diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 5d0e94533bf4..58ee78787176 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,27 +1,158 @@ -# Github owner file -# List of code reviewers for TVM modules +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. -# Global reviewers -* @dmlc/tvm-committers +# Github code owners file +# This file is used as a convenient tool to map +# committers' areas of expertise and faciliate the review process. +# +# This may not be the non-comprehensive list and is meant to be +# updated over time. -# LLVM backends -src/codegen/llvm/* @aatluri +# Per ASF policy, committer have global write permission. +# We normally recommend committers to shepherd code in their area of expertise. +* @apache/tvm-committers -# ROCM runtime -src/runtime/rocm/* @aatluri +# Order is important; the last matching pattern takes the most precedence. +# The sub modules should be ordered first by depth. +# Making sure we append new sub-module rules after exisiting modules rules. -# SGX support -src/runtime/sgx/* @nhynes -apps/sgx/* @nhynes +############################## +# Top-level Fallbacks +############################## +include/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +src/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +apps/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +python/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics + +# Thirdparty license audit +3rdparty/** @tqchen @jroesch +licenses/** @tqchen @jroesch # JVM language -jvm/* @yzhliu +jvm/** @yzhliu + +# Golang +golang/** @srkreddy1238 + +# WASM +web/** @tqchen @jroesch + +# Docker +docker/** @areusch @leandron @jroesch + +# Conda +conda/** @tqchen @junrushao1994 @comaniac + +# CMake +cmake/** @jroesch @tqchen @areusch @junrushao1994 @comaniac + +# rust bindings +rust/** @jroesch @nhynes @nhynes + +# vta +vta/** @tmoreau89 @vegaluisjose + +# docs +docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 +tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 + +# tests +tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 + +############################## +# Specific modules +############################## + +# automation related +src/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 +include/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 +python/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 + +python/tvm/autotvm/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 + +# node system and reflection +src/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac +include/tvm/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac + +# ir: Common IR +src/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac +include/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac +python/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac + +# tir +src/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +include/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +python/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were + +# te +src/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +include/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were +python/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were + +# target +src/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi +include/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi +python/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi + +# arith: Arithmetic module and simplifiers +src/arith/** @tqchen @junrushao1994 @vinx13 +include/tvm/arith/** @tqchen @junrushao1994 @vinx13 +python/tvm/arith/** @tqchen @junrushao1994 @vinx13 + +# parser +src/parser/** @jroesch @slyubomirsky + +# runtime +src/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 +include/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 +python/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 + +# runtime/micro +src/runtime/micro/** @areusch @liangfu @tmoreau89 +src/runtime/crt/** @areusch @liangfu @tmoreau89 +include/tvm/runtime/crt/** @areusch @liangfu @tmoreau89 +include/tvm/runtime/micro/** @areusch @liangfu @tmoreau89 +python/tvm/micro/** @areusch @liangfu @tmoreau89 + +# relay +src/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +include/tvm/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +python/tvm/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 + + +# relay/qnn +src/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang +inlcude/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang +python/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang + +# relay/backend/contrib: BYOC +src/relay/backend/contrib/** @zhiics @trevor-m @comaniac @mbaret + +# relay/frontends +python/tvm/relay/frontend/** @jwfromm @mbrookhart @srkreddy1238 @siju-samuel @Huyuwei @hlu1 @kazum @PariksheetPinjari909 -# WebGL backends -src/runtime/opengl/* @phisiart -src/codegen/*opengl* @phisiart +# topi: Operator definitions +src/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 +include/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 +python/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -# TOPI -topi/python/topi/* @Laurawly @Huyuwei +# tvm/driver/ +python/tvm/driver/** @leandron @jwfromm @tqchen @jroesch +# tvm/driver/tvmc +python/tvm/driver/tvmc/** @leandron @jwfromm diff --git a/CMakeLists.txt b/CMakeLists.txt index c56a929e276d..59786af38a9f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,7 @@ tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF) tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON) tvm_option(USE_LIBBACKTRACE "Build libbacktrace to supply linenumbers on stack traces" AUTO) tvm_option(BUILD_STATIC_RUNTIME "Build static version of libtvm_runtime" OFF) +tvm_option(USE_PAPI "Use Performance Application Programming Interface (PAPI) to read performance counters" OFF) # 3rdparty libraries tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") @@ -407,6 +408,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) +include(cmake/modules/contrib/PAPI.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) include(cmake/modules/RustExt.cmake) @@ -422,12 +424,15 @@ else() set(CMAKE_CUDA_STANDARD 14) endif() -add_lib_info(${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc) +set(LIBINFO_FILE ${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc) +add_lib_info(${LIBINFO_FILE}) +list(REMOVE_ITEM COMPILER_SRCS ${LIBINFO_FILE}) add_library(tvm_objs OBJECT ${COMPILER_SRCS}) add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS}) +add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE}) -add_library(tvm SHARED $ $) +add_library(tvm SHARED $ $ $) set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") if(BUILD_STATIC_RUNTIME) @@ -443,8 +448,10 @@ else() set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") endif() set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") + target_compile_definitions(tvm_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) target_compile_definitions(tvm_runtime_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) +target_compile_definitions(tvm_libinfo_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) target_compile_definitions(tvm PUBLIC DMLC_USE_LOGGING_LIBRARY=) target_compile_definitions(tvm_runtime PUBLIC DMLC_USE_LOGGING_LIBRARY=) @@ -472,19 +479,24 @@ if(USE_RELAY_DEBUG) target_compile_definitions(tvm_objs PRIVATE "TVM_LOG_DEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_RELAY_DEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "TVM_LOG_DEBUG") + target_compile_definitions(tvm_libinfo_objs PRIVATE "USE_RELAY_DEBUG") + target_compile_definitions(tvm_libinfo_objs PRIVATE "TVM_LOG_DEBUG") else() target_compile_definitions(tvm_objs PRIVATE "NDEBUG") target_compile_definitions(tvm_runtime_objs PRIVATE "NDEBUG") + target_compile_definitions(tvm_libinfo_objs PRIVATE "NDEBUG") endif(USE_RELAY_DEBUG) if(USE_FALLBACK_STL_MAP) message(STATUS "Building with STL Map...") target_compile_definitions(tvm_objs PRIVATE "USE_FALLBACK_STL_MAP=1") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_FALLBACK_STL_MAP=1") + target_compile_definitions(tvm_libinfo_objs PRIVATE "USE_FALLBACK_STL_MAP=1") else() message(STATUS "Building with TVM Map...") target_compile_definitions(tvm_objs PRIVATE "USE_FALLBACK_STL_MAP=0") target_compile_definitions(tvm_runtime_objs PRIVATE "USE_FALLBACK_STL_MAP=0") + target_compile_definitions(tvm_libinfo_objs PRIVATE "USE_FALLBACK_STL_MAP=0") endif(USE_FALLBACK_STL_MAP) if(BUILD_FOR_HEXAGON) @@ -515,17 +527,23 @@ target_include_directories( target_include_directories( tvm_objs PUBLIC "topi/include") +target_include_directories( + tvm_libinfo_objs + PUBLIC "topi/include") set(CRC16_INCLUDE_PATH "3rdparty/libcrc/include") target_include_directorieS( tvm_objs PRIVATE "${CRC16_INCLUDE_PATH}") +target_include_directorieS( + tvm_libinfo_objs + PRIVATE "${CRC16_INCLUDE_PATH}") target_include_directorieS( tvm_runtime_objs PRIVATE "${CRC16_INCLUDE_PATH}") set(TVM_TEST_LIBRARY_NAME tvm) if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - add_library(tvm_allvisible SHARED $ $) + add_library(tvm_allvisible SHARED $ $ $) target_include_directories(tvm_allvisible PUBLIC "$") target_link_libraries(tvm_allvisible PRIVATE "$") set(TVM_TEST_LIBRARY_NAME tvm_allvisible) @@ -603,6 +621,7 @@ endif(INSTALL_DEV) # More target definitions if(MSVC) target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS) + target_compile_definitions(tvm_libinfo_objs PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS) endif() @@ -619,6 +638,7 @@ if(TVM_IS_DEBUG_BUILD) if(FILE_PREFIX_MAP_SUPPORTED) target_compile_options(tvm PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) target_compile_options(tvm_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) + target_compile_options(tvm_libinfo_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) target_compile_options(tvm_runtime PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) target_compile_options(tvm_runtime_objs PRIVATE $<$:${FILE_PREFIX_MAP_FLAG}>) endif() @@ -635,3 +655,32 @@ if(APPLE AND TVM_IS_DEBUG_BUILD) VERBATIM ) endif() + +#Caches the build. +#Note that ccache-3.x doesn't support nvcc well, so CUDA kernels may never hit the cache and still +#need to be re-compiled every time. Using ccache 4.0+ can resolve this issue. + +if(USE_CCACHE) # True for AUTO, ON, /path/to/ccache + if("${USE_CCACHE}" STREQUAL "AUTO") # Auto mode + find_program(CCACHE_FOUND ccache) + if(CCACHE_FOUND) + message(STATUS "Found the path to ccache, enabling ccache") + set(PATH_TO_CCACHE ccache) + else() + message(STATUS "Didn't find the path to CCACHE, disabling ccache") + endif(CCACHE_FOUND) + elseif("${USE_CCACHE}" MATCHES ${IS_TRUE_PATTERN}) + find_program(CCACHE_FOUND ccache) + if(CCACHE_FOUND) + message(STATUS "Found the path to ccache, enabling ccache") + set(PATH_TO_CCACHE ccache) + else() + message(FATAL_ERROR "Cannot find ccache. Set USE_CCACHE mode to AUTO or OFF to build without ccache. USE_CCACHE=" "${USE_CCACHE") + endif(CCACHE_FOUND) + else() # /path/to/ccache + set(PATH_TO_CCACHE USE_CCACHE) + message(STATUS "Setting ccache path to " "${PATH_TO_CCACHE}") + endif() + # Set the flag for ccache + set(CXX_COMPILER_LAUNCHER PATH_TO_CCACHE) +endif(USE_CCACHE) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b26d38574c6f..550be102b562 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -66,7 +66,7 @@ We do encourage everyone to work anything they are interested in. - [Andrew Reusch](https://github.com/areusch): @areusch - runtime, microTVM - [Jared Roesch](https://github.com/jroesch) (PMC): @jroesch - relay - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - frontends -- [Junru Shao](https://github.com/junrushao1994) @junrushao1994 - relay, compiler +- [Junru Shao](https://github.com/junrushao1994) (PMC): @junrushao1994 - relay, compiler - [Haichen Shen](https://github.com/icemelon9) (PMC): @icemelon9 - relay, topi - [Siva](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web @@ -77,7 +77,7 @@ We do encourage everyone to work anything they are interested in. - [Jian Weng](https://github.com/were): @were: - hybrid script - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene - runtime, topi, frontends - [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi -- [Hao Yu](https://github.com/comaniac): @comaniac - relay, byoc, auto_scheduler +- [Hao Yu](https://github.com/comaniac): @comaniac (PMC) - relay, byoc, auto_scheduler - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, auto_scheduler, topi, relay ## Reviewers diff --git a/Jenkinsfile b/Jenkinsfile index f26b148085fb..65ccbf27326f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -45,12 +45,12 @@ // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = "tlcpack/ci-lint:v0.66" -ci_gpu = "tlcpack/ci-gpu:v0.75" -ci_cpu = "tlcpack/ci-cpu:v0.74" +ci_gpu = "tlcpack/ci-gpu:v0.76" +ci_cpu = "tlcpack/ci-cpu:v0.75" ci_wasm = "tlcpack/ci-wasm:v0.71" ci_i386 = "tlcpack/ci-i386:v0.73" -ci_qemu = "tlcpack/ci-qemu:v0.05" -ci_arm = "tlcpack/ci-arm:v0.05" +ci_qemu = "tlcpack/ci-qemu:v0.06" +ci_arm = "tlcpack/ci-arm:v0.06" // <--- End of regex-scanned config. // tvm libraries diff --git a/apps/README.md b/apps/README.md index 41c39248706b..01630a3ee8c1 100644 --- a/apps/README.md +++ b/apps/README.md @@ -25,4 +25,4 @@ they also serve as examples on how to use TVM in your own project. - [android_rpc](android_rpc) Android RPC server. - [benchmark](benchmark) Example end to end compilation benchmarks - [howto_deploy](howto_deploy) Tutorial on how to deploy TVM with minimum code dependency. -- [wasm_standalone](tvm-standalone) WebAssembly standalone for deep learning framework with TVM runtime. +- [wasm_standalone](wasm-standalone) WebAssembly standalone for deep learning framework with TVM runtime. diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 1331e1a65ca8..bbb43a4e9cb1 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -34,6 +34,7 @@ #define TVM_LOG_CUSTOMIZE 1 #include "../src/runtime/c_runtime_api.cc" +#include "../src/runtime/container.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/dso_library.cc" #include "../src/runtime/file_utils.cc" diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index e897a975de28..be172cea4e53 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -47,7 +47,7 @@ namespace { std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { std::string untar_cmd; untar_cmd.reserve(512); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) untar_cmd += "tar -C "; untar_cmd += output_dir; untar_cmd += " -zxf "; @@ -224,7 +224,7 @@ std::vector ListDir(const std::string& dirname) { return vec; } -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) /*! * \brief LinuxShared Creates a linux shared library * \param output The output file name @@ -280,7 +280,7 @@ void WindowsShared(const std::string& output, const std::vector& fi * \param files The files for building */ void CreateShared(const std::string& output, const std::vector& files) { -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) LinuxShared(output, files); #elif defined(_WIN32) WindowsShared(output, files); @@ -290,7 +290,8 @@ void CreateShared(const std::string& output, const std::vector& fil } std::string BuildSharedLibrary(std::string file) { - if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) { + if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll") || + support::EndsWith(file, ".dylib")) { return file; } diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 5dc84105388b..f5706d315f4a 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -140,7 +140,7 @@ class RPCServer { * \brief ListenLoopProc The listen process. */ void ListenLoopProc() { - TrackerClient tracker(tracker_addr_, key_, custom_addr_); + TrackerClient tracker(tracker_addr_, key_, custom_addr_, port_); while (true) { support::TCPSocket conn; support::SockAddr addr("0.0.0.0", 0); diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index 1497ab3251be..58824e6aa9af 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -49,12 +49,17 @@ class TrackerClient { * \brief Constructor. */ TrackerClient(const std::string& tracker_addr, const std::string& key, - const std::string& custom_addr) + const std::string& custom_addr, int port) : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), + port_(port), gen_(std::random_device{}()), - dis_(0.0, 1.0) {} + dis_(0.0, 1.0) { + if (custom_addr_.empty()) { + custom_addr_ = "null"; + } + } /*! * \brief Destructor. */ @@ -80,7 +85,7 @@ class TrackerClient { std::ostringstream ss; ss << "[" << static_cast(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_ - << "\"}]"; + << "\", \"addr\": [" << custom_addr_ << ", \"" << port_ << "\"]}]"; tracker_sock_.SendBytes(ss.str()); // Receive status and validate @@ -105,9 +110,6 @@ class TrackerClient { void ReportResourceAndGetKey(int port, std::string* matchkey) { if (!tracker_sock_.IsClosed()) { *matchkey = RandomKey(key_ + ":", old_keyset_); - if (custom_addr_.empty()) { - custom_addr_ = "null"; - } std::ostringstream ss; ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port @@ -230,6 +232,7 @@ class TrackerClient { std::string tracker_addr_; std::string key_; std::string custom_addr_; + int port_; support::TCPSocket tracker_sock_; std::set old_keyset_; std::mt19937 gen_; diff --git a/apps/microtvm/reference-vm/zephyr/Vagrantfile b/apps/microtvm/reference-vm/zephyr/Vagrantfile index be41c0b733e5..bd0094fcec66 100644 --- a/apps/microtvm/reference-vm/zephyr/Vagrantfile +++ b/apps/microtvm/reference-vm/zephyr/Vagrantfile @@ -57,6 +57,7 @@ Vagrant.configure("2") do |config| vb.customize ["modifyvm", :id, "--usb", "on"] vb.customize ["modifyvm", :id, "--usbehci", "on"] vb.customize ["modifyvm", :id, "--usbxhci", "on"] + vb.customize [ "guestproperty", "set", :id, "/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold", 10000] dirs_to_mount.each do |d| overrides.vm.synced_folder d.to_s, d.to_s end diff --git a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf index d298325eb4a4..6c588c86b0d5 100644 --- a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf +++ b/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf @@ -29,3 +29,6 @@ CONFIG_TEST_RANDOM_GENERATOR=y # For debugging. CONFIG_LED=y + +# For models with floating point. +CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf b/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf new file mode 100644 index 000000000000..52a6753c733b --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file is specific to the STM32L4R5ZI Nucleo board. + +# For intrinsics used by generated optimized operators. +CONFIG_CMSIS_DSP=y + +# For AOT runtime which requires lots of function call. +CONFIG_MAIN_STACK_SIZE=3000 + +# For random number generation. +CONFIG_ENTROPY_GENERATOR=y +CONFIG_TEST_RANDOM_GENERATOR=y + +# For debugging. +CONFIG_LED=y diff --git a/apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf b/apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf new file mode 100644 index 000000000000..267589ba8f0c --- /dev/null +++ b/apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is specific to the QEMU-emulated Cortex R5 microTVM board. + +# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. +CONFIG_TEST_RANDOM_GENERATOR=y +CONFIG_TIMER_RANDOM_GENERATOR=y + +# Default stack size is 1k, this is required for debug mode. +CONFIG_MAIN_STACK_SIZE=2000 diff --git a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf b/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf index 5f3c4a4bed36..4b0e494068fa 100644 --- a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf +++ b/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf @@ -23,3 +23,6 @@ CONFIG_TIMER_RANDOM_GENERATOR=y # Default stack size is 1k, this is required for debug mode. CONFIG_MAIN_STACK_SIZE=2000 + +# For models with floating point. +CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/aot_demo/prj.conf b/apps/microtvm/zephyr/aot_demo/prj.conf index 5f4d7a0689dc..c6ab10e9d86e 100644 --- a/apps/microtvm/zephyr/aot_demo/prj.conf +++ b/apps/microtvm/zephyr/aot_demo/prj.conf @@ -28,8 +28,5 @@ CONFIG_UART_INTERRUPT_DRIVEN=y CONFIG_CPLUSPLUS=y CONFIG_NEWLIB_LIBC=y -# For models with floating point. -CONFIG_FPU=y - # For TVMPlatformAbort(). CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/aot_demo/src/main.c b/apps/microtvm/zephyr/aot_demo/src/main.c index 43cc7b33987b..0c16572fc744 100644 --- a/apps/microtvm/zephyr/aot_demo/src/main.c +++ b/apps/microtvm/zephyr/aot_demo/src/main.c @@ -32,6 +32,7 @@ #include "input_data.h" #include "output_data.h" +#include "tvmgen_default.h" #include "zephyr_uart.h" #ifdef CONFIG_ARCH_POSIX @@ -194,18 +195,18 @@ void main(void) { } TVMLogf("Zephyr AOT Runtime\n"); - void* inputs[1] = { - input_data, + struct tvmgen_default_inputs inputs = { + .input_1 = input_data, }; - void* outputs[1] = { - output_data, + struct tvmgen_default_outputs outputs = { + .output = output_data, }; StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); double elapsed_time = 0; TVMPlatformTimerStart(); - int ret_val = tvm_runtime_run(&tvmgen_default_network, inputs, outputs); + int ret_val = tvmgen_default_run(&inputs, &outputs); TVMPlatformTimerStop(&elapsed_time); if (ret_val != 0) { diff --git a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c b/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c index c9eec8751100..02401584f652 100644 --- a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c +++ b/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c @@ -77,5 +77,11 @@ uint32_t TVMPlatformWriteSerial(const char* data, uint32_t size) { void TVMPlatformUARTInit() { // Claim console device. g_microtvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console))); + const struct uart_config config = {.baudrate = 115200, + .parity = UART_CFG_PARITY_NONE, + .stop_bits = UART_CFG_STOP_BITS_1, + .data_bits = UART_CFG_DATA_BITS_8, + .flow_ctrl = UART_CFG_FLOW_CTRL_NONE}; + uart_configure(g_microtvm_uart, &config); uart_rx_init(&uart_rx_rbuf, g_microtvm_uart); } diff --git a/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf index 149a69ea3b5b..511ff0121d32 100644 --- a/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf +++ b/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf @@ -29,3 +29,6 @@ CONFIG_TEST_RANDOM_GENERATOR=y # For debugging. CONFIG_LED=y + +# For models with floating point. +CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf index eba023294894..33b08032c32e 100644 --- a/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf +++ b/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf @@ -28,3 +28,6 @@ CONFIG_ENTROPY_GENERATOR=y # For debugging. CONFIG_LED=y + +# For models with floating point. +CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf new file mode 100644 index 000000000000..b87206019026 --- /dev/null +++ b/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file is specific to the STM32L4R5ZI Nucleo board. + +# For intrinsics used by generated optimized operators. +CONFIG_CMSIS_DSP=y + +# For operations that stack allocates a large float array. +CONFIG_MAIN_STACK_SIZE=1536 + +# For random number generation. +CONFIG_ENTROPY_GENERATOR=y +CONFIG_TEST_RANDOM_GENERATOR=y + +# For debugging. +CONFIG_LED=y diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf new file mode 100644 index 000000000000..4097f7ec5487 --- /dev/null +++ b/apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is specific to the QEMU-emulated Cortex R5 microTVM board. + +# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. +CONFIG_TEST_RANDOM_GENERATOR=y +CONFIG_TIMER_RANDOM_GENERATOR=y + +# Default stack size is 1k, this is required for debug mode. +CONFIG_MAIN_STACK_SIZE=1536 diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf index 3733568ed02f..b94d96b11fba 100644 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf +++ b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf @@ -24,6 +24,9 @@ CONFIG_TIMER_RANDOM_GENERATOR=y # Default is 512, raised here for operations with large floating point data. CONFIG_MAIN_STACK_SIZE=2048 +# For models with floating point. +CONFIG_FPU=y + # For floating point operations. It has exception on floating point operations # without this flag. CONFIG_FPU_SHARING=y diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf index a8a055bcc748..1da5f054da46 100644 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf +++ b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf @@ -23,3 +23,6 @@ CONFIG_TIMER_RANDOM_GENERATOR=y # Default 512, for operations with large floating point data. CONFIG_MAIN_STACK_SIZE=2048 + +# For models with floating point. +CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf b/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf index 505f1babc3f4..542faf28cd67 100644 --- a/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf +++ b/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf @@ -26,3 +26,6 @@ CONFIG_TEST_RANDOM_GENERATOR=y # For debugging. CONFIG_LED=n + +# For models with floating point. +CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/prj.conf b/apps/microtvm/zephyr/host_driven/prj.conf index 5f4d7a0689dc..c6ab10e9d86e 100644 --- a/apps/microtvm/zephyr/host_driven/prj.conf +++ b/apps/microtvm/zephyr/host_driven/prj.conf @@ -28,8 +28,5 @@ CONFIG_UART_INTERRUPT_DRIVEN=y CONFIG_CPLUSPLUS=y CONFIG_NEWLIB_LIBC=y -# For models with floating point. -CONFIG_FPU=y - # For TVMPlatformAbort(). CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/qemu-hack/qemu-system-arm index 58fc8296c31f..ebbc8ad5ad9d 120000 --- a/apps/microtvm/zephyr/qemu-hack/qemu-system-arm +++ b/apps/microtvm/zephyr/qemu-hack/qemu-system-arm @@ -1 +1 @@ -./qemu-system-i386 \ No newline at end of file +qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64 new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64 @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/wasm-standalone/README.md b/apps/wasm-standalone/README.md index e40d218634aa..b8a977f6ae50 100644 --- a/apps/wasm-standalone/README.md +++ b/apps/wasm-standalone/README.md @@ -116,16 +116,10 @@ This project should be considered **experimental** at the very early stage, all - Build DL library in the WebAssembly format. - - Download model + - Compile the model ``` - cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx - ``` - - - Compile - - ``` - LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx + cd wasm-graph/tools && LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ``` ### Build wasm-graph package @@ -170,9 +164,14 @@ $ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/syn $ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv original image dimensions: (256, 256) resized image dimensions: (224, 224) -input image belongs to the class `tabby, tabby cat` +input image belongs to the class `tiger cat` ``` +Note: this example also works without WASI support. Please modify `wasm-graph/.cargo/config` to change the target to +`wasm32-unknown-unknown` and uncomment the raw wasm engine in `wasm-runtime/src/graph.rs` to run in pure wasm32. SIMD +may not be supported without WASI support. You may also need to delete ` -mattr=+simd128` in the +[build script](apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py). + ## Future Work ### More networks support diff --git a/apps/wasm-standalone/wasm-graph/src/lib.rs b/apps/wasm-standalone/wasm-graph/src/lib.rs index 2b4187849edc..92a3d5c2f3b0 100644 --- a/apps/wasm-standalone/wasm-graph/src/lib.rs +++ b/apps/wasm-standalone/wasm-graph/src/lib.rs @@ -48,6 +48,7 @@ lazy_static! { "/lib/graph.json" ))) .unwrap(); + let params_bytes = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params")); let params = tvm_graph_rt::load_param_dict(params_bytes) @@ -57,6 +58,7 @@ lazy_static! { .collect::>>(); let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap(); + exec.load_params(params); Mutex::new(exec) @@ -68,14 +70,14 @@ pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 { let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) }; let input: TVMTensor = in_tensor.as_dltensor().into(); - GRAPH_EXECUTOR.lock().unwrap().set_input("data", input); - GRAPH_EXECUTOR.lock().unwrap().run(); - let output = GRAPH_EXECUTOR - .lock() - .unwrap() - .get_output(0) - .unwrap() - .as_dltensor(false); + // since this executor is not multi-threaded, we can acquire lock once + let mut executor = GRAPH_EXECUTOR.lock().unwrap(); + + executor.set_input("data", input); + + executor.run(); + + let output = executor.get_output(0).unwrap().as_dltensor(false); let out_tensor: Tensor = output.into(); let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) }; diff --git a/apps/wasm-standalone/wasm-graph/src/types.rs b/apps/wasm-standalone/wasm-graph/src/types.rs index a3761a758cff..f08b7be84990 100644 --- a/apps/wasm-standalone/wasm-graph/src/types.rs +++ b/apps/wasm-standalone/wasm-graph/src/types.rs @@ -24,7 +24,7 @@ use std::{ }; pub use tvm_sys::ffi::DLTensor; use tvm_sys::ffi::{ - DLDevice, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU, + DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDevice, DLDeviceType_kDLCPU, }; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] diff --git a/apps/wasm-standalone/wasm-graph/src/utils.rs b/apps/wasm-standalone/wasm-graph/src/utils.rs index fd4a71745f4f..92d386e3062a 100644 --- a/apps/wasm-standalone/wasm-graph/src/utils.rs +++ b/apps/wasm-standalone/wasm-graph/src/utils.rs @@ -24,13 +24,20 @@ use std::ptr; pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor { let in_addr = in_addr as *mut u8; - let mut data_vec = Vec::new(); - for i in 0..in_size { - data_vec.push(ptr::read(in_addr.offset(i as isize))); - } - let input: Tensor = serde_json::from_slice(&data_vec).unwrap(); + println!("DEBUG: in_addr {:?}, in_size {:?}", in_addr, in_size); + + let data_vec = unsafe { std::slice::from_raw_parts(in_addr, in_size) }; - input + let input = serde_json::from_slice(&data_vec); + match input { + Ok(result) => { + println!("DEBUG: SER SUCCEED!!! and Ok"); + result + } + Err(e) => { + panic!("DEBUG: SER SUCCEED!!! but Err, {:?}", &e); + } + } } pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize { diff --git a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py old mode 100644 new mode 100755 index 3d8a349b8744..b1cdb199a871 --- a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py +++ b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -"""Builds a simple graph for testing.""" +"""Builds a simple resnet50 graph for testing.""" import argparse import os import subprocess @@ -25,47 +25,78 @@ import onnx import tvm from tvm import relay, runtime +from tvm.contrib.download import download_testdata +from tvm.contrib import graph_executor +from PIL import Image +import numpy as np +import tvm.relay as relay -def _get_mod_and_params(model_file): - onnx_model = onnx.load(model_file) - shape_dict = {} - for input in onnx_model.graph.input: - shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim] +# This example uses resnet50-v2-7 model +model_url = "".join( + [ + "https://github.com/onnx/models/raw/", + "master/vision/classification/resnet/model/", + "resnet50-v2-7.onnx", + ] +) - return relay.frontend.from_onnx(onnx_model, shape_dict) - -def build_graph_lib(model_file, opt_level): +def build_graph_lib(opt_level): """Compiles the pre-trained model with TVM""" out_dir = os.path.join(sys.path[0], "../lib") if not os.path.exists(out_dir): os.makedirs(out_dir) - # Compile the relay mod - mod, params = _get_mod_and_params(model_file) + # Follow the tutorial to download and compile the model + model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx") + onnx_model = onnx.load(model_path) + + img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg" + img_path = download_testdata(img_url, "imagenet_cat.png", module="data") + + # Resize it to 224x224 + resized_image = Image.open(img_path).resize((224, 224)) + img_data = np.asarray(resized_image).astype("float32") + + # Our input image is in HWC layout while ONNX expects CHW input, so convert the array + img_data = np.transpose(img_data, (2, 0, 1)) + + # Normalize according to the ImageNet input specification + imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) + imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev + + # Add the batch dimension, as we are expecting 4-dimensional input: NCHW. + img_data = np.expand_dims(norm_img_data, axis=0) + + input_name = "data" + shape_dict = {input_name: img_data.shape} + + mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib" + with tvm.transform.PassContext(opt_level=opt_level): - graph_json, lib, params = relay.build(mod, target=target, params=params) + factory = relay.build(mod, target=target, params=params) # Save the model artifacts to obj_file obj_file = os.path.join(out_dir, "graph.o") - lib.save(obj_file) + factory.get_lib().save(obj_file) + # Run llvm-ar to archive obj_file into lib_file lib_file = os.path.join(out_dir, "libgraph_wasm32.a") cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), "rcs", lib_file, obj_file] subprocess.run(cmds) + # Save the json and params with open(os.path.join(out_dir, "graph.json"), "w") as f_graph: - f_graph.write(graph_json) - + f_graph.write(factory.get_graph_json()) with open(os.path.join(out_dir, "graph.params"), "wb") as f_params: - f_params.write(runtime.save_param_dict(params)) + f_params.write(runtime.save_param_dict(factory.get_params())) if __name__ == "__main__": parser = argparse.ArgumentParser(description="ONNX model build example") - parser.add_argument("model_file", type=str, help="the path of onnx model file") parser.add_argument( "-O", "--opt-level", @@ -75,4 +106,4 @@ def build_graph_lib(model_file, opt_level): ) args = parser.parse_args() - build_graph_lib(args.model_file, args.opt_level) + build_graph_lib(args.opt_level) diff --git a/apps/wasm-standalone/wasm-runtime/Cargo.toml b/apps/wasm-standalone/wasm-runtime/Cargo.toml index 99f6db54431f..d3f860170d4e 100644 --- a/apps/wasm-standalone/wasm-runtime/Cargo.toml +++ b/apps/wasm-standalone/wasm-runtime/Cargo.toml @@ -26,8 +26,8 @@ license = "Apache-2.0" keywords = ["wasm", "machine learning", "wasmtime"] [dependencies] -wasmtime = "0.16.0" -wasmtime-wasi = "0.16.0" +wasmtime = "0.28.0" +wasmtime-wasi = "0.28.0" anyhow = "1.0.31" serde = "1.0.53" serde_json = "1.0.53" diff --git a/apps/wasm-standalone/wasm-runtime/src/graph.rs b/apps/wasm-standalone/wasm-runtime/src/graph.rs index e7c39cbb0687..bfa1c2f19c56 100644 --- a/apps/wasm-standalone/wasm-runtime/src/graph.rs +++ b/apps/wasm-standalone/wasm-runtime/src/graph.rs @@ -19,7 +19,7 @@ use anyhow::Result; use wasmtime::*; -use wasmtime_wasi::{Wasi, WasiCtx}; +use wasmtime_wasi::{WasiCtx, WasiCtxBuilder}; use super::Tensor; @@ -27,6 +27,9 @@ pub struct GraphExecutor { pub(crate) wasm_addr: i32, pub(crate) input_size: i32, pub(crate) output_size: i32, + pub(crate) store: Option>, + // None-WASI version: + // pub(crate) store: Option>, pub(crate) instance: Option, } @@ -37,25 +40,44 @@ impl GraphExecutor { wasm_addr: 0, input_size: 0, output_size: 0, + store: None, instance: None, } } pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> { - let engine = Engine::new(Config::new().wasm_simd(true)); - let store = Store::new(&engine); + // It seems WASI in this example is not necessary + // None WASI version: works with no SIMD + // let engine = Engine::new(Config::new().wasm_simd(true)).unwrap(); + // let mut store = Store::new(&engine, ()); + // let module = Module::from_file(store.engine(), &wasm_graph_file)?; + + // let instance = Instance::new(&mut store, &module, &[])?; + + // self.instance = Some(instance); + // self.store = Some(store); + + // Ok(()) + + // WASI version: + let engine = Engine::new(Config::new().wasm_simd(true)).unwrap(); // First set up our linker which is going to be linking modules together. We // want our linker to have wasi available, so we set that up here as well. - let mut linker = Linker::new(&store); + let mut linker = Linker::new(&engine); + wasmtime_wasi::add_to_linker(&mut linker, |s| s)?; // Create an instance of `Wasi` which contains a `WasiCtx`. Note that // `WasiCtx` provides a number of ways to configure what the target program // will have access to. - let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?); - wasi.add_to_linker(&mut linker)?; + let wasi = WasiCtxBuilder::new() + .inherit_stdio() + .inherit_args()? + .build(); + let mut store = Store::new(&engine, wasi); - let module = Module::from_file(&store, &wasm_graph_file)?; - self.instance = Some(linker.instantiate(&module)?); + let module = Module::from_file(&engine, &wasm_graph_file)?; + self.instance = Some(linker.instantiate(&mut store, &module)?); + self.store = Some(store); Ok(()) } @@ -65,26 +87,24 @@ impl GraphExecutor { .instance .as_ref() .unwrap() - .get_memory("memory") + .get_memory(self.store.as_mut().unwrap(), "memory") .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?; // Specify the wasm address to access the wasm memory. - let wasm_addr = memory.data_size(); + let wasm_addr = memory.data_size(self.store.as_mut().unwrap()); + // Serialize the data into a JSON string. let in_data = serde_json::to_vec(&input_data)?; let in_size = in_data.len(); + // Grow up memory size according to in_size to avoid memory leak. - memory.grow((in_size >> 16) as u32 + 1)?; + memory.grow(self.store.as_mut().unwrap(), (in_size >> 16) as u32 + 1)?; - // Insert the input data into wasm memory. - for i in 0..in_size { - unsafe { - memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap(); - } - } + memory.write(self.store.as_mut().unwrap(), wasm_addr, &in_data)?; self.wasm_addr = wasm_addr as i32; self.input_size = in_size as i32; + Ok(()) } @@ -94,11 +114,12 @@ impl GraphExecutor { .instance .as_ref() .unwrap() - .get_func("run") - .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))? - .get2::()?; + .get_func(self.store.as_mut().unwrap(), "run") + .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?; - let out_size = run(self.wasm_addr, self.input_size)?; + let params = [Val::I32(self.wasm_addr), Val::I32(self.input_size)]; + let out_size = run.call(self.store.as_mut().unwrap(), ¶ms[..])?; + let out_size = (*out_size)[0].unwrap_i32(); if out_size == 0 { panic!("graph run failed!"); } @@ -107,18 +128,22 @@ impl GraphExecutor { Ok(()) } - pub fn get_output(&self) -> Result { + pub fn get_output(&mut self) -> Result { let memory = self .instance .as_ref() .unwrap() - .get_memory("memory") + .get_memory(self.store.as_mut().unwrap(), "memory") .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?; - let out_data = unsafe { - &memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize] - }; - let out_vec: Tensor = serde_json::from_slice(out_data).unwrap(); + let mut out_data = vec![0 as u8; self.output_size as _]; + memory.read( + self.store.as_mut().unwrap(), + self.wasm_addr as _, + &mut out_data, + )?; + + let out_vec: Tensor = serde_json::from_slice(&out_data).unwrap(); Ok(out_vec) } } diff --git a/cmake/config.cmake b/cmake/config.cmake index ae257d435155..daa0f1e84315 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -299,3 +299,23 @@ set(USE_LIBBACKTRACE AUTO) # not be included in the final executable. This would make the corresponding # runtime functions to be unavailable to the program. set(BUILD_STATIC_RUNTIME OFF) + + +# Caches the build so that building is faster when switching between branches. +# If you switch branches, build and then encounter a linking error, you may +# need to regenerate the build tree through "make .." (the cache will +# still provide significant speedups). +# Possible values: +# - AUTO: search for path to ccache, disable if not found. +# - ON: enable ccache by searching for the path to ccache, report an error if not found +# - OFF: disable ccache +# - /path/to/ccache: use specific path to ccache +set(USE_CCACHE AUTO) + +# Whether to enable PAPI support in profiling. PAPI provides access to hardware +# counters while profiling. +# Possible values: +# - ON: enable PAPI support. Will search PKG_CONFIG_PATH for a papi.pc +# - OFF: disable PAPI support. +# - /path/to/folder/containing/: Path to folder containing papi.pc. +set(USE_PAPI OFF) diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index b908df2f869b..8b064d2eb2eb 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -34,6 +34,9 @@ if(USE_ROCM) file(GLOB RUNTIME_ROCM_SRCS src/runtime/rocm/*.cc) list(APPEND RUNTIME_SRCS ${RUNTIME_ROCM_SRCS}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPHCC_LIBRARY}) + if (ROCM_HSA_LIBRARY) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HSA_LIBRARY}) + endif() if(USE_MIOPEN) message(STATUS "Build with MIOpen support") diff --git a/cmake/modules/contrib/PAPI.cmake b/cmake/modules/contrib/PAPI.cmake new file mode 100644 index 000000000000..257591451ca8 --- /dev/null +++ b/cmake/modules/contrib/PAPI.cmake @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_PAPI) + find_package(PkgConfig REQUIRED) + + set(ENV{PKG_CONFIG_PATH} "${USE_PAPI}:$ENV{PKG_CONFIG_PATH}") + pkg_check_modules(PAPI REQUIRED IMPORTED_TARGET papi>=6.0) + list(APPEND TVM_RUNTIME_LINKER_LIBS PkgConfig::PAPI) + list(APPEND RUNTIME_SRCS src/runtime/contrib/papi/papi.cc) +endif() diff --git a/cmake/utils/FindROCM.cmake b/cmake/utils/FindROCM.cmake index 7d4e282956d9..4d895ff89d13 100644 --- a/cmake/utils/FindROCM.cmake +++ b/cmake/utils/FindROCM.cmake @@ -55,6 +55,7 @@ macro(find_rocm use_rocm) endif() find_library(ROCM_MIOPEN_LIBRARY MIOpen ${__rocm_sdk}/lib) find_library(ROCM_ROCBLAS_LIBRARY rocblas ${__rocm_sdk}/lib) + find_library(ROCM_HSA_LIBRARY hsa-runtime64 ${__rocm_sdk}/lib) if(ROCM_HIPHCC_LIBRARY) set(ROCM_FOUND TRUE) diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index 9479d7194d3b..974998b9d6fe 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -32,6 +32,9 @@ RUN bash /install/ubuntu_install_llvm.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh RUN bash /install/ubuntu_install_cmake_source.sh diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 65afa6931d9c..bc6c0f116c1e 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -27,6 +27,9 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh @@ -109,3 +112,13 @@ RUN bash /install/ubuntu_install_androidsdk.sh ENV ANDROID_HOME=/opt/android-sdk-linux/ ENV ANDROID_NDK_HOME=/opt/android-sdk-linux/ndk/21.3.6528147/ +# Arm(R) Ethos(TM)-U NPU driver +COPY install/ubuntu_install_ethosu_driver_stack.sh /install/ubuntu_install_ethosu_driver_stack.sh +RUN bash /install/ubuntu_install_ethosu_driver_stack.sh + +# Install Vela compiler +COPY install/ubuntu_install_vela.sh /install/ubuntu_install_vela.sh +RUN bash /install/ubuntu_install_vela.sh + +# Update PATH +ENV PATH /opt/arm/gcc-arm-none-eabi/bin:$PATH diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 09c6425da6fb..a76cf4664f47 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -28,6 +28,9 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu1804_install_llvm.sh /install/ubuntu1804_install_llvm.sh RUN bash /install/ubuntu1804_install_llvm.sh diff --git a/docker/Dockerfile.ci_i386 b/docker/Dockerfile.ci_i386 index 2383f4675e37..564731c12d36 100644 --- a/docker/Dockerfile.ci_i386 +++ b/docker/Dockerfile.ci_i386 @@ -31,6 +31,9 @@ RUN bash /install/ubuntu_install_llvm.sh COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh RUN bash /install/ubuntu_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh RUN bash /install/ubuntu_install_cmake_source.sh diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index 2adb793a3517..ae8c6b0c2c16 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -27,6 +27,9 @@ RUN apt-get update && apt-get install -y wget git sudo make COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + RUN apt-get update && apt-get install -y doxygen graphviz RUN pip3 install cpplint pylint==2.4.4 mypy==0.902 black==20.8b1 diff --git a/docker/Dockerfile.ci_qemu b/docker/Dockerfile.ci_qemu index 72189bd79afa..e309bb5f0e6c 100644 --- a/docker/Dockerfile.ci_qemu +++ b/docker/Dockerfile.ci_qemu @@ -26,7 +26,10 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python_venv.sh /install/ubuntu1804_install_python_venv.sh RUN bash /install/ubuntu1804_install_python_venv.sh -ENV PATH=/opt/tvm-venv/bin:$PATH +ENV PATH=/opt/tvm-venv/bin:/opt/zephyr-sdk/sysroots/x86_64-pokysdk-linux/usr/bin:$PATH + +# Globally disable pip cache +RUN pip config set global.cache-dir false COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh @@ -56,10 +59,6 @@ RUN bash /install/ubuntu_install_tensorflow.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh -# QEMU deps -COPY install/ubuntu_install_qemu.sh /install/ubuntu_install_qemu.sh -RUN bash /install/ubuntu_install_qemu.sh - # Zephyr SDK deps COPY install/ubuntu_install_zephyr.sh /install/ubuntu_install_zephyr.sh COPY install/ubuntu_init_zephyr_project.sh /install/ubuntu_init_zephyr_project.sh diff --git a/docker/Dockerfile.ci_wasm b/docker/Dockerfile.ci_wasm index 85f942d57ca3..03a1ded5572f 100644 --- a/docker/Dockerfile.ci_wasm +++ b/docker/Dockerfile.ci_wasm @@ -24,6 +24,9 @@ RUN bash /install/ubuntu_install_core.sh COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh +# Globally disable pip cache +RUN pip config set global.cache-dir false + COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh diff --git a/docker/build.sh b/docker/build.sh index e654a6253317..3b58bcc52a75 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -22,7 +22,8 @@ # # Usage: build.sh [--tag ] # [--dockerfile ] [-it] -# [--net=host] [--cache-from ] [] +# [--net=host] [--cache-from ] +# [--context-path ] [] # # CONTAINER_TYPE: Type of the docker container used the run the build, # e.g. "ci_cpu", "ci_gpu" @@ -37,6 +38,9 @@ # IMAGE_NAME: An image to be as a source for cached layers when building the # Docker image requested. # +# CONTEXT_PATH: Path to be used for relative path resolution when building +# the Docker images. +# # COMMAND (optional): Command to be executed in the docker container # SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -47,7 +51,6 @@ shift 1 # Dockerfile to be used in docker build DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE}" -DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" if [[ "$1" == "--tag" ]]; then DOCKER_IMAGE_TAG="$2" @@ -57,9 +60,7 @@ fi if [[ "$1" == "--dockerfile" ]]; then DOCKERFILE_PATH="$2" - DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") echo "Using custom Dockerfile path: ${DOCKERFILE_PATH}" - echo "Using custom docker build context path: ${DOCKER_CONTEXT_PATH}" shift 2 fi @@ -85,6 +86,15 @@ if [[ "$1" == "--cache-from" ]]; then shift 1 fi +if [[ "$1" == "--context-path" ]]; then + DOCKER_CONTEXT_PATH="$2" + echo "Using custom context path: ${DOCKER_CONTEXT_PATH}" + shift 2 +else + DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") + echo "Using default context path: ${DOCKER_CONTEXT_PATH}" +fi + if [[ ! -f "${DOCKERFILE_PATH}" ]]; then echo "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" exit 1 diff --git a/docker/install/ubuntu_install_darknet.sh b/docker/install/ubuntu_install_darknet.sh index 37adf4a30270..8020899f8bf1 100755 --- a/docker/install/ubuntu_install_darknet.sh +++ b/docker/install/ubuntu_install_darknet.sh @@ -23,4 +23,7 @@ set -o pipefail #install the necessary dependancies, cffi, opencv wget -q 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so debian_version=`cat /etc/debian_version` -pip3 install opencv-python cffi + +pip3 install \ + cffi \ + opencv-python diff --git a/docker/install/ubuntu_install_ethosu_driver_stack.sh b/docker/install/ubuntu_install_ethosu_driver_stack.sh new file mode 100755 index 000000000000..35b2b4c74b7b --- /dev/null +++ b/docker/install/ubuntu_install_ethosu_driver_stack.sh @@ -0,0 +1,94 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +fvp_dir="/opt/arm/FVP_Corstone_SSE-300_Ethos-U55" +cmake_dir="/opt/arm/cmake" +ethosu_dir="/opt/arm/ethosu" +ethosu_driver_ver="21.05" +cmsis_ver="5.7.0" + +mkdir -p /opt/arm + +tmpdir=$(mktemp -d) + +cleanup() +{ + rm -rf "$tmpdir" +} + +trap cleanup 0 + +# Ubuntu 18.04 dependencies +apt-get update + +apt-get install -y \ + bsdmainutils \ + build-essential \ + cpp \ + git \ + linux-headers-generic \ + make \ + python-dev \ + python3 \ + ssh \ + wget \ + xxd + +# Download the FVP +mkdir -p "$fvp_dir" +cd "$tmpdir" +curl -sL https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/MPS3/FVP_Corstone_SSE-300_Ethos-U55_11.14_24.tgz | tar -xz +./FVP_Corstone_SSE-300_Ethos-U55.sh --i-agree-to-the-contained-eula --no-interactive -d "$fvp_dir" +rm -rf FVP_Corstone_SSE-300_Ethos-U55.sh license_terms + +# Setup cmake 3.19.5 +mkdir -p "${cmake_dir}" +cd "$tmpdir" +curl -sL -o cmake-3.19.5-Linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.19.5/cmake-3.19.5-Linux-x86_64.sh +chmod +x cmake-3.19.5-Linux-x86_64.sh +./cmake-3.19.5-Linux-x86_64.sh --prefix="${cmake_dir}" --skip-license +rm cmake-3.19.5-Linux-x86_64.sh +export PATH="${cmake_dir}/bin:${PATH}" + +# Install the GCC toolchain +mkdir -p /opt/arm/gcc-arm-none-eabi/ +gcc_arm_url='https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2?revision=ca0cbf9c-9de2-491c-ac48-898b5bbc0443&la=en&hash=68760A8AE66026BCF99F05AC017A6A50C6FD832A' +curl --retry 64 -sSL ${gcc_arm_url} | tar -C /opt/arm/gcc-arm-none-eabi --strip-components=1 -jx +export PATH="/opt/arm/gcc-arm-none-eabi/bin:${PATH}" + +# Clone Arm(R) Ethos(TM)-U NPU driver stack +mkdir "${ethosu_dir}" +cd "${ethosu_dir}" +git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-driver" core_driver +cd core_driver +git checkout tags/${ethosu_driver_ver} + +cd "${ethosu_dir}" +git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-platform" core_platform +cd core_platform +git checkout tags/${ethosu_driver_ver} + +# Clone CMSIS +cd "${ethosu_dir}" +git clone "https://github.com/ARM-software/CMSIS_5.git" cmsis +cd cmsis +git checkout -f tags/${cmsis_ver} diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index 8f462284c2ba..ef0bf1b012c6 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -22,11 +22,14 @@ set -o pipefail # We need to fix the onnx version because changing versions tends to break tests # TODO(mbrookhart): periodically update -pip3 install onnx==1.8.1 -pip3 install onnxruntime==1.7.0 +pip3 install \ + onnx==1.8.1 \ + onnxruntime==1.7.0 # torch depends on a number of other packages, but unhelpfully, does # not expose that in the wheel!!! pip3 install future -pip3 install torch==1.7.0 torchvision==0.8.1 +pip3 install \ + torch==1.7.0 \ + torchvision==0.8.1 diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 7989a49a4826..2ca298a43857 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -21,4 +21,21 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip3 install six numpy pytest cython decorator scipy tornado pytest pytest-xdist pytest-profiling mypy orderedset attrs requests Pillow packaging cloudpickle synr +pip3 install \ + attrs \ + cloudpickle \ + cython \ + decorator \ + mypy \ + numpy \ + orderedset \ + packaging \ + Pillow \ + pytest \ + pytest-profiling \ + pytest-xdist \ + requests \ + scipy \ + six \ + synr \ + tornado diff --git a/docker/install/ubuntu_install_redis.sh b/docker/install/ubuntu_install_redis.sh index 0eb46eb8edec..d2600d828d49 100755 --- a/docker/install/ubuntu_install_redis.sh +++ b/docker/install/ubuntu_install_redis.sh @@ -21,4 +21,6 @@ set -u set -o pipefail apt-get update && apt-get install -y redis-server -pip3 install "xgboost>=1.1.0" psutil +pip3 install \ + psutil \ + "xgboost>=1.1.0" diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 8a7ce1d3f798..12208bbe6643 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -21,4 +21,13 @@ set -u set -o pipefail # NOTE: install docutils < 0.17 to work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 -pip3 install sphinx sphinx-gallery==0.4.0 autodocsumm sphinx_rtd_theme sphinx_autodoc_annotation matplotlib Image "commonmark>=0.7.3" "docutils>=0.11,<0.17" +pip3 install \ + autodocsumm \ + "commonmark>=0.7.3" \ + "docutils>=0.11,<0.17" \ + Image \ + matplotlib \ + sphinx \ + sphinx_autodoc_annotation \ + sphinx-gallery==0.4.0 \ + sphinx_rtd_theme diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 81802964ba0e..8a51fbbbb178 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -20,4 +20,7 @@ set -e set -u set -o pipefail -pip3 install tensorflow==2.4.2 +pip3 install \ + "h5py<3.0" \ + keras==2.4.3 \ + tensorflow==2.4.2 diff --git a/docker/install/ubuntu_install_vela.sh b/docker/install/ubuntu_install_vela.sh new file mode 100644 index 000000000000..e75a99d9d563 --- /dev/null +++ b/docker/install/ubuntu_install_vela.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +pip3 install -U setuptools +# In a refactor between v2.1.1 and v3.0.0, find_block_configs was removed from Vela. +# Since this is still required for the TVM port, it will be reinstated in Vela in a future release. +# Until then, it needs to be pinned to v2.1.1. +pip3 install ethos-u-vela==2.1.1 diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst index c2a1a1e106d5..e7fb3c114689 100644 --- a/docs/api/python/ir.rst +++ b/docs/api/python/ir.rst @@ -23,6 +23,14 @@ tvm.ir :autosummary: +tvm.instrument +-------------- +.. automodule:: tvm.instrument + :members: + :imported-members: + :autosummary: + + tvm.transform ------------- .. automodule:: tvm.transform diff --git a/docs/conf.py b/docs/conf.py index 1f645645f25d..b008c305b1e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -190,7 +190,7 @@ def git_describe_version(original_version): intersphinx_mapping = { "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), "numpy": ("https://numpy.org/doc/stable", None), - "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), "matplotlib": ("https://matplotlib.org/", None), } @@ -273,7 +273,12 @@ def git_describe_version(original_version): "tune_network_x86.py", "tune_network_cuda.py", ], - "dev": ["low_level_custom_pass.py", "use_pass_infra.py", "bring_your_own_datatypes.py"], + "dev": [ + "low_level_custom_pass.py", + "use_pass_infra.py", + "use_pass_instrument.py", + "bring_your_own_datatypes.py", + ], } diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 60ab5e5ae9d2..efc8b32832c0 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct auto it = fmap_.find(name); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } @@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel( fcache_[device_id], wl.grid_dim(0), diff --git a/docs/dev/index.rst b/docs/dev/index.rst index b4fb37d790f4..76d50f496e75 100644 --- a/docs/dev/index.rst +++ b/docs/dev/index.rst @@ -423,3 +423,4 @@ microTVM :maxdepth: 1 microtvm_design + model_library_format diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst index 010d0d42d37e..28e034dc44cb 100644 --- a/docs/dev/inferbound.rst +++ b/docs/dev/inferbound.rst @@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preser :: - // attr [compute(D, 0x2c070b0)] realize_scope = "" realize D([0, 4], [0, 5], [0, 16]) { produce D { for (di, 0, 4) { for (dj, 0, 5) { for (dk, 0, 16) { - // attr [compute(C, 0x2c29990)] realize_scope = "" realize C([dj, 1], [dk, 1]) { produce C { for (i, 0, 1) { diff --git a/docs/dev/model_library_format.rst b/docs/dev/model_library_format.rst new file mode 100644 index 000000000000..fec90de4bcea --- /dev/null +++ b/docs/dev/model_library_format.rst @@ -0,0 +1,169 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Model Library Format +==================== + +About Model Library Format +-------------------------- + +TVM traditionally exports generated libraries as Dynamic Shared Objects (e.g. DLLs (Windows) or .so +(linux)). Inferences can be performed using those libraries by loading them into an executable using +``libtvm_runtime.so``. This process is very dependent on services provided by traditional OS. + +For deployment to unconventional platforms (e.g. those lacking traditional OS), TVM provides another +output format, Model Library Format. Initially, the microTVM project is the primary use case for this +format. Should it become useful in other use cases (and in particular, should it become possible to +export BYOC artifacts in Model Library Format), it could be used as a general-purpose TVM export +format. Model Library Format is a tarball containing a file for each piece of the TVM compiler +output. + +What can be Exported? +--------------------- + +At the time of writing, export is limited to full models built with ``tvm.relay.build``. + +Directory Layout +---------------- + +Model Library Format is contained within a tarball. All paths are relative to the root of the +tarball: + +- ``/`` - Root of the tarball + + - ``codegen`` - Root directory for all generated device code + + - (see `codegen`_ section) + + - ``executor-config/`` - Configuration for the executor which drives model inference + + - ``graph/`` - Root directory containing configuration for the GraphExecutor + + - ``graph.json`` - GraphExecutor JSON configuration + + - ``metadata.json`` - Machine-parseable metadata for this model + + - ``parameters/`` - Root directory where simplified parameters are placed + + - ``.params`` - Parameters for the model tvm.relay._save_params format + + - ``src/`` - Root directory for all source code consumed by TVM + + - ``relay.txt`` - Relay source code for the generated model + +Description of Sub-directories +------------------------------ + +.. _subdir_codegen: + +``codegen`` +^^^^^^^^^^^ + +All TVM-generated code is placed in this directory. At the time of writing, there is 1 file per +Module in the generated Module tree, though this restriction may change in the future. Files in +this directory should have filenames of the form ``/(lib|src)/.``. + +These components are described below: + + * ```` - Identifies the TVM target on which the code should run. Currently, only ``host`` + is supported. + * ```` - A unique slug identifying this file. Currently ``lib``, with ``>`` an + auto-incrementing integer. + * ```` - Suffix identifying the filename format. Currently ``c`` or ``o``. + +An example directory tree for a CPU-only model is shown below: + +- ``codegen/`` - Codegen directory + + - ``host/`` - Generated code for ``target_host`` + + - ``lib/`` - Generated binary object files + + - ``lib0.o`` - LLVM module (if ``llvm`` target is used) + - ``lib1.o`` - LLVM CRT Metadata Module (if ``llvm`` target is used) + + - ``src/`` - Generated C source + + - ``lib0.c`` - C module (if ``c`` target is used) + - ``lib1.c`` - C CRT Metadata module (if ``c`` target is used) + +``executor-config`` +^^^^^^^^^^^^^^^^^^^ + +Contains machine-parsable configuration for executors which can drive model inference. Currently, +only the GraphExecutor produces configuration for this directory, in ``graph/graph.json``. This +file should be read in and the resulting string supplied to the ``GraphExecutor()`` constructor for +parsing. + +``parameters`` +^^^^^^^^^^^^^^ + +Contains machine-parseable parameters. A variety of formats may be provided, but at present, only +the format produced by ``tvm.relay._save_params`` is supplied. When building with +``tvm.relay.build``, the ``name`` parameter is considered to be the model name. A single file is +created in this directory ``.json``. + +``src`` +^^^^^^^ + +Contains source code parsed by TVM. Currently, just the Relay source code is created in +``src/relay.txt``. + +Metadata +-------- + +Machine-parseable metadata is placed in a file ``metadata.json`` at the root of the tarball. +Metadata is a dictionary with these keys: + +- ``export_datetime``: Timestamp when this Model Library Format was generated, in + `strftime `_ + format ``"%Y-%M-%d %H:%M:%SZ",``. +- ``memory``: A summary of the memory usage of each generated function. Documented in + `Memory Usage Summary`_. +- ``model_name``: The name of this model (e.g. the ``name`` parameter supplied to + ``tvm.relay.build``). +- ``executors``: A list of executors supported by this model. Currently, this list is always + ``["graph"]``. +- ``target``: A dictionary mapping ``device_type`` (the underlying integer, as a string) to the + sub-target which describes that relay backend used for that ``device_type``. +- ``version``: A numeric version number that identifies the format used in this Model Library + Format. This number is incremented when the metadata structure or on-disk structure changes. + This document reflects version ``5``. + +Memory Usage Summary +^^^^^^^^^^^^^^^^^^^^ + +A dictionary with these sub-keys: + + - ``"main"``: ``list[MainFunctionWorkspaceUsage]``. A list summarizing memory usage for each + workspace used by the main function and all sub-functions invoked. + - ``"operator_functions"``: ``map[string, list[FunctionWorkspaceUsage]]``. Maps operator function + name to a list summarizing memory usage for each workpace used by the function. + +A ``MainFunctionWorkspaceUsage`` is a dict with these keys: + +- ``"device"``: ``int``. The ``device_type`` associated with this workspace. +- ``"workspace_size_bytes"``: ``int``. Number of bytes needed in this workspace by this function + and all sub-functions invoked. +- ``"constants_size_bytes"``: ``int``. Size of the constants used by the main function. +- ``"io_size_bytes"``: ``int``. Sum of the sizes of the buffers used from this workspace by this + function and sub-functions. + +A ``FunctionWorkspaceUsage`` is a dict with these keys: + +- ``"device"``: ``int``. The ``device_type`` associated with this workspace. +- ``"workspace_size_bytes"``: ``int``. Number of bytes needed in this workspace by this function. diff --git a/docs/dev/pass_infra.rst b/docs/dev/pass_infra.rst index 67ef30a29504..9fc24d87ef0d 100644 --- a/docs/dev/pass_infra.rst +++ b/docs/dev/pass_infra.rst @@ -109,7 +109,8 @@ configure the compilation options, including optimization level and required/disabled passes, etc. For instance, we may have a configuration which performs all passes at ``opt_level=3`` with some disabled passes using ``disabled_pass=xx`` provided by ``PassContext``. Now we could glob all passes -at ``opt_level=3`` and exclude those in the disabled pass list. +at ``opt_level=3`` and exclude those in the disabled pass list. ``PassContext`` +also provides a way to instrument all passes. See section :ref:`pass_instrument_cpp_backend`. This class is designed for users to conveniently write the Python ``with`` syntax to perform optimizations under a certain configuration. In addition, the @@ -123,16 +124,22 @@ Python APIs to create a compilation pipeline using pass context. class PassContextNode : public Object { public: - ErrorReporter err_reporter; int opt_level{2}; tvm::Array required_pass; tvm::Array disabled_pass; + mutable Optional diag_ctx; + Map config; + Array instruments; }; class PassContext : public NodeRef { public: TVM_DLL static PassContext Create(); TVM_DLL static PassContext Current(); + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; /* Other fields are omitted. */ private: @@ -338,7 +345,7 @@ favorably use Python APIs to create a specific pass object. Pass Sequential(tvm::Array passes, PassInfo pass_info); Pass Registration -~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^ We've covered the concept of different level of passes and the context used for compilation. It would be interesting to see how easily users can register @@ -389,6 +396,148 @@ To allow other C++ modules to apply this pass, we declare a free function in TVM_DLL Pass FoldConstant(); +.. _pass_instrument_cpp_backend: + +Pass Instrument +^^^^^^^^^^^^^^^ + +Pass Instrument is a mechanism to analyze the pass itself. For example, +we can use the infrastructure to know how much time and memory a pass requires +or how a pass can transform the IR module. + +We introduce four instrument points in the life-cycle of ``PassContext``. + +.. code:: c++ + + TVM_DLL void InstrumentEnterPassContext(); + TVM_DLL void InstrumentExitPassContext(); + TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const; + TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const; + +``InstrumentEnterPassContext`` is called immediately when entering the scope +of the ``PassContext`` instance. + +``InstrumentExitPassContext`` is called when leaving the scope of ``PassContext``, +or exceptions occur during the execution of passes. +This method is also called when instruments is being overriden by ``override_instruments`` in :py:class:`tvm.transform.PassContext`. +See :ref:`pass_instrument_overriden`. + +``InstrumentBeforePass`` is called before execution. +``InstrumentAfterPass`` is called after execution if the pass should be run. The behavior is like: + +.. code:: c++ + + if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) { + new_ir_module = run_pass(ir_module, pass_ctx); + pass_ctx.InstrumentAfterPass(new_ir_module, pass_info); + return new_ir_module; + } + +The ``PassInstrument`` interface allow you to run arbitrary code inside above four methods. +Multiple ``PassInstrument`` instances can be registed into a single +``PassContext``. ``PassInstrument`` instances are called sequentially in the order of +``instruments`` argument passed to ``PassContext``. + +``PassInstrument`` provides following interfaces: + +.. code:: c++ + + namespace instrument { + + class PassInstrumentNode : public Object { + public: + String name; + virtual void EnterPassContext() const = 0; + virtual void ExitPassContext() const = 0; + virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0; + virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; + /* Other fields are omitted. */ + }; + + class PassInstrument : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode); + }; + + } // namespace instrument + +Python frontend are provided to implement ``PassInstrument`` quickly. See :ref:`pass_instrument_py_frontend`. + +Within a ``PassContext``, the call sequence of a ``PassInstrument`` instance is like: + +:: + + with PassContext(instruments=[pi]) # pi = a PassInstrument implementation. + pi.EnterPassContext() + + if pi.ShouldRun(Pass1): + pi.RunBeforePass() + Pass1() + pi.RunAfterPass() + + if pi.ShouldRun(Pass2): + pi.RunBeforePass() + Pass2() + pi.RunAfterPass() + + pi.ExitPassContext() + +Here is a brief introduction of relations between ``PassInstrument`` interfaces +and ``PassContext`` methods. See (`src/ir/transform.cc`_) for more details. + +- ``InstrumentEnterPassContext`` + + * ``EnterPassContext()`` is executed in the order of ``instruments`` passed to the ``PassContext``. + * When an exception raises, ``PassContext`` disable the pass instrumentation + by clearing all registered ``PassInstrument`` instances. + * Then ``PassContext`` execute ``ExitPassContext()`` method of each ``PassInstrument`` + instances which successfully finished ``EnterPassContext()`` + * For example, if ``PassInstrument`` A, B, and C are registered to a ``PassContext`` + and A finished ``EnterPassContext()`` while B throws an exception, then C + is never executed; ``ExitPassContext()`` of A is executed. + +- ``InstrumentExitPassContext`` + + * ``ExitPassContext()`` of each ``PassInstrument`` instances are executed in + the order of ``instruments`` passed to the ``PassContext``. + * While an exception occurs, ``instruments`` is cleared. + * ``PassInstrument`` Instances registered after the one throwing exceptions do not execute ``ExitPassContext``. + +- ``InstrumentBeforePass`` + + * ``ShouldRun`` is executed if the pass is not listed as a required pass. + * ``RunBeforePass`` is executed in the order of ``instruments`` if the pass is not blocked by ``ShouldRun``. + * Note that ``InstrumentBeforePass`` returns a boolean indicating whether or not the pass should be run. + * When an exception occur, it is thrown immediately. + We rely on Python Context Manager to exit ``PassContext`` safely + (meaning ``ExitPassContext`` of each instruments will be run. For C++, please refer to `include/tvm/support/with.h`_.) + +- ``InstrumentAfterPass`` + + * ``RunAfterPass`` is executed in the order of ``instruments`` passed to the ``PassContext``. + * When an exception occur, it is thrown immediately. + We rely on Python Context Manager or ``With`` class(`include/tvm/support/with.h`_) to exit ``PassContext`` safely + +Built-in Instrument +^^^^^^^^^^^^^^^^^^^ + +There are several built-in instruments. Those marked with *TODO* are not implemented yet. + +- PassTimingInstrument (see `src/ir/instrument.cc`_) + + * Profile the execution time of passes. + +- PrintIRBefore(TODO) + + * Print the IR module before the pass transforms it. :py:func:`tvm.transform.PrintIR` + can also serve this purpose if we insert it around passes. However, + with the ``PassInstrument``, we don't need to modify the sequence of passes. + +- PrintAfter(TODO) + + * Print the IR module after the pass transforms it. + Python Frontend ~~~~~~~~~~~~~~~ @@ -478,7 +627,7 @@ Users can build a pass through decoration like the following: x = relay.var("x", tp) gv = relay.GlobalVar("abs") func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) + new_mod = tvm.IRModule({gv: func}) new_mod.update(mod) return new_mod @@ -494,7 +643,7 @@ function. .. code:: python - mod = relay.Module() + mod = tvm.IRModule() mod = module_pass(mod) Correspondingly, we also offer such functionality for ``function_pass``. For @@ -526,16 +675,78 @@ decorators and then invoke it. For more examples about how to customize your own optimization pipeline and debug Relay and tir passes, please refer to the `use pass infra`_ tutorial. + +.. _pass_instrument_py_frontend: + +Pass Instrument +^^^^^^^^^^^^^^^ + +One can implement a ``PassInstrument`` by using the ``pass_instrument`` +decorator(`python/tvm/ir/instrument.py`_) on a class implementing following methods. +Note that it is recommended to use the ``pass_instrument`` decorator to implement +``PassInstrument``, instead of overriding or subclassing. + +- ``enter_pass_ctx`` + + * This method is run when entering ``PassContext``. + +- ``exit_pass_ctx`` + + * This method is run when exiting ``PassContext``. + +- ``should_run`` + + * This method is run before a pass is executed, returning a boolean + indicating whether or not the pass should be run. + +- ``run_before_pass`` + + * If a pass should be run, this method is run just before pass execution. + +- ``run_after_pass`` + + * This method is run right after a pass has been executed. + +``PassInstrument`` instances can be registered through ``instruments`` argument in +:py:class:`tvm.transform.PassContext`. + +`use pass instrument`_ tutorial provides examples for how to implement ``PassInstrument`` with Python APIs. + +.. _pass_instrument_overriden: + +Override Instruments in Current PassContext +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``override_instruments`` method is provided to override the ``instruments`` of current ``PassContext``. +For example, if passes are run without explicitly creating a new ``PassContext``, +one can still register ``PassInstrument`` into the global ``PassContext`` by: + +.. code:: python + + cur_pass_ctx = tvm.transform.PassContext.current() + # override PassInstrument instances + cur_pass_ctx.override_instruments([pass_inst]) + mod = pass_seq(mod) + result = pass_inst.get_result() + +Note that when ``override_instruments`` is called, the ``exit_pass_ctx`` method of +old ``PassInstrument`` instances are called. Then the ``enter_pass_ctx`` method of +new ``PassInstrument`` are called. + .. _Sequential: https://pytorch.org/docs/stable/nn.html?highlight=sequential#torch.nn.Sequential .. _Block: https://mxnet.apache.org/api/python/docs/api/gluon/block.html#gluon-block .. _include/tvm/ir/transform.h: https://github.com/apache/tvm/blob/main/include/tvm/ir/transform.h +.. _include/tvm/support/with.h: https://github.com/apache/tvm/blob/main/include/tvm/support/with.h + .. _src/relay/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/relay/ir/transform.cc .. _src/ir/transform.cc: https://github.com/apache/tvm/blob/main/src/ir/transform.cc +.. _src/ir/instrument.cc: https://github.com/apache/tvm/blob/main/src/ir/instrument.cc + .. _src/relay/transforms/fold_constant.cc: https://github.com/apache/tvm/blob/main/src/relay/transforms/fold_constant.cc .. _python/tvm/relay/transform/transform.py: https://github.com/apache/tvm/blob/main/python/tvm/relay/transform/transform.py @@ -544,6 +755,10 @@ optimization pipeline and debug Relay and tir passes, please refer to the .. _python/tvm/ir/transform.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/transform.py +.. _python/tvm/ir/instrument.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/instrument.py + .. _src/tir/transforms/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tir/transforms/unroll_loop.cc .. _use pass infra: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py + +.. _use pass instrument: https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_instrument.py diff --git a/docs/index.rst b/docs/index.rst index a7ae68c87b01..491c42712e9a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -78,6 +78,7 @@ For Developers :caption: MISC vta/index + profiling/index Index diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index c8e1ab7fc045..5c3a2544c578 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -63,7 +63,8 @@ The minimal building requirements for the ``TVM`` libraries are: - CMake 3.5 or higher - We highly recommend to build with LLVM to enable all the features. - If you want to use CUDA, CUDA toolkit version >= 8.0 is required. If you are upgrading from an older version, make sure you purge the older version and reboot after installation. - - On macOS, you may want to install `Homebrew ` to easily install and manage dependencies. + - On macOS, you may want to install `Homebrew `_ to easily install and manage dependencies. + - Python is also required. Avoid using Python 3.9.X+ which is not `supported `_. 3.7.X+ and 3.8.X+ should be well supported however. To install the these minimal pre-requisites on Ubuntu/Debian like linux operating systems, execute (in a terminal): @@ -73,6 +74,15 @@ linux operating systems, execute (in a terminal): sudo apt-get update sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +Use Homebrew to install the required dependencies for macOS running either the Intel or M1 processors. You must follow the post-installation steps specified by +Homebrew to ensure the dependencies are correctly installed and configured: + +.. code:: bash + + brew install gcc git cmake + brew install llvm + brew install python@3.8 + We use cmake to build the library. The configuration of TVM can be modified by editing `config.cmake` and/or by passing cmake flags to the command line: @@ -293,6 +303,21 @@ like ``virtualenv``. pip3 install --user tornado psutil xgboost cloudpickle +Note on M1 macs, you may have trouble installing xgboost / scipy. scipy and xgboost requires some additional dependencies to be installed, +including openblas and its dependencies. Use the following commands to install scipy and xgboost with the required dependencies and +configuration. A workaround for this is to do the following commands: + + .. code:: bash + + brew install openblas gfortran + + pip install pybind11 cython pythran   + + export OPENBLAS=/opt/homebrew/opt/openblas/lib/ + + pip install scipy --no-use-pep517 + + pip install xgboost Install Contrib Libraries ------------------------- @@ -316,7 +341,7 @@ tests in TVM. The easiest way to install GTest is from source. cd googletest mkdir build cd build - cmake -DMAKE_SHARED_LIBS=ON .. + cmake -DBUILD_SHARED_LIBS=ON .. make sudo make install diff --git a/docs/langref/relay_expr.rst b/docs/langref/relay_expr.rst index 658ec7d56f1b..72af31663e8a 100644 --- a/docs/langref/relay_expr.rst +++ b/docs/langref/relay_expr.rst @@ -432,7 +432,7 @@ This definition would result in a module entry mapping the identifier :code:`@ac with the parameters, return type, and body above. Any reference to the identifier :code:`@ackermann` elsewhere in the code could then look up the identifier in the module and replace the function definition as needed. -See :py:class:`~tvm.relay.Module` for the definition and documentation of a module. +See :py:class:`~tvm.IRModule` for the definition and documentation of a module. Constant ======== diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index febe542b83b1..3e797fc93b31 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -181,7 +181,9 @@ This level enables additional math and transform operators. .. autosummary:: :nosignatures: - tvm.relay.image.resize + tvm.relay.image.resize1d + tvm.relay.image.resize2d + tvm.relay.image.resize3d tvm.relay.image.crop_and_resize tvm.relay.image.dilation2d tvm.relay.vision.multibox_prior diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 49d3a42d3e98..b74c58921d3f 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -434,13 +434,15 @@ Pattern Rewriting If you would like to replace the matched pattern with another subgraph, you can leverage the ``rewrite`` transformation. Here is an example of rewriting a series of arithmetic operators -with a single batch_norm op: +with a single batch_norm op. The constructor parameter ``require_type`` indicates whether InferType +is required to be run before the callback. .. code-block:: python class BatchnormCallback(DFPatternCallback): # A callback class to rewrite the matched pattern to a batch_norm op. - def __init__(self): + def __init__(self, require_type=False): + super().__init__(require_type) self.x = wildcard() self.var = wildcard() self.mean = wildcard() diff --git a/docs/profiling/index.rst b/docs/profiling/index.rst new file mode 100644 index 000000000000..9443fef25ea6 --- /dev/null +++ b/docs/profiling/index.rst @@ -0,0 +1,24 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Profiling Deep Learning Models +==================================== + +.. toctree:: + :maxdepth: 1 + + papi diff --git a/docs/profiling/papi.rst b/docs/profiling/papi.rst new file mode 100644 index 000000000000..b7c23b2c0c73 --- /dev/null +++ b/docs/profiling/papi.rst @@ -0,0 +1,114 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Getting Started With PAPI +========================= + +The Performance Application Programming Interface (PAPI) is a library that +provides performance counters on a variety of platforms. Performance counters +provide accurate low-level information about processors behavior during a given +execution run. This information can contain simple metrics like total cycle +count, cache misses, and instructions executed as well as more high level +information like total FLOPS and warp occupancy. PAPI makes these metrics +available while profiling. + +Installing PAPI +--------------- + +PAPI can either be installed using your package manager (``apt-get install libpapi-dev`` +on Ubuntu), or from source here: +https://bitbucket.org/icl/papi/src/master/. + + +Building TVM With PAPI +---------------------- + +To include PAPI in your build of TVM, set the following line in you ``config.cmake``: + +.. code:: + + set(USE_PAPI ON) + +If PAPI is installed in a non-standard place, you can specify where it is like so: + +.. code:: + + set(USE_PAPI path/to/papi.pc) + + +Using PAPI While Profiling +-------------------------- + +If TVM has been built with PAPI (see above), then you can pass a +:py:class:`tvm.runtime.profiling.PAPIMetricCollector` to +:py:meth:`tvm.runtime.GraphModule.profile` to collect performance metrics. Here +is an example: + +.. code:: python + + target = "llvm" + dev = tvm.cpu() + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, target, params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, dev) + + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + report = vm.profile( + [data], + func_name="main", + collectors=[tvm.runtime.profiling.PAPIMetricCollector()], + ) + print(report) + +.. code:: + + Name perf::CACHE-MISSES perf::CYCLES perf::STALLED-CYCLES-BACKEND perf::INSTRUCTIONS perf::STALLED-CYCLES-FRONTEND + fused_nn_dense_nn_bias_add_nn_relu 2,494 1,570,698 85,608 675,564 39,583 + fused_nn_dense_nn_bias_add_nn_relu_1 1,149 655,101 13,278 202,297 21,380 + fused_nn_dense_nn_bias_add 288 600,184 8,321 163,446 19,513 + fused_nn_batch_flatten 301 587,049 4,636 158,636 18,565 + fused_nn_softmax 154 575,143 8,018 160,738 18,995 + ---------- + Sum 4,386 3,988,175 119,861 1,360,681 118,036 + Total 10,644 8,327,360 179,310 2,660,569 270,044 + +You can also change which metrics are collected: + +.. code:: python + + report = vm.profile( + [data], + func_name="main", + collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: ["PAPI_FP_OPS"])], + ) + +.. code:: + + Name PAPI_FP_OPS + fused_nn_dense_nn_bias_add_nn_relu 200,832 + fused_nn_dense_nn_bias_add_nn_relu_1 16,448 + fused_nn_dense_nn_bias_add 1,548 + fused_nn_softmax 160 + fused_nn_batch_flatten 0 + ---------- + Sum 218,988 + Total 218,988 + +You can find a list of available metrics by running the ``papi_avail`` and +``papi_native_avail`` commands. diff --git a/golang/sample/complex.go b/golang/sample/complex.go index 911d0a7a28c1..91821c978e96 100644 --- a/golang/sample/complex.go +++ b/golang/sample/complex.go @@ -88,7 +88,7 @@ func main() { // Array allocation attributes tshapeIn := []int64{1, 224, 224, 3} - tshapeOut := []int64{1, 1000} + tshapeOut := []int64{1, 1001} // Allocate input Array inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 641d0e0f5321..6c72cbeafdd4 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -282,6 +282,39 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); +/*! + * \brief Use IterVarMap detector to rewrite and simplify the indices + * + * \param indices The indices to detect pattern for. + * \param input_iters Map from variable to iterator's range. + * \param input_pred The predicate constraints on the input iterators + * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * + * \return The indices after rewrite + */ +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective); + +/*! + * \brief Apply the inverse of the affine transformation to the outputs. + * + * Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions + * in reverse topology order and applies the inverse of the affine transformation until it reaches + * the input. The affine iter map is required to be bijective. + * + * For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1], + * the affine transformation specified by `iter_map` will be applied to `outputs` and the result + * will be {l0: ((output_0*16) + output_1)}. + * + * \sa DetectIterMap + * + * \param iter_map The bijective affine iter map. + * \param outputs The outputs of the affine transformation. + * + * \return The map from the input to the transformed result. + */ +Map InverseAffineIterMap(const Array& iter_map, + const Array outputs); /*! * \brief Detect if bindings can be written as diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h new file mode 100644 index 000000000000..afbe1f343bb8 --- /dev/null +++ b/include/tvm/ir/affine_type.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/affine_type.h + * \brief Quantized Tensor Types. + */ +#ifndef TVM_IR_AFFINE_TYPE_H_ +#define TVM_IR_AFFINE_TYPE_H_ + +#include +#include + +namespace tvm { + +/*! + * \brief AffineType representation + * \sa AffineType + */ +class AffineTypeNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "AffineType"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); +}; + +/*! + * \brief Managed reference to AffineTypeNode. + * \sa AffineTypeNode + */ +class AffineType : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); +}; + +/*! + * \brief TensorAffineType representation + * \sa TensorAffineType + * + * This Type represents a quantized integer tensor that can be converted + * back to real space via the x_real = scale * (x_quant - zero_point) + */ +class TensorAffineTypeNode : public AffineTypeNode { + public: + /*! \brief The scale of this type */ + RelayExpr scale; + /*! \brief The zero point of this type */ + RelayExpr zero_point; + /*! \brief The data type of this type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("scale", &scale); + v->Visit("zero_point", &zero_point); + v->Visit("dtype", &dtype); + } + + bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(scale, other->scale) && equal(zero_point, other->zero_point) && + equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(scale); + hash_reduce(zero_point); + hash_reduce(dtype); + } + + static constexpr const char* _type_key = "TensorAffineType"; + TVM_DECLARE_BASE_OBJECT_INFO(TensorAffineTypeNode, AffineTypeNode); +}; + +/*! + * \brief Managed reference to AffineTypes. + * \sa AffineTypeNode + */ +class TensorAffineType : public AffineType { + public: + TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode); +}; + +/*! + * \brief TupleAffineType representation + * \sa TupleAffineType + */ +class TupleAffineTypeNode : public AffineTypeNode { + public: + /*! \brief The types of this tuple*/ + Array types; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("types", &types); } + + bool SEqualReduce(const TupleAffineTypeNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(types, other->types); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(types); + } + + static constexpr const char* _type_key = "TupleAffineType"; + TVM_DECLARE_BASE_OBJECT_INFO(TupleAffineTypeNode, AffineTypeNode); +}; + +/*! + * \brief Managed reference to TupleAffineTypes. + * \sa TupleAffineType + */ +class TupleAffineType : public AffineType { + public: + TVM_DLL TupleAffineType(Array types); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleAffineType, AffineType, TupleAffineTypeNode); +}; + +} // namespace tvm +#endif // TVM_IR_AFFINE_TYPE_H_ diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index c1a012f05318..09c074cb71bd 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -43,7 +43,7 @@ namespace tvm { */ enum class CallingConv : int { /*! - * \brief Default calling convetion. + * \brief Default calling convention. * * - Uses the native calling convention of the target. * - Implementation: specified by the native target. diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 4a2eb63c7e6a..8379e6471561 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode { } }; +/*! + * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. + */ +struct TIRCallAttrs : public tvm::AttrsNode { + /*! \brief The metadata attached to the call node. */ + Map metadata; + + TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") { + TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ANNOTATION_H_ diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index baceb04958f0..b851add61e4a 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -32,31 +32,74 @@ namespace tvm { namespace relay { -/*! \brief Attributes used in image resize operator */ -struct ResizeAttrs : public tvm::AttrsNode { +/*! \brief Attributes used in image resize1d operator */ +struct Resize1DAttrs : public tvm::AttrsNode { Array size; std::string layout; std::string method; std::string coordinate_transformation_mode; std::string rounding_method; - double bicubic_alpha; - int bicubic_exclude; + double cubic_alpha; + int cubic_exclude; DataType out_dtype; - TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { + TVM_DECLARE_ATTRS(Resize1DAttrs, "relay.attrs.Resize1DAttrs") { + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel and width" + "dimensions respectively. Resize is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Linear Interpolation" + "cubic - Cubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for cubic interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during cubic interpolation"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + } +}; + +/*! \brief Attributes used in image resize2d operator */ +struct Resize2DAttrs : public tvm::AttrsNode { + Array size; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + std::string rounding_method; + double cubic_alpha; + int cubic_exclude; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize2DAttrs, "relay.attrs.Resize2DAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); TVM_ATTR_FIELD(layout).set_default("NCHW").describe( "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Resize is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method) - .set_default("bilinear") - .describe( - "Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode) .set_default("half_pixel") .describe( @@ -69,10 +112,10 @@ struct ResizeAttrs : public tvm::AttrsNode { .describe( "indicates how to find the \"nearest\" pixel in nearest_neighbor method" "Available options are round, floor, and ceil."); - TVM_ATTR_FIELD(bicubic_alpha) + TVM_ATTR_FIELD(cubic_alpha) .set_default(-0.5) .describe("Spline Coefficient for Bicubic Interpolation"); - TVM_ATTR_FIELD(bicubic_exclude) + TVM_ATTR_FIELD(cubic_exclude) .set_default(0) .describe("Flag to exclude exterior of the image during bicubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); @@ -80,32 +123,46 @@ struct ResizeAttrs : public tvm::AttrsNode { }; /*! \brief Attributes used in image resize3d operator */ -struct Resize3dAttrs : public tvm::AttrsNode { +struct Resize3DAttrs : public tvm::AttrsNode { Array size; - String layout; - String method; - String coordinate_transformation_mode; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + std::string rounding_method; + double cubic_alpha; + int cubic_exclude; DataType out_dtype; - TVM_DECLARE_ATTRS(Resize3dAttrs, "relay.attrs.Resize3dAttrs") { + TVM_DECLARE_ATTRS(Resize3DAttrs, "relay.attrs.Resize3DAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Resize3d is applied on the 'D', 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method) - .set_default("trilinear") - .describe( - "Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Trilinear Interpolation" + "cubic - Tricubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode) .set_default("half_pixel") .describe( "Describes how to transform the coordinate in the resized tensor" "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for Tricubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during tricubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 3c7574562676..694001f612e7 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1003,16 +1003,26 @@ struct DenseAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for batch matmul operator */ +/*! \brief Attributes for batch matmul operator. */ struct BatchMatmulAttrs : public tvm::AttrsNode { - tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite DataType out_dtype; + bool transpose_a; + bool transpose_b; + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") { // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(transpose_a) + .set_default(false) + .describe("Whether the first input tensor is in transposed format."); + + TVM_ATTR_FIELD(transpose_b) + .set_default(false) + .describe("Whether the second input tensor is in transposed format."); } }; diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 95eaad0b2797..fccd1f937a06 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -126,7 +126,7 @@ namespace attr { /*! \brief Mark the function as a primitive function. */ constexpr const char* kPrimitive = "Primitive"; /*! - * \brief Indicate the compiler that should be used for builing this function. + * \brief Indicate the compiler that should be used for building this function. * When this is unset or set to "default", the default compilation pipeline will be used. */ constexpr const char* kCompiler = "Compiler"; diff --git a/include/tvm/runtime/contrib/papi.h b/include/tvm/runtime/contrib/papi.h new file mode 100644 index 000000000000..ff2d75c483eb --- /dev/null +++ b/include/tvm/runtime/contrib/papi.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \brief Performance counters for profiling via the PAPI library. + */ +#ifndef TVM_RUNTIME_CONTRIB_PAPI_H_ +#define TVM_RUNTIME_CONTRIB_PAPI_H_ + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace profiling { + +/*! \brief Construct a metric collector that collects data from hardware + * performance counters using the Performance Application Programming Interface + * (PAPI). + * + * \param metrics A mapping from a device type to the metrics that should be + * collected on that device. You can find the names of available metrics by + * running `papi_native_avail`. + */ +TVM_DLL MetricCollector CreatePAPIMetricCollector(Map> metrics); +} // namespace profiling +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_PAPI_H_ diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 58b9ff1932cc..71188574ac2a 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -291,7 +291,14 @@ inline Device RemoveRPCSessionMask(Device dev) { return dev; } -inline std::ostream& operator<<(std::ostream& os, DLDevice dev); +inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) + if (tvm::runtime::IsRPCSessionDevice(dev)) { + os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-"; + dev = tvm::runtime::RemoveRPCSessionMask(dev); + } + os << tvm::runtime::DeviceName(static_cast(dev.device_type)) << "(" << dev.device_id << ")"; + return os; +} /*! * \brief Add a RPC session mask to a Device. @@ -308,14 +315,6 @@ inline Device AddRPCSessionMask(Device dev, int session_table_index) { return dev; } -inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) - if (IsRPCSessionDevice(dev)) { - os << "remote[" << GetRPCSessionIndex(dev) << "]-"; - dev = RemoveRPCSessionMask(dev); - } - os << runtime::DeviceName(static_cast(dev.device_type)) << "(" << dev.device_id << ")"; - return os; -} } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 9dd7423c6679..71be8d218d2d 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -230,8 +230,10 @@ constexpr const char* tvm_module_main = "__tvm_main__"; constexpr const char* tvm_param_prefix = "__tvm_param__"; /*! \brief A PackedFunc that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; -/*! \brief The main AOT executor function */ +/*! \brief The main AOT executor function generated from TIR */ constexpr const char* tvm_run_func_suffix = "run_model"; +/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */ +constexpr const char* tvm_entrypoint_suffix = "run"; } // namespace symbol // implementations of inline functions. diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 5b44e020f4e4..eea195f64a6d 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -37,6 +37,7 @@ #include namespace tvm { + namespace runtime { /*! \brief Base class for all implementations. @@ -150,6 +151,26 @@ class Timer : public ObjectRef { Timer DefaultTimer(Device dev); namespace profiling { +/*! \brief Wrapper for `Device` because `Device` is not passable across the + * PackedFunc interface. + */ +struct DeviceWrapperNode : public Object { + /*! The device */ + Device device; + + /*! Constructor */ + explicit DeviceWrapperNode(Device device) : device(device) {} + + static constexpr const char* _type_key = "runtime.profiling.DeviceWrapper"; + TVM_DECLARE_BASE_OBJECT_INFO(DeviceWrapperNode, Object); +}; + +/*! \brief Wrapper for `Device`. */ +class DeviceWrapper : public ObjectRef { + public: + explicit DeviceWrapper(Device dev) { data_ = make_object(dev); } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DeviceWrapper, ObjectRef, DeviceWrapperNode); +}; /*! \brief Data collected from a profiling run. Includes per-call metrics and per-device metrics. */ @@ -184,6 +205,39 @@ class ReportNode : public Object { * `aggregate` is true. */ String AsTable(bool sort = true, bool aggregate = true) const; + /*! \brief Convert this report to JSON. + * + * Output JSON will be of this format: + * \code + * { + * "calls": [ + * { + * "Duration (us)": { + * "microseconds": 12.3 + * }, + * "Name": "fused_dense", + * "Count": { + * "count": 1 + * }, + * "Percent": { + * "percent": 10.3 + * } + * } + * ], + * "device_metrics": { + * "cpu": { + * "Duration (us)": { + * "microseconds": 334.2 + * }, + * "Percent": { + * "percent": 100 + * } + * } + * } + * } + * \endcode + */ + String AsJSON() const; static constexpr const char* _type_key = "runtime.profiling.Report"; TVM_DECLARE_FINAL_OBJECT_INFO(ReportNode, Object); @@ -200,6 +254,57 @@ class Report : public ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode); }; +/*! \brief Interface for user defined profiling metric collection. + * + * Users can register their own collector by registering a packed function with + * the name "runtime.profiling.metrics.my_collector_name" where + * "my_collector_name" is the name of their collector. This function should + * take an Array of Device as input which contains the devices the collector + * will be run on. + * + * `MetricCollectorNode`s will be called in the following fashion. + * \code + * MetricCollector mc; + * for (auto op : model) { + * auto o = mc.Start(); + * op(); + * auto metrics = mc.Stop(o); // metrics are added the profiling report + * } + * \endcode + */ +class MetricCollectorNode : public Object { + public: + /*! \brief Initialization call. Called before profiling has started. Any + * expensive precomputation should happen here. + * \param devs The list of devices this collector will be run on. + */ + virtual void Init(Array devs) = 0; + /*! \brief Start colling metrics for a function call. + * \param dev The device the call will be run on. + * \returns An object used to maintain state of the metric collection. This + * object will be passed to the corresponding `Stop` call. If the device is + * not supported, this function will return a nullptr ObjectRef. + */ + virtual ObjectRef Start(Device dev) = 0; + /*! \brief Stop collecting metrics. + * \param obj The object created by the corresponding `Start` call. + * \returns A set of metric names and the associated values. Values must be + * one of DurationNode, PercentNode, CountNode, or StringObj. + */ + virtual Map Stop(ObjectRef obj) = 0; + + virtual ~MetricCollectorNode() {} + + static constexpr const char* _type_key = "runtime.profiling.MetricCollector"; + TVM_DECLARE_BASE_OBJECT_INFO(MetricCollectorNode, Object); +}; + +/*! \brief Wrapper for `MetricCollectorNode`. */ +class MetricCollector : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetricCollector, ObjectRef, MetricCollectorNode); +}; + /*! Information about a single function or operator call. */ struct CallFrame { /*! Device on which the call was made */ @@ -210,6 +315,10 @@ struct CallFrame { Timer timer; /*! Extra performance metrics */ std::unordered_map extra_metrics; + /*! User defined metric collectors. Each pair is the MetricCollector and its + * associated data (returned from MetricCollector.Start). + */ + std::vector> extra_collectors; }; /*! Runtime profiler for function and/or operator calls. Used in the graph @@ -217,9 +326,10 @@ struct CallFrame { * * Example usage: * \code{.cpp} - * Profiler prof; * Device cpu, gpu; - * prof.Start({cpu, gpu}); + * Profiler prof({cpu, gpu}); + * my_gpu_kernel(); // do a warmup iteration + * prof.Start(); * prof.StartCall("my_gpu_kernel", gpu); * my_gpu_kernel(); * prof.StopCall(); @@ -232,13 +342,24 @@ struct CallFrame { */ class Profiler { public: - /*! \brief Start the profiler. + /*! Constructor. + * + * The profiler should be constructed before you do any warmup iterations. + * + * \note + * Calling this constructor will reset the TVM threadpool. It is necessary in + * order to install thread handlers required by certain collectors. + * * \param devs The list of devices the profiler will be running on. Should * include all devices used by profiled operators. + * \param metric_collectors Additional `MetricCollector`s to use with this profiler. + */ + explicit Profiler(std::vector devs, std::vector metric_collectors); + /*! \brief Start the profiler. * * This function should only be called once per object. */ - void Start(const std::vector& devs); + void Start(); /*! \brief Stop the profiler. * * This function should only be called once per object after start has been called. @@ -270,12 +391,14 @@ class Profiler { /*! \brief Check if the profiler is currently running. * \returns Whether or not the profiler is running. */ - bool IsRunning() const { return !global_timers_.empty(); } + bool IsRunning() const { return is_running_; } private: - std::vector> global_timers_; + std::vector devs_; + bool is_running_{false}; std::vector calls_; std::stack in_flight_; + std::vector collectors_; }; /* \brief A duration in time. */ diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h index 95a64049fd45..43636ddbdb1f 100644 --- a/include/tvm/runtime/threading_backend.h +++ b/include/tvm/runtime/threading_backend.h @@ -94,6 +94,14 @@ void Yield(); */ int MaxConcurrency(); +/*! + * \brief Reset the threads in the pool. All current threads are destroyed and + * new ones are created. + * + * Note that this does nothing when openmp is used. + */ +void ResetThreadPool(); + } // namespace threading } // namespace runtime } // namespace tvm diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 27e48999a7d1..13f39317dbe4 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object { * \param stage the op's stage. * \param realize_map The realization domain map of the operators. * \param body The body that is going to get + * \param storage_scope The storage scope associated with this realization * \return A realization statement that wraps body. */ virtual Stmt BuildRealize(const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + const std::unordered_map& realize_map, const Stmt& body, + String storage_scope = "") const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; @@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 262ac688f2e0..dce9736adec7 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -96,22 +96,20 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); /*! - * \brief Whether e expression used any var in variable set.. - * \param expr The expression to be checked. - * \param vset_contains The check function to see if var is in the vset. - * \return Whether e uses vset. + * \brief Whether the given Stmt uses any var in the given variable set. + * \param stmt The Stmt to be checked. + * \param vset_contains The check function to see if a var is in the variable set. + * \return Whether `stmt` uses any var in the given variable set. */ -TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function vset_contains); +TVM_DLL bool UsesVar(const Stmt& stmt, std::function vset_contains); /*! - * \brief Whether e expression used var. - * \param expr The expression to be checked. - * \param var The variable. - * \return Whether e uses v. + * \brief Whether the given PrimExpr uses any var in the given variable set. + * \param expr The PrimExpr to be checked. + * \param vset_contains The check function to see if var is in the variable set. + * \return Whether `expr` uses any var in the given variable set. */ -inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { - return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); -} +TVM_DLL bool UsesVar(const PrimExpr& expr, std::function vset_contains); /*! * \brief Verifies whether the IR stmt or Expr is in SSA form. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index a01d69b372d2..28d202cb50a9 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -67,8 +67,6 @@ class BufferNode : public Object { // Meta data /*! \brief optional name of the buffer */ String name; - /*! \brief storage scope of the buffer, if other than global */ - String scope; /*! \brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -93,7 +91,6 @@ class BufferNode : public Object { v->Visit("strides", &strides); v->Visit("elem_offset", &elem_offset); v->Visit("name", &name); - v->Visit("scope", &scope); v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); v->Visit("buffer_type", &buffer_type); @@ -105,7 +102,7 @@ class BufferNode : public Object { // in its semantics, skip name as name is not important. return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && - equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) && + equal.DefEqual(elem_offset, other->elem_offset) && equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } @@ -115,7 +112,6 @@ class BufferNode : public Object { hash_reduce.DefHash(shape); hash_reduce.DefHash(strides); hash_reduce.DefHash(elem_offset); - hash_reduce(scope); hash_reduce(data_alignment); hash_reduce(buffer_type); } @@ -141,8 +137,8 @@ class Buffer : public ObjectRef { // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type, Span span = Span()); + PrimExpr elem_offset, String name, int data_alignment, int offset_factor, + BufferType buffer_type, Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one @@ -182,7 +178,13 @@ class Buffer : public ObjectRef { */ TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + /*! + * \brief Return the storage scope associated with this buffer. + */ + TVM_DLL String scope() const; + TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode); }; /*! @@ -190,12 +192,13 @@ class Buffer : public ObjectRef { * \param shape The shape of the buffer, * \param dtype The content data type. * \param name The name of the buffer + * \param storage_scope The storage scope associated with this buffer * \param span The location of this object in the source code. * \return The created buffer. * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", Span span = Span()); + String name = "buffer", String storage_scope = "", Span span = Span()); /*! * \brief Base node for data producers. diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 40d66a2d8357..8ea48dd592d5 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1126,6 +1126,9 @@ class AnyNode : public PrimExprNode { /*! \brief Convert to var. */ Var ToVar() const { return Var("any_dim", DataType::Int(32)); } + /*! \brief Convert to SizeVar. */ + SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); } + static constexpr const char* _type_key = "tir.Any"; TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 97ee7f7211d4..55f4fc62649c 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! + * \brief Specialize parameters of PrimFunc. + * \param func The PrimFunc to be specialized. + * \param param_map The mapping from function params to the instance. + * \return The new function with parameter specialized. + * \note We can define a Meta TIR function with symbolic shape: + * + * \code + * @tvm.script.tir + * def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + * A = tir.match_buffer(a, (m, n), "float32") + * B = tir.match_buffer(b, (m, n), "float32") + * + * with tir.block([m, n], "") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + * \endcode + * + * Then we can make it specialized with given shapes or buffers. + * + * \code + * a, _, m, n = mem_copy.params + * func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + * # or + * func = mem_copy.specialize({n: 16, m: 16}) + * \endcode + * + * \code {.language-id} + * @tvm.script.tir + * def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + * A = tir.match_buffer(a, (16, 16), "float32") + * B = tir.match_buffer(b, (16, 16), "float32") + * + * with tir.block([16, 16], "") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + * \endcode + */ +PrimFunc Specialize(PrimFunc func, const Map& param_map); + /*! * \brief PrimFunc specific attribute names. * @@ -202,10 +240,12 @@ namespace attr { * * Call(f, * [arg1, arg2, ..., arg_n, - * work_size_1, work_size_2, ... work_size_m]) + * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size]) * * Here n = len(arg), m = len(work_size) = len(device_thread_axis). * + * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is omitted. + * * The list of device_thread_axis indicates how can be bind the * work_size arguments to the corresponding threads. * @@ -213,6 +253,13 @@ namespace attr { */ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; +/*! + * \brief Whether or not use dynamic shared memory. + * + * Type: Integer + */ +constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory"; + /*! * \brief Whether to set noalias rule on the function arguments. * diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h index fb08583b7771..be3e79a18331 100644 --- a/include/tvm/tir/schedule/block_scope.h +++ b/include/tvm/tir/schedule/block_scope.h @@ -262,7 +262,7 @@ class BlockScope : public ObjectRef { * \param child_block_srefs The srefs to the leaf blocks * \note We assume the leaf blocks are given in pre-DFS order */ - TVM_DLL BlockScope(const Array& child_block_srefs); + TVM_DLL explicit BlockScope(const Array& child_block_srefs); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); }; diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h new file mode 100644 index 000000000000..5a9e687dc8c7 --- /dev/null +++ b/include/tvm/tir/schedule/instruction.h @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_H_ +#define TVM_TIR_SCHEDULE_INSTRUCTION_H_ + +#include + +#include + +namespace tvm { + +// Forward declaration +template +class AttrRegistry; + +namespace tir { + +// Forward declaration +class Schedule; + +/*! + * \brief Type of the functor that applies the instruction to a TensorIR schedule + * \param sch The schedule to be applied on + * \param inputs The input random variables + * \param attrs Instruction attributes + * \param decision Decisions made on the instruction + * \return The functor returns an array of output random variables + */ +using FInstructionApply = runtime::TypedPackedFunc( + Schedule sch, const Array& inputs, const Array& attrs, + const Optional& decision)>; + +/*! + * \brief Type of the functor that converts the instruction to a statement in python syntax + * \param inputs Names of the input random variables + * \param attrs Instruction attributes + * \param decisions Decisions made on the instruction + * \param outputs Names of the output random variables + * \return A string representing the python api call + */ +using FInstructionAsPython = runtime::TypedPackedFunc& inputs, const Array& attrs, + const Optional& decision, const Array& outputs)>; + +/*! + * \brief Type of the functor that serialize its attributes to JSON + * \param attrs The attributes to be serialized + * \return An array, serialized attributes + * \note This functor is nullable + */ +using FInstructionAttrsAsJSON = runtime::TypedPackedFunc attrs)>; + +/*! + * \brief Type of the functor that deserialize its attributes from JSON + * \param json_attrs The attributes to be serialized + * \return An array, deserialized attributes + * \note This functor is nullable + */ +using FInstructionAttrsFromJSON = runtime::TypedPackedFunc(ObjectRef json_attrs)>; + +/*! + * \brief Kind of an instruction, e.g. Split, Reorder, etc. + * Besides the name, every kind of instruction has its own properties, including: + * 1) A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state + * 2) A functor that applies the instruction to a TensorIR schedule + * 3) A functor that converts the instruction to a statement in python syntax + * 4) A functor that serialize its attributes to JSON + * 5) A functor that deserialize its attributes from JSON + * + * Unlike `tvm::OpNode`, `InstructionKindNode` doesn't support unstructured properties, + * mainly because there is no such usecase yet to add any other property. + */ +class InstructionKindNode : public runtime::Object { + public: + /*! \brief The name of a kind of instructions */ + String name; + /*! + * \brief Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule + * state. For example, the instruction `GetBlock` is pure because it changes + * nothing, while `ComputeInline` is not because removing it leads to a different resulting + * schedule. + */ + bool is_pure{false}; + /*! \brief A functor that applies the instruction to a TensorIR schedule */ + FInstructionApply f_apply_to_schedule{nullptr}; + /*! \brief A functor that converts the instruction to a statement in python syntax */ + FInstructionAsPython f_as_python{nullptr}; + /*! + * \brief A functor that serialize its attributes to JSON + * \note If the functor is null, it means no conversion is needed + */ + FInstructionAttrsAsJSON f_attrs_as_json{nullptr}; + /*! + * \brief A functor that deserialize its attributes from JSON + * \note If the functor is null, it means no conversion is needed + */ + FInstructionAttrsFromJSON f_attrs_from_json{nullptr}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("_is_pure", &is_pure); + // not visited: f_apply_to_schedule + // not visited: f_as_python + // not visited: f_attrs_as_json + // not visited: f_attrs_from_json + } + + static constexpr const char* _type_key = "tir.InstructionKind"; + TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); +}; + +/*! + * \brief Managed reference to InstructionKindNode + * \sa InstructionKindNode + */ +class InstructionKind : public runtime::ObjectRef { + public: + /*! + * \brief Retrieve an InstructionKind using its name + * \param name The registered name of the InstructionKind + * \return The InstructionKind retrieved + */ + static InstructionKind Get(const String& name); + TVM_DEFINE_OBJECT_REF_METHODS(InstructionKind, runtime::ObjectRef, InstructionKindNode); +}; + +/*! \brief Schedule instructions each corresponds to a schedule primitive */ +class InstructionNode : public runtime::Object { + public: + /*! \brief The kind of the instruction */ + InstructionKind kind; + /*! + * \brief The input random variables of the instruction, and the type of each element can be one + * of the following: + * - BlockRV + * - LoopRV + * - ExprRV + * - FloatImm + * - IntImm + * - String + * - null pointer + */ + Array inputs; + /*! + * \brief The attributes of the instruction. Similar to attributes of an operator, + * attributes of an instruction are arbitrary constant metadata required by the instructions. + * For example, the name of the block to be retrieved in `GetBlock`. + */ + Array attrs; + /*! \brief The output random variables of the instruction, and the type of each element can be one + * of the following: + * - BlockRV + * - LoopRV + * - ExprRV, atomic variables only, won't be constants or composite PrimExpr + */ + Array outputs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("kind", &kind); + v->Visit("inputs", &inputs); + v->Visit("attrs", &attrs); + v->Visit("outputs", &outputs); + } + + static constexpr const char* _type_key = "tir.Instruction"; + TVM_DECLARE_FINAL_OBJECT_INFO(InstructionNode, runtime::Object); +}; + +/*! + * \brief Managed reference to InstructionNode + * \sa InstructionNode + */ +class Instruction : public runtime::ObjectRef { + public: + /*! + * \brief Constructor + * \param kind The kind of the instruction + * \param inputs The input random variables of the instruction + * \param attrs The attributes of the instruction + * \param outputs The output random variables of the instruction + */ + explicit Instruction(InstructionKind kind, Array inputs, Array attrs, + Array outputs); + + TVM_DEFINE_OBJECT_REF_METHODS(Instruction, runtime::ObjectRef, InstructionNode); +}; + +/*! + * \brief A helper macro to register InstructionKind, only used in `TVM_REGISTER_INST_KIND` + * \note This macro is not user-facing. + * \sa TVM_REGISTER_INST_KIND + */ +#define TVM_INST_KIND_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::tir::InstructionKindRegEntry& __make_##InstructionKind + +/*! + * \brief Register an InstructionKind + * \param InstructionKindName The name of the InstructionKind + * + * Example: + * + * \code + * + * TVM_REGISTER_INST_KIND("ComputeInline") + * .set_is_pure(false) + * .set_apply_to_schedule(ApplyToSchedule) + * .set_attrs_as_json(AttrsAsJSON) + * .set_attrs_from_json(AttrsFromJSON) + * .set_as_python(AsPython); + * + * \endcode + */ +#define TVM_REGISTER_INST_KIND(InstructionKindName) \ + TVM_STR_CONCAT(TVM_INST_KIND_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::tir::InstructionKindRegEntry::RegisterOrGet(InstructionKindName).set_name() + +/*! \brief An entry in the registry of InstructionKind */ +class InstructionKindRegEntry { + public: + static InstructionKindRegEntry& RegisterOrGet(const String& name); + + InstructionKindRegEntry& set_name() { + get_mutable()->name = this->name; + return *this; + } + + InstructionKindRegEntry& set_is_pure(bool is_pure) { + get_mutable()->is_pure = is_pure; + return *this; + } + + InstructionKindRegEntry& set_apply_to_schedule(FInstructionApply f_apply_to_schedule) { + get_mutable()->f_apply_to_schedule = std::move(f_apply_to_schedule); + return *this; + } + + InstructionKindRegEntry& set_as_python(FInstructionAsPython f_as_python) { + get_mutable()->f_as_python = std::move(f_as_python); + return *this; + } + + InstructionKindRegEntry& set_attrs_as_json(FInstructionAttrsAsJSON f_attrs_as_json) { + get_mutable()->f_attrs_as_json = std::move(f_attrs_as_json); + return *this; + } + + InstructionKindRegEntry& set_attrs_from_json(FInstructionAttrsFromJSON f_attrs_from_json) { + get_mutable()->f_attrs_from_json = std::move(f_attrs_from_json); + return *this; + } + + private: + /*! \brief Private constructor, used only by AttrRegistry */ + explicit InstructionKindRegEntry(uint32_t reg_index); + /*! \brief Get the mutable reference to the internal InstructionKind */ + InstructionKindNode* get_mutable() const { + return const_cast(inst_kind_.get()); + } + + /*! \brief The name of the registry entry */ + String name; + /*! \brief The instruction kind */ + InstructionKind inst_kind_; + template + friend class ::tvm::AttrRegistry; + friend class InstructionKind; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_INSTRUCTION_H_ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9a09d0ad211f..bd2377397626 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -180,7 +180,8 @@ class ScheduleNode : public runtime::Object { virtual void RemoveRV(const ExprRV& expr_rv) = 0; public: - /******** Block/Loop relation ********/ + /******** Schedule: Sampling ********/ + /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name * \param name The name of the block to be retrieved @@ -195,8 +196,29 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; - /******** Schedule: loops manipulation ********/ - /******** Schedule: compute location ********/ + /******** Schedule: Transform loops ********/ + /*! + * \brief Fuse a list of consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The (i+1)-th loop must be the only child of the i-th loop. + * 3) All loops must start with 0. + * \param loop_rvs The loops to be fused + * \return The new loop after fusion + */ + virtual LoopRV Fuse(const Array& loop_rvs) = 0; + /*! + * \brief Split a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param loop_rv The loop to be split + * \param factors The tiling factors, and at most one of which is -1, which means that + * factor is inferred. + * \return The new loops after split + */ + virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; + /******** Schedule: Manipulate ForKind ********/ + /******** Schedule: Insert cache stages ********/ + /******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: * 1) The block is a complete non-root block, which only produces one buffer @@ -220,10 +242,30 @@ class ScheduleNode : public runtime::Object { * \param block The block to be inlined to its producer */ virtual void ReverseComputeInline(const BlockRV& block) = 0; - /******** Schedule: loop binding/annotation ********/ - /******** Schedule: cache read/write ********/ - /******** Schedule: reduction ********/ - /******** Schedule: blockize & tensorize ********/ + /******** Schedule: Reduction ********/ + /*! + * \brief Factorize an associative reduction block by the specified loop. + * \details An associative reduction cannot be parallelized directly, + * because it leads to potential race condition during accumulation. + * Alternatively, the reduction could be factorized on a loop with the following steps: + * - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent + * - Step 2: compute the chunks separately and write the result into `n` intermediate buffers; + * - Step 3: accumulate the `n` separate buffer into the result buffer. + * Note that the Step 2 above introduces opportunities for parallelization. + * RFactor is a schedule primitive that implements the transformation described above. + * \param loop_rv The loop outside block we want to do rfactor + * \param factor_axis The position where the new dimension is placed in the new introduced rfactor + * buffer. Suppose the original reduction block writes to buffer `B` with + * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, + * ndim(B)]`, and the negative index will be normalized to a non-negative one + * \return The rfactor block + */ + virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; + /******** Schedule: Blockize & Tensorize ********/ + /******** Schedule: Annotation ********/ + /******** Schedule: Misc ********/ + /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ + virtual void EnterPostproc() = 0; }; /*! diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 83ac7150543f..077bf938f48a 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -190,14 +190,6 @@ class ScheduleState : public ObjectRef { * and each time after calling the Replace method. */ TVM_DLL explicit ScheduleState(IRModule mod, int debug_mode = 0); - /*! - * \brief Construct a schedule state from a PrimFunc - * \param func The PrimFunc to be scheduled. A new IRModule will be created with - * this specific PrimFunc as "main" function in the module to be scheduled - * \param debug_mode Do extra correctness checking after the class creation - * and each time after calling the Replace method. - */ - TVM_DLL explicit ScheduleState(PrimFunc func, int debug_mode = 0); /*! \return The mutable pointer to the ScheduleStateNode */ ScheduleStateNode* get() const { return static_cast(data_.get()); } diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h new file mode 100644 index 000000000000..b6b3b57226c8 --- /dev/null +++ b/include/tvm/tir/schedule/trace.h @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRACE_H_ +#define TVM_TIR_SCHEDULE_TRACE_H_ + +#include + +namespace tvm { +namespace tir { + +// Forward declaration +class Trace; + +/*! + * \brief A callback that allows users to mutate decisions on the fly + * when applying instructions. The signature of the callback is: + * \param inst The instruction + * \param inputs The input random variables + * \param attrs The attributes + * \param decision The original decision + * \return A new decision + */ +using FTraceDecisionProvider = runtime::TypedPackedFunc& inputs, const Array& attrs, + const Optional& decision)>; + +/*! + * \brief An execution trace of a scheduling program + * + * A trace has two parts: + * 1) The instructions invoked so far in the program execution + * 2) The random decisions made upon those instructions, if any + * + * A trace can be serialized to: + * 1) Roundtrippable JSON format: can be saved to file and loaded back + * 2) Python syntax: allows users to copy-paste the trace to reproduce the scheduling process + * + * A trace can be applied to a TensorIR schedule by re-applying all its instructions possibly with + * their decisions accordingly. Re-sampling is invoked if a sampling instruction doesn't have its + * corresponding decision; Otherwise the existing decision will be reused accordingly. + */ +class TraceNode : public runtime::Object { + public: + /*! \brief The instructions invoked so far in the program execution */ + Array insts; + /*! \brief The random decisions made upon those instructions */ + Map decisions; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("insts", &insts); + v->Visit("decisions", &decisions); + } + + static constexpr const char* _type_key = "tir.Trace"; + TVM_DECLARE_FINAL_OBJECT_INFO(TraceNode, runtime::Object); + + public: + /*! + * \brief Retrieve the decision made on a specific instruction + * \param inst The instruction whose decision is to be retrieved + * \return The corresponding decision; NullOpt if there is no decision made on the instruction + */ + Optional GetDecision(const Instruction& inst) const; + /*! + * \brief Append a new instruction to the trace + * \param inst The new instruction to be appended + */ + void Append(Instruction inst); + /*! + * \brief Append a new instruction with a random decision to the trace + * \param inst The new instruction to be appended + * \param decision The random decision made on this instruction + * The type of `decision` depends on the instruction, e.g. + * the decision of `SamplePerfectTile` has type `Array` + */ + void Append(Instruction inst, ObjectRef decision); + /*! + * \brief Remove the last instruction, along with the decision made on that instruction, if any + * \return The instruction removed; NullOpt if the trace is empty + */ + Optional Pop(); + /*! + * \brief Apply the trace to a TensorIR schedule + * \param sch The schedule to be applied onto + * \param remove_postproc If postprocessing instructions are removed + * \param decision_provider A callback that allows users to mutate decisions on the fly + * when applying instructions. + * \sa FTraceDecisionProvider + */ + void ApplyToSchedule(Schedule sch, bool remove_postproc, + FTraceDecisionProvider decision_provider = nullptr) const; + /*! + * \brief Serialize the trace as a JSON-style object + * \param remove_postproc If postprocessing instructions are removed + * \return The JSON-style object + */ + ObjectRef AsJSON(bool remove_postproc) const; + /*! + * \brief Serialize the trace as a sequence of python statements + * \param remove_postproc If postprocessing instructions are removed + * \return A sequence of python statements + */ + Array AsPython(bool remove_postproc) const; + /*! + * \brief Create a new trace with an instruction whose decision is changed, + * assuming this instruction exists in the resulting trace + * \param inst The instruction whose decision is to be changed + * \param decision The decision to be changed to + * \param remove_postproc If postprocessing instructions are removed + * \return The new trace with the decision changed + */ + Trace WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const; + /*! + * \brief Simplify the trace with dead-code elimination + * \param remove_postproc If postprocessing instructions are removed + * \return A simplified trace + */ + Trace Simplified(bool remove_postproc) const; +}; + +/*! + * \brief Managed reference to TraceNode + * \sa TraceNode + */ +class Trace : public runtime::ObjectRef { + public: + /*! \brief Default constructor. Creating an empty trace. */ + Trace(); + /*! + * \brief Constructor. Creating a trace from existing instructions and their decisions + * \param insts The instructions used + * \param decisions The decisions made in sampling + */ + explicit Trace(Array insts, Map decisions); + /*! + * \brief Apply a JSON-serialized trace to a TensorIR schedule + * \param json The JSON-serialized trace + * \param sch The TensorIR schedule + */ + static void ApplyJSONToSchedule(ObjectRef json, Schedule sch); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRACE_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index cc10c218c8ff..0da8e55be023 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -464,18 +464,22 @@ class ProducerRealizeNode : public StmtNode { PrimExpr condition; /*! \brief The body of realization. */ Stmt body; + /*! \brief The storage scope associated with this realization. */ + String storage_scope; void VisitAttrs(AttrVisitor* v) { v->Visit("producer", &producer); v->Visit("bounds", &bounds); v->Visit("condition", &condition); v->Visit("body", &body); + v->Visit("storage_scope", &storage_scope); v->Visit("span", &span); } bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { return equal(producer, other->producer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body); + equal(condition, other->condition) && equal(body, other->body) && + equal(storage_scope, other->storage_scope); } void SHashReduce(SHashReducer hash_reduce) const { @@ -483,6 +487,7 @@ class ProducerRealizeNode : public StmtNode { hash_reduce(bounds); hash_reduce(condition); hash_reduce(body); + hash_reduce(storage_scope); } static constexpr const char* _type_key = "tir.ProducerRealize"; @@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode { class ProducerRealize : public Stmt { public: TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span = Span()); + String storage_scope = "", Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); }; @@ -860,6 +865,7 @@ class For : public Stmt { Map annotations = Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); }; /*! @@ -1235,8 +1241,6 @@ constexpr const char* extern_scope = "extern_scope"; * This can hint some code generator to create a new function for compute. */ constexpr const char* compute_scope = "compute_scope"; -/*! \brief Mark storage scope of buffers */ -constexpr const char* storage_scope = "storage_scope"; /*! \brief Mark storage alignement requirement of buffers */ constexpr const char* storage_alignment = "storage_alignment"; /*! \brief Mark storage scope of realization */ @@ -1356,6 +1360,24 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); // overload printing of for type. TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); +// inline implementations +inline const char* ForKind2String(ForKind t) { + switch (t) { + case ForKind::kSerial: + return "serial"; + case ForKind::kParallel: + return "parallel"; + case ForKind::kVectorized: + return "vectorized"; + case ForKind::kUnrolled: + return "unroll"; + case ForKind::kThreadBinding: + return "thread_binding"; + } + LOG(FATAL) << "Unknown ForKind" << t; + return "Unknown"; +} + } // namespace tir } // namespace tvm #endif // TVM_TIR_STMT_H_ diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 5ee847e2f010..d1308fe0059e 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -423,6 +423,12 @@ TVM_DLL Pass CompactBufferAllocation(); */ TVM_DLL Pass LegalizePackedCalls(); +/*! + * \brief Remove match buffers inside the block. Also, it will validate the binding. + * \return The pass. + */ +TVM_DLL Pass LowerMatchBuffer(); + /*! * \brief Flatten the multi-dimensional BufferLoad and BufferStore * to single dimensional Load/Store. Also remove Block to @@ -431,6 +437,11 @@ TVM_DLL Pass LegalizePackedCalls(); */ TVM_DLL Pass FlattenBuffer(); +/*! + * A pass to merge multiple TIR-level dynamic shared memory allocations into one + */ +TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index caca1e85e520..2561f8d1ca27 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -48,7 +48,7 @@ using namespace tvm::te; inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return Buffer(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); + return Buffer(data, dtype, shape, Array(), elem_offset, name, -1, 0, kDefault); } /*! diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java b/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java index 9811ae19afd8..10df50237628 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/ConnectTrackerServerProcessor.java @@ -243,8 +243,8 @@ private boolean needRefreshKey() throws IOException { // handcrafted JSON private String generateCinfo(String key) { - String cinfo = "{\"key\" : " + "\"server:" + key + "\", \"addr\": [\"" - + trackerHost + "\", \"" + trackerPort + "\"]}"; + String cinfo = "{\"key\" : " + "\"server:" + key + "\", \"addr\": [null, \"" + + serverPort + "\"]}"; return "[" + RPC.TrackerCode.UPDATE_INFO + ", " + cinfo + "]"; } diff --git a/python/gen_requirements.py b/python/gen_requirements.py index dc338a3fcd3b..dac435146469 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -"""TVM Python requriements.txt generator. +"""TVM Python requirements.txt generator. This script generates a set of requirements.txt files (stored in `./requirements`) that describe TVM's Python dependencies. @@ -75,6 +75,16 @@ ], ), ), + # Provide support for Arm(R) Ethos(TM)-U NPU. + ( + "ethosu", + ( + "Requirements for using Arm(R) Ethos(TM)-U NPU", + [ + "ethos-u-vela", + ], + ), + ), # Relay frontends. ( "importer-caffe2", @@ -205,6 +215,7 @@ "docutils", "<0.17", ), # Work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 + ("ethos-u-vela", "==2.1.1"), ("future", None), ("image", None), ("matplotlib", None), diff --git a/python/setup.py b/python/setup.py index b47e5b14f6a7..dd13a12d8903 100644 --- a/python/setup.py +++ b/python/setup.py @@ -41,7 +41,7 @@ def get_lib_path(): """Get library path, name and version""" # We can not import `libinfo.py` in setup.py directly since __init__.py - # Will be invoked which introduces dependences + # Will be invoked which introduces dependencies libinfo_py = os.path.join(CURRENT_DIR, "./tvm/_ffi/libinfo.py") libinfo = {"__file__": libinfo_py} exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo) diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index d1e4431a2e0e..f5a0478dc008 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,4 +22,9 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide +from .iter_affine_map import ( + detect_iter_map, + normalize_iter_map_to_expr, + subspace_divide, + inverse_affine_iter_map, +) diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index bfd5dfadc800..85513ecae5c4 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -173,3 +173,30 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi Empty array if no match can be found. """ return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) + + +def inverse_affine_iter_map(iter_map, outputs): + """Apply the inverse of the affine transformation to the outputs. + Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions + in reverse topology order and applies the inverse of the affine transformation until it reaches + the input. The affine iter map is required to be bijective. + + For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1], + the affine transformation specified by `iter_map` will be applied to `outputs` and the result + will be {l0: ((output_0*16) + output_1)}. + + See also :any:`detect_iter_map`. + + Parameters + ---------- + iter_map : List[IterSumExpr] + The bijective affine iter map. + outputs : List[PrimExpr] + The outputs of the affine transformation. + + Returns + ------- + results : Map[Var, PrimExpr] + The map from the input to the transformed result. + """ + return _ffi_api.InverseAffineIterMap(iter_map, outputs) diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index c843dcfccdf0..cc1e76b9faa8 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -53,7 +53,7 @@ def __init__(self): def query(self, target, workload_key, has_complex_op, dag, func_name): """ Query the context to get the specific config for a workload. - If cannot find the result inside this context, this function will query it + If this function cannot find the result inside this context, it will query the result from the upper contexts. Parameters diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0d18bc08e5ed..850e50004337 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -150,7 +150,7 @@ def extract_tasks( # create search tasks tasks = [] weights = [] - for (func_name, wkl_key), weight in env.wkl_key_to_weight.items(): + for wkl_key, (weight, func_names) in env.wkl_key_to_weight.items(): tasks.append( SearchTask( workload_key=wkl_key, @@ -165,7 +165,7 @@ def extract_tasks( else None ), task_inputs_save_to_file=True, - desc=func_name, + desc=",".join(func_names), ) ) weights.append(weight) @@ -189,6 +189,7 @@ class TracingEnvironment: def __init__(self, tracing_mode): self.tracing_mode = tracing_mode self.relay_disable_build_cache = "false" + self.func_name_to_wkl_key = {} self.wkl_key_to_weight = {} self.wkl_key_to_input_names = {} @@ -210,10 +211,12 @@ def add_workload_key(self, func_name, workload_key): workload_key: str The workload key of a task. """ - key = (func_name, workload_key) - if key not in self.wkl_key_to_weight: - self.wkl_key_to_weight[key] = 0 - self.wkl_key_to_weight[key] += 1 + self.func_name_to_wkl_key[func_name] = workload_key + if workload_key not in self.wkl_key_to_weight: + self.wkl_key_to_weight[workload_key] = (0, set()) + weight, func_names = self.wkl_key_to_weight[workload_key] + func_names.add(func_name) + self.wkl_key_to_weight[workload_key] = (weight + 1, func_names) def add_workload_input_names(self, workload_key, input_names): """Add special task inputs to this workload. @@ -318,6 +321,7 @@ def auto_schedule_topi(func_name, outs): A tuned schedule or none (if not tuned) in the final build mode; None in the tracing mode so that the fallback topi schedule will be used. """ + # pylint: disable=import-outside-toplevel from tvm.auto_scheduler.measure import ( prepare_input_map, @@ -376,6 +380,41 @@ def auto_schedule_topi(func_name, outs): return schedule +@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights") +def te_compiler_update_weights(function_weights): + """A callback for updating the weights of extracted tasks. When using the TE compiler + that avoids compiling the same function multiple times by caching, all extracted tasks + have weight 1, so the TE compiler invokes this callback at the end. In this case, + we override existing weights with the use_count in TE compiler cache. + + Parameters + ---------- + function_weights: Dict[str, int] + Mapping from function names to their weights. + """ + env = TracingEnvironment.current + if env is not None: + # Override this map with the weights in the TE compiler. + env.wkl_key_to_weight = {} + + for func_name, weight in function_weights.items(): + # If the function name is not in the map, then it means we are not interested in + # this function during task extraction (e.g., a function without reduction). + if func_name not in env.func_name_to_wkl_key: + continue + + workload_key = env.func_name_to_wkl_key[func_name] + if workload_key not in env.wkl_key_to_weight: + env.wkl_key_to_weight[workload_key] = (0, set()) + + # Note that the function appears multiple times in a model will be renamed + # to make sure function names are unique, so we use the workload key generated + # from the function's TE compute to determine their weights. + old_weight, func_names = env.wkl_key_to_weight[workload_key] + func_names.add(func_name) + env.wkl_key_to_weight[workload_key] = (old_weight + weight, func_names) + + def tensor_no_check_call(self, *indices): """An indexing function without any check. This is the same as `tvm.te.Tensor::__call__` except that the safety diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index dd5073331083..9b975063105f 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -598,7 +598,7 @@ def pre_tune(self, task_scheduler, task_id): # overall info if all(cost < 1e9 for cost in task_scheduler.best_costs): - total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3) else: total_latency_str = "-" print( @@ -629,7 +629,7 @@ def __init__(self, log_file): def post_tune(self, task_scheduler, task_id): if all(cost < 1e9 for cost in task_scheduler.best_costs): - total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3) + total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3) else: total_latency_str = "N/A" diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 3de25cb6100b..db4ff26857bd 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -32,6 +32,7 @@ import typing from collections import namedtuple from random import getrandbits +import warnings import tvm._ffi import tvm.ir.transform @@ -235,6 +236,7 @@ def __init__( self.number = number self.repeat = repeat self.min_repeat_ms = min_repeat_ms + self._ref_input = None self.enable_cpu_cache_flush = enable_cpu_cache_flush self.cooldown_interval = cooldown_interval @@ -242,6 +244,25 @@ def __init__( self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1)) + @property + def ref_input(self): + """ + Fixed input for tuning special operators, e.g., sparse operators + requiring indices as input. + """ + return self._ref_input + + @ref_input.setter + def ref_input(self, val): + warnings.warn( + "You are specifying fixed input for tuning the operator. " + "Be sure your input always fits the operator. Some " + "operators may conduct layout transformation during tuning, " + "thus can lead to unexpected behaviors. ", + RuntimeWarning, + ) + self._ref_input = val + def set_task(self, task): self.task = task @@ -308,6 +329,7 @@ def run(self, measure_inputs, build_results): self.min_repeat_ms, self.cooldown_interval, remote_kwargs, + self.ref_input, self.enable_cpu_cache_flush, module_loader, ) @@ -508,6 +530,7 @@ def run_through_rpc( min_repeat_ms, cooldown_interval, remote_kwargs, + ref_input, enable_cpu_cache_flush=False, module_loader=None, ): @@ -539,6 +562,8 @@ def run_through_rpc( The cool down interval between two measurements remote_kwargs: dict Passed to module_loader(). Ultimately, keyword args to request_remote(). + ref_input: List of np.ndarray + The reference input used for tuning. Empty for randomly filled input. enable_cpu_cache_flush: bool Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to @@ -573,18 +598,22 @@ def run_through_rpc( f_preproc=f_prepare, ) - try: - random_fill = remote.get_function("tvm.contrib.random.random_fill") - except AttributeError: - raise AttributeError( - "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" - ) - args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info] - if "scatter" not in measure_input.task.name: - # the index tensor of scatter op cannot be randomly initialized - for arg in args: - random_fill(arg) - dev.sync() + if ref_input: + args = [nd.array(x, device=dev) for x in ref_input] + else: + try: + random_fill = remote.get_function("tvm.contrib.random.random_fill") + except AttributeError: + raise AttributeError( + "Please make sure USE_RANDOM is ON in the config.cmake " + "on the remote devices" + ) + args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info] + if "scatter" not in measure_input.task.name: + # the index tensor of scatter op cannot be randomly initialized + for arg in args: + random_fill(arg) + dev.sync() costs = time_f(*args).results diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 1f5827d7e9d0..ee1750896fca 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -40,11 +40,11 @@ def _lookup_task(name): task = TASK_TABLE.get(name) if task is None: - raise RuntimeError( - f"Could not find a registered function for the task {name}. It is " - "possible that the function is registered in a python file which was " - "not imported in this run." - ) + # Unable to find the given task. This might be because we are + # creating a task based on a name that has not been imported. + # Rather than raising an exception here, we return a dummy + # task which cannot be invoked. + task = MissingTask(name) return task @@ -61,7 +61,7 @@ def _encode(x): return ("TENSOR", get_const_tuple(x.shape), x.dtype) if isinstance(x, (tuple, list, container.Array)): return tuple([_encode(a) for a in x]) - if isinstance(x, (str, int, float, np.int, np.float, expr.Var, expr.Any)): + if isinstance(x, (str, int, float, expr.Var, expr.Any)): return x if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): return x.value @@ -264,6 +264,25 @@ def _get_inputs(out): return inputs +class MissingTask(TaskTemplate): + """ + Dummy task template for a task lookup which cannot be resolved. + This can occur if the task being requested from _lookup_task() + has not been imported in this run. + """ + + def __init__(self, taskname: str): + super().__init__() + self._taskname = taskname + + def __call__(self, *args, **kwargs): + raise RuntimeError( + f"Attempting to invoke a missing task {self._taskname}." + "It is possible that the function is registered in a " + "Python module that is not imported in this run, or the log is out-of-date." + ) + + def _register_task_compute(name, func=None): """Register compute function to autotvm task diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index 5c0a46336532..f438bc197afe 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -178,7 +178,7 @@ def download_package(tophub_location, package_name): download_url = "{0}/{1}".format(tophub_location, package_name) logger.info("Download pre-tuned parameters package from %s", download_url) - download(download_url, Path(rootpath, package_name), True, verbose=0) + download(download_url, Path(rootpath, package_name), overwrite=True) # global cache for load_reference_log diff --git a/python/tvm/autotvm/tuner/xgboost_tuner.py b/python/tvm/autotvm/tuner/xgboost_tuner.py index 2f4d0ee88ce9..9dec54c2d5f7 100644 --- a/python/tvm/autotvm/tuner/xgboost_tuner.py +++ b/python/tvm/autotvm/tuner/xgboost_tuner.py @@ -55,7 +55,9 @@ class XGBTuner(ModelBasedTuner): The cost model predicts relative rank score. num_threads: int, optional - The number of threads. optimizer: str or ModelOptimizer, optional + The number of threads. + + optimizer: str or ModelOptimizer, optional If is 'sa', use a default simulated annealing optimizer. Otherwise it should be a ModelOptimizer object. diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py index dc043353c475..622f27c358b6 100644 --- a/python/tvm/contrib/debugger/debug_executor.py +++ b/python/tvm/contrib/debugger/debug_executor.py @@ -268,23 +268,28 @@ def run_individual(self, number, repeat=1, min_repeat_ms=0): ret = self._run_individual(number, repeat, min_repeat_ms) return ret.strip(",").split(",") if ret else [] - def profile(self, **input_dict): + def profile(self, collectors=None, **input_dict): """Run forward execution of the graph and collect overall and per-op performance metrics. Parameters ---------- + collectors : Optional[Sequence[MetricCollector]] + Extra metrics to collect. + input_dict : dict of str to NDArray List of input values to be feed to + Return ------ timing_results : str Per-operator and whole graph timing results in a table format. """ + collectors = [] if collectors is None else collectors if input_dict: self.set_input(**input_dict) - return self._profile() + return self._profile(collectors) def exit(self): """Exits the dump folder and all its contents""" diff --git a/python/tvm/contrib/download.py b/python/tvm/contrib/download.py index f7c68a99229e..e0c13acc8de9 100644 --- a/python/tvm/contrib/download.py +++ b/python/tvm/contrib/download.py @@ -15,15 +15,18 @@ # specific language governing permissions and limitations # under the License. """Helper utility for downloading""" + +import logging +import os from pathlib import Path -from os import environ -import sys -import time -import uuid import shutil +import tempfile +import time + +LOG = logging.getLogger("download") -def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3): +def download(url, path, overwrite=False, size_compare=False, retries=3): """Downloads the file from the internet. Set the input options correctly to overwrite or do the size comparison @@ -33,19 +36,18 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries= Download url. path : str - Local file path to save downloaded file + Local file path to save downloaded file. overwrite : bool, optional - Whether to overwrite existing file + Whether to overwrite existing file, defaults to False. size_compare : bool, optional - Whether to do size compare to check downloaded file. - - verbose: int, optional - Verbose level + Whether to do size compare to check downloaded file, defaults + to False retries: int, optional - Number of time to retry download, default at 3. + Number of time to retry download, defaults to 3. + """ # pylint: disable=import-outside-toplevel import urllib.request as urllib2 @@ -62,21 +64,19 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries= res_get = urllib2.urlopen(url) url_file_size = int(res_get.headers["Content-Length"]) if url_file_size != file_size: - print("exist file got corrupted, downloading %s file freshly..." % path) - download(url, path, True, False) + LOG.warning("Existing file %s has incorrect size, downloading fresh copy", path) + download(url, path, overwrite=True, size_compare=False, retries=retries) return - print("File {} exists, skip.".format(path)) + + LOG.info("File %s exists, skipping.", path) return - if verbose >= 1: - print("Downloading from url {} to {}".format(url, path)) + LOG.info("Downloading from url %s to %s", url, path) # Stateful start time start_time = time.time() dirpath = path.parent dirpath.mkdir(parents=True, exist_ok=True) - random_uuid = str(uuid.uuid4()) - tempfile = Path(dirpath, random_uuid) def _download_progress(count, block_size, total_size): # pylint: disable=unused-argument @@ -84,44 +84,53 @@ def _download_progress(count, block_size, total_size): if count == 0: return duration = time.time() - start_time - progress_size = int(count * block_size) - speed = int(progress_size / (1024 * duration)) + progress_bytes = int(count * block_size) + progress_megabytes = progress_bytes / (1024.0 * 1024) + speed_kbps = int(progress_bytes / (1024 * duration)) percent = min(int(count * block_size * 100 / total_size), 100) - sys.stdout.write( - "\r...%d%%, %.2f MB, %d KB/s, %d seconds passed" - % (percent, progress_size / (1024.0 * 1024), speed, duration) + + # Temporarily suppress newlines on the output stream. + prev_terminator = logging.StreamHandler.terminator + logging.StreamHandler.terminator = "" + LOG.debug( + "\r...%d%%, %.2f MB, %d KB/s, %d seconds passed", + percent, + progress_megabytes, + speed_kbps, + duration, ) - sys.stdout.flush() - - while retries >= 0: - # Disable pyling too broad Exception - # pylint: disable=W0703 - try: - if sys.version_info >= (3,): - urllib2.urlretrieve(url, tempfile, reporthook=_download_progress) - print("") - else: - f = urllib2.urlopen(url) - data = f.read() - with open(tempfile, "wb") as code: - code.write(data) - shutil.move(tempfile, path) - break - except Exception as err: - retries -= 1 - if retries == 0: - if tempfile.exists(): - tempfile.unlink() - raise err - print( - "download failed due to {}, retrying, {} attempt{} left".format( - repr(err), retries, "s" if retries > 1 else "" + logging.StreamHandler.terminator = prev_terminator + + with tempfile.TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) + download_loc = tempdir.joinpath(path.name) + + for i_retry in range(retries): + # pylint: disable=broad-except + try: + + urllib2.urlretrieve(url, download_loc, reporthook=_download_progress) + LOG.debug("") + try: + download_loc.rename(path) + except OSError: + # Prefer a move, but if the tempdir and final + # location are in different drives, fall back to a + # copy. + shutil.copy2(download_loc, path) + return + + except Exception as err: + if i_retry == retries - 1: + raise err + + LOG.warning( + "%s\nDownload attempt %d/%d failed, retrying.", repr(err), i_retry, retries ) - ) -if "TEST_DATA_ROOT_PATH" in environ: - TEST_DATA_ROOT_PATH = Path(environ.get("TEST_DATA_ROOT_PATH")) +if "TEST_DATA_ROOT_PATH" in os.environ: + TEST_DATA_ROOT_PATH = Path(os.environ.get("TEST_DATA_ROOT_PATH")) else: TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data") TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True) @@ -141,10 +150,16 @@ def download_testdata(url, relpath, module=None, overwrite=False): module : Union[str, list, tuple], optional Subdirectory paths under test data folder. + overwrite : bool, defaults to False + If True, will download a fresh copy of the file regardless of + the cache. If False, will only download the file if a cached + version is missing. + Returns ------- abspath : str Absolute file path of downloaded file + """ global TEST_DATA_ROOT_PATH if module is None: diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index 34b37537776f..6364ef749dd9 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -176,23 +176,26 @@ def buf_align(var): def visit(stmt): """Collect information about VTCM buffers and their alignments.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.append(stmt.node) - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": if not stmt.node in alignments: alignments[stmt.node] = [] alignments[stmt.node].append(stmt.value) + elif isinstance(stmt, tvm.tir.Allocate): + scope = stmt.buffer_var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.append(stmt.buffer_var) def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.pop() - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": alignments[stmt.node].pop() return stmt if isinstance(stmt, tvm.tir.Allocate): var = stmt.buffer_var + scope = var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.pop() if var in vtcm_buffers: is_null = tvm.tir.call_intrin("bool", tvm.ir.Op.get("tir.isnullptr"), var) throw_error = tvm.tir.call_intrin( diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index 112fc320973b..0e336c1c82b9 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -136,3 +136,55 @@ def conv2d_forward( ), name="y", ) + + +def softmax(x, axis=-1): + """Compute softmax with MIOpen + + Parameters + ---------- + x : tvm.te.Tensor + The input tensor + + axis : int + The axis to compute softmax over + + Returns + ------- + ret : tvm.te.Tensor + The result tensor + """ + return te.extern( + x.shape, + [x], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.miopen.softmax.forward", ins[0], outs[0], axis + ), + name="y", + ) + + +def log_softmax(x, axis=-1): + """Compute log softmax with MIOpen + + Parameters + ---------- + x : tvm.te.Tensor + The input tensor + + axis : int + The axis to compute log softmax over + + Returns + ------- + ret : tvm.te.Tensor + The result tensor + """ + return te.extern( + x.shape, + [x], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.miopen.log_softmax.forward", ins[0], outs[0], axis + ), + name="y", + ) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index b05265fa976a..b839af669fe6 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -24,6 +24,7 @@ import onnx import onnx.utils from onnx import numpy_helper, OperatorSetIdProto, defs +from onnx import TensorProto import tvm from tvm import relay import tvm._ffi @@ -138,6 +139,21 @@ def convert_attributes(cls, attrs): } +class ConvTranspose(OpConverter): + """Operator converter for ConvTranspose.""" + + @classmethod + def convert_attributes(cls, attrs): + return { + "group": attrs.get_int("groups"), + "pads": attrs.get_int_tuple("padding"), + "strides": attrs.get_int_tuple("strides"), + "dilations": attrs.get_int_tuple("dilation"), + "kernel_shape": attrs.get_int_tuple("kernel_size"), + "output_padding": attrs.get_int_tuple("output_padding"), + } + + class MaxPool(OpConverter): """Operator converter for MaxPool.""" @@ -147,6 +163,7 @@ def convert_attributes(cls, attrs): "pads": attrs.get_int_tuple("padding"), "strides": attrs.get_int_tuple("strides"), "kernel_shape": attrs.get_int_tuple("pool_size"), + "ceil_mode": 1 if attrs.ceil_mode else 0, } @@ -330,7 +347,10 @@ def convert_attributes(cls, attrs): after.append(axis_pads[1]) pads = before + after pads = numpy.asarray(pads, dtype=pads[0].dtype) - return {"pads": pads, "mode": attrs.get_str("pad_mode"), "constant_value": attrs.pad_value} + return { + "pads": pads, + "mode": attrs.get_str("pad_mode"), + } @classmethod def convert(cls, node_entry, model_container, node_dict): @@ -341,16 +361,17 @@ def convert(cls, node_entry, model_container, node_dict): attrs = cls.convert_attributes(node_entry["relay_node"].attrs) name = node_entry["name"] - data = numpy.asarray(attrs["pads"], dtype=attrs["pads"][0].dtype).astype(numpy.int64) - value = numpy.dtype(node_entry["types"][0].dtype).type(attrs["constant_value"]) + pad_data = numpy.asarray(attrs["pads"], dtype=attrs["pads"][0].dtype).astype(numpy.int64) input_names = [ node_entry["input_names"][0], - add_input(data, name, "pads", model_container), - add_input(value, name, "value", model_container), + add_input(pad_data, name, "pads", model_container), + node_entry["input_names"][1], ] - node = onnx.helper.make_node(cls.__name__, input_names, node_entry["output_names"]) + node = onnx.helper.make_node( + cls.__name__, input_names, node_entry["output_names"], mode=attrs["mode"] + ) model_container.add_nodes([node]) @@ -633,9 +654,108 @@ def convert_attributes(cls, attrs): return {"alpha": attrs.alpha, "beta": attrs.beta, "bias": attrs.bias, "size": attrs.size} +class Cast(OpConverter): + """ Operator converter for Cast.""" + + @classmethod + def convert_attributes(cls, attrs): + return {"to": getattr(TensorProto, attrs.dtype.upper())} + + +class Resize(OpConverter): + """Operator converter for Resize.""" + + @classmethod + def convert_attributes(cls, attrs): + method = attrs.get_str("method") + if method == "nearest_neighbor": + mode = b"nearest" + elif "linear" in method: # linear / bilinear + mode = b"linear" + elif "cubic" in method: # cubic / bicubic + mode = b"cubic" + else: + raise RuntimeError("Unsupported method %s in operator Resize" % method) + + coord_trans = attrs.get_str("coordinate_transformation_mode") + if coord_trans == "half_pixel": + coord_trans = b"half_pixel" + elif coord_trans == "align_corners": + coord_trans = b"align_corners" + elif coord_trans == "asymmetric": + coord_trans = b"asymmetric" + else: + raise RuntimeError( + "Unsupported coordinate transform mode %s in operator Resize" % coord_trans + ) + + rounding_method = attrs.get_str("rounding_method") + if rounding_method == "round": + rounding_method = b"round_prefer_ceil" + elif rounding_method == "floor": + rounding_method = b"floor" + elif rounding_method == "ceil": + rounding_method = b"ceil" + else: + raise RuntimeError( + "Unsupported rounding method %s in operator Resize" % rounding_method + ) + + size = attrs.get_int_tuple("size") + + return { + "mode": mode, + "coord_trans": coord_trans, + "size": size, + "nearest_mode": rounding_method, + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry["relay_node"].attrs) + + name = node_entry["name"] + input_node = node_dict[node_entry["inputs"][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + input_shape = input_node["types"][0].shape + + # (TBD) needed in opset 11 + roi = [0] * len(input_shape) + [1] * len(input_shape) + roi_array = numpy.asarray(roi).astype(numpy.float64) + roi_node = add_input(roi_array, name, "roi", model_container) + + out_size = attrs["size"] + + # (onnx) rank of scale / size must match rank of X + # relay size node contains only spatial dimensions + # pad with 1s to match rank + match_rank_pad = len(input_shape) - len(out_size) + out_size_full_rank = input_shape[:match_rank_pad] + list(out_size) + out_size_array = numpy.asarray(out_size_full_rank).astype(numpy.int64) + + input_size_array = numpy.asarray(list(input_shape)).astype(numpy.int64) + + scale_array = numpy.divide(out_size_array, input_size_array).astype(numpy.float32) + scale_node = add_input(scale_array, name, "scales", model_container) + + input_names = [node_entry["input_names"][0], roi_node, scale_node] + + resize_node = onnx.helper.make_node( + cls.__name__, + input_names, + node_entry["output_names"], + mode=attrs["mode"], + coordinate_transformation_mode=attrs["coord_trans"], + nearest_mode=attrs["nearest_mode"], + ) + model_container.add_nodes([resize_node]) + + relay_to_onnx_op_mapping = { "reshape": Reshape, "nn.conv2d": Conv, + "nn.conv2d_transpose": ConvTranspose, "add": rename("Add"), "nn.relu": rename("Relu"), "transpose": Transpose, @@ -667,6 +787,11 @@ def convert_attributes(cls, attrs): "clip": Clip, "expand_dims": Expand, "nn.lrn": LRN, + "sigmoid": rename("Sigmoid"), + "copy": rename("Identity"), + "round": rename("Round"), + "cast": Cast, + "image.resize2d": Resize, } diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index 6451896c6bd1..68c6b3d5bf6b 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -19,6 +19,7 @@ import contextlib import datetime import os +import pathlib import tempfile import threading import shutil @@ -119,6 +120,18 @@ def remove(self): self.TEMPDIRS.remove(self.temp_dir) self.temp_dir = None + @property + def path(self): + return pathlib.Path(self.temp_dir) + + def __div__(self, other): + if not isinstance(other, (str, pathlib.Path)): + raise TypeError( + "TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,) + ) + + return self.path / other + def __del__(self): temp_dirs = getattr(self, "TEMPDIRS", None) if temp_dirs is None: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a9e07299f6dd..a7ebc00c315f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -24,6 +24,7 @@ import tvm.tir +from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container from tvm.ir import CallingConv @@ -106,7 +107,7 @@ def lower( It should be None if we want to lower TensorIR. name : str - The name of result function. + The name of the result function. binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] Dictionary that maps the Tensor to Buffer which specified the data layout @@ -160,7 +161,10 @@ def _build_for_device(input_mod, target, target_host): mod_mixed = input_mod mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - opt_mixed = [tvm.tir.transform.VerifyMemory()] + opt_mixed = [ + tvm.tir.transform.VerifyMemory(), + tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), + ] if len(mod_mixed.functions) == 1: opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] @@ -372,12 +376,32 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - return create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) - if target_host.kind.name == "llvm": + elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - return create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + else: + to_return = rt_mod_host + + return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) + + +class OperatorModule(Module): + """Wraps the Module returned by tvm.build() and captures additional outputs of that function.""" + + @classmethod + def from_module(cls, mod, **kwargs): + # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. + # If an exception occurs in cls.__init__, handle will be deleted. For this reason, + # set mod.handle to None. + handle = mod.handle + mod.handle = None + return cls(handle, **kwargs) - return rt_mod_host + def __init__(self, handle, ir_module_by_target=None, name=None): + super(OperatorModule, self).__init__(handle) + self.ir_module_by_target = ir_module_by_target + self.name = name diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 033522d0e81a..15c09753d46f 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -44,14 +44,14 @@ def convert_graph_layout(mod, desired_layout): Parameters ---------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module to convert. desired_layout : str The layout to convert to. Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The converted module. """ @@ -396,7 +396,7 @@ def parse_shape_string(inputs_string): """ # Create a regex pattern that extracts each separate input mapping. - pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + pattern = r"(?:\w+\/)?\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" input_mappings = re.findall(pattern, inputs_string) if not input_mappings: raise argparse.ArgumentTypeError( diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index ceee5ccd7266..928259e30f0c 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -68,7 +68,7 @@ def load(self, path, shape_dict=None, **kwargs): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The produced relay module. params : dict The parameters (weights) for the relay module. diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 8c8828ddd49b..7dc3fd4cdd36 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -336,8 +336,8 @@ def import_package(self, package_path: str): with open(temp.relpath("metadata.json")) as metadata_json: metadata = json.load(metadata_json) - is_graph_runtime = "graph" in metadata["runtimes"] - graph = temp.relpath("runtime-config/graph/graph.json") if is_graph_runtime else None + has_graph_executor = "graph" in metadata["executors"] + graph = temp.relpath("executor-config/graph/graph.json") if has_graph_executor else None params = temp.relpath("parameters/default.params") self.type = "mlf" diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index b4cc4421b169..83557a3eae19 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -21,6 +21,7 @@ from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType +from .affine_type import TensorAffineType, TupleAffineType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range from .op import Op, register_op_attr, register_intrin_lowering diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py new file mode 100644 index 000000000000..a1ce08017b1b --- /dev/null +++ b/python/tvm/ir/affine_type.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Types for quantized Tensors.""" +import tvm._ffi + +from .base import Node +from . import _ffi_api + + +class AffineType(Node): + """The base class of Affine Types.""" + + def __eq__(self, other): + """Compare two types for structural equivalence.""" + return bool(tvm.ir.structural_equal(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + +@tvm._ffi.register_object("TensorAffineType") +class TensorAffineType(AffineType): + """The quantized type of a tensor, with scale, zero point, and datatype + + The real space value is calculated as x = x_q * scale + zero_point + + Parameters + ---------- + scale: Expr + The scale + + zero_point: Expr + The zero_point + + dtype : str + The content data type. + """ + + def __init__(self, scale, zero_point, dtype): + self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype) + + +@tvm._ffi.register_object("TupleAffineType") +class TupleAffineType(AffineType): + """Affine types of a node with multiple outputs + + Parameters + ---------- + types : List[TensorAffineType] + The shape of the Tensor + + """ + + def __init__(self, types): + self.__init_handle_by_constructor__(_ffi_api.TupleAffineType, types) diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index c322f2bef3fc..1948a6787eac 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -30,11 +30,8 @@ class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. Users don't need to interact with this class directly. - Instead, a `PassInstrument` instance should be created through `pass_instrument`. - - See Also - -------- - `pass_instrument` + Instead, a `PassInstrument` instance should be created through + :py:func:`pass_instrument` """ @@ -91,13 +88,14 @@ def pass_instrument(pi_cls=None): Parameters ---------- - pi_class : + pi_class : class + Instrument class. See example below. Examples -------- - The following code block decorates a pass instrument class. .. code-block:: python + @tvm.instrument.pass_instrument class SkipPass: def __init__(self, skip_pass_name): @@ -155,5 +153,17 @@ def render(): ------- string : string The rendered string result of time profiles + + Examples + -------- + + .. code-block:: python + + timing_inst = PassTimingInstrument() + with tvm.transform.PassContext(instruments=[timing_inst]): + relay_mod = relay.transform.InferType()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + # before exiting the context, get profile results. + profiles = timing_inst.render() """ return _ffi_instrument_api.RenderTimePassProfiles() diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 9296244f6cfe..93aae45930e3 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -107,8 +107,8 @@ def __exit__(self, ptype, value, trace): def override_instruments(self, instruments): """Override instruments within this PassContext. - If there are existing instruments, their exit_pass_ctx callbacks are called. - Then switching to new instruments and calling new enter_pass_ctx callbacks. + If there are existing instruments, their ``exit_pass_ctx`` callbacks are called. + Then switching to new instruments and calling new ``enter_pass_ctx`` callbacks. instruments : Sequence[PassInstrument] The list of pass instrument implementations. diff --git a/python/tvm/micro/contrib/zephyr.py b/python/tvm/micro/contrib/zephyr.py index 3c79c200d155..77cfb8d09bf2 100644 --- a/python/tvm/micro/contrib/zephyr.py +++ b/python/tvm/micro/contrib/zephyr.py @@ -406,6 +406,7 @@ def _get_nrf_device_args(self): # kwargs passed to usb.core.find to find attached boards for the openocd flash runner. BOARD_USB_FIND_KW = { + "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B}, "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B}, "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B}, } diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py new file mode 100644 index 000000000000..915bee08175c --- /dev/null +++ b/python/tvm/micro/interface_api.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines functions for generating a C interface header""" + +import os + +from tvm.relay.backend.utils import mangle_module_name + + +def _emit_brief(header_file, module_name, description): + header_file.write("/*!\n") + header_file.write(f' * \\brief {description} for TVM module "{module_name}" \n') + header_file.write(" */\n") + + +def generate_c_interface_header(module_name, inputs, outputs, output_path): + """Generates a C interface header for a given modules inputs and outputs + + Parameters + ---------- + module_name : str + Name of the module to be used in defining structs and naming the header + inputs : list[str] + List of module input names to be placed in generated structs + outputs : list[str] + List of module output names to be placed in generated structs + output_path : str + Path to the output folder to generate the header into + """ + + mangled_name = mangle_module_name(module_name) + metadata_header = os.path.join(output_path, f"{mangled_name}.h") + with open(metadata_header, "w") as header_file: + header_file.write( + "#include \n" + f"#ifndef {mangled_name.upper()}_H_\n" + f"#define {mangled_name.upper()}_H_\n" + ) + + _emit_brief(header_file, module_name, "Input tensor pointers") + header_file.write(f"struct {mangled_name}_inputs {{\n") + for input_name in inputs: + header_file.write(f" void* {input_name};\n") + header_file.write("};\n\n") + + _emit_brief(header_file, module_name, "Output tensor pointers") + header_file.write(f"struct {mangled_name}_outputs {{\n") + for output_name in outputs: + header_file.write(f" void* {output_name};\n") + header_file.write("};\n\n") + + header_file.write( + "/*!\n" + f' * \\brief entrypoint function for TVM module "{module_name}"\n' + " * \\param inputs Input tensors for the module \n" + " * \\param outputs Output tensors for the module \n" + " */\n" + f"int32_t {mangled_name}_run(\n" + f" struct {mangled_name}_inputs* inputs,\n" + f" struct {mangled_name}_outputs* outputs\n" + ");\n" + ) + + header_file.write(f"#endif // {mangled_name.upper()}_H_\n") diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 7062b20e0d54..5e682c72ed73 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -20,12 +20,20 @@ import datetime import json import os +import pathlib import re import tarfile +import typing +from tvm.ir.type import TupleType +from .._ffi import get_global_func +from .interface_api import generate_c_interface_header from ..contrib import utils +from ..driver import build_module +from ..runtime import ndarray as _nd from ..relay.backend import executor_factory from ..relay import param_dict +from ..tir import expr # This should be kept identical to runtime::symbol::tvm_module_main MAIN_FUNC_NAME_STR = "__tvm_main__" @@ -49,7 +57,6 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): """ dso_modules = mod._collect_dso_modules() - dso_module_handles = [m.handle.value for m in dso_modules] non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules) if non_dso_modules: raise UnsupportedInModelLibraryFormatError( @@ -207,67 +214,246 @@ def _build_function_memory_map(function_metadata): return ret -def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name): - """Export the build artifact in Model Library Format. +def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule): + main_func = mod.function_metadata[MAIN_FUNC_NAME_STR] + target = list(main_func.relay_primfuncs.keys())[0] + return main_func.relay_primfuncs[target] - This function creates a .tar archive containing the build artifacts in a standardized - layout. It's intended to allow downstream automation to build TVM artifacts against the C - runtime. + +def _convert_tuple_to_outputs(ret_type, offset=0): + outputs = [] + added_fields = len(ret_type.fields) + for output_index in range(added_fields): + next_output = offset + len(outputs) + if isinstance(ret_type.fields[output_index], TupleType): + outputs.extend(_convert_tuple_to_outputs(ret_type.fields[output_index], next_output)) + else: + outputs.append(f"output{next_output}") + return outputs + + +def _get_inputs_and_outputs_from_module(mod): + main_func = _get_main_relay_func(mod) + inputs = [argument.name_hint for argument in main_func.params] + + outputs = ["output"] + if isinstance(main_func.ret_type, TupleType): + outputs = _convert_tuple_to_outputs(main_func.ret_type) + + return inputs, outputs + + +def _should_generate_interface_header(mod): + return any(target.attrs.get("interface-api") == "c" for target in mod.target.values()) + + +def _make_tar(source_dir, tar_file_path): + """Build a tar file from source_dir.""" + with tarfile.open(tar_file_path, "w") as tar_f: + + def reset(tarinfo): + tarinfo.uid = tarinfo.gid = 0 + tarinfo.uname = tarinfo.gname = "root" + return tarinfo + + tar_f.add(str(source_dir), arcname=".", filter=reset) + + +_GENERATED_VERSION = 5 + + +def _export_graph_model_library_format( + mod: executor_factory.ExecutorFactoryModule, tempdir: pathlib.Path +): + """Export a tvm.relay.build artifact in Model Library Format. Parameters ---------- mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule The return value of tvm.relay.build, which will be exported into Model Library Format. - file_name : str - Path to the .tar archive to generate. - - Returns - ------- - file_name : str - The path to the generated .tar archive. + tempdir : pathlib.Path + Temporary directory to populate with Model Library Format contents. """ - tempdir = utils.tempdir() is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) - runtime = ["aot"] if is_aot else ["graph"] + executor = ["aot"] if is_aot else ["graph"] metadata = { - "version": 3, + "version": _GENERATED_VERSION, "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": _build_memory_map(mod), "target": {int(k): str(v) for k, v in mod.target.items()}, - "runtimes": runtime, + "executors": executor, + "style": "full-model", } - with open(tempdir.relpath("metadata.json"), "w") as json_f: + with open(tempdir / "metadata.json", "w") as json_f: json.dump(metadata, json_f, indent=2, sort_keys=True) - codegen_dir_path = tempdir.relpath("codegen") - os.mkdir(codegen_dir_path) - _populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name) + codegen_dir = tempdir / "codegen" + codegen_dir.mkdir() + _populate_codegen_dir(mod.lib, codegen_dir, mod.libmod_name) - parameters_dir_path = tempdir.relpath("parameters") - os.mkdir(parameters_dir_path) - param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params") + if _should_generate_interface_header(mod): + include_path = codegen_dir / "host" / "include" + include_path.mkdir() + inputs, outputs = _get_inputs_and_outputs_from_module(mod) + generate_c_interface_header(mod.libmod_name, inputs, outputs, include_path) + + parameters_dir = tempdir / "parameters" + parameters_dir.mkdir() + param_filename = parameters_dir / f"{mod.libmod_name}.params" with open(param_filename, "wb") as f: f.write(param_dict.save_param_dict(mod.params)) - with open(tempdir.relpath("relay.txt"), "w") as f: + src_dir = tempdir / "src" + src_dir.mkdir() + with open(src_dir / "relay.txt", "w") as f: f.write(str(mod.ir_mod)) if not is_aot: - graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) - os.makedirs(graph_config_dir_path) - with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: + graph_config_dir = tempdir / "executor-config" / "graph" + graph_config_dir.mkdir(parents=True) + with open(graph_config_dir / "graph.json", "w") as f: f.write(mod.get_executor_config()) - with tarfile.open(file_name, "w") as tar_f: - def reset(tarinfo): - tarinfo.uid = tarinfo.gid = 0 - tarinfo.uname = tarinfo.gname = "root" - return tarinfo +class NonStaticShapeError(Exception): + """Raised when a shape has elements other than IntImm.""" + + +def _shape_to_size(shape, dtype): + bits_per_item = int( + re.match(r"((float)|(int))(?P[0-9]+)", dtype).group("width_bits") + ) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + total_bits = bits_per_item + for s in shape: + total_bits *= s + + return (total_bits + 7) // 8 + + +def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_target): + def _eval_shape(param_name, buffer_shape): + shape = [] + for x in buffer_shape: + if not isinstance(x, expr.IntImm): + raise NonStaticShapeError( + f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}" + ) + shape.append(x.value) + return shape + + memory_map = {} + for target_device_type, target in targets.items(): + ir_mod = ir_module_by_target[target] + printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) + with open(src_dir / f"tir-{target_device_type}.txt", "w") as f: + f.write(printer["print"](ir_mod)) + + for v in ir_mod.get_global_vars(): + map_entry = [] + for p, b in ir_mod[v.name_hint].buffer_map.items(): + shape = _eval_shape(p.name, b.shape) + buffer_size_bytes = _shape_to_size(shape, str(b.dtype)) + # NOTE: cannot tell what is an input or output at this point. + map_entry.append( + { + "size_bytes": buffer_size_bytes, + "shape": [int(x) for x in b.shape], + "dtype": b.dtype, + "input_binding": printer["get_var_name"](p), + } + ) + memory_map[v.name_hint] = map_entry + + return memory_map + + +def _export_operator_model_library_format(mod: build_module.OperatorModule, tempdir): + """Export the result of tvm.build() in Model Library Format. + + Parameters + ---------- + mod : runtime.Module + The Module returned from tvm.build(). + args : list of Buffer or Tensor or Var, optional + The args supplied to tvm.build(). + file_name : str + Path to the .tar archive to generate. + """ + targets = {} + for target in mod.ir_module_by_target.keys(): + if str(target.kind) not in ("llvm", "c"): + raise UnsupportedInModelLibraryFormatError( + f"Operator has non-DSO-exportable target {target!s}, which is not yet supported in " + "Model Library Format" + ) + + targets[int(_nd.device(str(target)).device_type)] = target + + src_dir = tempdir / "src" + src_dir.mkdir() + memory_map = _write_tir_and_build_operator_memory_map(src_dir, targets, mod.ir_module_by_target) + + metadata = { + "version": _GENERATED_VERSION, + "model_name": mod.name, + "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), + "memory": memory_map, + "target": {k: str(v) for k, v in targets.items()}, + "executors": [], + "style": "operator", + } + with open(tempdir / "metadata.json", "w") as metadata_f: + json.dump(metadata, metadata_f) + + codegen_dir = tempdir / "codegen" + codegen_dir.mkdir() + _populate_codegen_dir(mod, codegen_dir) + + +ExportableModule = typing.Union[ + build_module.OperatorModule, + executor_factory.AOTExecutorFactoryModule, + executor_factory.GraphExecutorFactoryModule, +] + + +def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): + """Export the build artifact in Model Library Format. + + This function creates a .tar archive containing the build artifacts in a standardized + layout. It's intended to allow downstream automation to build TVM artifacts against the C + runtime. + + Parameters + ---------- + mod : ExportableModule + The return value of tvm.build or tvm.relay.build. + file_name : str + Path to the .tar archive to generate. + + Returns + ------- + file_name : str + The path to the generated .tar archive. + """ + file_name = pathlib.Path(file_name) + + tempdir = utils.tempdir() - tar_f.add(tempdir.temp_dir, arcname=".", filter=reset) + if isinstance(mod, build_module.OperatorModule): + _export_operator_model_library_format(mod, tempdir.path) + elif isinstance( + mod, + (executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule), + ): + _export_graph_model_library_format(mod, tempdir.path) + else: + raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") + + _make_tar(tempdir.path, file_name) return file_name diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 661d7523ad77..a8f1a993552e 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -370,7 +370,7 @@ def extract_fused_functions(mod): Parameters ---------- - mod : tvm.relay.IRModule + mod : tvm.IRModule Returns ------- diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 2db8c5a669f0..e9129db7b200 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -429,7 +429,7 @@ def dump(self): res += "------------------------------------\n" res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.func_name) + res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) res += "----relay function----\n" res += k.source_func.astext() + "\n" res += "----tir function----- \n" @@ -444,7 +444,7 @@ def dump(self): res += "------------------------------------\n" res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.func_name) + res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) res += "----relay function----\n" res += k.source_func.astext() + "\n" res += "----tir function----- \n" diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index 11274b97197f..58717a0ab482 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -20,14 +20,14 @@ The compiler is built from a few pieces. First we define a compiler from a single Relay expression to the -graph langauge. We require the expression to be a function. +graph language. We require the expression to be a function. The function's parameters correspond to the placeholder/inputs and model parameters found in the computation graph representation. The body of the function represents the computation graph. The compiler's output is a program in the graph language, which is composed of -graph langauge is composed of Node, NodeRef, InputNode, OpNode. -This "little language" represents programs in TVM's graph format. +Node, NodeRef, InputNode, OpNode. This "little language" represents programs in +TVM's graph format. To connect to the graph executor, we use a printer that converts our graph format into TVM's JSON format. The resulting string can be loaded by diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index b62fca86668d..81edf74a0a03 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -227,6 +227,8 @@ def _make_executor(self, expr=None): if expr is None or isinstance(expr, GlobalVar): assert self.mod is not None + _intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target) + def _interp_wrapper(*args, **kwargs): if expr is None: args = self._convert_args(self.mod["main"], args, kwargs) @@ -253,7 +255,6 @@ def _interp_wrapper(*args, **kwargs): mod = self.optimize() opt_expr = Call(mod["main"], relay_args) - _intrp = _backend.CreateInterpreter(mod, self.device, self.target) return _intrp(opt_expr) return _interp_wrapper diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index ed722643ff70..d1cf1c9bea2f 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -40,7 +40,23 @@ from .backend.vm import VMExecutor -def _update_target(target): +def build_target_by_device_type_map(target): + """Build a map from DLDevice device_type to a Target used with that device. + + At runtime, TVM assigns target code to DLDevices by determining a device_type for each Target. + This function handles this process at compile time and, as a side effect, validates that exactly + one target maps to one device_type. + + Parameters + ---------- + target : Target or str or dict + If a Target or str: assumes that exactly one device type is present in the model. + If a dict: keys are tvm.ndarray.device, values are the targets used for each device. + + Returns + ------- + + """ target = target if target else Target.current() if target is None: raise ValueError("Target is not set in env or passed as argument.") @@ -132,7 +148,7 @@ def build( params : dict The parameters of the final graph. """ - target = _update_target(target) + target = build_target_by_device_type_map(target) target, target_host = Target.check_and_update_host_consist( target, target_host, target_is_dict_key=False ) @@ -187,7 +203,7 @@ def optimize(self, mod, target=None, params=None): params : dict The parameters of the final graph. """ - target = _update_target(target) + target = build_target_by_device_type_map(target) # Setup the params. if params: @@ -316,7 +332,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" "instead of deprecated parameter mod (tvm.relay.function.Function)", DeprecationWarning, ) - target = _update_target(target) + target = build_target_by_device_type_map(target) if isinstance(target_host, (str, Target)): target_host = Target(target_host) elif target_host: @@ -395,7 +411,7 @@ def optimize(mod, target=None, params=None): DeprecationWarning, ) - target = _update_target(target) + target = build_target_by_device_type_map(target) # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub @@ -495,7 +511,7 @@ def _graph_wrapper(*args, **kwargs): return _graph_wrapper -def create_executor(kind="debug", mod=None, device=None, target="llvm"): +def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None): """Factory function to create an executor. Example @@ -528,6 +544,10 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm"): target : :py:class:`tvm.Target` The corresponding context + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. + Returns ------- executor : :py:class:`~tvm.relay.backend.interpreter.Executor` @@ -539,6 +559,9 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm"): else: device = _nd.device(str(target), 0) + if params is not None: + mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) + if isinstance(target, str): target = Target(target) if kind == "debug": diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 8d73a090ed6f..8461885b38ce 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -23,7 +23,7 @@ import tvm._ffi from tvm._ffi import base as _base from tvm.runtime import NDArray, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar +from tvm.ir import RelayExpr, GlobalVar, Node from .base import RelayNode from . import _ffi_api @@ -538,3 +538,25 @@ def bind(expr, binds): The expression or function after binding. """ return _ffi_api.Bind(expr, binds) + + +@tvm._ffi.register_object("relay.StorageInfo") +class StorageInfo(Node): + """StorageInfo + + The static storage information produced by memory planning. + Contains the storage ids where expressions are stored, the + type of the "virtual devices" the expressions are stored on, + and the sizes of each storage element.""" + + @property + def storage_ids(self): + return _ffi_api.StorageInfoStorageIds(self) + + @property + def device_types(self): + return _ffi_api.StorageInfoDeviceTypes(self) + + @property + def storage_sizes(self): + return _ffi_api.StorageInfoStorageSizes(self) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 40a116ab0b43..b9ca7d0e11f2 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -152,8 +152,10 @@ def visit_let(self, let): self.visit(let.value) self.visit(let.body) - def visit_function(self, f): - self.visit(f.body) + def visit_function(self, fn): + for x in fn.params: + self.visit(x) + self.visit(fn.body) def visit_if(self, i): self.visit(i.cond) diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index d48e5634d986..b8273b0324c0 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -771,7 +771,7 @@ def from_caffe(init_net, predict_net, shape_dict, dtype_dict): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation. params : dict of str to tvm.NDArray diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py old mode 100644 new mode 100755 index d0e8c79c6392..9c53b59f9998 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -263,7 +263,7 @@ def get_relay_op(op_name): The Relay operator name. """ if "." in op_name: - # explicit hierachical modules + # explicit hierarchical modules op = _op try: for opn in op_name.split("."): @@ -527,6 +527,7 @@ def infer_value(input_val, params, mod=None): assert all( var.name_hint in params.keys() for var in analysis.free_vars(input_val) ), "All inputs to infer must be available in params." + assert tvm.runtime.enabled("llvm"), "LLVM must be enabled to infer value." try: # TODO(kevinthesun): Use VM for all cases. # pylint: disable=import-outside-toplevel @@ -553,7 +554,7 @@ def infer_value(input_val, params, mod=None): def infer_value_simulated(input_val, params): - """Extention to infer_value that can be used when some input + """Extension to infer_value that can be used when some input values are missing. This function creates dummy inputs with the same shape and random values then calls infer_value. This is helpful when implementing certain onnx operators where we need to evaluate the graph diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index f850750fad51..e515843e5fe2 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -562,7 +562,7 @@ def from_coreml(model, shape=None): etab = ExprTable() for i in spec.description.input: - input_shape = shape[i.name] if shape is not None and i.name in shape else None + input_shape = list(shape[i.name]) if shape is not None and i.name in shape else None etab.set_expr(i.name, _expr.var(i.name, shape=input_shape)) for pp in cc.preprocessing: diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 63521a67b065..aa185923d02e 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -725,6 +725,7 @@ def _convert_upsample3d(inexpr, keras_layer, etab): params["scale_h"] = h params["scale_w"] = w params["layout"] = etab.data_layout + params["coordinate_transformation_mode"] = "asymmetric" out = _op.nn.upsampling3d(inexpr, **params) return out diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3b940bd15f5b..59b4e99de999 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -963,7 +963,7 @@ def _mx_resize(inputs, attrs): if scale_width is not None: width = (scale_width * shape[3]).astype("int32") size = (height, width) - return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") + return _op.image.resize2d(inputs[0], size, coordinate_transformation_mode="align_corners") def _mx_amp_multicast(inputs, attrs): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7135fccdf43b..c12e096e9051 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -34,6 +34,7 @@ from .. import qnn as _qnn from .. import ty as _ty from .. import vision as _vision +from .. import random as _random from .common import ( AttrCvt, Renamer, @@ -441,7 +442,10 @@ def autopad( # pad N and C with zeros pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type) + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) class Conv(OnnxOpConverter): @@ -457,6 +461,7 @@ def _impl_v1(cls, inputs, attr, params): kernel_type = infer_type(inputs[1]) kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shapes[0][2:] @@ -580,6 +585,70 @@ def _impl_v1(cls, inputs, attr, params): out = _op.nn.bias_add(out, inputs[2]) return out + @classmethod + def _impl_v11(cls, inputs, attr, params): + # get number of channels + out_type = infer_type(inputs[1]) + out_shapes = [get_const_tuple(out_type.checked_type.shape)] + channels = out_shapes[0][1] + attr["channels"] = channels + groups = attr.get("group", 1) + + if "kernel_shape" not in attr: + attr["kernel_shape"] = out_shapes[0][2:] + + attr["groups"] = groups + # infer pads for auto_pad + data = inputs[0] + input_shape = infer_shape(data) + ndim = len(input_shape) + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + kernel_shape = attr["kernel_shape"] + kndim = len(kernel_shape) + dilations = attr.get("dilations", [1] * kndim) + output_padding = attr.get("output_padding", [0] * kndim) + strides = attr["strides"] + total_pad = [0] * kndim + for i in range(kndim): + total_pad[i] = ( + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - strides[i] + ) + left = [p // 2 for p in total_pad] + right = [total_pad[i] - left[i] for i in range(kndim)] + if "LOWER" in attr["auto_pad"]: + pad = left + right + else: + pad = right + left + attr["pads"] = pad + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(ndim - 2)]) + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out = AttrCvt( + op_name=dimension_picker("conv", "_transpose"), + transforms={ + "kernel_shape": "kernel_size", + "dilations": ("dilation", 1), + "pads": ("padding", 0), + "group": ("groups", 1), + }, + disables=["output_shape"], + custom_check=dimension_constraint(), + )([data, inputs[1]], attr, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + class GlobalAveragePool(OnnxOpConverter): """Operator converter for GlobalAveragePool""" @@ -677,25 +746,34 @@ def _impl_v1(cls, inputs, attr, params): # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: - def flatten_to_3d(x, x_shape): + def flatten_to_nd(x, x_shape, nd=3): ndims = infer_shape(x_shape)[0] + if ndims == nd: + return x newshape = _op.concatenate( [ _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype), - _op.strided_slice(x_shape, [ndims - 2], [ndims]), + _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]), ], 0, ) out = _op.reshape(x, fold_constant(newshape)) return out - # Convert a and b into 3 dimensional tensors. - a = flatten_to_3d(inputs[0], a_shape) - b = flatten_to_3d(inputs[1], b_shape) - # Transpose matrix dimensions of b. - b = _op.transpose(b, [0, 2, 1]) - # Perform a batch matmul. - output = _op.nn.batch_matmul(a, b) + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b) + else: + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(inputs[0], a_shape, 3) + b = flatten_to_nd(inputs[1], b_shape, 3) + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a batch matmul. + output = _op.nn.batch_matmul(a, b) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) @@ -930,7 +1008,7 @@ def _impl_v11(cls, inputs, attr, params): if len(inputs) == 3: value = fold_constant(_op.take(inputs[2], _op.const(0))) else: - value = 0 + value = 0.0 pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1)))) pad_mode = attr.get("mode", b"constant").decode("utf-8") @@ -1199,7 +1277,13 @@ def _impl_v9(cls, inputs, attr, params): layout = "NCDHW" out = _op.nn.upsampling3d( - inputs[0], scale_d, scale_h, scale_w, layout=layout, method=method + inputs[0], + scale_d, + scale_h, + scale_w, + layout=layout, + method=method, + coordinate_transformation_mode="asymmetric", ) # in 2d case, use dynamic op else: @@ -2388,9 +2472,9 @@ def _impl_v10(cls, inputs, attr, params): if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": - method = "bilinear" + method = "linear" elif mode == "cubic": - method = "bicubic" + method = "cubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2398,21 +2482,31 @@ def _impl_v10(cls, inputs, attr, params): scale = inputs[1] size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale - layout = "NCHW" # ONNX assumes NCHW layout - out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric") + ndims = len(infer_shape(inputs[0])) + out = None + if ndims == 3: + out_size = fold_constant(_op.strided_slice(size, [2], [3])) + out = _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric") + elif ndims == 4: + out_size = fold_constant(_op.strided_slice(size, [2], [4])) + out = _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric") + elif ndims == 5: + out_size = fold_constant(_op.strided_slice(size, [2], [5])) + out = _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric") + else: + raise NotImplementedError("Resize only supports 3, 4, or 5 dims") + return out @classmethod def _impl_v11(cls, inputs, attr, params): - layout = "NCHW" # ONNX assumes NCHW layout - + ndims = len(infer_shape(inputs[0])) mode = attr.get("mode").decode("ascii") if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": - method = "bilinear" + method = "linear" elif mode == "cubic": - method = "bicubic" + method = "cubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2434,10 +2528,26 @@ def _impl_v11(cls, inputs, attr, params): assert len(scale_shape) != 0, "One of scale or size should be passed." size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) + out = None + if ndims == 3: + out_size = fold_constant(_op.strided_slice(size, [2], [3])) + out = _op.image.resize1d( + inputs[0], out_size, "NCW", method, coord_trans, nearest_mode, alpha, exclude + ) + elif ndims == 4: + out_size = fold_constant(_op.strided_slice(size, [2], [4])) + out = _op.image.resize2d( + inputs[0], out_size, "NCHW", method, coord_trans, nearest_mode, alpha, exclude + ) + elif ndims == 5: + out_size = fold_constant(_op.strided_slice(size, [2], [5])) + out = _op.image.resize3d( + inputs[0], out_size, "NCDHW", method, coord_trans, nearest_mode, alpha, exclude + ) + else: + raise NotImplementedError("Resize only supports 3, 4, or 5 dims") - return _op.image.resize( - inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude - ) + return out class NonZero(OnnxOpConverter): @@ -2696,24 +2806,24 @@ def get_var(name, val, scan=False): loop_var_names = [v.name_hint for v in loop_vars] num_scan_outputs = len(body.output) - (1 + num_deps) - # TODO (jwfromm) Test with strided slice once type unifier for this case is fixed. - if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]: - warnings.warn( - """ - Using scan outputs in a loop with strided slice - currently may cause errors during compilation. - """ - ) - # Construct variables and intial empty tensors for any scan outputs. + # Construct variables and initial empty tensors for any scan outputs. + # To do this, we'll figure out the output shapes of the body subgraph by importing + # it and doing type inference. scan_output_vars = [] scan_output_init = [] + if num_scan_outputs > 0: + with subgraph_scope: + loop_outputs = subgraph_scope.from_onnx( + body, graph_scope.opset, get_output_expr=True + ) + loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output)) + for i in range(num_scan_outputs): - name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) - if dtype is None: - dtype = infer_type(loop_deps[i]).checked_type.dtype - if dtype == "float": - dtype = "float32" + name, _, _, _ = get_info(body.output[i + 1 + num_deps]) + output_node = infer_type(loop_outputs[i + 1 + num_deps]) + shape = get_const_tuple(output_node.checked_type.shape) + dtype = output_node.checked_type.dtype scan_output_vars.append( _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype) ) @@ -2846,7 +2956,10 @@ def _impl_v1(cls, inputs, attr, params): graph_scope._nodes.update({var.name_hint: var}) # Now we can construct the relay if statement and return. - return _expr.If(cond, then_expr, else_expr) + ret = _expr.If(cond, then_expr, else_expr) + if len(then_branch.output) > 1: + ret = _expr.TupleWrapper(ret, len(then_branch.output)) + return ret class NonMaxSuppression(OnnxOpConverter): @@ -3157,6 +3270,79 @@ def get_scalar(x, dtype="float32"): return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype) +class ConvInteger(OnnxOpConverter): + """Operator converter for ConvInteger.""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + data = inputs[0] + weight = inputs[1] + data_zp = inputs[2] + weight_zp = inputs[3] + if data_zp is None: + data_zp = _expr.const(0, "int32") + if weight_zp is None: + weight_zp = _expr.const(0, "int32") + + input_type = infer_type(data) + input_shape = get_const_tuple(input_type.checked_type.shape) + + ndim = len(input_shape) + kernel_type = infer_type(weight) + kernel_shape = get_const_tuple(kernel_type.checked_type.shape) + if "kernel_shape" not in attr: + attr["kernel_shape"] = kernel_shape[2:] + + if "auto_pad" in attr: + attr["auto_pad"] = attr["auto_pad"].decode("utf-8") + if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"): + # Warning: Convolution does not yet support dynamic shapes, + # one will need to run dynamic_to_static on this model after import + data = autopad( + data, + attr.get("strides", [1] * (ndim - 2)), + attr["kernel_shape"], + attr.get("dilations", [1] * (ndim - 2)), + ndim, + pad_value=data_zp, + mode=attr["auto_pad"], + ) + elif attr["auto_pad"] == "VALID": + attr["pads"] = tuple([0 for i in range(ndim - 2)]) + elif attr["auto_pad"] == "NOTSET": + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"])) + attr.pop("auto_pad") + + out_channels = kernel_shape[0] + dilation = attr.get("dilations", [1] * (ndim - 2)) + strides = attr.get("strides", [1] * (ndim - 2)) + padding = attr["pads"] if "pads" in attr else 0 + groups = attr["group"] if "group" in attr else 1 + + if ndim != 4: + raise tvm.error.OpAttributeInvalid( + "Only 2D kernels are supported for operator ConvInteger." + ) + + return _qnn.op.conv2d( + data, + weight, + _op.cast(data_zp, "int32"), + _op.cast(weight_zp, "int32"), + _expr.const(1.0, "float32"), + _expr.const(1.0, "float32"), + kernel_size=attr["kernel_shape"], + channels=out_channels, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + ) + + class BitShift(OnnxOpConverter): """Operator converter for NonZero""" @@ -3208,6 +3394,30 @@ def _impl_v11(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4) +class RandomUniform(OnnxOpConverter): + """Operator converter for random_uniform""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + dtype = get_type(attr.get("dtype", 1)) + high = attr.get("high", 1.0) + low = attr.get("low", 0.0) + seed = attr.get("seed", None) + shape = attr["shape"] + + assert dtype in [ + "float32", + "float64", + ], "Only float random value generation is currently supported." + + if seed is None: + seed = np.random.randint(1e6) + key = _random.threefry_key(seed) + output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high) + _, vals = _expr.TupleWrapper(output, 2) + return vals + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -3385,6 +3595,9 @@ def _get_convert_map(opset): "ReverseSequence": ReverseSequence.get_converter(opset), "QLinearConv": QLinearConv.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), + "ConvInteger": ConvInteger.get_converter(opset), + # Random number generation. + "RandomUniform": RandomUniform.get_converter(opset), } diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 00fa9f597d06..33cb83b883bc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -41,6 +41,7 @@ from . import qnn_torch from .common import AttrCvt, get_relay_op from .common import infer_value as _infer_value +from .common import infer_shape as _infer_shape from .common import infer_value_simulated as _infer_value_simulated from .common import try_infer_value from .pytorch_utils import is_version_greater_than @@ -1798,7 +1799,7 @@ def get_upsample_out_size(self, inputs, method): else: out_size.append(size) else: - scale_index = 3 if method in ["bilinear", "trilinear"] else 2 + scale_index = 3 if method == "linear" else 2 scales = inputs[scale_index] assert scales is not None, "neither out size nor scale provided" assert isinstance(scales, list) @@ -1813,7 +1814,7 @@ def upsample(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "bilinear": + if len(inputs) > 2 and method == "linear": align_corners = inputs[2] else: align_corners = False @@ -1826,7 +1827,7 @@ def upsample(inputs, input_types): coord_trans = "half_pixel" def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(x, out_size, "NCHW", method, coord_trans) if self.is_quantized_tensor(data): # input qparams are manually appended by us @@ -1845,7 +1846,7 @@ def upsample3d(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "trilinear": + if len(inputs) > 2 and method == "linear": align_corners = inputs[2] else: align_corners = False @@ -1877,9 +1878,6 @@ def Float(self, inputs, input_types): assert len(inputs) == 1 return _op.cast(inputs[0], "float32") - def mm(self, inputs, input_types): - return _op.nn.dense(inputs[0], inputs[1]) - def bitwise_not(self, inputs, input_types): data = inputs[0] # The input tensor must be of integral or Boolean types. @@ -2195,6 +2193,8 @@ def interpolate(self, inputs, input_types): method = inputs[3] if method.startswith("nearest"): method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] if method == "nearest_neighbor": coord_trans = "asymmetric" @@ -2203,7 +2203,7 @@ def interpolate(self, inputs, input_types): else: coord_trans = "half_pixel" - return _op.image.resize(data, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans) def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) @@ -2325,6 +2325,303 @@ def nll_loss(self, inputs, input_types): weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + def flip(self, inputs, input_types): + data = inputs[0] + axis = inputs[1] + return _op.transform.reverse(data, axis=axis[0]) + + def lstm_cell(self, input_seqs, hidden, weights, has_proj=False): + if has_proj: + assert len(weights) == 5 + else: + assert len(weights) == 4 + outputs_list = [] + # Default activations types + f_act = _op.sigmoid + g_act = _op.tanh + h_act = _op.tanh + + # Input hiddens + H_t = hidden[0] # (batch, hidden_size) + C_t = hidden[1] # (batch, hidden_size) + for x_t in input_seqs: + # x_t shape = (batch, feature size) + # gates shape = (batch, 4 * hidden_size) + gates = _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1]) + # Add biases + if weights[2] is not None: + gates += weights[2] + if weights[3] is not None: + gates += weights[3] + i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size) + + i = f_act(i) + f = f_act(f) + c = g_act(c) + o = f_act(o) + + C = f * C_t + i * c + H = o * h_act(C) + + if has_proj: + H = _op.nn.dense(H, weights[4]) + + H_t = H + C_t = C + outputs_list.append(H) # [seq_num, (batch, hidden_size)] + hidden_outputs = (H_t, C_t) + + return (outputs_list, hidden_outputs) + + def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False): + fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj) + + rev_input_seq = [] + seq_len = len(input_seq) + for i in range(seq_len): + rev_input_seq.append(input_seq[seq_len - 1 - i]) # [seq_num, (batch, hidden_size)] + rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1], has_proj) + + final_outputs = [] # [seq_num, (batch, 2 * hidden_size)] + for j in range(seq_len): + final_outputs.append( + _op.concatenate([fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j]], -1) + ) + + return final_outputs, (fw_outputs[1], rev_outputs[1]) + + def lstm_layers( + self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False + ): + hidden_layers_num = len(hiddens) + assert len(weights) == hidden_layers_num + + # split input sequence to samples set + input_seqs = self.unbind((input_data, 0), dtype) # [seq_num, (batch, feature_size)] + output_hiddens = [] + for k in range(hidden_layers_num): + hiddens_input = hiddens[k] + weights_input = weights[k] + + outputs = ( + self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input, has_proj) + if bidirectional + else self.lstm_cell(input_seqs, hiddens_input, weights_input, has_proj) + ) + + output_hiddens.append(outputs[1]) + # input_seqs shape = [seq_num, (batch, feature_size)] or + # [seq_num, (batch, 2*feature_size)] for bidirectional + input_seqs = outputs[0] + + # TODO (vvchernov): in pytorch implementation train is also checked + # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339 + # /aten/src/ATen/native/RNN.cpp#L1054 + if dropout_p != 0 and k < hidden_layers_num - 1: + # for input in input_seqs: + # input = _op.dropout(input, dropout_p) + raise NotImplementedError("Dropout for LSTM has not been supported yet!") + final_hiddens = [] + if bidirectional: + for i in range(hidden_layers_num): + final_hiddens.append(output_hiddens[i][0]) + final_hiddens.append(output_hiddens[i][1]) + else: + final_hiddens = output_hiddens + + return _op.stack(input_seqs, 0), final_hiddens + + def lstm(self, inputs, input_types): + """ + Description of LSTM in pytorch:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html + Native implementation for torch version less than 1.8.0 (projection is unsupported): + https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \ + src/ATen/native/RNN.cpp#L1396 + Native implementation for torch version from 1.8.0 and higher (projection is supported): + https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483 + """ + # TODO (vvchernov): support dropout + assert len(inputs) == 9, "Input of size 9 is expected" + # Unpack inputs, note that if optional and not provided then value will be None. + _X = inputs[0] + # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size) + + hidden_states = inputs[1] + assert len(hidden_states) == 2, "lstm expects two hidden states" + h_0 = hidden_states[0] + c_0 = hidden_states[1] + # H0 shape (hidden_layers_num, batch, proj_size) if projection + # else (hidden_layers_num, batch, hidden_size) + # C0 shape (hidden_layers_num, batch, hidden_size) + + _weights = inputs[2] + # If no projection + # Wi layer[0] shape (4 * hidden_size, feature_size) + # Wh layer[0] shape (4 * hidden_size, hidden_size) + # Bi layer[0] shape (4 * hidden_size) + # Bh layer[0] shape (4 * hidden_size) + + # Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions) + # Wh layer[>0] shape (4 * hidden_size, hidden_size) + # Bi layer[>0] shape (4 * hidden_size) + # Bh layer[>0] shape (4 * hidden_size) + + # If projection + # Wi layer[0] shape (4 * hidden_size, feature_size) + # Wh layer[0] shape (4 * hidden_size, proj_size) + # Bi layer[0] shape (4 * hidden_size) + # Bh layer[0] shape (4 * hidden_size) + # P layer[0] shape (proj_size, hidden_size) + + # Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions) + # Wh layer[>0] shape (4 * hidden_size, proj_size) + # Bi layer[>0] shape (4 * hidden_size) + # Bh layer[>0] shape (4 * hidden_size) + # P layer[>0] shape (proj_size, hidden_size) + + # Scalar inputs + has_biases = inputs[3] + num_layers = inputs[4] + dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout + # train = inputs[6] + bidirectional = inputs[7] + batch_first = inputs[8] + + num_directions = 1 + if bidirectional: + num_directions = 2 + + rsd = len(_weights) % num_layers + assert rsd == 0, "The number of weights must be a multiple of the number of layers!" + rsd = (len(_weights) / num_layers) % num_directions + assert ( + rsd == 0 + ), "The number of weights in layer must be a multiple of the number of directions!" + has_proj = False + proj_size = 0 + weights_num = int(len(_weights) / num_layers / num_directions) + if has_biases: + if weights_num == 5: + has_proj = True + proj_size = _infer_shape(_weights[4])[0] + else: + assert weights_num == 4, "The weights number in layer is expected equal to 4" + else: + if weights_num == 3: + has_proj = True + proj_size = _infer_shape(_weights[2])[0] + else: + assert weights_num == 2, "The weights number in layer is expected equal to 2" + + weights = [] + if has_biases: + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_weights = [] + rev_weights = [] + for j in range(weights_num): + fw_weights.append(_weights[i + j]) + rev_weights.append(_weights[i + j + weights_num]) + weights.append((fw_weights, rev_weights)) + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), weights_num): + fw_weights = [] + for j in range(weights_num): + fw_weights.append(_weights[i + j]) + weights.append(fw_weights) + else: + if bidirectional: + rsd = len(_weights) % (2 * weights_num) + assert rsd == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), 2 * weights_num): + fw_weights = [] + rev_weights = [] + k = i + weights_num + if has_proj: + fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]] + rev_weights = [_weights[k], _weights[k + 1], None, None, _weights[k + 2]] + else: + fw_weights = [_weights[i], _weights[i + 1], None, None] + rev_weights = [_weights[k], _weights[k + 1], None, None] + weights.append((fw_weights, rev_weights)) + else: + assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights" + for i in range(0, len(_weights), weights_num): + if has_proj: + fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]] + else: + fw_weights = [_weights[i], _weights[i + 1], None, None] + weights.append(fw_weights) + assert ( + len(weights) == num_layers + ), "For stacked LSTM number of weights tuples should be the same as number of layers!" + + X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X + # TODO (vvchernov): Which data type should be used? from input or weights? + # Instead of it _infer_type(X).checked_type.dtype can be used + X_dtype = input_types[0] + X_shape = _infer_shape(X) # (seq_num, batch, feature_size) + + hidden_size = _infer_shape(_weights[0])[0] / 4 + batch_size = X_shape[1] + + # Initialize hidden states if not provided. + layers_h = [] + layers_c = [] + hidden_layers_num = num_directions * num_layers + if h_0 is None: + if has_proj: + h_0 = _op.zeros((batch_size, proj_size), X_dtype) + else: + h_0 = _op.zeros((batch_size, hidden_size), X_dtype) + for i in range(hidden_layers_num): + layers_h.append(h_0) + else: + layers_h = self.unbind((h_0, 0), X_dtype) + if c_0 is None: + c_0 = _op.zeros((batch_size, hidden_size), X_dtype) + for i in range(hidden_layers_num): + layers_c.append(c_0) + else: + layers_c = self.unbind((c_0, 0), X_dtype) + + hiddens = [] + for i in range(num_layers): + if bidirectional: + hiddens.append( + ((layers_h[2 * i], layers_c[2 * i]), (layers_h[2 * i + 1], layers_c[2 * i + 1])) + ) + else: + hiddens.append((layers_h[i], layers_c[i])) + + outputs = self.lstm_layers( + X, + hiddens, + weights, + bidirectional, + dtype=X_dtype, + dropout_p=dropout_p, + has_proj=has_proj, + ) + + # output shape = (seq_num, batch, hidden_size) or + # (seq_num, batch, 2*feature_size) for bidirectional + output = outputs[0] + + hy = [] + cy = [] + for hidden in outputs[1]: + hy.append(hidden[0]) + cy.append(hidden[1]) + + if batch_first: + output = _op.transpose(output, (1, 0, 2)) + + return (output, _op.stack(hy, 0), _op.stack(cy, 0)) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2473,9 +2770,9 @@ def create_convert_map(self): "aten::clamp": self.clamp, "aten::clamp_": self.clamp, "aten::detach": self.identity, - "aten::upsample_bilinear2d": self.make_upsample("bilinear"), + "aten::upsample_bilinear2d": self.make_upsample("linear"), "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), - "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"), + "aten::upsample_trilinear3d": self.make_upsample3d("linear"), "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), "aten::expand_as": self.expand_as, "aten::lt": self.make_elemwise("less"), @@ -2533,12 +2830,15 @@ def create_convert_map(self): "aten::hardsigmoid": self.hard_sigmoid, "aten::cumsum": self.cumsum, "aten::masked_fill": self.masked_fill, + "aten::masked_fill_": self.masked_fill, "aten::masked_select": self.masked_select, "aten::argsort": self.argsort, "aten::sort": self.sort, "aten::_unique2": self.unique, "aten::nll_loss": self.nll_loss, "aten::nll_loss2d": self.nll_loss, + "aten::flip": self.flip, + "aten::lstm": self.lstm, } def update_convert_map(self, custom_map): @@ -3284,7 +3584,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The module that optimizations will be performed on. params : dict of str to tvm.runtime.NDArray diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 02b2484d4fb7..753b1f253a89 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -90,7 +90,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold): """ one = is_constant() - # Equivelent PyTorch code from above snippet + # Equivalent PyTorch code from above snippet # offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) cast = is_op("cast")(idxs) mx = is_op("max")(boxes) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index f614982aac6c..9eafae905baf 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -163,7 +163,7 @@ def _get_quant_param_for_input(input_value): """ We want to know the input scale and zp of this input_value, since input quant params are not explicitly passed around in torch (they - are embeded in a QTensor data structure, not visible statically). + are embedded in a QTensor data structure, not visible statically). We know that it is quantized using output scale and zp of some previous quantized op. The purpose of this function is to find that pair of parameters. diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index e297398ffe5b..d35e0e1c203d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -52,6 +52,11 @@ # However, please note that `nn.matmul` is in experimental so it may have some performance # issues. "use_dense": True, + # By default, TVM converts `tf.batch_matmul` to `transpose(weight) + nn.batch_matmul_NT`. + # Change this flag to False to directly convert to `nn.batch_matmul`. + # Note that `nn.batch_matmul` with format other than NT is in experimental, it may have some + # performance issues. + "use_nt_batch_matmul": True, } # compatible operators that do NOT require any conversion. @@ -117,7 +122,7 @@ def _in_while_loop(control_flow_node_map, op_name): Parameters ---------- control_flow_node_map : Dict[str, Set[str]] - A dictionay contains the unique control flow execution frame name to + A dictionary contains the unique control flow execution frame name to a set of primitive operators mapping. op_name : str @@ -139,7 +144,7 @@ class RewriteSubgraph(ExprMutator): Parameters ---------- rewrite_map : Dict[expr, expr] - A dictionay contains a set of expr to var mapping. + A dictionary contains a set of expr to var mapping. """ def __init__(self, rewrite_map): @@ -1214,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): return func, self._params -def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True): +def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config=None): """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. @@ -1232,10 +1237,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op outputs : List of output tensor names (Optional) if not specified then the last node is assumed as graph output. - use_dense_op : bool (Optional) = True - Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. - The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be - transposed, may insert extra `transpose` to the original graph. + convert_config : Optional[Dict[str, Any]] + Default config: + use_dense : bool = True + Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. + The `nn.dense` op requires the data tensor to be non-transposed and weight tensor + to be transposed, may insert extra `transpose` to the original graph. + use_nt_batch_matmul : bool = True + True to convert `tf.batch_matmul` to `nn.batch_matmul` strict to NT format + (transpose_a=False, transpose_b=True). Returns ------- @@ -1246,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op Dict of converted parameters stored in tvm.nd.NDArray format """ global TF_DEFAULT_CONFIGS - TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op + if convert_config is not None: + TF_DEFAULT_CONFIGS.update(convert_config) g = GraphProto() mod, params = g.from_tensorflow(graph, layout, shape, outputs) diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index e5339b33c4e9..465f530624b9 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks """Tensorflow2.x graph to relay converter. If model is constructed using tf2.x API, then use this converter: @@ -38,12 +38,20 @@ from .common import infer_type as _infer_type from .tensorflow_ops import _convert_map as _convert_map_common -from .tensorflow_ops import _need_prelude_for_shape_inference +from .tensorflow_ops import _get_more_static_shape_rank +from .tensorflow2_ops import _convert_map as _convert_map_tf2 +from .tensorflow2_ops import _need_prelude_for_shape_inference from ..ty import Any __all__ = ["from_tensorflow"] +# A map to record tensor list write ops and input tl/tensor indices +# Value is (index of tensor list, index of written node) +_tensor_list_write_ops = { + "TensorListSetItem": (0, 2), +} + def _infer_type_with_prelude(val, prelude): body = _infer_type(val, prelude.mod) @@ -66,6 +74,11 @@ def set_span(sym, node_name): return sym +def is_tensor_list_constuctor(tf_node): + """Check whether is tensor list constructor node.""" + return tf_node.op == "TensorListReserve" + + def convert_const_node(node, shape): """convert tf const node into relay const or var""" @@ -196,6 +209,10 @@ def __init__(self, module): self._output_shapes = {} self._tf_node_map = {} self._gdef_lib = {} + self._tensor_list_shapes = {} + self._tensor_list_shape_nodes = {} + self._sub_map = {} + self._sub_input_idx_map = {} def from_tensorflow( self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None @@ -215,10 +232,134 @@ def from_tensorflow( ) return func, self._params + def _analysis_tensor_list_op( + self, + graph, + node, + tl_write_nodes, + tl_stack_nodes, + tl_construct_nodes, + sub_func_name="", + root_node="", + ): + if sub_func_name and sub_func_name not in self._sub_input_idx_map: + self._sub_input_idx_map[sub_func_name] = {} + + if node.op == "Placeholder": + # record placeholder node in sub functions + self._sub_map[sub_func_name] = node + self._sub_input_idx_map[sub_func_name][node.name] = len( + self._sub_input_idx_map[sub_func_name] + ) + + if node.op.startswith("TensorList"): + if is_tensor_list_constuctor(node): + tl_construct_nodes.append(node) + else: + for tl_write_name, idx in _tensor_list_write_ops.items(): + if node.op.startswith(tl_write_name): + tl_write_nodes.append((node, idx, sub_func_name, root_node)) + if node.op.startswith("TensorListStack"): + tl_stack_nodes.append(node) + elif node.op.startswith("StatelessWhile"): + root_node = node.name + cond_fn_name, body_fn_name = [ + parse_attr(node.attr).get(x).name for x in ["cond", "body"] + ] + for fn_name in [cond_fn_name, body_fn_name]: + subfunction = self._gdef_lib[fn_name] + sub_func_name = fn_name + for sub_node in subfunction.node: + # bypass const node + if sub_node.op == "Const": + continue + self._tf_node_map[sub_node.name] = sub_node + self._analysis_tensor_list_op( + subfunction, + sub_node, + tl_write_nodes, + tl_stack_nodes, + tl_construct_nodes, + sub_func_name=sub_func_name, + root_node=root_node, + ) + + def _infer_static_shape_stack_node(self, tl_stack_nodes): + for stack_node in tl_stack_nodes: + if len(stack_node.input) < 2: + # Stack node does not have shape + continue + input_shape_name = stack_node.input[1].split(":")[0] + input_shape_node = self._tf_node_map[input_shape_name] + stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]] + in_idx = -1 + while stack: + cnode = stack.pop(0) + if not cnode.op.startswith("TensorList"): + if in_idx and cnode.op.startswith("StatelessWhile"): + stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]]) + else: + for iname in cnode.input: + if self._tf_node_map[iname.split(":")[0]].op.startswith( + "StatelessWhile" + ): + # identify input index based on output index + if iname.split(":")[1]: + in_idx = int(iname.split(":")[1]) + stack.append(self._tf_node_map[iname.split(":")[0]]) + # identify the corresponding constructor node and add shape to _tensor_list_shapes + elif cnode.name != stack_node.name: + if is_tensor_list_constuctor(cnode): + shape_attr = parse_attr(input_shape_node.attr) + if "value" not in shape_attr: + continue + raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"]) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(int(dim)) + self._tensor_list_shapes[cnode.name] = elem_shape + break + + def _infer_static_shape_write_node(self, tl_write_nodes): + for item in tl_write_nodes: + wnode = item[0] + ta_idx, inode_idx = item[1] + sub_func_name = item[2] + root_name = item[3] + stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]] + while stack: + cnode = stack.pop(0) + + if not cnode.op.startswith("TensorList"): + if cnode.op == "Placeholder" and sub_func_name: + # need to map subfunction + input_idx = self._sub_input_idx_map[sub_func_name][cnode.name] + stack.append( + self._tf_node_map[ + self._tf_node_map[root_name].input[input_idx].split(":")[0] + ] + ) + else: + for iname in cnode.input: + stack.append(self._tf_node_map[iname.split(":")[0]]) + # identify the corresponding constructor node and add it to _tensor_list_shape_nodes + elif cnode.name != wnode.name: + if is_tensor_list_constuctor(cnode): + inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] + tn = wnode.input[inode_idx].split(":") + output_index = int(tn[1]) if len(tn) > 1 else 0 + self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index) + break + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None): if input_types is None: input_types = {} - + tl_write_nodes = [] + tl_stack_nodes = [] + tl_construct_nodes = [] self._layout = layout for node in graph.node: name = node.name @@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ self._nodes[node.name] = sym if param: self._params[node.name] = param + # recursivly iterate tensorlist op if seen while loop + else: + self._analysis_tensor_list_op( + graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes + ) + + # Use tensor list stack to infer static tensor list shape + self._infer_static_shape_stack_node(tl_stack_nodes) + + # Fetch node contains static tensor list shape + self._infer_static_shape_write_node(tl_write_nodes) + for node in graph.node: self._backtrack_construct(graph, node.name) @@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): gdef_lib=self._gdef_lib, ) elif op_name in _convert_map_common: + # assert op are exclusive + assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys()) if _need_prelude_for_shape_inference(op_name): sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude) else: sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod) + elif op_name in _convert_map_tf2: + if _need_prelude_for_shape_inference(op_name): + sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) sym = set_span(sym, node_name) return sym + def _parse_element_shape(self, elem_shape, shape_attr): + if "value" in shape_attr: + raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"]) + + if raw_elem_shape.size == 1 and raw_elem_shape == -1: + elem_shape.append(Any()) + else: + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(dim) + def _backtrack_construct(self, graph, node_name): """Convert a specific tensorflow node to relay expression. @@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name): CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), []) """ - input_op_name = node_name.split(":")[0].split("^")[-1] + if input_op_name not in self._nodes: node = self._tf_node_map[input_op_name] attr = parse_attr(node.attr) @@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name): attr["_node_name"] = node.name attr["_target_layout"] = self._layout inputs = [self._backtrack_construct(graph, iname) for iname in node.input] - op = self._convert_operator(graph, node.op, node.name, inputs, attr) + # infer shape for TensorList op + if is_tensor_list_constuctor(node): + input_shape_name = ( + node.input[1] if "TensorListFromTensor" in node.op else node.input[0] + ) + input_shape_name = input_shape_name.split(":")[0] + input_shape_node = self._tf_node_map[input_shape_name] + shape_attr = parse_attr(input_shape_node.attr) + elem_shape = [] + + self._parse_element_shape(elem_shape, shape_attr) + + if elem_shape: + attr["shape"] = elem_shape + if ( + "identical_element_shapes" in attr and attr["identical_element_shapes"] + ) or elem_shape: + shape = elem_shape + if node.name in self._tensor_list_shapes: + preset_shape = self._tensor_list_shapes[node.name] + shape = _get_more_static_shape_rank(shape, preset_shape) + attr["shape"] = shape + + op = self._convert_operator(graph, node.op, node.name, inputs, attr) if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) op = [ @@ -512,7 +708,7 @@ def _convert_function( Examples -------- - a tf function "x+1", is implemented as a subgraph in the libary section of the graph. + a tf function "x+1", is implemented as a subgraph in the library section of the graph. this subgraph is converted to a relay function such as fn (%x: float32) { add(%x, 1f) /* Identity */ diff --git a/python/tvm/relay/frontend/tensorflow2_ops.py b/python/tvm/relay/frontend/tensorflow2_ops.py new file mode 100644 index 000000000000..945554816984 --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow2_ops.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +"""Tensorflow2.x to relay converter ops and helper""" +import tvm +from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape + +from .. import op as _op +from ..ty import Any +from .common import infer_value as _infer_value +from .common import infer_type as _infer_type +from .tensorflow_ops import _get_more_static_shape_rank + + +def _infer_type_with_prelude(val, prelude): + body = _infer_type(val, prelude.mod) + return body.checked_type + + +def _need_prelude_for_shape_inference(op): + return "TensorList" in op or "TensorArray" in op + + +def _tensorlist_reserve(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("element_dtype").name + elem_shape = _infer_value(inputs[0], params, prelude.mod) + elem_shape = tuple(elem_shape.asnumpy().astype("int32").flatten()) + + if elem_shape or "shape" in attr: + shape = attr["shape"] if "shape" in attr else elem_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape) + static_tensor_array_ops.register() + tensor_array_constructor = static_tensor_array_ops.get_global_var("tensor_array") + tensor_array = tensor_array_constructor(inputs[1]) + else: + tensor_array_constructor = prelude.get_global_var("tensor_array", dtype_str) + tensor_array = tensor_array_constructor(inputs[1]) + return tensor_array + + return _impl + + +def _tensorlist_set_item(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("element_dtype").name + input_ta = inputs[0] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + input_t_shape = _infer_type_with_prelude(inputs[2], prelude).shape + input_rank = len(input_t_shape) + + if input_ta_shape is None: + tensor_name = "tensor{}".format(input_rank) + tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str) + v = tensor_func(inputs[2]) + write_func = prelude.get_global_var("tensor_array_write", dtype_str) + out = write_func(input_ta, inputs[1], v) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + tensor_func = static_tensor_array_ops.get_ctor("tensor_constructor") + v = tensor_func(inputs[2]) + # Write tensor with more static shape + # convert shape with -1 to any() + input_ta_shape_a = [] + for dim in input_ta_shape: + if isinstance(dim, (int, tvm.tir.expr.IntImm)): + if dim < 0: + input_ta_shape_a.append(Any()) + else: + input_ta_shape_a.append(dim) + else: + input_ta_shape_a.append(dim) + actual_shape = _get_more_static_shape_rank(input_t_shape, input_ta_shape_a) + if actual_shape != input_ta_shape_a: + new_shape = [] + num_any_dim = 0 + for dim in actual_shape: + if not isinstance(dim, int): + num_any_dim += 1 + new_shape.append(dim if isinstance(dim, int) else -1) + if num_any_dim <= 1: + v = tensor_func(_op.reshape(inputs[2], new_shape)) + write_func = prelude.get_global_var_static( + "tensor_array_write", dtype_str, input_ta_shape_a + ) + out = write_func(input_ta, inputs[1], v) + return out + + return _impl + + +def _tensorlist_get_item(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude) + + if input_shape is None: + read_func = prelude.get_global_var("tensor_array_read", dtype_str) + out = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0))) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) + static_tensor_array_ops.register() + read_func = static_tensor_array_ops.get_global_var("tensor_array_read") + out_tensor = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0))) + get_data_func = static_tensor_array_ops.get_global_var("tensor_get_data") + out = get_data_func(out_tensor) + return out + + return _impl + + +def _tensorlist_stack(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_ta_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude) + + if input_ta_shape is None: + stack_func = prelude.get_global_var("tensor_array_stack", dtype_str) + out = stack_func(inputs[0]) + else: + if "num_elements" in attr: + num_elements = attr["num_elements"] + static_tensor_array_ops = StaticTensorArrayOps( + prelude, dtype_str, input_ta_shape, num_elements + ) + static_tensor_array_ops.register() + stack_func = prelude.get_global_var_static( + "tensor_array_stack", dtype_str, input_ta_shape, num_elements + ) + out_tensor = stack_func(inputs[0]) + out_shape = ( + (num_elements,) + input_ta_shape + if num_elements and num_elements == 1 + else (Any(),) + input_ta_shape + ) + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) + static_tensor_array_ops.register() + get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape) + out = get_data_func(out_tensor) + + return out + + return _impl + + +def _tensorlist_from_tensor(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_ta_shape = _infer_type_with_prelude(inputs[0], prelude).shape + + if input_ta_shape is None: + unstack_func = prelude.get_global_var("tensor_array_unstack", dtype_str) + out = unstack_func(inputs[0]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + unstack_func = prelude.get_global_var_static( + "tensor_array_unstack", dtype_str, input_ta_shape + ) + out = unstack_func(inputs[0]) + return out + + return _impl + + +_convert_map = { + "TensorListFromTensor": _tensorlist_from_tensor(), + "TensorListGetItem": _tensorlist_get_item(), + "TensorListReserve": _tensorlist_reserve(), + "TensorListSetItem": _tensorlist_set_item(), + "TensorListStack": _tensorlist_stack(), +} diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 004174f076fd..a8213d4b1c49 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -138,6 +138,18 @@ def _get_more_static_shape(shape0, shape1): return shape1 +def _get_more_static_shape_rank(shape0, shape1): + """Compare two shapes with different rank, + and return the one with fewer symbolic dimension. + """ + num_sym_dim0 = sum([not isinstance(dim, (int, tvm.tir.expr.IntImm)) for dim in list(shape0)]) + num_sym_dim1 = sum([not isinstance(dim, (int, tvm.tir.expr.IntImm)) for dim in list(shape1)]) + + if num_sym_dim0 < num_sym_dim1: + return shape0 + return shape1 + + def _rsqrt(): def _impl(inputs, attr, params, mod): inputs.append(tvm.relay.const(-0.5, attr["T"].name)) @@ -1075,7 +1087,7 @@ def _impl(inputs, attr, params, mod): # Ignore the new attributes from TF2.0, for now. return AttrCvt( - op_name="resize", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} + op_name="resize2d", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} )(inputs, attr) return _impl @@ -1137,6 +1149,8 @@ def _impl(inputs, attr, params, mod): def _batch_matmul(): def _impl(inputs, attr, params, mod): + from .tensorflow import TF_DEFAULT_CONFIGS + input_x = inputs[0] input_y = inputs[1] orig_shape_x = _infer_shape(input_x, mod) @@ -1173,9 +1187,16 @@ def _impl(inputs, attr, params, mod): input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) adj_x = attr["adj_x"] adj_y = attr["adj_y"] - input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x - input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y - ret = get_relay_op("batch_matmul")(input_x, input_y) + + if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Strictly convert all batch_matmul to NT format + input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y + ret = get_relay_op("batch_matmul")(input_x, input_y) + else: + ret = get_relay_op("batch_matmul")( + input_x, input_y, transpose_a=adj_x, transpose_b=adj_y + ) # reshape result back to n-dimensional if ndim > 3: @@ -1483,7 +1504,13 @@ def _impl(inputs, attr, params, mod): def _concatV2(): def _impl(inputs, attr, params, mod): pop_node = inputs.pop(len(inputs) - 1) - axis = int(_get_num_param(params, pop_node)) + try: + axis = int(_get_num_param(params, pop_node)) + except (IndexError, KeyError, AttributeError): + try: + axis = int(_infer_value(pop_node, params, mod).numpy()) + except Exception: + axis = int(pop_node) return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})( [inputs], attr ) @@ -2244,7 +2271,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin[index] < 0 else begin[index] ) - m_end[final_index] = begin[index] + 1 + m_end[final_index] = m_begin[final_index] + 1 m_stride[final_index] = 1 fshape_indices.append(-2) else: @@ -2943,8 +2970,8 @@ def _impl(inputs, attr, params, mod): "Relu": AttrCvt("relu"), "Relu6": _relu6(), "Reshape": _reshape(), - "ResizeBicubic": _resize("bilinear"), - "ResizeBilinear": _resize("bilinear"), + "ResizeBicubic": _resize("cubic"), + "ResizeBilinear": _resize("linear"), "ResizeNearestNeighbor": _resize("nearest_neighbor"), "ReverseV2": _reverse_v2(), "RightShift": AttrCvt("right_shift"), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a47fdf0141b5..5501185f7985 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -255,23 +255,23 @@ def get_op_code_str(self, op): op_c = self.model.OperatorCodes(op_code_list_idx) # In TFlite 2.4.x there was a change where the type of the field that contained # the builtin code changed from int8 to int32 in the flat buffer representation. - # However to retain support for old flat buffers that were created, they retained - # the original 8 bit encoding for the operator but in a new field accessed by the - # DeprecatedBuiltinCode method. - # This means that the API function BuiltinCode() is used on an operator - # which was originally encoded as an 8 bit quantity it would look for the - # code in the new int32 field in the schema and this creates the need - # for the check for the magic number of 127 which is indicated by - # BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES + # However, to retain support for old flat buffers that were created, they retained + # the original 8 bit field, but named it "deprecated_builtin_code" in TFLite 2.4. + # This means that the API function BuiltinCode() which originally returned the value + # of the 8 bit field would now look for the value in the new int32 field in the + # schema and DeprecatedBuiltinCode() will look at the old 8 bit field. + # In TFLite 2.4, if the opcode value is less than 127, it can be in either field + # (however, if it is only in the "builtin_code" field, the model is not backward + # compatible), so similarly to TFLite 2.4 reader, we'll pick the higher value of the + # two fields. # Remember however that this value came into existence only after Tensorflow # lite 2.4.x and hence encase it in a try -except block. # Phew ! try: - if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES: - opc = op_c.DeprecatedBuiltinCode() - else: - opc = op_c.BuiltinCode() + opc = max(op_c.DeprecatedBuiltinCode(), op_c.BuiltinCode()) except AttributeError: + # In versions before 2.4 the int8 field that holds the builtin code is accessed + # by BuiltinCode() and DeprecatedBuiltinCode() doesn't exist opc = op_c.BuiltinCode() op_code_id = opc @@ -630,7 +630,7 @@ def _convert_resize(self, method, op): # Options - align_corners (bool) resize_options = None align_corners = False - bilinear_method = method == "bilinear" + bilinear_method = method == "linear" if bilinear_method: assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions resize_options = ResizeBilinearOptions() @@ -647,7 +647,7 @@ def _convert_resize(self, method, op): coord_trans = "align_corners" if align_corners else "asymmetric" if bilinear_method and input_tensor.qnn_params: in_expr = self.dequantize(in_expr, input_tensor) - out = _op.image.resize( + out = _op.image.resize2d( in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans ) if bilinear_method and output_tensor.qnn_params: @@ -656,7 +656,7 @@ def _convert_resize(self, method, op): def convert_resize_bilinear(self, op): """Convert TFLite RESIZE_BILINEAR""" - return self._convert_resize("bilinear", op) + return self._convert_resize("linear", op) def convert_resize_nearest_neighbor(self, op): """Convert TFLite RESIZE_NEAREST_NEIGHBOR""" @@ -3471,7 +3471,7 @@ def _get_flattened_index(indices, shape): indices_list = [] # Below function iterates through each applicable indices per dimension - # based on format type specified and finaly produce the dense matrix and the NZ indices. + # based on format type specified and finally produce the dense matrix and the NZ indices. def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx): if level == len(indices): start_pos = 0 diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py index 4b5d2b9c605c..4533886d14da 100644 --- a/python/tvm/relay/frontend/tflite_flexbuffer.py +++ b/python/tvm/relay/frontend/tflite_flexbuffer.py @@ -88,7 +88,7 @@ def indirect_jump(self, offset, byte_width): def decode_keys(self, end, size, byte_width): """Decodes the flexbuffer type vector. Map keys are stored in this form""" - # Keys are strings here. The format is all strings seperated by null, followed by back + # Keys are strings here. The format is all strings separated by null, followed by back # offsets for each of the string. For example, (str1)\0(str1)\0(offset1)(offset2) The end # pointer is pointing at the end of all strings keys = list() diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index fa2772c1299f..3793f947c5cc 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad): GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk """ lhs, rhs = orig.args + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True): + # ki, jk -> ij + # jk, ij -> ki + # ij, ki -> jk + return [ + collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs), + collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False): + # ki, kj -> ij + # kj, ij -> ki + # ki, ij -> kj + return [ + collapse_sum_like( + _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs + ), + collapse_sum_like( + _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs + ), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True): + # ik, jk -> ij + # ij, jk -> ik + # ij, ik -> jk + # Keep using NT format batch_matmul here for not involving extra ops + # TODO(jcf94): Merge all to normal batch_matmul when it is finally ready + return [ + collapse_sum_like( + _nn.batch_matmul( + grad, + transpose(rhs, [0, 2, 1]), + transpose_a=False, + transpose_b=True, + ), + lhs, + ), + collapse_sum_like( + _nn.batch_matmul( + transpose(grad, [0, 2, 1]), + transpose(lhs, [0, 2, 1]), + transpose_a=False, + transpose_b=True, + ), + rhs, + ), + ] + # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False) + # ik, kj -> ij + # ij, kj -> ik + # ik, ij -> kj return [ - collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs), - collapse_sum_like( - _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs - ), + collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs), + collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs), ] diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index cbe6a22f4a4d..cec7c4d141cb 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -23,6 +23,7 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem +from tvm.ir import Op from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -1035,6 +1036,7 @@ def visit_tuple_getitem(self, op): return visit if ( isinstance(visit.tuple_value, Call) + and isinstance(visit.tuple_value.op, Op) and visit.tuple_value.op.name == "nn.dropout" and visit.index == 0 ): diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 32bd88456ffc..5e97d2461100 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -26,36 +26,36 @@ # resize -@reg.register_compute("dyn.image.resize") -def compute_resize(attrs, inputs, out_type): +@reg.register_compute("dyn.image.resize2d") +def compute_resize2d(attrs, inputs, out_type): layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ - tvm.topi.image.resize( + tvm.topi.image.resize2d( inputs[0], inputs[1], layout, method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, out_type.shape, ) ] -reg.register_injective_schedule("dyn.image.resize") +reg.register_injective_schedule("dyn.image.resize2d") @script -def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): +def _resize2d_shape_func(dshape, size, ndim, height_axis, width_axis): out = output_tensor((ndim,), "int64") for i in const_range(ndim): out[i] = int64(dshape[i]) @@ -64,15 +64,15 @@ def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): return out -@reg.register_shape_func("dyn.image.resize", True) -def resize_shape_func(attrs, inputs, _): +@reg.register_shape_func("dyn.image.resize2d", True) +def resize2d_shape_func(attrs, inputs, _): """ Shape function for dyn.image.resize op. """ layout = attrs.layout if nchw_pack_layout(layout) or nchw_xc_layout(layout): out = [ - _resize_shape_func( + _resize2d_shape_func( inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), convert(2), convert(3) ) ] @@ -84,7 +84,7 @@ def resize_shape_func(attrs, inputs, _): if letter == "W": width_axis = i out = [ - _resize_shape_func( + _resize2d_shape_func( inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 2071a43f828b..ec24ff76b90e 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -26,42 +26,42 @@ from .. import op as reg from .. import strategy from ..op import OpPattern -from .image import resize +from .image import resize2d # resize -@reg.register_compute("image.resize") -def compute_resize(attrs, inputs, out_type): - """compute definition for resize op""" +@reg.register_compute("image.resize1d") +def compute_resize1d(attrs, inputs, out_type): + """compute definition for resize1d op""" size = attrs.size layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ - topi.image.resize( + topi.image.resize1d( inputs[0], size, layout, method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) ] -reg.register_injective_schedule("image.resize") +reg.register_injective_schedule("image.resize1d") -@reg.register_convert_op_layout("image.resize") -def convert_image_resize(attrs, inputs, tinfos, desired_layouts): - """Convert Layout pass registration for image resize op. +@reg.register_convert_op_layout("image.resize1d") +def convert_image_resize1d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize1d op. Parameters ---------- @@ -86,11 +86,104 @@ def convert_image_resize(attrs, inputs, tinfos, desired_layouts): desired_layout = str(desired_layouts[0]) assert desired_layout != "default", "Layout cannot be default" new_attrs["layout"] = desired_layout - return resize(*inputs, **new_attrs) + return resize1d(*inputs, **new_attrs) @script -def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): +def _resize1d_shape_func(image_shape, size, batch_axis, width_axis, channel_axis): + out = output_tensor((3,), "int64") + out[batch_axis] = int64(image_shape[0]) + out[width_axis] = int64(size[1]) + out[channel_axis] = image_shape[channel_axis] + return out + + +@reg.register_shape_func("image.resize1d", False) +def resize1d_shape_func(attrs, inputs, _): + """ + Shape function for resize2d op. + """ + layout = attrs.layout + width_axis = channel_axis = 1 + for i, letter in enumerate(layout): + if letter == "N": + batch_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + size = get_const_tuple(attrs.size) + return [ + _resize1d_shape_func( + inputs[0], + convert(size), + convert(batch_axis), + convert(width_axis), + convert(channel_axis), + ) + ] + + +@reg.register_compute("image.resize2d") +def compute_resize2d(attrs, inputs, out_type): + """compute definition for resize2d op""" + size = attrs.size + layout = attrs.layout + method = attrs.method + coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude + out_dtype = attrs.out_dtype + return [ + topi.image.resize2d( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + ] + + +reg.register_injective_schedule("image.resize2d") + + +@reg.register_convert_op_layout("image.resize2d") +def convert_image_resize2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize2d op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current resize op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data input. + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + new_attrs = dict(attrs) + assert len(desired_layouts) == 1, "Only one desired layout is expected" + desired_layout = str(desired_layouts[0]) + assert desired_layout != "default", "Layout cannot be default" + new_attrs["layout"] = desired_layout + return resize2d(*inputs, **new_attrs) + + +@script +def _resize2d_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") out[batch_axis] = int64(image_shape[0]) out[height_axis] = int64(size[0]) @@ -99,10 +192,10 @@ def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, c return out -@reg.register_shape_func("image.resize", False) -def resize_shape_func(attrs, inputs, _): +@reg.register_shape_func("image.resize2d", False) +def resize2d_shape_func(attrs, inputs, _): """ - Shape function for resize op. + Shape function for resize2d op. """ layout = attrs.layout height_axis = width_axis = channel_axis = 1 @@ -117,7 +210,7 @@ def resize_shape_func(attrs, inputs, _): channel_axis = i size = get_const_tuple(attrs.size) return [ - _resize_shape_func( + _resize2d_shape_func( inputs[0], convert(size), convert(batch_axis), @@ -130,12 +223,28 @@ def resize_shape_func(attrs, inputs, _): @reg.register_compute("image.resize3d") def compute_resize3d(attrs, inputs, out_type): + """compute definition for resize3d op""" size = attrs.size layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype - return [topi.image.resize3d(inputs[0], size, layout, method, coord_trans, out_dtype)] + return [ + topi.image.resize3d( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + ] reg.register_injective_schedule("image.resize3d") diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 6d7d79264844..7f5bd80159f9 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -20,18 +20,94 @@ from ...expr import Expr, Constant -def resize( +def resize1d( + data, + size, + layout="NCW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.5, + cubic_exclude=0, + out_dtype=None, +): + """Image resize1d operator. + + This operator takes data as input and does 1D scaling to the given scale factor. + In the default case, where the data_layout is `NCW` + with data of shape (n, c, w) + out will have a shape (n, c, size[0]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + size: Tuple of Int or Expr + The out size to which the image will be resized. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + [half_pixel, align_corners, asymmetric] + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for cubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during cubic interpolation + + out_dtype : str, optional + Type to return. If left None returns the same type as input. + + Returns + ------- + result: relay.Expr + The resized result. + """ + if isinstance(size, Constant): + size = list(size.data.numpy().astype("int32")) + if isinstance(size, Expr): + raise NotImplementedError("dyn.resize1d is not yet implemented, got size", size) + return _make.resize1d( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + + +def resize2d( data, size, layout="NCHW", - method="bilinear", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, - bicubic_exclude=0, + cubic_alpha=-0.5, + cubic_exclude=0, out_dtype=None, ): - """Image resize operator. + """Image resize2d operator. This operator takes data as input and does 2D scaling to the given scale factor. In the default case, where the data_layout is `NCHW` @@ -39,7 +115,7 @@ def resize( out will have a shape (n, c, size[0], size[1]) method indicates the algorithm to be used while calculating the out value - and method can be one of ("bilinear", "nearest_neighbor", "bicubic") + and method can be one of ("linear", "nearest_neighbor", "cubic") Parameters ---------- @@ -53,7 +129,7 @@ def resize( Layout of the input. method : str, optional - Scale method to used [nearest_neighbor, bilinear, bicubic]. + Scale method to used [nearest_neighbor, linear, cubic]. coordinate_transformation_mode : string, optional Describes how to transform the coordinate in the resized tensor @@ -65,10 +141,10 @@ def resize( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] - bicubic_alpha: float - Spline Coefficient for Bicubic Interpolation + cubic_alpha: float + Spline Coefficient for bicubic interpolation - bicubic_exclude: int + cubic_exclude: int Flag to exclude exterior of the image during bicubic interpolation out_dtype : str, optional @@ -82,26 +158,26 @@ def resize( if isinstance(size, Constant): size = list(size.data.numpy().astype("int32")) if isinstance(size, Expr): - return _dyn_make.resize( + return _dyn_make.resize2d( data, size, layout, method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) - return _make.resize( + return _make.resize2d( data, size, layout, method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) @@ -110,11 +186,14 @@ def resize3d( data, size, layout="NCDHW", - method="trilinear", + method="linear", coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.5, + cubic_exclude=0, out_dtype=None, ): - """Image resize 3D operator. + """Image resize3d operator. This operator takes data as input and does 3D scaling to the given scale factor. In the default case, where the data_layout is `NCDHW` @@ -122,27 +201,38 @@ def resize3d( out will have a shape `(n, c, size[0], size[1], size[2])` method indicates the algorithm to be used while calculating the out value - and method can be one of ("trilinear", "nearest_neighbor") + and method can be one of ("linear", "nearest_neighbor", "cubic") Parameters ---------- data : relay.Expr The input data to the operator. - size: Tuple of Expr + size: Tuple of Int or Expr The out size to which the image will be resized. layout : str, optional Layout of the input. method : str, optional - Scale method to used [nearest_neighbor, trilinear]. + Scale method to used [nearest_neighbor, linear, cubic]. coordinate_transformation_mode : string, optional Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. [half_pixel, align_corners, asymmetric] + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for cubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during cubic interpolation + out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -151,7 +241,21 @@ def resize3d( result: relay.Expr The resized result. """ - return _make.resize3d(data, size, layout, method, coordinate_transformation_mode, out_dtype) + if isinstance(size, Constant): + size = list(size.data.numpy().astype("int32")) + if isinstance(size, Expr): + raise NotImplementedError("dyn.resize3d is not yet implemented, got size", size) + return _make.resize3d( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) def crop_and_resize( diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 753a17605667..96cef8bc3588 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -964,7 +964,8 @@ def compute_space_to_depth(attrs, inputs, out_dtype): @script -def _conv_shape_func(dshape, kshape, strides, padding, dilation): +def _conv_shape_func_nchw(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nchw & oihw layout.""" out = output_tensor((dshape.shape[0],), "int64") out[0] = dshape[0] out[1] = kshape[0] @@ -975,23 +976,52 @@ def _conv_shape_func(dshape, kshape, strides, padding, dilation): return out +@script +def _conv_shape_func_nhwc_hwio(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nhwc & hwio layout.""" + out = output_tensor((dshape.shape[0],), "int64") + out[0] = dshape[0] + out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 1] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i] - 1) * dilation[i] + 1 + out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1 + return out + + +@script +def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nhwc & hwoi layout.""" + out = output_tensor((dshape.shape[0],), "int64") + out[0] = dshape[0] + out[dshape.shape[0] - 1] = kshape[kshape.shape[0] - 2] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i] - 1) * dilation[i] + 1 + out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1 + return out + + def conv_shape_func(attrs, inputs, _): - """ - Shape function for contrib_conv2d_NCHWc op. - """ + """Shape function for conv*d op.""" strides = get_const_tuple(attrs.strides) padding = get_const_tuple(attrs.padding) dilation = get_const_tuple(attrs.dilation) - return [ - _conv_shape_func( - inputs[0], - inputs[1], - convert(strides), - convert(padding), - convert(dilation), + shape_func = None + if attrs["data_layout"] == "NCHW" and attrs["kernel_layout"] == "OIHW": + shape_func = _conv_shape_func_nchw + elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO": + shape_func = _conv_shape_func_nhwc_hwio + elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI": + shape_func = _conv_shape_func_nhwc_hwoi + else: + raise ValueError( + "Unsupported data/kernel layout: %s, %s" + % (attrs["data_layout"], attrs["kernel_layout"]) ) - ] + + return [shape_func(inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation))] reg.register_shape_func("nn.conv1d", False, conv_shape_func) @@ -1052,27 +1082,22 @@ def conv2d_NCHWc_shape_func(attrs, inputs, _): @script -def _conv2d_transpose_nchw_shape_func(dshape, kshape, strides, padding, dilation, output_padding): +def _conv_transpose_shape_func(dshape, kshape, strides, padding, dilation, output_padding): out = output_tensor((dshape.shape[0],), "int64") - kheight = kshape[2] - kwidth = kshape[3] - dilated_kh = (kheight - 1) * dilation[0] + 1 - dilated_kw = (kwidth - 1) * dilation[1] + 1 - - out_height = strides[0] * (dshape[2] - 1) + dilated_kh - 2 * padding[0] + output_padding[0] - out_width = strides[1] * (dshape[3] - 1) + dilated_kw - 2 * padding[1] + output_padding[1] - out[0] = dshape[0] out[1] = kshape[1] - out[2] = out_height - out[3] = out_width + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1 + out[i + 2] = ( + strides[i] * (dshape[i + 2] - 1) + dilated_k - 2 * padding[i] + output_padding[i] + ) return out -@reg.register_shape_func("nn.conv2d_transpose", False) -def conv2d_transpose_nchw_shape_func(attrs, inputs, _): +def conv_transpose_shape_func(attrs, inputs, _): """ - Shape function for conv2d_transpose op. + Shape function for contrib_conv2d_NCHWc op. """ strides = get_const_tuple(attrs.strides) padding = get_const_tuple(attrs.padding) @@ -1080,7 +1105,7 @@ def conv2d_transpose_nchw_shape_func(attrs, inputs, _): output_padding = get_const_tuple(attrs.output_padding) return [ - _conv2d_transpose_nchw_shape_func( + _conv_transpose_shape_func( inputs[0], inputs[1], convert(strides), @@ -1091,6 +1116,10 @@ def conv2d_transpose_nchw_shape_func(attrs, inputs, _): ] +reg.register_shape_func("nn.conv1d_transpose", False, conv_transpose_shape_func) +reg.register_shape_func("nn.conv2d_transpose", False, conv_transpose_shape_func) + + @script def _pool2d_shape_func(data_shape, pool_size, strides, padding, height_axis, width_axis): out = output_tensor((data_shape.shape[0],), "int64") @@ -1247,14 +1276,11 @@ def dense_pack_shape_func(attrs, inputs, _): @script -def _batch_matmul_shape_func(data_shape, weight_shape): - out = output_tensor((data_shape.shape[0],), "int64") - for i in const_range(out.shape[0] - 1): - if i == 0: - out[i] = max(data_shape[i], weight_shape[i]) - else: - out[i] = data_shape[i] - out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2] +def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b): + out = output_tensor((tensor_a_shape.shape[0],), "int64") + out[0] = max(tensor_a_shape[0], tensor_b_shape[0]) + out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1] + out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2] return out @@ -1262,9 +1288,16 @@ def _batch_matmul_shape_func(data_shape, weight_shape): @reg.register_shape_func("nn.batch_matmul", False) def batch_matmul_shape_func(attrs, inputs, _): """ - Shape function for dense op. + Shape function for batch matmul op. """ - ret = [_batch_matmul_shape_func(inputs[0], inputs[1])] + ret = [ + _batch_matmul_shape_func( + inputs[0], + inputs[1], + expr.IntImm("bool", attrs.transpose_a), + expr.IntImm("bool", attrs.transpose_b), + ) + ] return ret @@ -1307,4 +1340,7 @@ def dilate_shape_func(attrs, inputs, _): reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) reg.register_shape_func("nn.softmax", False, elemwise_shape_func) +reg.register_shape_func("nn.fast_softmax", False, elemwise_shape_func) reg.register_shape_func("nn.relu", False, elemwise_shape_func) +reg.register_shape_func("nn.leaky_relu", False, elemwise_shape_func) +reg.register_shape_func("nn.prelu", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 4c94102275bb..64b397a4d4f9 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2137,32 +2137,40 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True, return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale) -def batch_matmul(x, y, out_dtype=""): +def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False, transpose_b=True): r""" - Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data - in batch. + Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. .. math:: - \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + \mbox{batch_matmul}(A, B)[i, :, :] = \mbox{matmul}(A[i, :, :], B[i, :, :]) Parameters ---------- - x : tvm.relay.Expr + tensor_a : tvm.relay.Expr The first input. - y : tvm.relay.Expr + tensor_b : tvm.relay.Expr The second input. - out_dtype : str, optional - Specifies the output data type for mixed precision batch matmul + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- result: tvm.relay.Expr The computed result. """ - return _make.batch_matmul(x, y, out_dtype) + return _make.batch_matmul(tensor_a, tensor_b, out_dtype, transpose_a, transpose_b) # pylint: disable=no-else-return,inconsistent-return-statements diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 780badc89fc4..507dd9371a97 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -74,6 +74,11 @@ class DenseAttrs(Attrs): """Attributes for nn.dense""" +@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs") +class BatchMatmulAttrs(Attrs): + """Attributes for nn.batch_matmul""" + + @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" @@ -139,9 +144,19 @@ class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@tvm._ffi.register_object("relay.attrs.ResizeAttrs") -class ResizeAttrs(Attrs): - """Attributes for image.resize""" +@tvm._ffi.register_object("relay.attrs.Resize1DAttrs") +class Resize1DAttrs(Attrs): + """Attributes for image.resize1d""" + + +@tvm._ffi.register_object("relay.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + """Attributes for image.resize2d""" + + +@tvm._ffi.register_object("relay.attrs.Resize3DAttrs") +class Resize3DAttrs(Attrs): + """Attributes used in resize3d operators""" @tvm._ffi.register_object("relay.attrs.CropAndResizeAttrs") @@ -499,11 +514,6 @@ class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" -@tvm._ffi.register_object("relay.attrs.Resize3dAttrs") -class Resize3dAttrs(Attrs): - """Attributes used in resize3d operators""" - - @tvm._ffi.register_object("relay.attrs.ScatterAttrs") class ScatterAttrs(Attrs): """Attributes used in scatter operators""" diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index 24e68a47bbeb..8008391fe86c 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -65,6 +65,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack), name="conv2d_nchw_spatial_pack.bifrost", ) + elif layout == "NHWC": + assert kernel_layout == "HWIO" + # For now just reuse general Mali strategy. + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.bifrost", + ) else: raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index aeeb62af11a9..ba47ae7bc4f1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): """batch_matmul cuda strategy""" strategy = _op.OpStrategy() x, y = inputs - if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32": + if ( + x.dtype == "int8" + and y.dtype == "int8" + and out_type.dtype == "int32" + and not attrs["transpose_a"] + and attrs["transpose_b"] + ): strategy.add_implementation( wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8), @@ -840,17 +846,25 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): name="batch_matmul_cublas.cuda", plevel=15, ) - if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target): + if ( + target.kind.name == "cuda" + and nvcc.have_tensorcore(target=target) + and not attrs["transpose_a"] + and attrs["transpose_b"] + ): x, y = inputs _, M, K = get_const_tuple(x.shape) _, N, K = get_const_tuple(y.shape) - if x.dtype in ["float16", "int8", "uint8"] and ( - (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) - or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) - or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) - ): + if ( + x.dtype in ["float16", "int8", "uint8"] + and ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ) + ) or (x.dtype in ["int4", "uint4"] and K % 32 == 0 and M % 8 == 0 and N % 8 == 0): strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore), + wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore), name="batch_matmul_tensorcore.cuda", plevel=20, diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3348d8033904..9c756f201721 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -799,10 +799,11 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne def _compute_batch_matmul(attrs, inputs, out_type): args = [inputs[0], inputs[1], out_type.shape] + args.append(out_type.dtype if need_out_dtype else None) + args.append(attrs.transpose_a) + args.append(attrs.transpose_b) if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) - if need_out_dtype: - args.append(out_type.dtype) return [topi_compute(*args)] return _compute_batch_matmul diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index 6c6440e486f1..d38fe0d82758 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -73,36 +73,39 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWIO" if not is_auto_scheduler_enabled(): - raise RuntimeError( - "conv2d NHWC layout is not enabled for mali without auto_scheduler." - ) - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), - naive_schedule, - name="conv2d_nhwc.mali", - ) - is_winograd_applicable = False - if len(kernel.shape) == 4: - kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape) - is_winograd_applicable = ( - "float" in data.dtype - and "float" in kernel.dtype - and kernel_h == 3 - and kernel_w == 3 - and stride_h == 1 - and stride_w == 1 - and dilation_h == 1 - and dilation_w == 1 + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.mali", ) - if is_winograd_applicable: + else: strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True - ), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd", - plevel=15, + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + naive_schedule, + name="conv2d_nhwc.mali", ) + is_winograd_applicable = False + if len(kernel.shape) == 4: + kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape) + is_winograd_applicable = ( + "float" in data.dtype + and "float" in kernel.dtype + and kernel_h == 3 + and kernel_w == 3 + and stride_h == 1 + and stride_w == 1 + and dilation_h == 1 + and dilation_w == 1 + ) + if is_winograd_applicable: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True + ), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc.winograd", + plevel=15, + ) else: raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout)) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index f4538071e11e..64373dcdd7bf 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -20,6 +20,7 @@ from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition from tvm.contrib.thrust import can_use_rocthrust +from tvm.contrib import miopen from .generic import * from .. import op as _op @@ -304,3 +305,41 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): plevel=15, ) return strategy + + +@softmax_strategy.register(["rocm"]) +def softmax_strategy_rocm(attrs, inputs, out_type, target): + """rocm strategy for softmax""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.softmax), + wrap_topi_schedule(topi.cuda.schedule_softmax), + name="softmax.rocm", + ) + if "miopen" in target.libs: + strategy.add_implementation( + wrap_compute_softmax(miopen.softmax), + wrap_topi_schedule(topi.generic.schedule_extern), + name="softmax.miopen", + plevel=15, + ) + return strategy + + +@log_softmax_strategy.register(["rocm"]) +def log_softmax_strategy_rocm(attrs, inputs, out_type, target): + """rocm strategy for log softmax""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_softmax(topi.nn.log_softmax), + wrap_topi_schedule(topi.cuda.schedule_softmax), + name="log_softmax.rocm", + ) + if "miopen" in target.libs: + strategy.add_implementation( + wrap_compute_softmax(miopen.log_softmax), + wrap_topi_schedule(topi.generic.schedule_extern), + name="log_softmax.miopen", + plevel=15, + ) + return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 6a4030514580..a6e141f2753b 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -521,14 +521,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() if is_dynamic(out_type) or is_auto_scheduler_enabled(): strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True), + wrap_compute_batch_matmul( + topi.nn.batch_matmul, need_auto_scheduler_layout=True, need_out_dtype=True + ), wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), name="batch_matmul.generic", plevel=10, ) else: strategy.add_implementation( - wrap_compute_batch_matmul(topi.x86.batch_matmul), + wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", plevel=10, diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 376bf4a4804d..542980561e78 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude): return None -def _get_name_static(canonical, dtype, shape): - """Get name for static shape tensor array op corresponding - to the canonical name""" +def _get_name_static(canonical, dtype, shape, batch_dim=None): + """Get name for static shape tensor array op + + By design, static ADT tensor in TVM has type name in the format + of static_tensor_dim0_dim1_..._dimN_t + or static_tensor_batch1_dim0_dim1_..._dimN_t if tensorlist stack only have one item. + + Parameters + ---------- + canonical : String + Tensor array op name + + dtype : str + Data type. + + shape : tuple of (int, Any) or None + Tensor array shape + + batch_dim: None or int + 1 if tensorlist stack only have one item. + None by default + + Returns + ------- + name : String + The tensor array op name + """ dim_names = [] for dim in shape: if isinstance(dim, Any): @@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape): shape_str = "scalar" if canonical == "tensor_t": return "static_tensor_{}_{}_t".format(dtype, shape_str) - return "{}_{}_{}".format(canonical, dtype, shape_str) + if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]: + return "{}_{}_{}".format(canonical, dtype, shape_str) + if batch_dim != 1: + return "{}_{}_{}".format(canonical, dtype, shape_str) + return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str) class StaticTensorArrayOps(object): """Contains tensor array related ops for fixed rank tensor array""" - def __init__(self, prelude, dtype, shape): + def __init__(self, prelude, dtype, shape, batch_dim=None): """Create tensor array ops registry""" self.prelude = prelude self.dtype = dtype self.shape = shape + self.batch_dim = batch_dim self.list, self.cons, self.nil = self.prelude.mod.get_type("List") def get_name(self, canonical): """Get name corresponding to the canonical name""" - return _get_name_static(canonical, self.dtype, self.shape) + return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim) def get_global_var(self, canonical): """Get global corresponding to the canonical name""" - return self.prelude.get_global_var_static(canonical, self.dtype, self.shape) + return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim) def get_type(self, canonical): """Get type corresponding to the canonical name""" @@ -262,9 +291,10 @@ def define_tensor_expand_dims(self): # Note: we set the added axis to be Any() instead of 1 due to # in stack op, we need to recursively concatenate. + new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape( [ - Any(), + new_axis, ] + list(self.shape) ) @@ -573,20 +603,27 @@ def define_tensor_array_stack(self): expand_dims_var = self.get_global_var("tensor_expand_dims") # Register tensor_concatenate for output_shape + new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else self.batch_dim output_shape = [ - Any(), + new_axis, ] + list(self.shape) - _, _, output_ops = self._get_adt_by_shape(output_shape) output_ops.define_tensor_concatenate() concat_var = output_ops.get_global_var("tensor_concatenate") tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) - tensors = self.prelude.foldl( - concat_var, - self.prelude.hd(tensor_array_expand_dims), - self.prelude.tl(tensor_array_expand_dims), - ) + if self.batch_dim is not None and self.batch_dim == 1: + # only one element + tensors = self.prelude.id( + self.prelude.hd(tensor_array_expand_dims), + ) + else: + tensors = self.prelude.foldl( + concat_var, + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims), + ) + output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape) self.prelude.mod[stack_var] = Function( [tensor_array], tensors, output_tensor_type_var(), [] @@ -599,8 +636,9 @@ def define_tensor_array_gather(self): helper_name = self.get_name("tensor_array_gather_helper") helper_var = self._create_global_var(helper_name) + new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim output_shape = [ - Any(), + new_axis, ] + list(self.shape) output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape) stack_var = self.get_global_var("tensor_array_stack") @@ -668,7 +706,7 @@ def register(self): def _get_adt_by_shape(self, shape): """Get ADT type and constructor with given shape.""" - adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape) + adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim) adt_ops.define_tensor_adt() tensor_type_var = adt_ops.get_type("tensor_t") tensor_constructor = adt_ops.get_ctor("tensor_constructor") @@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype): ty = self.get_type("tensor_t", dtype) return self.get_ctor(ty.name_hint, canonical, dtype) - def get_name_static(self, canonical, dtype, shape): + def get_name_static(self, canonical, dtype, shape, batch_dim=None): """Get name corresponding to the canonical name""" - return _get_name_static(canonical, dtype, shape) + return _get_name_static(canonical, dtype, shape, batch_dim) - def get_global_var_static(self, canonical, dtype, shape): + def get_global_var_static(self, canonical, dtype, shape, batch_dim=None): """Get var corresponding to the canonical name""" - name = self.get_name_static(canonical, dtype, shape) + name = self.get_name_static(canonical, dtype, shape, batch_dim) return self.mod.get_global_var(name) def get_type_static(self, canonical, dtype, shape): diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index f02f8227e14a..e74256ec74c3 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -682,6 +682,44 @@ def subtract( ) +def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="int32"): + r""" + Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data + in batch. + + .. math:: + + \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + + Parameters + ---------- + x : tvm.relay.Expr + The first quantized input. + A quantized tensor is represented in following manner + `A = scale_a x (QA - zp_A)` + where QA is quantized tensor, scale_a and zp_A are quantization + params. + y : tvm.relay.Expr + The second quantized input. + x_zero_point: tvm.relay.Expr + The first input zero point. + y_zero_point: tvm.relay.Expr + The second input zero point. + x_scale: tvm.relay.Expr + The scale for the first input tensor. + y_scale: tvm.relay.Expr + The scale for the second input tensor. + out_dtype : str, optional + Specifies the output data type for mixed precision dense can be int32 or int16. + + Returns + ------- + result: tvm.relay.Expr + The computed result. + """ + return _make.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype) + + # register fuse pattern for qnn ops reg.register_pattern("qnn.quantize", OpPattern.OPAQUE) reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index bfe797d844a8..de85ed69238a 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -154,15 +154,16 @@ def check_grad( assert len(grads) > 0, "You must test at least one gradient." # Get numeric gradients for each dimension of each param, using two-sided approximation. + fwd_func_compiled = intrp.evaluate(fwd_func) approx_grads = [] for x in test_inputs: approx_grad = np.zeros(x.shape) for i in np.ndindex(*x.shape): x_i = x[i] x[i] = x_i + eps - fwd_plus = intrp.evaluate(fwd_func)(*inputs).numpy().astype("float64") + fwd_plus = fwd_func_compiled(*inputs).numpy().astype("float64") x[i] = x_i - eps - fwd_minus = intrp.evaluate(fwd_func)(*inputs).numpy().astype("float64") + fwd_minus = fwd_func_compiled(*inputs).numpy().astype("float64") x[i] = x_i approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps)) approx_grads.append(approx_grad) diff --git a/python/tvm/relay/testing/yolo_detection.py b/python/tvm/relay/testing/yolo_detection.py index a387f3076bf5..949d024bd86f 100644 --- a/python/tvm/relay/testing/yolo_detection.py +++ b/python/tvm/relay/testing/yolo_detection.py @@ -103,8 +103,8 @@ def _get_yolo_detections(l, im_shape, net_shape, thresh, relative, dets): l["biases"], np.asarray(l["mask"])[location[0]], location, - data.shape[2], data.shape[3], + data.shape[2], net_shape[0], net_shape[1], ) @@ -139,10 +139,10 @@ def _get_region_detections(l, im_shape, net_shape, thresh, relative, dets): l["biases"], n, location, - data.shape[2], data.shape[3], data.shape[2], data.shape[3], + data.shape[2], ) objectness = scale if scale > thresh else 0 if objectness: diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 9ed40f85c3bc..378b0c38ff64 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,4 +19,4 @@ # transformation passes from .transform import * from .recast import recast -from . import fake_quantization_to_integer +from . import fake_quantization_to_integer, mixed_precision diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 5f4c53772eec..783204fb700f 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -17,13 +17,12 @@ """Relay functions for rewriting fake quantized ops.""" import tvm from tvm import relay +from tvm.ir import TensorAffineType, TupleAffineType from ..op import register_fake_quantization_to_integer def fold_constant(expr): - mod = tvm.IRModule.from_expr(expr) - mod = relay.transform.FoldConstant()(mod) - return mod["main"].body + return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) @register_fake_quantization_to_integer("qnn.dequantize") @@ -31,7 +30,7 @@ def dequantize(expr, type_map): """Remove dequantize op""" out = expr.args[0] t = type_map[expr] - return [out, t.scale, t.zero_point, t.dtype] + return [out, t] @register_fake_quantization_to_integer("qnn.quantize") @@ -54,23 +53,26 @@ def quantize(expr, type_map): expr.args[2], out_dtype=expr.attrs.out_dtype, ) - return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] + return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] -def register_unary_identity(op_name, op): +def register_unary_identity(op_name): def identity(expr, type_map): assert len(expr.args) == 1 arg = expr.args[0] t = type_map[arg] - out = op(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] + return [expr, t] return register_fake_quantization_to_integer(op_name, identity) -register_unary_identity("reshape", relay.op.reshape) -register_unary_identity("transpose", relay.op.transpose) -register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d) +register_unary_identity("reshape") +register_unary_identity("squeeze") +register_unary_identity("strided_slice") +register_unary_identity("transpose") +register_unary_identity("expand_dims") +register_unary_identity("nn.max_pool2d") +register_unary_identity("nn.batch_flatten") @register_fake_quantization_to_integer("nn.avg_pool2d") @@ -81,7 +83,7 @@ def avgpool2d(expr, type_map): arg = relay.op.cast(arg, "int32") out = relay.op.nn.avg_pool2d(arg, **expr.attrs) out = relay.op.cast(out, t.dtype) - return [out, t.scale, t.zero_point, t.dtype] + return [out, t] @register_fake_quantization_to_integer("nn.bias_add") @@ -99,10 +101,10 @@ def bias_add(expr, type_map): b_t.zero_point, in_scale, in_zero_point, - out_dtype=xt.dtype, + out_dtype=x_t.dtype, ) out = relay.op.nn.bias_add(x, b, **expr.attrs) - return [out, x_t.scale, x_t.zero_point, x_t.dtype] + return [out, x_t] @register_fake_quantization_to_integer("nn.conv2d") @@ -118,7 +120,23 @@ def conv2d(expr, type_map): out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, conv_scale, conv_zp, out.attrs.out_dtype] + return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)] + + +@register_fake_quantization_to_integer("nn.dense") +def dense(expr, type_map): + """Rewrite a dense op""" + attrs = {**expr.attrs} + attrs.pop("out_dtype") + x, weight = expr.args + x_t = type_map[x] + w_t = type_map[weight] + dense_scale = fold_constant(x_t.scale * w_t.scale) + dense_zp = relay.const(0) + out = relay.qnn.op.dense( + x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs + ) + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)] @register_fake_quantization_to_integer("concatenate") @@ -126,8 +144,9 @@ def concat(expr, type_map): """Rewrite a concat op""" scales = [] zps = [] - for arg in expr.args[0].fields: - t = type_map[arg] + + tuple_type = type_map[expr.args[0]] + for t in tuple_type.types: scales.append(t.scale) zps.append(t.zero_point) @@ -141,7 +160,21 @@ def concat(expr, type_map): out_type.zero_point, **expr.attrs, ) - return [out, out_type.scale, out_type.zero_point, out_type.dtype] + return [out, out_type] + + +@register_fake_quantization_to_integer("split") +def split(expr, type_map): + """Rewrite a split op""" + arg = expr.args[0] + t = type_map[arg] + attrs = {**expr.attrs} + if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm): + num_split = attrs["indices_or_sections"].value + attrs["indices_or_sections"] = num_split + else: + num_split = len(attrs["indices_or_sections"]) + 1 + return [expr, TupleAffineType([t] * num_split)] @register_fake_quantization_to_integer("clip") @@ -163,4 +196,133 @@ def clip(expr, type_map): amin = relay.op.round(relay.op.const(amin) / scale + z_p) amax = relay.op.round(relay.op.const(amax) / scale + z_p) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) - return [out, t.scale, t.zero_point, t.dtype] + return [out, t] + + +@register_fake_quantization_to_integer("nn.pad") +def pad(expr, type_map): + """Rewite an nn.pad op""" + arg = expr.args[0] + t = type_map[arg] + pad_value = expr.args[1] + ## TF2ONNX will sometimes implement the pad_value as a constant without a quantize + ## To support that, the pass lets branches that terminate in a constant through + if pad_value in type_map: + ## if the pad value is calcuated from a dequantize op, it should be in the type map + ## and we need to make sure it's affine type matches the arg + pad_t = type_map[pad_value] + if not tvm.ir.structural_equal(t, pad_t): + pad_value = relay.qnn.op.requantize( + pad_value, + pad_t.scale, + pad_t.zero_point, + t.scale, + t.zero_point, + out_dtype=t.dtype, + ) + else: + ## If the pad-value is a constant, we need to quantize it + assert isinstance(pad_value, relay.expr.Constant) + pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point) + + out = relay.op.nn.pad(arg, pad_value=pad_value, **expr.attrs) + return [out, t] + + +def get_binary_types(expr, type_map): + """Get Affine types of a binary op's inputs and unify them""" + ##Support the case where one input is quantized and the other is a constant float + left = expr.args[0] + right = expr.args[1] + left_t = None + right_t = None + + if left in type_map: + left_t = type_map[left] + if right in type_map: + right_t = type_map[right] + + out_t = type_map[expr] + if left_t is None and right_t is None: + raise TypeError("neither input is quantized!") + if left_t is None: + assert isinstance(left, relay.expr.Constant) + left = relay.qnn.op.quantize( + left, right_t.scale, right_t.zero_point, out_dtype=right_t.dtype + ) + left_t = right_t + out_t = right_t + if right_t is None: + assert isinstance(right, relay.expr.Constant) + right = relay.qnn.op.quantize( + right, left_t.scale, left_t.zero_point, out_dtype=left_t.dtype + ) + right_t = left_t + out_t = left_t + + # Handle the case of mismatched inputs + if not left_t.dtype == out_t.dtype: + out_t = left_t + + return left, right, left_t, right_t, out_t + + +def register_binary_qnn(op_name, op): + """Register a Binary Op that converts to QNN""" + + def binary(expr, type_map): + left, right, left_t, right_t, out_t = get_binary_types(expr, type_map) + out = op( + left, + right, + left_t.scale, + left_t.zero_point, + right_t.scale, + right_t.zero_point, + out_t.scale, + out_t.zero_point, + ) + return [out, out_t] + + return register_fake_quantization_to_integer(op_name, binary) + + +# Use lambdas here to avoid a circular import problem +# pylint: disable=unnecessary-lambda +register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args)) +register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args)) +register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args)) + + +def register_binary_identity(op_name, op): + """Register a binary op that works directly on int8""" + + def binary(expr, type_map): + left, right, left_t, right_t, out_t = get_binary_types(expr, type_map) + if left_t != out_t: + left = relay.qnn.op.requantize( + left, + left_t.scale, + left_t.zero_point, + out_t.scale, + out_t.zero_point, + out_dtype=out_t.dtype, + ) + + if right_t != out_t: + right = relay.qnn.op.requantize( + right, + right_t.scale, + right_t.zero_point, + out_t.scale, + out_t.zero_point, + out_dtype=out_t.dtype, + ) + out = op(left, right) + return [out, out_t] + + return register_fake_quantization_to_integer(op_name, binary) + + +register_binary_identity("minimum", relay.op.minimum) +register_binary_identity("maximum", relay.op.maximum) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 6f8ecb970221..1e982a0f18a4 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -18,7 +18,6 @@ """Default behavior for ops in mixed_precision pass. Import this file to use.""" from typing import List -from tvm import relay from tvm.relay.op import register_mixed_precision_conversion # MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory @@ -141,7 +140,7 @@ def decorator(func): return decorator -def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]: +def get_generic_out_dtypes(call_node: "relay.Call", mixed_precision_type: str) -> List[str]: """A function which returns output dtypes in a way which works for most ops. Parameters @@ -174,15 +173,15 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> # Take in CallNodes and a DType and returns a conversion type, # an accumulation dtype, and an output_dtype. @register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST) -def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_always_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST) -def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_follow_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type) @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST) -def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List: +def generic_never_op(call_node: "relay.Call", mixed_precision_type: str) -> List: return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 4531ceca2ce9..d8199c4c93a6 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -327,7 +327,7 @@ def text_summary(self): res += "----------------------------\n" for item in data["server_info"]: addr = item["addr"] - res += addr[0] + ":" + str(addr[1]) + "\t" + res += str(addr[0]) + ":" + str(addr[1]) + "\t" res += item["key"] + "\n" key = item["key"].split(":")[1] # 'server:rasp3b` -> 'rasp3b' if key not in total_ct: diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 0b49b675d77d..52a7a898269a 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -365,9 +365,14 @@ def _popen_start_rpc_server( custom_addr=None, silent=False, no_fork=False, + server_init_callback=None, ): if no_fork: multiprocessing.set_start_method("spawn") + + if server_init_callback: + server_init_callback() + # This is a function that will be sent to the # Popen worker to run on a separate process. # Create and start the server in a different thread @@ -420,6 +425,25 @@ class Server(object): no_fork: bool, optional Whether forbid fork in multiprocessing. + + server_init_callback: Callable, optional + Additional initialization function when starting the server. + + Note + ---- + The RPC server only sees functions in the tvm namespace. + To bring additional custom functions to the server env, you can use server_init_callback. + + .. code:: python + + def server_init_callback(): + import tvm + # must import mypackage here + import mypackage + + tvm.register_func("function", mypackage.func) + + server = rpc.Server(host, server_init_callback=server_init_callback) """ def __init__( @@ -434,6 +458,7 @@ def __init__( custom_addr=None, silent=False, no_fork=False, + server_init_callback=None, ): try: if _ffi_api.ServerLoop is None: @@ -455,6 +480,7 @@ def __init__( custom_addr, silent, no_fork, + server_init_callback, ], ) # receive the port diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index 0c2abd296b42..cac7c6d8fdae 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -56,8 +56,10 @@ def __dir__(self): return sorted([fnames(i) for i in range(size)] + class_names) def __getattr__(self, name): - if name in self.__slots__: - raise AttributeError(f"{name} is not set") + # specially check handle since + # this is required for PackedFunc calls + if name == "handle": + raise AttributeError("handle is not set") try: return _ffi_node_api.NodeGetAttr(self, name) diff --git a/python/tvm/runtime/profiler_vm.py b/python/tvm/runtime/profiler_vm.py index e1c3dc66a360..b3043d8b8760 100644 --- a/python/tvm/runtime/profiler_vm.py +++ b/python/tvm/runtime/profiler_vm.py @@ -50,7 +50,7 @@ def get_stat(self, sort_by_time=True): # pylint: disable=unused-argument warnings.warn("get_stat has been removed, use profile instead") return "" - def profile(self, *args, func_name="main", **kwargs): + def profile(self, *args, func_name="main", collectors=None, **kwargs): """Profile a function call. Parameters @@ -58,6 +58,9 @@ def profile(self, *args, func_name="main", **kwargs): func_name : str The name of the function. + collectors : Optional[Sequence[MetricCollector]] + Extra metrics to collect. + args : list[tvm.runtime.NDArray] or list[np.ndarray] The arguments to the function. @@ -69,6 +72,7 @@ def profile(self, *args, func_name="main", **kwargs): timing_results : str Overall and per-op timing results formatted in a table. """ + collectors = [] if collectors is None else collectors if args or kwargs: self.set_input(func_name, *args, **kwargs) - return self._profile(func_name) + return self._profile(func_name, collectors) diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py new file mode 100644 index 000000000000..881691609398 --- /dev/null +++ b/python/tvm/runtime/profiling/__init__.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Registration of profiling objects in python.""" + +from typing import Dict, Sequence, Optional +from ... import _ffi +from . import _ffi_api +from .. import Object, Device + + +@_ffi.register_object("runtime.profiling.Report") +class Report(Object): + """A container for information gathered during a profiling run. + + Attributes + ---------- + calls : Array[Dict[str, Object]] + Per-call profiling metrics (function name, runtime, device, ...). + + device_metrics : Dict[Device, Dict[str, Object]] + Per-device metrics collected over the entire run. + """ + + def csv(self): + """Convert this profiling report into CSV format. + + This only includes calls and not overall metrics. + + Returns + ------- + csv : str + `calls` in CSV format. + """ + return _ffi_api.AsCSV(self) + + def json(self): + """Convert this profiling report into JSON format. + + Example output: + + .. code-block: + + { + "calls": [ + { + "Duration (us)": { + "microseconds": 12.3 + }, + "Name": "fused_dense", + "Count": { + "count": 1 + }, + "Percent": { + "percent": 10.3 + } + } + ], + "device_metrics": { + "cpu": { + "Duration (us)": { + "microseconds": 334.2 + }, + "Percent": { + "percent": 100 + } + } + } + } + + {"calls": + [ + {"Duration (us)": {"microseconds": 12.3} + ,"Name": "fused_dense" + ,"Count": {"count":1} + ,"Percent": {"percent": 10.3} + } + ], + "device_metrics": + {"cpu": + {"Duration (us)": {"microseconds": 334.2} + ,"Percent": {"percent": 100.0} + } + } + } + + Returns + ------- + json : str + Formatted JSON + """ + return _ffi_api.AsJSON(self) + + +@_ffi.register_object("runtime.profiling.MetricCollector") +class MetricCollector(Object): + """Interface for user defined profiling metric collection.""" + + +@_ffi.register_object("runtime.profiling.DeviceWrapper") +class DeviceWrapper(Object): + """Wraps a tvm.runtime.Device""" + + def __init__(self, dev: Device): + self.__init_handle_by_constructor__(_ffi_api.DeviceWrapper, dev) + + +# We only enable this class when TVM is build with PAPI support +if _ffi.get_global_func("runtime.profiling.PAPIMetricCollector", allow_missing=True) is not None: + + @_ffi.register_object("runtime.profiling.PAPIMetricCollector") + class PAPIMetricCollector(MetricCollector): + """Collects performance counter information using the Performance + Application Programming Interface (PAPI). + """ + + def __init__(self, metric_names: Optional[Dict[Device, Sequence[str]]] = None): + """ + Parameters + ---------- + metric_names : Optional[Dict[Device, Sequence[str]]] + List of per-device metrics to collect. You can find a list of valid + metrics by runing `papi_native_avail` from the command line. + """ + metric_names = {} if metric_names is None else metric_names + wrapped = dict() + for dev, names in metric_names.items(): + wrapped[DeviceWrapper(dev)] = names + self.__init_handle_by_constructor__(_ffi_api.PAPIMetricCollector, wrapped) diff --git a/python/tvm/runtime/profiling.py b/python/tvm/runtime/profiling/_ffi_api.py similarity index 52% rename from python/tvm/runtime/profiling.py rename to python/tvm/runtime/profiling/_ffi_api.py index 5a1cd6796b64..d26b847a699f 100644 --- a/python/tvm/runtime/profiling.py +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -14,35 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Registration of profiling objects in python.""" - -from .. import _ffi -from . import Object +"""FFI for profiling""" +from ... import _ffi _ffi._init_api("runtime.profiling", __name__) - - -@_ffi.register_object("runtime.profiling.Report") -class Report(Object): - """A container for information gathered during a profiling run. - - Attributes - ---------- - calls : Array[Dict[str, Object]] - Per-call profiling metrics (function name, runtime, device, ...). - - device_metrics : Dict[Device, Dict[str, Object]] - Per-device metrics collected over the entire run. - """ - - def csv(self): - """Convert this profiling report into CSV format. - - This only includes calls and not overall metrics. - - Returns - ------- - csv : str - `calls` in CSV format. - """ - return AsCSV(self) diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 429da5892628..e748c297c76f 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -45,7 +45,7 @@ def _convert(arg, cargs): _convert(field, field_args) cargs.append(container.tuple_object(field_args)) elif isinstance(arg, (_base.numeric_types, bool)): - dtype = "int32" if isinstance(arg, (int, bool)) else "float32" + dtype = "int32" if isinstance(arg, (_base.integer_types, bool)) else "float32" value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) cargs.append(value) else: diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index ae3e9d885f1a..44c92b792f12 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -57,7 +57,7 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # match_buffers of the block, # which bind a sub-region of source buffer into a new buffer - D = tir.match_buffer_region(C[vi, vj]) + D = tir.match_buffer(C[vi, vj], ()) # init part of the block, executed when all reduce axes are the beginning value with tir.init(): @@ -65,13 +65,13 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # block body CC[0, 0] = A[vi, vk] * B[vj, vk] - D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] + D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] """ alloc_buffers: List[Buffer] = [] """List[Buffer]: list of tir.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] - """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature""" + """List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature""" iter_bindings: Mapping[Var, PrimExpr] = {} """Mapping[Var, PrimExpr]: map of block iter var to its values""" reads: Optional[List[BufferSlice]] = None diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/diagnostics.py index fc196f6b16ae..e676461ab39e 100644 --- a/python/tvm/script/diagnostics.py +++ b/python/tvm/script/diagnostics.py @@ -17,8 +17,9 @@ """Bridge from synr's (the library used for parsing the python AST) DiagnosticContext to TVM's diagnostics """ -import tvm from synr import DiagnosticContext, ast + +import tvm from tvm.ir.diagnostics import DiagnosticContext as TVMCtx from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 38ff1b71f07d..e2d44440e2ac 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -37,67 +37,67 @@ def handle(self, arg_list: List[Any], span: tvm.ir.Span): @register def bool(imm, span): - return tvm.tir.Cast("bool", imm, span) + return imm.astype("bool", span) @register def int8(imm, span): - return tvm.tir.Cast("int8", imm, span) + return imm.astype("int8", span) @register def int16(imm, span): - return tvm.tir.Cast("int16", imm, span) + return imm.astype("int16", span) @register def int32(imm, span): - return tvm.tir.Cast("int32", imm, span) + return imm.astype("int32", span) @register def int64(imm, span): - return tvm.tir.Cast("int64", imm, span) + return imm.astype("int64", span) @register def uint8(imm, span): - return tvm.tir.Cast("uint8", imm, span) + return imm.astype("uint8", span) @register def uint16(imm, span): - return tvm.tir.Cast("uint16", imm, span) + return imm.astype("uint16", span) @register def uint32(imm, span): - return tvm.tir.Cast("uint32", imm, span) + return imm.astype("uint32", span) @register def uint64(imm, span): - return tvm.tir.Cast("uint64", imm, span) + return imm.astype("uint64", span) @register def float8(imm, span): - return tvm.tir.Cast("float8", imm, span) + return imm.astype("float8", span) @register def float16(imm, span): - return tvm.tir.Cast("float16", imm, span) + return imm.astype("float16", span) @register def float32(imm, span): - return tvm.tir.Cast("float32", imm, span) + return imm.astype("float32", span) @register def float64(imm, span): - return tvm.tir.Cast("float64", imm, span) + return imm.astype("float64", span) @register @@ -120,6 +120,11 @@ def floormod(x, y, span): return tvm.tir.floormod(x, y, span) +@register +def abs(x, span): + return tvm.tir.abs(x, span) + + @register def load(dtype, var, index, predicate=None, span=None): return tvm.tir.Load(dtype, var, index, predicate, span) diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py index c4593683da78..cfbc668946a0 100644 --- a/python/tvm/script/node.py +++ b/python/tvm/script/node.py @@ -108,7 +108,7 @@ def check_index(index: Union[int, PrimExpr]): span, ) - slices: List[Slice] = [] + slices: List[Union[Slice, BufferSlice]] = [] for index in indices: if isinstance(index, Slice): check_index(index.start) @@ -117,6 +117,10 @@ def check_index(index: Union[int, PrimExpr]): elif isinstance(index, (PrimExpr, int)): check_index(index) slices.append(Slice(index)) + elif isinstance(index, BufferSlice): + buffer_load = index.asobject() + check_index(buffer_load) + slices.append(Slice(buffer_load)) else: report_error( "Unsupported index type for BufferSlice, " diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 49f71041590b..9acf21b6ba3a 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -784,10 +784,11 @@ def transform_Slice(self, node): def transform_Subscript(self, node): """Array access visitor. - By now only 2 types of Subscript are supported: + By now only 3 types of Subscript are supported: 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) Var[index] Buffer element access() 2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...])) + 3. Array[index], Buffer element access """ symbol = self.transform(node.params[0]) @@ -812,6 +813,25 @@ def transform_Subscript(self, node): return BufferSlice( symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) ) + elif isinstance(symbol, tvm.container.Array): + if len(indexes) > 1: + self.report_error( + "Array access should be one-dimension access, but the indices are " + + str(indexes), + node.span, + ) + index = indexes[0] + if not isinstance(index, (int, tvm.tir.expr.IntImm)): + self.report_error( + "Array access index expected int or IntImm, but got " + type(index), + node.span, + ) + if int(index) >= len(symbol): + self.report_error( + f"Array access out of bound, size: {len(symbol)}, got index {index}.", + node.span, + ) + return symbol[int(index)] else: self.report_error( f"Cannot subscript from a {type(symbol).__name__}. Only variables and " diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index a23401d926e9..bb408f6cdc8f 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -110,10 +110,9 @@ def __init__(self): def allocate(extents, dtype, scope, condition=True, span=None): condition = tvm.runtime.convert(condition) scope = tvm.runtime.convert(scope) - body = tvm.tir.Allocate( + return tvm.tir.Allocate( self.buffer_var, dtype, extents, condition, self.body, span=span ) - return tvm.tir.AttrStmt(self.buffer_var, "storage_scope", scope, body, span=span) super().__init__(allocate, concise_scope=True, def_symbol=True) self.buffer_var = None @@ -140,7 +139,7 @@ def enter_scope( def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 7eb938c58f96..25af7635742b 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -96,12 +96,24 @@ def handle( @register class MatchBuffer(SpecialStmt): - """Special Stmt match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align, + """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align, offset_factor, buffer_type) + + Note + ---- + This Special Stmt will perform different behavior depends on the type of param. + If the param is a var in function parameter, it will create a buffer from DLTensor. + Else if the param is a subregion of other buffers, then create a subregion match inside a block. + Example ------- + Match buffer from function parameter .. code-block:: python A = tir.match_buffer(a, (128, 128), dtype="float32") + + Match buffer from Buffer subregion + .. code-block:: python + A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") """ def __init__(self): @@ -123,10 +135,6 @@ def match_buffer( "match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)", self.node.span, ) - if param not in self.context.func_params: - self.context.report_error( - "Can not bind non-input param to buffer", self.node.rhs.params[0].span - ) if strides is None: strides = [] align = convert_to_int(align, "align", self.context.report_error, self.node.span) @@ -146,7 +154,23 @@ def match_buffer( buffer_type, span=span, ) - self.context.func_buffer_map[param] = buffer + if isinstance(param, tvm.tir.Var): + if param not in self.context.func_params: + self.context.report_error( + "Can not bind non-input param to buffer", self.node.rhs.params[0].span + ) + self.context.func_buffer_map[param] = buffer + elif isinstance(param, BufferSlice): + buffer_region = buffer_slice_to_region(param) + self.context.current_block_scope().match_buffers.append( + tvm.tir.MatchBufferRegion(buffer, buffer_region) + ) + else: + self.context.report_error( + "The source of match_buffer expected Var or BufferSlice, but got " + + str(type(param)), + self.node.rhs.params[0].span, + ) self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -225,7 +249,7 @@ def alloc_buffer( data=None, strides=None, elem_offset=None, - scope="", + scope="global", align=-1, offset_factor=0, buffer_type="default", @@ -414,68 +438,6 @@ def where(predicate, span=None): super().__init__(where, def_symbol=False) -@register -class BlockMatchBufferRegion(SpecialStmt): - """Special function match_buffer_region(source, strides, elem_offset, align, offset_factor) - - Example - ------- - .. code-block:: python - - B = tir.match_buffer_region(A[0: 4]) - """ - - def __init__(self): - def match_buffer_region( - source, - strides=None, - elem_offset=None, - align=-1, - offset_factor=0, - span=None, - ): - assert self.context, "call 'exit_scope' before 'enter_scope'" - if not isinstance(self.node, ast.Assign): - self.context.report_error( - "match_buffer_region must be assigned to a buffer, " - + "e.g. A = match_buffer_region(...)", - self.node.span, - ) - - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - - if not isinstance(source, BufferSlice): - self.context.report_error( - "match_buffer_region needs a buffer region as source", - span=span, - ) - buffer_region = buffer_slice_to_region(source) - shape = [r.extent for r in buffer_region.region] - buffer = tvm.tir.decl_buffer( - shape, - buffer_region.buffer.dtype, - self.node.lhs.id.name, - data=None, - strides=strides, - elem_offset=elem_offset, - scope=buffer_region.buffer.scope, - data_alignment=align, - offset_factor=offset_factor, - span=span, - ) - self.context.current_block_scope().match_buffers.append( - tvm.tir.MatchBufferRegion(buffer, buffer_region) - ) - self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) - - super().__init__(match_buffer_region, def_symbol=True) - - @register class VarDef(SpecialStmt): """Special function for defining a Var""" @@ -491,6 +453,22 @@ def var(dtype, span): super().__init__(var, def_symbol=True) +@register +class BufferVarDef(SpecialStmt): + """Special function for defining a variable of pointer type""" + + def __init__(self): + def buffer_var(dtype, storage_scope, span): + assert isinstance( + self.node, ast.Assign + ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" + ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + v = te.var(self.node.lhs.id.name, ptr_type, span=span) + self.context.update_symbol(v.name, v, self.node) + + super().__init__(buffer_var, def_symbol=True) + + @register class EnvThread(SpecialStmt): """Bind a var to thread env""" diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index a9ff9294e8a5..106432cd44f7 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -281,9 +281,11 @@ def intel_graphics(model="unknown", options=None): MICRO_SUPPORTED_MODELS = { "host": [], - "stm32f746xx": ["-mcpu=cortex-m7", "-march=armv7e-m"], - "nrf5340dk": ["-mcpu=cortex-m33"], "mps2_an521": ["-mcpu=cortex-m33"], + "nrf5340dk": ["-mcpu=cortex-m33"], + "stm32f746xx": ["-mcpu=cortex-m7", "-march=armv7e-m"], + "stm32l4r5zi": ["-mcpu=cortex-m4"], + "zynq_mp_r5": ["-mcpu=cortex-r5"], } diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 7bb85e3da83c..442aeb6f1027 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.from_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) - body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope)) for elem in to_pop: self.symbols.pop(elem) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 4721c0050656..79518ac24984 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -78,7 +78,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """Version of np.testing.assert_allclose with `atol` and `rtol` fields set in reasonable defaults. - Arguments `actual` and `desired` are not interchangable, since the function + Arguments `actual` and `desired` are not interchangeable, since the function compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 030918a5e18e..500195ac9a13 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -16,13 +16,17 @@ # under the License. """Wrapping existing analysis utils.""" # pylint: disable=invalid-name -from typing import Dict +from typing import Dict, List + +from tvm.tir.stmt import Block, BufferRegion +from tvm.tir.stmt import PrimExpr +from tvm.tir.expr import Var from . import _ffi_api from ..function import PrimFunc from .. import Buffer, Stmt -def expr_deep_equal(lhs, rhs): +def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool: """Deeply compare two nested expressions. Parameters @@ -56,10 +60,10 @@ def expr_deep_equal(lhs, rhs): -------- tvm.ir.structural_equal """ - return _ffi_api.expr_deep_equal(lhs, rhs) + return _ffi_api.expr_deep_equal(lhs, rhs) # type: ignore -def verify_ssa(func): +def verify_ssa(func: PrimFunc) -> bool: """Verify if the func is in SSA form. Parameters @@ -72,10 +76,10 @@ def verify_ssa(func): result : bool The result of verification. """ - return _ffi_api.verify_ssa(func) + return _ffi_api.verify_ssa(func) # type: ignore -def verify_memory(func): +def verify_memory(func: PrimFunc) -> bool: """Verify if func contains illegal host side direct memory access. Parameters @@ -88,10 +92,10 @@ def verify_memory(func): result : bool The result of verification. """ - return _ffi_api.verify_memory(func) + return _ffi_api.verify_memory(func) # type: ignore -def verify_gpu_code(func, constraints): +def verify_gpu_code(func: PrimFunc, constraints: Dict[str, int]) -> None: """Verify if module contains illegal host side direct memory access. Parameters @@ -107,10 +111,12 @@ def verify_gpu_code(func, constraints): result : bool The result of verification. """ - return _ffi_api.verify_gpu_code(func, constraints) + return _ffi_api.verify_gpu_code(func, constraints) # type: ignore -def get_block_access_region(block, buffer_var_map): +def get_block_access_region( + block: Block, buffer_var_map: Dict[Var, Buffer] +) -> List[List[BufferRegion]]: """Detect which regions of tensors in this block are read or written to. Regions are sorted by order of appearance in the AST. @@ -130,10 +136,10 @@ def get_block_access_region(block, buffer_var_map): - second: write regions - third: opaque regions """ - return _ffi_api.get_block_access_region(block, buffer_var_map) + return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore -def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int): +def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int: """Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc. @@ -149,7 +155,7 @@ def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int): result : int Workspace size in bytes. """ - return _ffi_api.calculate_workspace_bytes(func, workspace_byte_alignment) + return _ffi_api.calculate_workspace_bytes(func, workspace_byte_alignment) # type: ignore def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: @@ -167,4 +173,4 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: result : Dict[Buffer, Stmt] Map from buffer to the LCA of all access to it. """ - return _ffi_api.detect_buffer_access_lca(func) # pylint: disable=no-member + return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index d905a53b3303..6dddd7b119a0 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -90,7 +90,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): raise ValueError("Unknown access_mask %s" % access_mask) access_mask = mask offset = convert(offset) - return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type, content_lanes, offset) + return _ffi_api.BufferAccessPtr( + self, access_mask, ptr_type, content_lanes, offset # type: ignore + ) def vload(self, begin, dtype=None): """Generate an Expr that loads dtype from begin index. @@ -111,7 +113,7 @@ def vload(self, begin, dtype=None): """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin dtype = dtype if dtype else self.dtype - return _ffi_api.BufferVLoad(self, begin, dtype) + return _ffi_api.BufferVLoad(self, begin, dtype) # type: ignore def vstore(self, begin, value): """Generate a Stmt that store value into begin index. @@ -130,7 +132,16 @@ def vstore(self, begin, value): The corresponding store stmt. """ begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin - return _ffi_api.BufferVStore(self, begin, value) + return _ffi_api.BufferVStore(self, begin, value) # type: ignore + + def scope(self): + """Return the storage scope associated with this buffer. + Returns + ------- + scope : str + The storage scope associated with this buffer. + """ + return _ffi_api.BufferStorageScope(self) # type: ignore def decl_buffer( @@ -244,21 +255,20 @@ def decl_buffer( dtype = "float32" if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: - shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32" elem_offset = Var("%s_elem_offset" % name, shape_dtype) if data is None: # Bool is represented as uint1 in the IR, but stored as int8 storage_type = PrimType(dtype) storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type - data = Var(name, PointerType(storage_type), span) - return _ffi_api.Buffer( + data = Var(name, PointerType(storage_type, scope), span) + return _ffi_api.Buffer( # type: ignore data, dtype, shape, strides, elem_offset, name, - scope, data_alignment, offset_factor, buffer_type, diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index 40805d9beb8e..f46a154612e1 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Data layout.""" +from typing import Union + import tvm._ffi from tvm.runtime import Object @@ -36,7 +38,7 @@ class Layout(Object): """ def __len__(self): - return _ffi_api.LayoutNdim(self) + return _ffi_api.LayoutNdim(self) # type: ignore def __contains__(self, axis): return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name @@ -44,7 +46,7 @@ def __contains__(self, axis): def __getitem__(self, index): if index >= len(self): raise IndexError("Layout index out of range") - return _ffi_api.LayoutGetItem(self, index) + return _ffi_api.LayoutGetItem(self, index) # type: ignore def index_of(self, axis): """Get the index of an axis @@ -59,7 +61,7 @@ def index_of(self, axis): index : int The index of the axis, -1 if not found. """ - return _ffi_api.LayoutIndexOf(self, axis) + return _ffi_api.LayoutIndexOf(self, axis) # type: ignore def factor_of(self, axis): """Get the factor size of the subordinate axis. @@ -76,7 +78,7 @@ def factor_of(self, axis): or the size of axis itself (if axis is a subordinate-axis). Return -1 if axis is not in the layout. """ - return _ffi_api.LayoutFactorOf(self, axis) + return _ffi_api.LayoutFactorOf(self, axis) # type: ignore @tvm._ffi.register_object("tir.BijectiveLayout") @@ -113,7 +115,7 @@ def forward_index(self, index): dst_index: Array of Expr The inferred indices in dst-layout. """ - return _ffi_api.BijectiveLayoutForwardIndex(self, index) + return _ffi_api.BijectiveLayoutForwardIndex(self, index) # type: ignore def backward_index(self, index): """Given the indices of the dst-layout, infer the src index. @@ -128,7 +130,7 @@ def backward_index(self, index): src_index: Array of Expr The inferred indices in src-layout. """ - return _ffi_api.BijectiveLayoutBackwardIndex(self, index) + return _ffi_api.BijectiveLayoutBackwardIndex(self, index) # type: ignore def forward_shape(self, shape): """Given the shape of the src-layout, infer the dst shape. @@ -143,7 +145,7 @@ def forward_shape(self, shape): dst_shape: Array of Expr The inferred shape in dst-layout. """ - return _ffi_api.BijectiveLayoutForwardShape(self, shape) + return _ffi_api.BijectiveLayoutForwardShape(self, shape) # type: ignore def backward_shape(self, shape): """Given the shape of the dst-layout, infer the src shape. @@ -158,10 +160,10 @@ def backward_shape(self, shape): src_shape: Array of Expr The inferred shape in src-layout. """ - return _ffi_api.BijectiveLayoutBackwardShape(self, shape) + return _ffi_api.BijectiveLayoutBackwardShape(self, shape) # type: ignore -def layout(layout_str): +def layout(layout_str: str) -> Layout: """Create a layout node from a string. Parameters @@ -180,10 +182,12 @@ def layout(layout_str): layout : Layout The created layout """ - return _ffi_api.Layout(layout_str) + return _ffi_api.Layout(layout_str) # type: ignore -def bijective_layout(src_layout, dst_layout): +def bijective_layout( + src_layout: Union[str, Layout], dst_layout: Union[str, Layout] +) -> BijectiveLayout: """Create a bijective layout mapping. Parameters @@ -203,4 +207,4 @@ def bijective_layout(src_layout, dst_layout): src_layout = layout(src_layout) if isinstance(dst_layout, str): dst_layout = layout(dst_layout) - return _ffi_api.BijectiveLayout(src_layout, dst_layout) + return _ffi_api.BijectiveLayout(src_layout, dst_layout) # type: ignore diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 286e4051da51..4ba8c5471b5d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,7 +27,10 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ +from typing import Optional, Union +from tvm import ir import tvm._ffi +from tvm.ir.base import Span from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const from tvm.ir import PrimExpr, Op @@ -47,13 +50,17 @@ def div_ambiguity_error(): def _dtype_is_int(value): if isinstance(value, int): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT + return ( + isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT + ) # type: ignore def _dtype_is_float(value): if isinstance(value, float): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT + return ( + isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT + ) # type: ignore class ExprOp(object): @@ -106,55 +113,55 @@ def __rfloordiv__(self, other): return _generic.floordiv(other, self, None) def __mod__(self, other): - return _ffi_api._OpFloorMod(self, other, None) + return _ffi_api._OpFloorMod(self, other, None) # type: ignore def __rmod__(self, other): - return _ffi_api._OpFloorMod(other, self, None) + return _ffi_api._OpFloorMod(other, self, None) # type: ignore def __neg__(self): - neg_one = const(-1, self.dtype) + neg_one = const(-1, self.dtype) # type: ignore return self.__mul__(neg_one) def __lshift__(self, other): - return _ffi_api.left_shift(self, other, None) + return _ffi_api.left_shift(self, other, None) # type: ignore def __rlshift__(self, other): - return _ffi_api.left_shift(other, self, None) + return _ffi_api.left_shift(other, self, None) # type: ignore def __rshift__(self, other): - return _ffi_api.right_shift(self, other, None) + return _ffi_api.right_shift(self, other, None) # type: ignore def __rrshift__(self, other): - return _ffi_api.right_shift(other, self, None) + return _ffi_api.right_shift(other, self, None) # type: ignore def __and__(self, other): - return _ffi_api.bitwise_and(self, other, None) + return _ffi_api.bitwise_and(self, other, None) # type: ignore def __rand__(self, other): - return _ffi_api.bitwise_and(other, self, None) + return _ffi_api.bitwise_and(other, self, None) # type: ignore def __or__(self, other): - return _ffi_api.bitwise_or(self, other, None) + return _ffi_api.bitwise_or(self, other, None) # type: ignore def __ror__(self, other): - return _ffi_api.bitwise_or(other, self, None) + return _ffi_api.bitwise_or(other, self, None) # type: ignore def __xor__(self, other): - return _ffi_api.bitwise_xor(self, other, None) + return _ffi_api.bitwise_xor(self, other, None) # type: ignore def __rxor__(self, other): - return _ffi_api.bitwise_xor(other, self, None) + return _ffi_api.bitwise_xor(other, self, None) # type: ignore def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.bitwise_not(self, None) + return _ffi_api.bitwise_not(self, None) # type: ignore def __lt__(self, other): - return _ffi_api._OpLT(self, other, None) + return _ffi_api._OpLT(self, other, None) # type: ignore def __le__(self, other): - return _ffi_api._OpLE(self, other, None) + return _ffi_api._OpLE(self, other, None) # type: ignore def __eq__(self, other): return EqualOp(self, other) @@ -163,10 +170,10 @@ def __ne__(self, other): return NotEqualOp(self, other) def __gt__(self, other): - return _ffi_api._OpGT(self, other, None) + return _ffi_api._OpGT(self, other, None) # type: ignore def __ge__(self, other): - return _ffi_api._OpGE(self, other, None) + return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): raise ValueError( @@ -193,9 +200,9 @@ def equal(self, other, span=None): ret : PrimExpr The equality expression. """ - return _ffi_api._OpEQ(self, other, span) + return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype, span=None): + def astype(self, dtype: str, span: Optional[Span] = None): """Cast the expression to other type. Parameters @@ -248,7 +255,7 @@ def __bool__(self): def asobject(self): """Convert object.""" - return _ffi_api._OpEQ(self.a, self.b, self.span) + return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore class NotEqualOp(ObjectGeneric, ExprOp): @@ -285,7 +292,7 @@ def __bool__(self): def asobject(self): """Convert object.""" - return _ffi_api._OpNE(self.a, self.b, self.span) + return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore class IntImmEnum(ObjectGeneric): @@ -307,7 +314,7 @@ def __init__(self, value, span=None): def asobject(self): """Convert object.""" - return IntImm("int32", self.value, self.span) + return IntImm("int32", self.value, self.span) # type: ignore class PrimExprWithOp(ExprOp, PrimExpr): @@ -350,8 +357,8 @@ class Var(PrimExprWithOp): The location of this itervar in the source code. """ - def __init__(self, name, dtype, span=None): - self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) + def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None): + self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore @tvm._ffi.register_object("tir.SizeVar") @@ -373,7 +380,7 @@ class SizeVar(Var): # pylint: disable=super-init-not-called def __init__(self, name, dtype, span=None): - self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) + self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore @tvm._ffi.register_object("tir.IterVar") @@ -428,7 +435,9 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): name = var if var is not None else "iter" dtype = "int32" if dom is None else dom.extent.dtype var = Var(name, dtype=dtype, span=span) if not isinstance(var, Var) else var - self.__init_handle_by_constructor__(_ffi_api.IterVar, dom, var, iter_type, thread_tag, span) + self.__init_handle_by_constructor__( + _ffi_api.IterVar, dom, var, iter_type, thread_tag, span # type: ignore + ) @tvm._ffi.register_object("tir.CommReducer") @@ -455,7 +464,7 @@ class CommReducer(Object): def __init__(self, lhs, rhs, result, identity_element, span=None): self.__init_handle_by_constructor__( - _ffi_api.CommReducer, lhs, rhs, result, identity_element, span + _ffi_api.CommReducer, lhs, rhs, result, identity_element, span # type: ignore ) @@ -489,7 +498,7 @@ class Reduce(PrimExprWithOp): def __init__(self, combiner, src, rdom, condition, value_index, init=None, span=None): self.__init_handle_by_constructor__( - _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span + _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span # type: ignore ) @@ -510,7 +519,9 @@ class FloatImm(ConstExpr): """ def __init__(self, dtype, value, span=None): - self.__init_handle_by_constructor__(tvm.ir._ffi_api.FloatImm, dtype, value, span) + self.__init_handle_by_constructor__( + tvm.ir._ffi_api.FloatImm, dtype, value, span # type: ignore + ) @tvm._ffi.register_object @@ -530,7 +541,9 @@ class IntImm(ConstExpr): """ def __init__(self, dtype, value, span=None): - self.__init_handle_by_constructor__(tvm.ir._ffi_api.IntImm, dtype, value, span) + self.__init_handle_by_constructor__( + tvm.ir._ffi_api.IntImm, dtype, value, span # type: ignore + ) def __hash__(self): return self.value @@ -542,16 +555,16 @@ def __nonzero__(self): return self.value != 0 def __eq__(self, other): - return _ffi_api._OpEQ(self, other, None) + return _ffi_api._OpEQ(self, other, None) # type: ignore def __ne__(self, other): - return _ffi_api._OpNE(self, other, None) + return _ffi_api._OpNE(self, other, None) # type: ignore def __bool__(self): return self.__nonzero__() -@tvm._ffi.register_object("tir.StringImm") +@tvm._ffi.register_object("tir.StringImm") # type: ignore class StringImm(ConstExpr): """String constant. @@ -565,7 +578,7 @@ class StringImm(ConstExpr): """ def __init__(self, value, span=None): - self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) + self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore def __eq__(self, other): if isinstance(other, ConstExpr): @@ -577,6 +590,9 @@ def __ne__(self, other): return self.value != other.value return self.value != other + def __hash__(self): + return PrimExpr.__hash__(self) + @tvm._ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): @@ -595,7 +611,7 @@ class Cast(PrimExprWithOp): """ def __init__(self, dtype, value, span=None): - self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) + self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore @tvm._ffi.register_object("tir.Add") @@ -615,7 +631,7 @@ class Add(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Sub") @@ -635,7 +651,7 @@ class Sub(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Mul") @@ -655,7 +671,7 @@ class Mul(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Div") @@ -675,7 +691,7 @@ class Div(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Mod") @@ -695,7 +711,7 @@ class Mod(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore @tvm._ffi.register_object("tir.FloorDiv") @@ -715,7 +731,7 @@ class FloorDiv(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore @tvm._ffi.register_object("tir.FloorMod") @@ -735,7 +751,7 @@ class FloorMod(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Min") @@ -755,7 +771,7 @@ class Min(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Max") @@ -775,7 +791,7 @@ class Max(BinaryOpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore @tvm._ffi.register_object("tir.EQ") @@ -795,7 +811,7 @@ class EQ(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore @tvm._ffi.register_object("tir.NE") @@ -815,7 +831,7 @@ class NE(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore @tvm._ffi.register_object("tir.LT") @@ -835,7 +851,7 @@ class LT(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore @tvm._ffi.register_object("tir.LE") @@ -855,7 +871,7 @@ class LE(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore @tvm._ffi.register_object("tir.GT") @@ -875,7 +891,7 @@ class GT(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore @tvm._ffi.register_object("tir.GE") @@ -895,7 +911,7 @@ class GE(CmpExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore @tvm._ffi.register_object("tir.And") @@ -915,7 +931,7 @@ class And(LogicalExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Or") @@ -935,7 +951,7 @@ class Or(LogicalExpr): """ def __init__(self, a, b, span=None): - self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) + self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore @tvm._ffi.register_object("tir.Not") @@ -952,7 +968,7 @@ class Not(LogicalExpr): """ def __init__(self, a, span=None): - self.__init_handle_by_constructor__(_ffi_api.Not, a, span) + self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore @tvm._ffi.register_object("tir.Select") @@ -983,7 +999,7 @@ class Select(PrimExprWithOp): def __init__(self, condition, true_value, false_value, span=None): self.__init_handle_by_constructor__( - _ffi_api.Select, condition, true_value, false_value, span + _ffi_api.Select, condition, true_value, false_value, span # type: ignore ) @@ -1011,9 +1027,9 @@ class Load(PrimExprWithOp): def __init__(self, dtype, buffer_var, index, predicate=None, span=None): if predicate is None: - predicate = _ffi_api.const_true(dtype, span) + predicate = _ffi_api.const_true(dtype, span) # type: ignore self.__init_handle_by_constructor__( - _ffi_api.Load, dtype, buffer_var, index, predicate, span + _ffi_api.Load, dtype, buffer_var, index, predicate, span # type: ignore ) @@ -1034,7 +1050,9 @@ class BufferLoad(PrimExprWithOp): """ def __init__(self, buffer, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.BufferLoad, buffer, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.BufferLoad, buffer, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.ProducerLoad") @@ -1054,7 +1072,9 @@ class ProducerLoad(PrimExprWithOp): """ def __init__(self, producer, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.ProducerLoad, producer, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.ProducerLoad, producer, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.Ramp") @@ -1077,7 +1097,9 @@ class Ramp(PrimExprWithOp): """ def __init__(self, base, stride, lanes, span=None): - self.__init_handle_by_constructor__(_ffi_api.Ramp, base, stride, lanes, span) + self.__init_handle_by_constructor__( + _ffi_api.Ramp, base, stride, lanes, span # type: ignore + ) @tvm._ffi.register_object("tir.Broadcast") @@ -1097,7 +1119,7 @@ class Broadcast(PrimExprWithOp): """ def __init__(self, value, lanes, span=None): - self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) + self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore @tvm._ffi.register_object("tir.Shuffle") @@ -1117,7 +1139,9 @@ class Shuffle(PrimExprWithOp): """ def __init__(self, vectors, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.Shuffle, vectors, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.Shuffle, vectors, indices, span # type: ignore + ) class CallEffectKind: @@ -1163,7 +1187,7 @@ def __init__(self, dtype, op, args, span=None): % op ) op = Op.get(op) - self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) + self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore @tvm._ffi.register_object("tir.Let") @@ -1186,11 +1210,11 @@ class Let(PrimExprWithOp): """ def __init__(self, var, value, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) + self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore @tvm._ffi.register_object("tir.Any") -class Any(PrimExpr): +class Any(PrimExprWithOp): """Any node. span : Optional[Span] @@ -1198,4 +1222,4 @@ class Any(PrimExpr): """ def __init__(self, span=None): - self.__init_handle_by_constructor__(_ffi_api.Any, span) + self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 79d18d8970b5..68d967aa497d 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,12 +16,14 @@ # under the License. """Function data types.""" +from typing import Mapping, Union + import tvm._ffi import tvm.runtime from tvm.runtime import Object from tvm.ir import BaseFunc from .buffer import Buffer -from .expr import Var +from .expr import Var, PrimExpr from . import _ffi_api @@ -65,7 +67,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span + _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore ) def with_body(self, new_body, span=None): @@ -85,3 +87,54 @@ def with_body(self, new_body, span=None): The created new function. """ return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) + + def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): + """Specialize parameters of PrimFunc + + Parameters + ---------- + + param_map : Mapping[Var, Union[PrimExpr, Buffer]] + The mapping from function params to the instance + + Examples + -------- + We can define a Meta TIR function with symbolic shape: + + .. code-block:: python + + @tvm.script.tir + def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32") + B = tir.match_buffer(b, (m, n), "float32") + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + Then we can make it specialized with given shapes or buffers. + + .. code-block:: python + + a, _, m, n = mem_copy.params + func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + # or + func = mem_copy.specialize({n: 16, m: 16}) + + The specialized function: + + .. code-block:: python + + @tvm.script.tir + def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + + with tir.block([16, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + Returns + ------- + func : PrimFunc + The new function with parameter specialized + """ + return _ffi_api.Specialize(self, param_map) # type: ignore diff --git a/python/tvm/tir/generic.py b/python/tvm/tir/generic.py index 1194eaa6b462..58efc0985970 100644 --- a/python/tvm/tir/generic.py +++ b/python/tvm/tir/generic.py @@ -43,7 +43,7 @@ def add(lhs, rhs, span=None): op : tvm.Expr The result Expr of add operaton. """ - return _ffi_api._OpAdd(lhs, rhs, span) + return _ffi_api._OpAdd(lhs, rhs, span) # type: ignore def subtract(lhs, rhs, span=None): @@ -63,7 +63,7 @@ def subtract(lhs, rhs, span=None): op : tvm.Expr The result Expr of subtract operaton. """ - return _ffi_api._OpSub(lhs, rhs, span) + return _ffi_api._OpSub(lhs, rhs, span) # type: ignore def multiply(lhs, rhs, span=None): @@ -83,7 +83,7 @@ def multiply(lhs, rhs, span=None): op : tvm.Expr The result Expr of multiply operaton. """ - return _ffi_api._OpMul(lhs, rhs, span) + return _ffi_api._OpMul(lhs, rhs, span) # type: ignore def divide(lhs, rhs, span=None): @@ -103,7 +103,7 @@ def divide(lhs, rhs, span=None): op : tvm.Expr The result Expr of divide operaton. """ - return _ffi_api._OpDiv(lhs, rhs, span) + return _ffi_api._OpDiv(lhs, rhs, span) # type: ignore def floordiv(lhs, rhs, span=None): @@ -123,7 +123,7 @@ def floordiv(lhs, rhs, span=None): op : tvm.Expr The result Expr of divide operaton. """ - return _ffi_api._OpFloorDiv(lhs, rhs, span) + return _ffi_api._OpFloorDiv(lhs, rhs, span) # type: ignore def cast(src, dtype, span=None): @@ -141,4 +141,4 @@ def cast(src, dtype, span=None): op : tvm.Expr The result Expr of divide operaton. """ - return _ffi_api._cast(dtype, src, span) + return _ffi_api._cast(dtype, src, span) # type: ignore diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 4934bf04727f..978c630b17ad 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -141,7 +141,7 @@ class IRBuilder(object): """ def __init__(self): - self._seq_stack = [[]] + self._seq_stack = [[]] # type: ignore self.nidx = 0 def _pop_seq(self): @@ -394,7 +394,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=None): + def allocate(self, dtype, shape, name="buf", scope=""): """Create a allocate statement. Parameters @@ -416,15 +416,13 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) + buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - if scope: - self.scope_attr(buffer_var, "storage_scope", scope) self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) - def pointer(self, content_type, name="ptr"): + def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. Parameters @@ -435,12 +433,15 @@ def pointer(self, content_type, name="ptr"): name : str, optional The name of the pointer. + scope : str, optional + The scope of the pointer. + Returns ------- ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) return BufferVar(self, buffer_var, None, content_type) def buffer_ptr(self, buf, shape=None): diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 874724b767cb..de3ca5fa8d5b 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -16,12 +16,14 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Operators used in TIR expression.""" +from typing import Any, Optional import tvm._ffi +from tvm.ir.base import Span from tvm.runtime import convert, const from tvm.ir import Array, Op from .buffer import Buffer -from .expr import Call, StringImm, Var, CommReducer +from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer from . import _ffi_api @@ -257,9 +259,9 @@ def any(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - val = _ffi_api._OpOr(args[0], args[1], span) + val = _ffi_api._OpOr(args[0], args[1], span) # type: ignore for i in range(2, len(args)): - val = _ffi_api._OpOr(val, args[i], span) + val = _ffi_api._OpOr(val, args[i], span) # type: ignore return val @@ -284,9 +286,9 @@ def all(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - val = _ffi_api._OpAnd(args[0], args[1], span) + val = _ffi_api._OpAnd(args[0], args[1], span) # type: ignore for i in range(2, len(args)): - val = _ffi_api._OpAnd(val, args[i], span) + val = _ffi_api._OpAnd(val, args[i], span) # type: ignore return val @@ -343,10 +345,10 @@ def min_value(dtype, span=None): value : tvm.Expr The minimum value of dtype. """ - return _ffi_api.min_value(dtype, span) + return _ffi_api.min_value(dtype, span) # type: ignore -def max_value(dtype, span=None): +def max_value(dtype: str, span: Optional[Span] = None) -> Any: """maximum value of dtype Parameters @@ -362,11 +364,11 @@ def max_value(dtype, span=None): value : tvm.Expr The maximum value of dtype. """ - return _ffi_api.max_value(dtype, span) + return _ffi_api.max_value(dtype, span) # type: ignore def exp(x): - """Take exponetial of input x. + """Take exponential of input x. Parameters ---------- @@ -769,7 +771,7 @@ def clz(x): return call_intrin("int32", "tir.clz", x) -def floor(x, span=None): +def floor(x: PrimExprWithOp, span=None): """Take floor of float input x. Parameters @@ -785,7 +787,7 @@ def floor(x, span=None): y : PrimExpr The result. """ - return _ffi_api.floor(x, span) + return _ffi_api.floor(x, span) # type: ignore def ceil(x, span=None): @@ -804,7 +806,7 @@ def ceil(x, span=None): y : PrimExpr The result. """ - return _ffi_api.ceil(x, span) + return _ffi_api.ceil(x, span) # type: ignore def trunc(x, span=None): @@ -826,7 +828,7 @@ def trunc(x, span=None): y : PrimExpr The result. """ - return _ffi_api.trunc(x, span) + return _ffi_api.trunc(x, span) # type: ignore def abs(x, span=None): @@ -845,7 +847,7 @@ def abs(x, span=None): y : PrimExpr The result. """ - return _ffi_api.abs(x, span) + return _ffi_api.abs(x, span) # type: ignore def round(x, span=None): @@ -864,7 +866,7 @@ def round(x, span=None): y : PrimExpr The result. """ - return _ffi_api.round(x, span) + return _ffi_api.round(x, span) # type: ignore def nearbyint(x, span=None): @@ -890,7 +892,7 @@ def nearbyint(x, span=None): y : PrimExpr The result. """ - return _ffi_api.nearbyint(x, span) + return _ffi_api.nearbyint(x, span) # type: ignore def nextafter(x1, x2): @@ -909,7 +911,7 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.nextafter", x1, x2) + return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore def hypot(x1, x2): @@ -928,7 +930,7 @@ def hypot(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.hypot", x1, x2) + return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore def copysign(x1, x2): @@ -947,7 +949,7 @@ def copysign(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.copysign", x1, x2) + return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore def ldexp(x1, x2): @@ -966,7 +968,7 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - return call_intrin(x1.dtype, "tir.ldexp", x1, x2) + return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore def isnan(x, span=None): @@ -985,7 +987,7 @@ def isnan(x, span=None): y : PrimExpr The result. """ - return _ffi_api.isnan(x, span) + return _ffi_api.isnan(x, span) # type: ignore def isfinite(x, span=None): @@ -1004,7 +1006,7 @@ def isfinite(x, span=None): y : PrimExpr The result. """ - return _ffi_api.isfinite(x, span) + return _ffi_api.isfinite(x, span) # type: ignore def isinf(x, span=None): @@ -1023,7 +1025,7 @@ def isinf(x, span=None): y : PrimExpr The result. """ - return _ffi_api.isinf(x, span) + return _ffi_api.isinf(x, span) # type: ignore def power(x, y, span=None): @@ -1045,7 +1047,7 @@ def power(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) + return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore def popcount(x): @@ -1141,7 +1143,7 @@ def if_then_else(cond, t, f, span=None): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) + return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) # type: ignore def div(a, b, span=None): @@ -1166,7 +1168,7 @@ def div(a, b, span=None): ---- When operands are integers, returns truncdiv(a, b, span). """ - return _ffi_api._OpDiv(a, b, span) + return _ffi_api._OpDiv(a, b, span) # type: ignore def indexdiv(a, b, span=None): @@ -1194,7 +1196,7 @@ def indexdiv(a, b, span=None): This function may take advantage of operands' non-negativeness. """ - return _ffi_api._OpIndexDiv(a, b, span) + return _ffi_api._OpIndexDiv(a, b, span) # type: ignore def indexmod(a, b, span=None): @@ -1222,7 +1224,7 @@ def indexmod(a, b, span=None): This function may take advantage of operands' non-negativeness. """ - return _ffi_api._OpIndexMod(a, b, span) + return _ffi_api._OpIndexMod(a, b, span) # type: ignore def truncdiv(a, b, span=None): @@ -1248,7 +1250,7 @@ def truncdiv(a, b, span=None): ---- This is the default integer division behavior in C. """ - return _ffi_api._OpTruncDiv(a, b, span) + return _ffi_api._OpTruncDiv(a, b, span) # type: ignore def truncmod(a, b, span=None): @@ -1274,7 +1276,7 @@ def truncmod(a, b, span=None): ---- This is the default integer division behavior in C. """ - return _ffi_api._OpTruncMod(a, b, span) + return _ffi_api._OpTruncMod(a, b, span) # type: ignore def floordiv(a, b, span=None): @@ -1296,7 +1298,7 @@ def floordiv(a, b, span=None): res : PrimExpr The result expression. """ - return _ffi_api._OpFloorDiv(a, b, span) + return _ffi_api._OpFloorDiv(a, b, span) # type: ignore def floormod(a, b, span=None): @@ -1318,7 +1320,7 @@ def floormod(a, b, span=None): res : PrimExpr The result expression. """ - return _ffi_api._OpFloorMod(a, b, span) + return _ffi_api._OpFloorMod(a, b, span) # type: ignore def comm_reducer(fcombine, fidentity, name="reduce"): @@ -1476,5 +1478,5 @@ def reducer(expr, axis, where=None, init=None, *args): # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") -min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") -max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max") +min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore +max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max") # type: ignore diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index ef1cab1fb663..5f0e169c43e3 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -18,5 +18,7 @@ """Namespace for the TensorIR schedule API.""" from .block_scope import BlockScope, Dependency, DepKind, StmtSRef +from .instruction import Instruction, InstructionKind +from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState -from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError +from .trace import Trace diff --git a/python/tvm/tir/schedule/_ffi_api_schedule.py b/python/tvm/tir/schedule/_ffi_api.py similarity index 100% rename from python/tvm/tir/schedule/_ffi_api_schedule.py rename to python/tvm/tir/schedule/_ffi_api.py diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py index 061a472ad9a9..30e047b4f78a 100644 --- a/python/tvm/tir/schedule/block_scope.py +++ b/python/tvm/tir/schedule/block_scope.py @@ -22,7 +22,7 @@ from tvm.runtime import Object from tvm.tir import Block, For -from . import _ffi_api_schedule +from . import _ffi_api @register_object("tir.StmtSRef") @@ -45,24 +45,24 @@ class StmtSRef(Object): @property def stmt(self) -> Optional[Union[Block, For]]: """The block/for stmt the object refers to""" - return _ffi_api_schedule.StmtSRefStmt(self) # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefStmt(self) # type: ignore # pylint: disable=no-member @property def parent(self) -> Optional["StmtSRef"]: """The parent sref""" - return _ffi_api_schedule.StmtSRefParent(self) # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefParent(self) # type: ignore # pylint: disable=no-member @staticmethod def inline_mark() -> "StmtSRef": """A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do the work of compute-inline""" - return _ffi_api_schedule.StmtSRefInlineMark() # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefInlineMark() # type: ignore # pylint: disable=no-member @staticmethod def root_mark() -> "StmtSRef": """A special StmtSRef, which doesn't point to any stmt in the AST, only serving as a "mark" to hint compute-at to do nothing""" - return _ffi_api_schedule.StmtSRefRootMark() # type: ignore # pylint: disable=no-member + return _ffi_api.StmtSRefRootMark() # type: ignore # pylint: disable=no-member class DepKind(IntEnum): @@ -137,7 +137,7 @@ def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: blocks: List[Dependency] The dependencies """ - return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.BlockScopeGetDepsBySrc(self, block) # type: ignore # pylint: disable=no-member def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: """Get all dependencies whose `dst` is the target `block`. @@ -152,4 +152,4 @@ def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: blocks: List[Dependency] The dependencies """ - return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.BlockScopeGetDepsByDst(self, block) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/instruction.py b/python/tvm/tir/schedule/instruction.py new file mode 100644 index 000000000000..09b2d70dc321 --- /dev/null +++ b/python/tvm/tir/schedule/instruction.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Schedule instructions each corresponds to a schedule primitive""" +from typing import TYPE_CHECKING, Any, List, Union + +from tvm._ffi import register_object as _register_object +from tvm.runtime import Object + +from . import _ffi_api + +if TYPE_CHECKING: + from .schedule import RAND_VAR_TYPE + + INPUT_RV_TYPE = Union[RAND_VAR_TYPE, float, int, str, None] # pylint: disable=invalid-name + OUTPUT_RV_TYPE = Union[RAND_VAR_TYPE] # pylint: disable=invalid-name + ATTR_TYPE = Any +else: + INPUT_RV_TYPE = OUTPUT_RV_TYPE = ATTR_TYPE = Any + + +@_register_object("tir.InstructionKind") +class InstructionKind(Object): + """Kind of an instruction, e.g. Split, Reorder, etc. + Besides the name, every kind of instruction has its own properties, including: + 1) A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state + 2) A functor that applies the instruction to a TensorIR schedule + 3) A functor that converts the instruction to a statement in python syntax + 4) A functor that serialize its attributes to JSON + 5) A functor that deserialize its attributes from JSON + + Unlike `tvm.ir.op`, `InstructionKind` doesn't support unstructured properties, + mainly because there is no such usecase yet to add any other property. + + Attributes + ---------- + name : str + The name of a kind of instructions + + Note + ---- + The functor properties are not exposed on python side at the moment + """ + + name: str + + @property + def is_pure(self) -> bool: + """Indicates if the instruction is pure, i.e. removing it alone doesn't mutate the schedule + state. For example, the instruction `GetBlock` is pure because it changes + nothing, while `ComputeInline` is not because removing it leads to a different resulting + schedule. + + Returns + ------- + pure : bool + The boolean flag indicating if the instruction is pure + """ + return bool(self._is_pure) + + @staticmethod + def get(name: str) -> "InstructionKind": + """Retrieve an InstructionKind using its name + + Parameters + ---------- + name : str + The registered name of the InstructionKind + + Returns + ------- + kind : InstructionKind + The InstructionKind retrieved + """ + return _ffi_api.InstructionKindGet(name) # type: ignore # pylint: disable=no-member + + +@_register_object("tir.Instruction") +class Instruction(Object): + """Schedule instructions each corresponds to a schedule primitive + + Attributes + ---------- + kind : InstructionKind + The kind of the instruction + inputs : List[INPUT_RV_TYPE] + The input random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV + - float + - int + - str + - None + attrs : List[ATTR_TYPE] + The attributes of the instruction. Similar to attributes of an operator, + attributes of an instruction are arbitrary constant metadata required by the instructions. + For example, the name of the block to be retrieved in `GetBlock`. + outputs : List[OUTPUT_RV_TYPE] + The output random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV, atomic variables only, won't be constants or composite PrimExpr + """ + + kind: InstructionKind + inputs: List[INPUT_RV_TYPE] + attrs: List[ATTR_TYPE] + outputs: List[OUTPUT_RV_TYPE] + + def __init__( + self, + kind: InstructionKind, + inputs: List[INPUT_RV_TYPE], + attrs: List[ATTR_TYPE], + outputs: List[OUTPUT_RV_TYPE], + ) -> None: + """Constructor + + Parameters + ---------- + kind : InstructionKind + The kind of the instruction + inputs : List[INPUT_RV_TYPE] + The input random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV + - float + - int + - str + - None + attrs : List[ATTR_TYPE] + The attributes of the instruction. Similar to attributes of an operator, + attributes of an instruction are arbitrary constant metadata required by the + instructions. For example, the name of the block to be retrieved in `GetBlock`. + outputs : List[OUTPUT_RV_TYPE] + The output random variables of the instruction, + and the type of each element can be one of the following: + - BlockRV + - LoopRV + - ExprRV, atomic variables only, won't be constants or composite PrimExpr + """ + self.__init_handle_by_constructor__( + _ffi_api.Instruction, # type: ignore # pylint: disable=no-member + kind, + inputs, + attrs, + outputs, + ) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 2091f4d80ab3..22c08398df33 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -22,9 +22,9 @@ from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object -from tvm.tir import Block, For, IntImm, PrimFunc, Var +from tvm.tir import Block, For, IntImm, PrimFunc -from . import _ffi_api_schedule +from . import _ffi_api from .state import ScheduleState, StmtSRef @@ -37,15 +37,33 @@ class ScheduleError(TVMError): class LoopRV(Object): """A random variable that refers to a loop""" + def __init__(self) -> None: + """Construct a new LoopRV.""" + self.__init_handle_by_constructor__( + _ffi_api.LoopRV # type: ignore # pylint: disable=no-member + ) + @_register_object("tir.BlockRV") class BlockRV(Object): """A random variable that refers to a block""" + def __init__(self) -> None: + """Construct a new BlockRV.""" + self.__init_handle_by_constructor__( + _ffi_api.BlockRV # type: ignore # pylint: disable=no-member + ) + + +# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 +# This feature is not supported until python 3.10: +# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias +ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer -ExprRV = PrimExpr # A random variable that evaluates to an integer +RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name -RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name +# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8 +ERROR_RENDER_LEVEL_CANDIDATES = Union[str] # pylint: disable=invalid-name @_register_object("tir.Schedule") @@ -63,20 +81,24 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ - ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2} + ERROR_RENDER_LEVEL = { + "detail": 0, + "fast": 1, + "none": 2, + } def __init__( self, - func_or_mod: Union[PrimFunc, IRModule], + mod: Union[PrimFunc, IRModule], *, debug_mode: Union[bool, int] = False, - error_render_level: str = "detail", - ): + error_render_level: ERROR_RENDER_LEVEL_CANDIDATES = "detail", + ) -> None: """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc Parameters ---------- - func_or_mod : Union[PrimFunc, IRModule] + mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled debug_mode : Union[bool, int] Do extra correctness checking after the class creation and each time @@ -88,11 +110,13 @@ def __init__( "none": Do not show any error message. Note - ---------- + ---- The checks performed includes: 1) VerifySRefTree 2) VerifyCachedFlags """ + if isinstance(mod, PrimFunc): + mod = IRModule({"main": mod}) if isinstance(debug_mode, bool): if debug_mode: debug_mode = -1 @@ -105,12 +129,11 @@ def __init__( 'error_render_level can be "detail", "fast", or "none", but got: ' + f"{error_render_level}" ) - error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level) # type: ignore self.__init_handle_by_constructor__( - _ffi_api_schedule.ConcreteSchedule, # type: ignore # pylint: disable=no-member - func_or_mod, + _ffi_api.ConcreteSchedule, # type: ignore # pylint: disable=no-member + mod, debug_mode, - error_render_level, + Schedule.ERROR_RENDER_LEVEL.get(error_render_level), ) ########## Utilities ########## @@ -118,12 +141,12 @@ def __init__( @property def mod(self) -> IRModule: """Returns the AST of the module being scheduled""" - return _ffi_api_schedule.ScheduleModule(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleModule(self) # type: ignore # pylint: disable=no-member @property def state(self) -> ScheduleState: """Returns the ScheduleState in the current schedule class""" - return _ffi_api_schedule.ScheduleGetState(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetState(self) # type: ignore # pylint: disable=no-member def copy(self) -> "Schedule": """Returns a copy of the schedule, including both the state and the symbol table, @@ -132,30 +155,34 @@ def copy(self) -> "Schedule": * 2) The IRModule being scheduled is untouched; * 3) All the random variables are valid in the copy, pointing to the correpsonding sref * reconstructed + Returns ------- copy : Schedule A new copy of the schedule """ - return _ffi_api_schedule.ScheduleCopy(self) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleCopy(self) # type: ignore # pylint: disable=no-member def seed(self, seed: int) -> None: """Seed the randomness + Parameters ---------- seed : int The new random seed, -1 if use device random, otherwise non-negative """ - return _ffi_api_schedule.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member def show(self, rand_var: RAND_VAR_TYPE) -> str: """Returns a string representation of the value that the random variable evaluates to + Parameters ---------- rand_var : Union[ExprRV, BlockRV, LoopRV] The random variable to be evaluated + Returns - ---------- + ------- str_repr : str The string representation """ @@ -173,18 +200,20 @@ def get( - the corresponding integer that a ExprRV evaluates to; - the corresponding Block that a block sref points to; - the corresponding For that a loop sref points to; + Parameters ---------- rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef] The random variable / sref to be evaluated + Returns - ---------- + ------- result : Optional[Union[int, Block, For]] The correpsonding result """ if isinstance(rand_var_or_sref, StmtSRef): return rand_var_or_sref.stmt - result = _ffi_api_schedule.ScheduleGet(self, rand_var_or_sref) # type: ignore # pylint: disable=no-member + result = _ffi_api.ScheduleGet(self, rand_var_or_sref) # type: ignore # pylint: disable=no-member if isinstance(result, IntImm): result = result.value return result @@ -195,49 +224,55 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Opti 2) BlockRV 3) Block 4) For + Parameters ---------- rand_var_or_stmt : Union[BlockRV, LoopRV, Block, For] The random variable / sref to be evaluated + Returns - ---------- + ------- result : Optional[StmtSRef] The correpsonding result """ - return _ffi_api_schedule.ScheduleGetSRef( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member self, rand_var_or_stmt ) def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: """Remove a random variable from the symbol table + Parameters ---------- rand_var : Union[BlockRV, LoopRV, ExprRV] The random variable to be removed """ - return _ffi_api_schedule.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member - ########## Block/Loop relation ########## + ########## Schedule: Sampling ########## + ########## Schedule: Get blocks & loops ########## def get_block( self, name: str, func_name: str = "main", ) -> BlockRV: """Retrieve a block in a specific function with its name + Parameters ---------- name : str The name of the block func_name : str = "main" The name of the function + Returns - ---------- + ------- block : BlockRV The block retrieved IndexError is raised if 0 or multiple blocks exist with the specific name. """ - return _ffi_api_schedule.ScheduleGetBlock( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetBlock( # type: ignore # pylint: disable=no-member self, name, func_name, @@ -245,19 +280,157 @@ def get_block( def get_loops(self, block: BlockRV) -> List[LoopRV]: """Get the parent loops of the block in its scope, from outer to inner + Parameters ---------- block : BlockRV The query block + Returns - ---------- + ------- loops : List[LoopRV] A list of loops above the given block in its scope, from outer to inner """ - return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + + ########## Schedule: Transform loops ########## + def fuse(self, *loops: List[LoopRV]) -> LoopRV: + """Fuse a list of consecutive loops into one. It requires: + 1) The loops can't have annotations or thread bindings. + 2) The (i+1)-th loop must be the only child of the i-th loop. + 3) All loops must start with 0. + + Parameters + ---------- + *loops : List[LoopRV] + The loops to be fused + + Returns + ------- + fused_loop : LoopRV + The new loop after fusion + + Examples + -------- + + Before applying fuse, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_fuse) + i, j = sch.get_loops(sch.get_block("B")) + sch.fuse(i, j) + print(tvm.script.asscript(sch.mod["main"])) + + After applying fuse, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # the 2 loops are fused into 1 + for i_j_fused in tir.serial(0, 16384): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 128)) + tir.bind(vj, tir.floormod(i_j_fused, 128)) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + return _ffi_api.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member + + def split( + self, + loop: LoopRV, + factors: List[Union[ExprRV, None]], + ) -> List[LoopRV]: + """Split a loop into a list of consecutive loops. It requires: + 1) The loop can't have annotation or thread binding. + 2) The loop must start with 0. + Predicates may be added to ensure the total loop numbers keeps unchanged. + In `factors`, at most one of the factors can be None, + which will be automatically inferred. + + Parameters + ---------- + loop : LoopRV + The loop to be split + + factors: List[Union[ExprRV, None]] + The splitting factors + Potential inputs are: + - None + - ExprRV + - Nonnegative constant integers + + Returns + ------- + split_loops : List[LoopRV] + The new loops after split + + Examples + -------- + + Before split, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_split) + i, j = sch.get_loops(sch.get_block("B")) + sch.split(i, factors=[2, 64]) + print(tvm.script.asscript(sch.mod["main"])) + + After applying split, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # the original loop is split into 2 loops + for i0, i1, j in tir.grid(2, 64, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, ((i0*64) + i1)) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + # it will be checked later in C++ implementation + # that there is at most one None in `factors` + return _ffi_api.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + + ########## Schedule: Manipulate ForKind ########## + + ########## Schedule: Insert cache stages ########## + + ########## Schedule: Compute location ########## - ########## Schedule: loops manipulation ########## - ########## Schedule: compute location ########## def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). It requires: @@ -297,7 +470,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: .. code-block:: python - sch = tir.Schedule(before_inline, debug_mode=True) + sch = tir.Schedule(before_inline) sch.compute_inline(sch.get_block("B")) print(tvm.script.asscript(sch.mod["main"])) @@ -313,7 +486,7 @@ def after_inline(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ - _ffi_api_schedule.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member def reverse_compute_inline(self, block: BlockRV) -> None: """Inline a block into its only producer. It requires: @@ -357,7 +530,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: .. code-block:: python - sch = tir.Schedule(before_inline, debug_mode=True) + sch = tir.Schedule(before_inline) sch.reverse_compute_inline(sch.get_block("C")) print(tvm.script.asscript(sch.mod["main"])) @@ -373,12 +546,162 @@ def after_inline(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ - _ffi_api_schedule.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member + + ########## Schedule: Reduction ########## + + def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: + """Factorize an associative reduction block by the specified loop. + + An associative reduction cannot be parallelized directly, + because it leads to potential race condition during accumulation. + Alternatively, the reduction could be factorized on a loop with the following steps: + - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent + - Step 2: compute the chunks separately and write the result into `n` intermediate buffers; + - Step 3: accumulate the `n` separate buffer into the result buffer. + Note that the Step 2 above introduces opportunities for parallelization. + + RFactor is a schedule primitive that implements the transformation described above: + Given a block that writes to buffer `B`, it factorizes a loop of extent `n`. + + For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`: + + .. code-block:: python + + for i in range(128): # loop i is a data parallel loop + for j in range(128): # loop j is a reduction loop + for k in range(128): # loop k is a reduction loop + B[i] = B[i] + A[i, j, k] + + Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`. + RFactor then creates an intermediate buffer and two blocks. + + 1. The intermediate buffer, or "rf-buffer" is a buffer of rank `ndim(B) + 1` and + size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n` + at the position specified by `factor_axis`. For example, + + * shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3] + * shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3] + * shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3] + * shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n] + + 2. The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without + accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop + to a data parallel loop. In our example, the rf-block is: + + .. code-block:: python + + B_rf = np.zeros((128, 128)) # the rf-buffer + for k in range(128): # loop k is converted to a data parallel loop + for i in range(128): # loop i is a data parallel loop (unchanged) + for j in range(128): # loop j is a reduction loop (unchanged) + B_rf[i, k] = B_rf[i, k] + A[i, j, k] + + + 3. The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into + the result buffer. All the reduction loops are removed except the loop `k` for accumulation. + In our example, the wb-block is: + + .. code-block:: python + + for i in range(128): # loop i is a data parallel loop (unchanged) + # loop j is removed because it is a reduction loop + for k in range(128): # loop k is a reduction loop (unchanged) + B[i] = B[i] + B_rf[i, k] + + + Parameters + ---------- + loop : LoopRV + The loop outside block for which we want to do rfactor + factor_axis : int + The position where the new dimension is placed in the new introduced rfactor buffer + + Returns + ------- + rf_block : BlockRV + The block which computes partial results over each slices (i.e., the first block + as described in the above illustration) + + Examples + -------- + + Before rfactor, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128,)) + with tir.block([128, tir.reduce_axis(0, 128), + tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + with tir.init(): + B[vii] = 0.0 + B[vii] = B[vii] + A[vii, vi, vj] + + Create the schedule and do rfactor: + + .. code-block:: python + + sch = tir.Schedule(before_rfactor) + _, _, k = sch.get_loops(sch.get_block("B")) + sch.rfactor(k, 0) + print(tvm.script.asscript(sch.mod["main"])) + + After applying rfactor, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128]) + B_rf = tir.alloc_buffer([128, 128]) + with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: + with tir.init(): + B_rf[vi2, vii] = 0.0 + B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: + with tir.init(): + B[vii_1] = 0.0 + B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) + + + Note + ---- + + Rfactor requires: + 1) `loop` has only one child block, and it is a reduction block; + 2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables + in the block binding; + 3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis; + 4) The block scope that `loop` is in is a staged-pipeline; + 5) The outermost loop outside the reduction block should has the reduction block as its + first child block; + 6) The outermost reduction loop should have only one child block; + 7) An unary extent loop that is not bound to any reduction or data parallel variables in + the block binding should not appear under some reduction loop; + 8) The reduction block should write to only one buffer, and its init and body are both + simple `BufferStore`s, and the pattern is registered as an associative reducer. + The pre-defined patterns include: plus, multiplication, min and max; + 9) Each of the loops on top of the block cannot be bound to a data parallel and a + reduction block binding at the same time; + 10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`, + where `B` is the buffer that the reduction block writes to. + Negative indexing is normalized according to numpy convention. + """ + return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member + + ########## Schedule: Blockize & Tensorize ########## + + ########## Schedule: Annotation ########## + + ########## Schedule: Misc ########## - ########## Schedule: loop binding/annotation ########## - ########## Schedule: cache read/write ########## - ########## Schedule: reduction ########## - ########## Schedule: blockize & tensorize ########## + def enter_postproc(self) -> None: + """A no-op that marks the start of postprocessing phase of scheduling""" + _ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member @_register_object("tir.ConcreteSchedule") diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index 845e1db5cb83..cc2415f150c9 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -24,7 +24,7 @@ from tvm.runtime import Object from tvm.tir import Block, BlockRealize, For, PrimFunc -from . import _ffi_api_schedule +from . import _ffi_api from .block_scope import BlockScope, StmtSRef CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"]) @@ -75,14 +75,14 @@ class ScheduleState(Object): def __init__( self, - func_or_mod: Union[PrimFunc, IRModule], + mod: Union[PrimFunc, IRModule], debug_mode: Union[bool, int] = False, - ): + ) -> None: """Construct a schedule state from an IRModule or a PrimFunc Parameters ---------- - func_or_mod : Union[PrimFunc, IRModule] + mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled debug_mode : Union[bool, int] Do extra correctness checking after the class creation and each time @@ -92,6 +92,8 @@ def __init__( 2) False - Turn off all the checks 3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask """ + if isinstance(mod, PrimFunc): + mod = IRModule({"main": mod}) if isinstance(debug_mode, bool): if debug_mode: debug_mode = -1 @@ -100,8 +102,8 @@ def __init__( if not isinstance(debug_mode, int): raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") self.__init_handle_by_constructor__( - _ffi_api_schedule.ScheduleState, # type: ignore # pylint: disable=no-member - func_or_mod, + _ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member + mod, debug_mode, ) @@ -118,7 +120,7 @@ def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: sref : StmtSRef The corresponding sref """ - return _ffi_api_schedule.ScheduleStateGetSRef(self, stmt) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleStateGetSRef(self, stmt) # type: ignore # pylint: disable=no-member def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: """Get the BlockScope correpsonding to the block sref @@ -133,7 +135,7 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: sref : StmtSRef The corresponding sref """ - return _ffi_api_schedule.ScheduleStateGetBlockScope( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleStateGetBlockScope( # type: ignore # pylint: disable=no-member self, block_sref ) @@ -151,14 +153,14 @@ def _get_cached_flags(self, block_sref: StmtSRef) -> CachedFlags: Three flags: affine_binding, region_cover, stage_pipeline Note - ------- + ---- It is an API intended for internal testing use. """ ( affine_binding, region_cover, stage_pipeline, - ) = _ffi_api_schedule.ScheduleStateGetCachedFlags( # type: ignore # pylint: disable=no-member + ) = _ffi_api.ScheduleStateGetCachedFlags( # type: ignore # pylint: disable=no-member self, block_sref ) return CachedFlags( @@ -199,12 +201,12 @@ def replace( the sref that points to the old block will point to the new one Note - ---------- + ---- The reuse of loop srefs are detected automatically according to the reuse of loop vars. """ if block_sref_reuse is None: block_sref_reuse = {} - _ffi_api_schedule.ScheduleStateReplace( # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleStateReplace( # type: ignore # pylint: disable=no-member self, src_sref, tgt_stmt, diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py new file mode 100644 index 000000000000..18bcca373dbb --- /dev/null +++ b/python/tvm/tir/schedule/trace.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""An execution trace of a scheduling program""" +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +from tvm._ffi import register_object as _register_object +from tvm.runtime import Object + +from ...ir import Array, Map +from ...runtime import String +from ..expr import FloatImm, IntImm +from . import _ffi_api +from .instruction import ATTR_TYPE, INPUT_RV_TYPE, Instruction + +if TYPE_CHECKING: + from .schedule import Schedule + + +DECISION_TYPE = Any +JSON_TYPE = Any + + +def _json_from_tvm(obj): + if obj is None: + return None + if isinstance(obj, Array): + return [_json_from_tvm(i) for i in obj] + if isinstance(obj, Map): + return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()} + if isinstance(obj, String): + return str(obj) + if isinstance(obj, (IntImm, FloatImm)): + return obj.value + raise TypeError("Not supported type: " + str(type(obj))) + + +@_register_object("tir.Trace") +class Trace(Object): + """An execution trace of a scheduling program. + + A trace has two parts: + 1) The instructions invoked so far + 2) The random decisions made upon those instructions, if any + + A trace can be serialized to: + 1) Roundtrippable JSON format: can be saved to file and loaded back + 2) Python syntax: allows users to copy-paste the trace to reproduce the scheduling process + + A trace can be applied to a TensorIR schedule by re-applying all its instructions possibly with + their decisions accordingly. Re-sampling is invoked if a sampling instruction doesn't have its + corresponding decision; Otherwise the existing decision will be reused accordingly. + + Attributes + ---------- + insts : List[Instruction] + The instructions invoked so far in the program execution + decisions : Dict[Instruction, DECISION_TYPE] + The random decisions made upon those instructions + """ + + insts: List[Instruction] + decisions: Dict[Instruction, DECISION_TYPE] + + def __init__( + self, + insts: List[Instruction], + decisions: Dict[Instruction, DECISION_TYPE], + ) -> None: + """Constructor + + Parameters + ---------- + insts : List[Instruction] + The instructions invoked so far in the program execution + decisions : Dict[Instruction, DECISION_TYPE] + The random decisions made upon those instructions + """ + self.__init_handle_by_constructor__( + _ffi_api.Trace, # type: ignore # pylint: disable=no-member + insts, + decisions, + ) + + def get_decision(self, inst: Instruction) -> Optional[DECISION_TYPE]: + """Retrieve the decision made on a specific instruction + + Parameters + ---------- + insts : Instruction + The instruction whose decision is to be retrieved + + Returns + ------- + decision : Optional[DECISION_TYPE] + The corresponding decision; None if there is no decision made on the instruction + """ + return _ffi_api.TraceGetDecision(self, inst) # type: ignore # pylint: disable=no-member + + def append( + self, + inst: Instruction, + decision: Optional[DECISION_TYPE] = None, + ) -> None: + """Append a new instruction to the trace + + Parameters + ---------- + insts : Instruction + The new instruction to be appended + decision : Optional[DECISION_TYPE] = None + The random decision made on this instruction + """ + _ffi_api.TraceAppend(self, inst, decision) # type: ignore # pylint: disable=no-member + + def pop(self) -> Optional[Instruction]: + """Remove the last instruction, along with the decision made on that instruction, if any + + Returns + ------- + popped_inst : Instruction + Returns the instruction removed; NullOpt if the trace is empty + """ + return _ffi_api.TracePop(self) # type: ignore # pylint: disable=no-member + + def apply_to_schedule( + self, + sch: "Schedule", + remove_postproc: bool, + decision_provider: Optional[ + Callable[ + [Instruction, List[INPUT_RV_TYPE], List[ATTR_TYPE], DECISION_TYPE], DECISION_TYPE + ] + ] = None, + ) -> None: + """Apply the trace to a TensorIR schedule + + Parameters + ---------- + sch : Schedule + The schedule to be applied onto + remove_postproc : bool + If postprocessing instructions are removed + decision_provider: Optional[Callable] = None + A callback that allows users to mutate decisions on the fly when applying instructions. + The signature of the callback is: + - The 1st argument: The instruction + - The 2nd argument: The input random variables + - The 3rd argument: The attributes + - The 4th argument: The decision + - Return: A new decision + """ + _ffi_api.TraceApplyToSchedule( # type: ignore # pylint: disable=no-member + self, + sch, + remove_postproc, + decision_provider, + ) + + def as_json(self, remove_postproc: bool = False) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + + Parameters + ---------- + remove_postproc : bool = False + If postprocessing instructions are removed + + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + obj = _ffi_api.TraceAsJSON(self, remove_postproc) # type: ignore # pylint: disable=no-member + return _json_from_tvm(obj) + + def as_python(self, remove_postproc: bool = False) -> List[str]: + """Serialize the trace as a sequence of python statements + + Parameters + ---------- + remove_postproc : bool = False + If postprocessing instructions are removed + + Returns + ------- + py_stmts: List[str] + A sequence of python statements + """ + return _ffi_api.TraceAsPython(self, remove_postproc) # type: ignore # pylint: disable=no-member + + def with_decision( + self, + inst: Instruction, + decision: DECISION_TYPE, + remove_postproc: bool, + ) -> "Trace": + """Create a new trace with an instruction whose decision is changed, + assuming this instruction exists in the resulting trace + + Parameters + ---------- + inst : Instruction + The instruction whose decision is to be changed + decision : DECISION_TYPE + The decision to be changed to + remove_postproc : bool + If postprocessing instructions are removed + + Returns + ------- + trace: Trace + The new trace with the decision changed + """ + return _ffi_api.TraceWithDecision( # type: ignore # pylint: disable=no-member + self, + inst, + decision, + remove_postproc, + ) + + def simplified(self, remove_postproc: bool) -> "Trace": + """Simplify the trace with dead-code elimination + + Parameters + ---------- + remove_postproc : bool + If postprocessing instructions are removed + + Returns + ------- + trace: Trace + A simplified trace + """ + return _ffi_api.TraceSimplified(self, remove_postproc) # type: ignore # pylint: disable=no-member + + @staticmethod + def apply_json_to_schedule(json_obj: JSON_TYPE, sch: "Schedule") -> None: + """Apply a JSON-serialized trace to a TensorIR schedule + + Parameters + ---------- + json_obj : JSON_TYPE + The JSON-serialized trace + sch : Schedule + The TensorIR schedule + """ + _ffi_api.TraceApplyJSONToSchedule(json_obj, sch) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index dd7665a56692..d57077f08b52 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -62,7 +62,9 @@ class LetStmt(Stmt): """ def __init__(self, var, value, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.LetStmt, var, value, body, span) + self.__init_handle_by_constructor__( + _ffi_api.LetStmt, var, value, body, span # type: ignore + ) @tvm._ffi.register_object("tir.AssertStmt") @@ -85,7 +87,9 @@ class AssertStmt(Stmt): """ def __init__(self, condition, message, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body, span) + self.__init_handle_by_constructor__( + _ffi_api.AssertStmt, condition, message, body, span # type: ignore + ) class ForKind(IntEnum): @@ -148,7 +152,7 @@ def __init__( span=None, ): self.__init_handle_by_constructor__( - _ffi_api.For, + _ffi_api.For, # type: ignore loop_var, min_val, extent, @@ -178,7 +182,7 @@ class While(Stmt): def __init__(self, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.While, + _ffi_api.While, # type: ignore condition, body, span, @@ -209,9 +213,9 @@ class Store(Stmt): def __init__(self, buffer_var, value, index, predicate=None, span=None): if predicate is None: - predicate = _ffi_api.const_true(value.dtype, span) + predicate = _ffi_api.const_true(value.dtype, span) # type: ignore self.__init_handle_by_constructor__( - _ffi_api.Store, buffer_var, value, index, predicate, span + _ffi_api.Store, buffer_var, value, index, predicate, span # type: ignore ) @@ -235,7 +239,9 @@ class BufferStore(Stmt): """ def __init__(self, buffer, value, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.BufferStore, buffer, value, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.BufferStore, buffer, value, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.BufferRealize") @@ -262,7 +268,7 @@ class BufferRealize(Stmt): def __init__(self, buffer, bounds, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.BufferRealize, buffer, bounds, condition, body, span + _ffi_api.BufferRealize, buffer, bounds, condition, body, span # type: ignore ) @@ -286,7 +292,9 @@ class ProducerStore(Stmt): """ def __init__(self, producer, value, indices, span=None): - self.__init_handle_by_constructor__(_ffi_api.ProducerStore, producer, value, indices, span) + self.__init_handle_by_constructor__( + _ffi_api.ProducerStore, producer, value, indices, span # type: ignore + ) @tvm._ffi.register_object("tir.Allocate") @@ -316,7 +324,7 @@ class Allocate(Stmt): def __init__(self, buffer_var, dtype, extents, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span + _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span # type: ignore ) @@ -343,7 +351,9 @@ class AttrStmt(Stmt): """ def __init__(self, node, attr_key, value, body, span=None): - self.__init_handle_by_constructor__(_ffi_api.AttrStmt, node, attr_key, value, body, span) + self.__init_handle_by_constructor__( + _ffi_api.AttrStmt, node, attr_key, value, body, span # type: ignore + ) @tvm._ffi.register_object("tir.ProducerRealize") @@ -364,13 +374,22 @@ class ProducerRealize(Stmt): body : Stmt The realize body + storage_scope : str + The storage scope associated with this realization + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, producer, bounds, condition, body, span=None): + def __init__(self, producer, bounds, condition, body, storage_scope="", span=None): self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, producer, bounds, condition, body, span + _ffi_api.ProducerRealize, + producer, + bounds, + condition, + body, + storage_scope, + span, # type: ignore ) @@ -388,7 +407,7 @@ class SeqStmt(Stmt): """ def __init__(self, seq, span=None): - self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span) + self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span) # type: ignore def __getitem__(self, i): return self.seq[i] @@ -418,7 +437,7 @@ class IfThenElse(Stmt): def __init__(self, condition, then_case, else_case, span=None): self.__init_handle_by_constructor__( - _ffi_api.IfThenElse, condition, then_case, else_case, span + _ffi_api.IfThenElse, condition, then_case, else_case, span # type: ignore ) @@ -436,7 +455,7 @@ class Evaluate(Stmt): """ def __init__(self, value, span=None): - self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) + self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore @tvm._ffi.register_object("tir.Prefetch") @@ -456,7 +475,7 @@ class Prefetch(Stmt): """ def __init__(self, buffer, bounds, span=None): - self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) + self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) # type: ignore @tvm._ffi.register_object("tir.BufferRegion") @@ -476,7 +495,7 @@ class BufferRegion(Object): region: List[Range] def __init__(self, buffer: Buffer, region: List[Range]): - self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) + self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) # type: ignore @tvm._ffi.register_object("tir.MatchBufferRegion") @@ -496,7 +515,9 @@ class MatchBufferRegion(Object): source: BufferRegion def __init__(self, buffer: Buffer, source: BufferRegion): - self.__init_handle_by_constructor__(_ffi_api.MatchBufferRegion, buffer, source) + self.__init_handle_by_constructor__( + _ffi_api.MatchBufferRegion, buffer, source # type: ignore + ) @tvm._ffi.register_object("tir.Block") @@ -567,7 +588,7 @@ def __init__( if annotations is None: annotations = {} self.__init_handle_by_constructor__( - _ffi_api.Block, + _ffi_api.Block, # type: ignore iter_vars, reads, writes, @@ -578,7 +599,7 @@ def __init__( match_buffers, annotations, span, - ) + ) # type: ignore @tvm._ffi.register_object("tir.BlockRealize") @@ -615,12 +636,12 @@ def __init__( if isinstance(predicate, bool): predicate = const(predicate, "bool") self.__init_handle_by_constructor__( - _ffi_api.BlockRealize, + _ffi_api.BlockRealize, # type: ignore iter_values, predicate, block, span, - ) + ) # type: ignore def stmt_seq(*args): diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py index 4ec755cdf922..56dc1c20c2b3 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tir/stmt_functor.py @@ -43,7 +43,7 @@ def ir_transform(stmt, preorder, postorder, only_enable=None): result : tvm.tir.Stmt The result. """ - return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) + return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) # type: ignore def post_order_visit(stmt, fvisit): @@ -55,7 +55,7 @@ def post_order_visit(stmt, fvisit): fvisit: function The visitor function. """ - return _ffi_api.PostOrderVisit(stmt, fvisit) + return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore def substitute(node, vmap): @@ -74,4 +74,4 @@ def substitute(node, vmap): result : tvm.tir.Stmt The result. """ - return _ffi_api.Substitute(node, vmap) + return _ffi_api.Substitute(node, vmap) # type: ignore diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 374e731725be..9450ade34e67 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -18,6 +18,7 @@ import inspect import types import functools +from typing import Callable, List, Optional, Union import tvm._ffi from tvm.ir.transform import Pass, PassInfo @@ -47,7 +48,10 @@ def __init__(self, *args, **kwargs): def _pass_func(func, mod, ctx): return inst.transform_function(func, mod, ctx) - self.__init_handle_by_constructor__(_ffi_api.CreatePrimFuncPass, _pass_func, pass_info) + self.__init_handle_by_constructor__( + _ffi_api.CreatePrimFuncPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst def __getattr__(self, name): @@ -61,7 +65,12 @@ def __getattr__(self, name): return PyFunctionPass -def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None): +def prim_func_pass( + pass_func=None, + opt_level: int = None, + name: Optional[str] = None, + required: Optional[List[str]] = None, +) -> Union[Callable, PrimFuncPass]: """Decorate a function pass. This function returns a callback when pass_func @@ -123,7 +132,7 @@ def transform(func, mod, ctx): assert isinstance(function_pass, transform.FunctionPass) assert function_pass.info.opt_level == 2 - # Given a module m, the optimization could be invoked as the follwoing: + # Given a module m, the optimization could be invoked as the following: updated_mod = function_pass(m) # Now constant folding should have been applied to every function in # the provided module m. And the updated module will be returned. @@ -144,7 +153,7 @@ def create_function_pass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): raise TypeError("pass_func must be a callable for Module pass") - return _ffi_api.CreatePrimFuncPass(pass_arg, info) + return _ffi_api.CreatePrimFuncPass(pass_arg, info) # type: ignore if pass_func: return create_function_pass(pass_func) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 51330f80afc6..537499a27fa9 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -16,6 +16,7 @@ # under the License. """Wrapping existing transformations.""" # pylint: disable=invalid-name +from typing import Optional from . import _ffi_api from . import function_pass as _fpass @@ -39,7 +40,7 @@ def Apply(ftransform): def _transform(func, mod, ctx): return ftransform(func) - return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") + return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore def Filter(fcond): @@ -59,7 +60,7 @@ def Filter(fcond): def _transform(func, mod, ctx): return func if fcond(func) else None - return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") + return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") # type: ignore def InjectPrefetch(): @@ -70,10 +71,10 @@ def InjectPrefetch(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectPrefetch() + return _ffi_api.InjectPrefetch() # type: ignore -def StorageFlatten(cache_line_size, create_bound_attribute=False): +def StorageFlatten(cache_line_size, create_bound_attribute: bool = False): """Flatten the multi-dimensional read/write to 1D. @@ -91,10 +92,10 @@ def StorageFlatten(cache_line_size, create_bound_attribute=False): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) + return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) # type: ignore -def InjectCopyIntrin(pragma_key, fintrin): +def InjectCopyIntrin(pragma_key: str, fintrin): """Inject virtual thread loops. Parameters @@ -110,7 +111,7 @@ def InjectCopyIntrin(pragma_key, fintrin): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) + return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) # type: ignore def CoProcSync(): @@ -121,10 +122,10 @@ def CoProcSync(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.CoProcSync() + return _ffi_api.CoProcSync() # type: ignore -def LiftAttrScope(attr_key): +def LiftAttrScope(attr_key: str): """Lift common attrs with attr_key to outer scope. Parameters @@ -137,7 +138,7 @@ def LiftAttrScope(attr_key): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LiftAttrScope(attr_key) + return _ffi_api.LiftAttrScope(attr_key) # type: ignore def LoopPartition(): @@ -148,10 +149,10 @@ def LoopPartition(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LoopPartition() + return _ffi_api.LoopPartition() # type: ignore -def VectorizeLoop(enable_vectorize=True): +def VectorizeLoop(enable_vectorize: bool = True): """Lower vectorization loops. Parameters @@ -165,7 +166,7 @@ def VectorizeLoop(enable_vectorize=True): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.VectorizeLoop(enable_vectorize) + return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore def InjectVirtualThread(): @@ -176,7 +177,7 @@ def InjectVirtualThread(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectVirtualThread() + return _ffi_api.InjectVirtualThread() # type: ignore def InjectDoubleBuffer(): @@ -187,7 +188,7 @@ def InjectDoubleBuffer(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InjectDoubleBuffer() + return _ffi_api.InjectDoubleBuffer() # type: ignore def StorageRewrite(): @@ -202,7 +203,7 @@ def StorageRewrite(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.StorageRewrite() + return _ffi_api.StorageRewrite() # type: ignore def UnrollLoop(): @@ -215,7 +216,7 @@ def UnrollLoop(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.UnrollLoop() + return _ffi_api.UnrollLoop() # type: ignore def RemoveNoOp(): @@ -226,7 +227,7 @@ def RemoveNoOp(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.RemoveNoOp() + return _ffi_api.RemoveNoOp() # type: ignore def BF16Legalize(): @@ -238,7 +239,7 @@ def BF16Legalize(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16Legalize() + return _ffi_api.BF16Legalize() # type: ignore def BF16Promote(): @@ -250,7 +251,7 @@ def BF16Promote(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16Promote() + return _ffi_api.BF16Promote() # type: ignore def BF16CastElimination(): @@ -269,7 +270,7 @@ def BF16CastElimination(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16CastElimination() + return _ffi_api.BF16CastElimination() # type: ignore def BF16TypeLowering(): @@ -281,7 +282,7 @@ def BF16TypeLowering(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16TypeLowering() + return _ffi_api.BF16TypeLowering() # type: ignore def RewriteUnsafeSelect(): @@ -292,7 +293,7 @@ def RewriteUnsafeSelect(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.RewriteUnsafeSelect() + return _ffi_api.RewriteUnsafeSelect() # type: ignore def Simplify(): @@ -303,7 +304,7 @@ def Simplify(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.Simplify() + return _ffi_api.Simplify() # type: ignore def InstrumentBoundCheckers(): @@ -314,7 +315,7 @@ def InstrumentBoundCheckers(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InstrumentBoundCheckers() + return _ffi_api.InstrumentBoundCheckers() # type: ignore def LowerCustomDatatypes(): @@ -327,24 +328,25 @@ def LowerCustomDatatypes(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerCustomDatatypes() + return _ffi_api.LowerCustomDatatypes() # type: ignore -def MakePackedAPI(num_unpacked_params=0): +def MakePackedAPI(num_unpacked_params: int = -1): """Transform the PrimFuncs in the module to a packed func API. Parameters ---------- num_unpacked_params : int Number of parameters that we hope to directly pass via normal arguments - following the PackedFunc input signature. + following the PackedFunc input signature. If it is specified as -1 or it + is less than the number of arguments, the pass will packed arguments still. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MakePackedAPI(num_unpacked_params) + return _ffi_api.MakePackedAPI(num_unpacked_params) # type: ignore def MakeUnpackedAPI(): @@ -355,7 +357,7 @@ def MakeUnpackedAPI(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MakeUnpackedAPI() + return _ffi_api.MakeUnpackedAPI() # type: ignore def SplitHostDevice(): @@ -366,7 +368,7 @@ def SplitHostDevice(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.SplitHostDevice() + return _ffi_api.SplitHostDevice() # type: ignore def DecorateDeviceScope(): @@ -377,7 +379,7 @@ def DecorateDeviceScope(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.DecorateDeviceScope() + return _ffi_api.DecorateDeviceScope() # type: ignore def SkipAssert(): @@ -388,10 +390,10 @@ def SkipAssert(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.SkipAssert() + return _ffi_api.SkipAssert() # type: ignore -def ThreadSync(storage_scope): +def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. Parameters @@ -404,7 +406,7 @@ def ThreadSync(storage_scope): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.ThreadSync(storage_scope) + return _ffi_api.ThreadSync(storage_scope) # type: ignore def LowerThreadAllreduce(): @@ -415,7 +417,7 @@ def LowerThreadAllreduce(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerThreadAllreduce() + return _ffi_api.LowerThreadAllreduce() # type: ignore def InferFragment(): @@ -426,7 +428,7 @@ def InferFragment(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.InferFragment() + return _ffi_api.InferFragment() # type: ignore def LowerWarpMemory(): @@ -437,7 +439,7 @@ def LowerWarpMemory(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerWarpMemory() + return _ffi_api.LowerWarpMemory() # type: ignore def LowerTVMBuiltin(): @@ -448,7 +450,7 @@ def LowerTVMBuiltin(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerTVMBuiltin() + return _ffi_api.LowerTVMBuiltin() # type: ignore def LegalizePackedCalls(): @@ -459,7 +461,7 @@ def LegalizePackedCalls(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LegalizePackedCalls() + return _ffi_api.LegalizePackedCalls() # type: ignore def LowerIntrin(): @@ -470,7 +472,7 @@ def LowerIntrin(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerIntrin() + return _ffi_api.LowerIntrin() # type: ignore def LowerDeviceStorageAccessInfo(): @@ -485,7 +487,7 @@ def LowerDeviceStorageAccessInfo(): ---- Run this pass after all storage access analysis finish. """ - return _ffi_api.LowerDeviceStorageAccessInfo() + return _ffi_api.LowerDeviceStorageAccessInfo() # type: ignore def CombineContextCall(): @@ -496,10 +498,10 @@ def CombineContextCall(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.CombineContextCall() + return _ffi_api.CombineContextCall() # type: ignore -def NarrowDataType(target_bits): +def NarrowDataType(target_bits: int): """Narrow down PrimExpr datatype in stmt to target_bits. Parameters @@ -516,7 +518,7 @@ def NarrowDataType(target_bits): ---- Run this pass after StorageFlatten. """ - return _ffi_api.NarrowDataType(target_bits) + return _ffi_api.NarrowDataType(target_bits) # type: ignore def VerifyMemory(): @@ -527,12 +529,12 @@ def VerifyMemory(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.VerifyMemory() + return _ffi_api.VerifyMemory() # type: ignore # pylint: disable=no-else-return,inconsistent-return-statements -def HoistIfThenElse(variant=None): - """Hoist loop-invariant IfThenElse nodes to outside the elligible loops. +def HoistIfThenElse(variant: Optional[str] = None): + """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. Parameters ---------- @@ -540,7 +542,7 @@ def HoistIfThenElse(variant=None): The variant of the pass. variant can have any one of following values ["basic", None(Default)]. - The basic variant supports basic hoisting scenarios where it exepects + The basic variant supports basic hoisting scenarios where it expects the For & If Nodes are in place consecutively and does not involve global scope variables or more advanced scenarios. @@ -555,20 +557,20 @@ def HoistIfThenElse(variant=None): The result pass """ if variant == "basic": - return _ffi_api.HoistIfThenElseBasic() + return _ffi_api.HoistIfThenElseBasic() # type: ignore elif variant is None: - return _ffi_api.HoistIfThenElse() + return _ffi_api.HoistIfThenElse() # type: ignore def LowerInitBlock(): - """Lower block init stmt into IfThenElse stmts + """Lower block init stmt into IfThenElse statements. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerInitBlock() + return _ffi_api.LowerInitBlock() # type: ignore def PlanAndUpdateBufferAllocationLocation(): @@ -581,7 +583,7 @@ def PlanAndUpdateBufferAllocationLocation(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.PlanAndUpdateBufferAllocationLocation() + return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore def ConvertBlocksToOpaque(): @@ -594,7 +596,7 @@ def ConvertBlocksToOpaque(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.ConvertBlocksToOpaque() + return _ffi_api.ConvertBlocksToOpaque() # type: ignore def CompactBufferAllocation(): @@ -639,7 +641,18 @@ def CompactBufferAllocation(): The result pass """ - return _ffi_api.CompactBufferAllocation() + return _ffi_api.CompactBufferAllocation() # type: ignore + + +def LowerMatchBuffer(): + """Remove match buffers inside the block. Also, it will validate the binding. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerMatchBuffer() # type: ignore def FlattenBuffer(): @@ -652,4 +665,16 @@ def FlattenBuffer(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FlattenBuffer() + return _ffi_api.FlattenBuffer() # type: ignore + + +def MergeDynamicSharedMemoryAllocations(): + """This pass merges multiple TIR-level dynamic shared memory allocations + into one allocation. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore diff --git a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py index f4cd9d899b73..4eed56a22572 100644 --- a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py +++ b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py @@ -247,7 +247,7 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, conv, output return s -def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): +def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2): """Spatial pack compute for Conv2d NHWC""" out_dtype = out_dtype or data.dtype @@ -273,12 +273,21 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_ data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0]) # ==================== define configuration space ==================== - n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW) + # If it has dynamic shape in batch, we fix the split factor to 1 + n = cfg.axis(N) if isinstance(N, int) else cfg.axis(1) + oc, oh, ow = cfg.axis(OC), cfg.axis(OH), cfg.axis(OW) ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW) - oco, oci = cfg.define_split("tile_co", oc, num_outputs=2) - oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2) - owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2) + if num_tile == 2: # for arm cpu + oco, oci = cfg.define_split("tile_co", oc, num_outputs=2) + oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2) + owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2) + elif num_tile == 3: # for mali gpu + oco, _, oci = cfg.define_split("tile_co", oc, num_outputs=3) + oho, _, ohi = cfg.define_split("tile_oh", oh, num_outputs=3) + owo, _, owi = cfg.define_split("tile_ow", ow, num_outputs=3) + else: + raise RuntimeError("Invalid num_tile") cfg.define_reorder( "reorder_conv", @@ -344,7 +353,7 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_ conv = te.compute( ovshape, lambda n, oho, owo, oco, ohi, owi, oci: te.sum( - data_vec[n, oho, owo, kh, kw, ohi, owi, ic].astype(out_dtype) + data_vec[n, oho, owo, kh, kw, ic, ohi, owi].astype(out_dtype) * kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype), axis=[ic, kh, kw], ), diff --git a/python/tvm/topi/arm_cpu/group_conv2d.py b/python/tvm/topi/arm_cpu/group_conv2d.py index d852b9acef66..81b2c7260f05 100644 --- a/python/tvm/topi/arm_cpu/group_conv2d.py +++ b/python/tvm/topi/arm_cpu/group_conv2d.py @@ -42,7 +42,9 @@ def schedule_group_conv2d_nchw(outs): return schedule_group_conv2d_nchwc(outs) -def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"): +def _get_default_config( + cfg, data, kernel, strides, padding, dilation, groups, out_dtype, layout="NCHW" +): """ Get default schedule config for the workload """ @@ -54,7 +56,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, static_data_shape.append(dim) data = te.placeholder(static_data_shape, dtype=data.dtype) - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout) _fallback_schedule(cfg, wkl) @@ -158,6 +160,7 @@ def group_conv2d_nchw_spatial_pack( ), strides, padding, + dilation, groups, out_dtype, ) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index fb91912f29a0..3fc8a584b557 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -27,9 +27,49 @@ @autotvm.register_topi_compute("batch_matmul.cuda") -def batch_matmul(cfg, x, y, out_shape=None): - """Compute conv2d with NCHW layout""" - return nn.batch_matmul(x, y) +def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ + return nn.batch_matmul( + x, + y, + oshape=out_shape, + out_dtype=out_dtype, + transpose_a=transpose_a, + transpose_b=transpose_b, + ) @autotvm.register_topi_schedule("batch_matmul.cuda") @@ -140,31 +180,54 @@ def _callback(op): @autotvm.register_topi_compute("batch_matmul_cublas.cuda") -def batch_matmul_cublas(cfg, x, y, out_shape=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. +def batch_matmul_cublas( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Compute batch matrix multiplication of `x` and `y`. + + Both `x` and `y` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + x : tvm.te.Tensor - 3-D with shape [batch, M, K] + 3-D with shape [batch, M, K] or [batch, K, M]. y : tvm.te.Tensor - 3-D with shape [batch, N, K] + 3-D with shape [batch, K, N] or [batch, N, K]. - out_shape : None - The output shape + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - b, m, k = get_const_tuple(x.shape) - b, n, k = get_const_tuple(y.shape) + if transpose_a: + b, k, m = get_const_tuple(x.shape) + else: + b, m, k = get_const_tuple(x.shape) + if transpose_b: + b, n, k = get_const_tuple(y.shape) + else: + b, k, n = get_const_tuple(y.shape) if all([isinstance(s, int) for s in [b, m, n, k]]): cfg.add_flop(b * m * k * n * 2) - return cublas.batch_matmul(x, y, False, True) + return cublas.batch_matmul(x, y, transa=transpose_a, transb=transpose_b) @autotvm.register_topi_schedule("batch_matmul_cublas.cuda") @@ -175,7 +238,31 @@ def schedule_batch_matmul_cublas(_, outs): @autotvm.register_topi_compute("batch_matmul_int8.cuda") def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): - """Batch Matmul operator for int8 on CUDA""" + """Batch Matmul operator for int8 on CUDA. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + + x : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + y : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ if out_dtype is None: out_dtype = x.dtype diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 962a8af7853b..a56d3c36ba33 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -29,10 +29,10 @@ @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda") -def batch_matmul_tensorcore(cfg, x, y, out_shape=None): +def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None): """batch matmul tensorcore operator on cuda""" # todo: deal with out_shape for broadcast, liuxin.ai - return batch_matmul_tensorcore_cuda(x, y) + return batch_matmul_tensorcore_cuda(x, y, out_dtype) @autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda") @@ -57,10 +57,8 @@ def _schedule(cfg, s, C): A, B = s[C].op.input_tensors batch, m_dim, k_dim = get_const_tuple(A.shape) batch, n_dim, k_dim = get_const_tuple(B.shape) + data_dtype = A.dtype out_dtype = C.dtype - # inline astype fp16 - s[A].compute_inline() - s[B].compute_inline() # Explicit memory access AS = s.cache_read(A, "shared", [C]) @@ -94,15 +92,28 @@ def _schedule(cfg, s, C): cfg.define_knob("vec", [1, 2, 4, 8]) # Ensure that the default parameters are applicable when autotvm is not in use - if m_dim % 32 == 0 and n_dim % 8 == 0: - cfg.define_knob("wmma_m", [32, 16, 8]) - elif m_dim % 16 == 0 and n_dim % 16 == 0: - cfg.define_knob("wmma_m", [16, 8, 32]) - elif m_dim % 8 == 0 and n_dim % 32 == 0: - cfg.define_knob("wmma_m", [8, 16, 32]) + if data_dtype in ["float16", "uint8", "int8"]: + if m_dim % 32 == 0 and n_dim % 8 == 0: + cfg.define_knob("wmma_m", [32, 16, 8]) + elif m_dim % 16 == 0 and n_dim % 16 == 0: + cfg.define_knob("wmma_m", [16, 8, 32]) + elif m_dim % 8 == 0 and n_dim % 32 == 0: + cfg.define_knob("wmma_m", [8, 16, 32]) + wmma_k = 16 + wmma_m = cfg["wmma_m"].val + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + elif data_dtype in ["int4", "uint4"]: + wmma_m = wmma_n = 8 + wmma_k = 32 + else: + raise ValueError("data dtype %s is not yet supported" % data_dtype) warp_size = 32 - wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val @@ -110,16 +121,8 @@ def _schedule(cfg, s, C): chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val - wmma_m = cfg["wmma_m"].val vec = cfg["vec"].val - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset @@ -211,10 +214,8 @@ def shared_shedule(stage, strides): shared_shedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) - # TODO: add checking here, datatype casting may cause precision loss - in_dtype = "float16" - AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype) - BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype) + AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), @@ -236,7 +237,7 @@ def shared_shedule(stage, strides): "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), - "float16", + data_dtype, ), ) s[BF].tensorize( @@ -248,7 +249,7 @@ def shared_shedule(stage, strides): "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), - "float16", + data_dtype, ), ) s[CF].tensorize( @@ -270,7 +271,7 @@ def _callback(op): return s -def batch_matmul_tensorcore_cuda(x, y): +def batch_matmul_tensorcore_cuda(x, y, out_dtype=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -294,22 +295,26 @@ def batch_matmul_tensorcore_cuda(x, y): assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent" batch, M, K = x.shape N = y.shape[1] - out_dtype = x.dtype - - assert ( - (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) - or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) - or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) - ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" - x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16")) - y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16")) + if out_dtype is None: + out_dtype = x.dtype + + assert x.dtype == y.dtype + assert x.dtype in ["float16", "uint8", "int8", "uint4", "int4"] + if x.dtype in ["float16", "uint8", "int8"]: + assert ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)" + else: + assert ( + M % 8 == 0 and K % 32 == 0 and N % 8 == 0 + ), "The shape of (M, K, N) must be multiple of (8, 32, 8)" k = te.reduce_axis((0, K), name="k") return te.compute( (batch, M, N), - lambda b, i, j: te.sum( - x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k - ), + lambda b, i, j: te.sum(x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k), tag="batch_matmul_tensorcore", ) diff --git a/python/tvm/topi/cuda/conv1d_transpose_ncw.py b/python/tvm/topi/cuda/conv1d_transpose_ncw.py index 58f53eab20ac..2098aa9089c6 100644 --- a/python/tvm/topi/cuda/conv1d_transpose_ncw.py +++ b/python/tvm/topi/cuda/conv1d_transpose_ncw.py @@ -142,7 +142,7 @@ def _callback(op): ##### space definition begin ##### n, f, x = s[conv].op.axis rc = s[conv].op.reduce_axis[0] - cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) + cfg.define_split("tile_n", cfg.axis(n if isinstance(n, int) else 1), num_outputs=4) cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3) diff --git a/python/tvm/topi/cuda/conv2d_int8.py b/python/tvm/topi/cuda/conv2d_int8.py index 001411d6e4c9..02470bab5228 100644 --- a/python/tvm/topi/cuda/conv2d_int8.py +++ b/python/tvm/topi/cuda/conv2d_int8.py @@ -98,9 +98,9 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_ ) out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape) - assert out_channels % 4 == 0, "Number of output channels should be multiple of {}".format( - oc_block_factor - ) + assert ( + out_channels % oc_block_factor == 0 + ), "Number of output channels should be multiple of {}".format(oc_block_factor) packed_kernel = te.compute( ( out_channels // oc_block_factor, diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/cuda/conv2d_nhwc.py index 991585587bbf..e4361e30b5c3 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc.py +++ b/python/tvm/topi/cuda/conv2d_nhwc.py @@ -43,12 +43,15 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): AL = s.cache_read(AA, "local", [OL]) WL = s.cache_read(WW, "local", [OL]) + # Currently Conv2d NHWC only support dynamic shpe in batch + dynamic_batch = isinstance(s[output].op.axis[0].dom.extent, tvm.tir.expr.Var) + # Schedule for autotvm - cfg.define_knob("tile_n", [2, 4, 8]) + cfg.define_knob("tile_n", [1] if dynamic_batch else [2, 4, 8]) cfg.define_knob("tile_c", [2, 4, 8]) - cfg.define_knob("num_thread_n", [4, 8, 16]) + cfg.define_knob("num_thread_n", [1] if dynamic_batch else [4, 8, 16]) cfg.define_knob("num_thread_c", [4, 8, 16]) - cfg.define_knob("vthread_n", [1, 2]) + cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2]) cfg.define_knob("vthread_c", [1, 2]) cfg.define_knob("step", [16, 3, 32, 64]) diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 430f8044528c..9bac34cbeaf7 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -60,22 +60,27 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None): out_dtype = data.dtype batch, in_dim = get_const_tuple(data.shape) out_dim, _ = get_const_tuple(weight.shape) - assert ( - (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) - or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) - or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0) - ), ( - "The shape of (batch, in_dim, out_dim) " - "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" - ) + + assert data.dtype == weight.dtype + assert data.dtype in ["float16", "int8", "uint8", "int4", "uint4"] + if data.dtype in ["float16", "int8", "uint8"]: + assert ( + (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) + or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) + or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0) + ), ( + "The shape of (batch, in_dim, out_dim) " + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" + ) + else: + assert ( + batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0 + ), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)" + k = te.reduce_axis((0, in_dim), name="k") - data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype("float16")) - weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype("float16")) matmul = te.compute( (batch, out_dim), - lambda i, j: te.sum( - data_16[i, k].astype(out_dtype) * weight_16[j, k].astype(out_dtype), axis=k - ), + lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k), name="T_dense", tag="dense_tensorcore", ) @@ -92,9 +97,8 @@ def _schedule_dense_tensorcore(cfg, s, C): """Schedule dense operator using Tensorcore""" A, B = s[C].op.input_tensors batch, out_dim = get_const_tuple(C.shape) + data_dtype = A.dtype out_dtype = C.dtype - s[A].compute_inline() - s[B].compute_inline() # Explicit memory access AS = s.cache_read(A, "shared", [C]) @@ -127,16 +131,29 @@ def _schedule_dense_tensorcore(cfg, s, C): cfg.define_knob("offsetCS", [0, 8]) cfg.define_knob("vec", [1, 2, 4, 8]) - # Ensure that the default parameters are applicable when autotvm is not in use - if batch % 32 == 0 and out_dim % 8 == 0: - cfg.define_knob("wmma_m", [32, 16, 8]) - elif batch % 16 == 0 and out_dim % 16 == 0: - cfg.define_knob("wmma_m", [16, 8, 32]) - elif batch % 8 == 0 and out_dim % 32 == 0: - cfg.define_knob("wmma_m", [8, 16, 32]) + if data_dtype in ["float16", "int8", "uint8"]: + # Ensure that the default parameters are applicable when autotvm is not in use + if batch % 32 == 0 and out_dim % 8 == 0: + cfg.define_knob("wmma_m", [32, 16, 8]) + elif batch % 16 == 0 and out_dim % 16 == 0: + cfg.define_knob("wmma_m", [16, 8, 32]) + elif batch % 8 == 0 and out_dim % 32 == 0: + cfg.define_knob("wmma_m", [8, 16, 32]) + wmma_k = 16 + wmma_m = cfg["wmma_m"].val + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + elif data_dtype in ["int4", "uint4"]: + wmma_m = wmma_n = 8 + wmma_k = 32 + else: + raise ValueError("data dtype %s is not yet supported" % data_dtype) warp_size = 32 - wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val @@ -144,16 +161,8 @@ def _schedule_dense_tensorcore(cfg, s, C): chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val - wmma_m = cfg["wmma_m"].val vec = cfg["vec"].val - if wmma_m == 16: - wmma_n = 16 - elif wmma_m == 8: - wmma_n = 32 - elif wmma_m == 32: - wmma_n = 8 - # Define the stride of intrin functions AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset @@ -245,10 +254,8 @@ def shared_shedule(stage, strides): shared_shedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) - # TODO: add checking here, datatype casting may cause precision loss - in_dtype = "float16" - AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype) - BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype) + AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype) k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm") CL_compute = te.compute( (wmma_m, wmma_n), @@ -264,13 +271,13 @@ def shared_shedule(stage, strides): s[AF].tensorize( b_ii, intrin_wmma_load_matrix_A( - AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16" + AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), data_dtype ), ) s[BF].tensorize( o_ii, intrin_wmma_load_matrix_W( - BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), "float16" + BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), data_dtype ), ) s[CF].tensorize( diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index cce56b796cea..0faddc31c25a 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable, """Schedule for composition of injective operator""" +import numpy as np + import tvm from tvm import te from .. import utils @@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out): sch: Schedule The updated schedule. """ + + def find_nearest_small_factor(num, target): + """Find the nearest factor of the given number that is smaller than the target.""" + for i in range(target, 0, -1): + if num % i == 0: + return i + # Unreachable because i=1 must hold. + return -1 + fused = sch[out].fuse(*sch[out].op.axis) num_thread = tvm.target.Target.current(allow_none=False).max_num_threads max_block = 256 - # vectorize on fp16 data type. This allows to better utilize the memory - # bandwidth. - vector_width = 4 if out.dtype == "float16" else 1 + # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization. + vector_width = 2 if out.dtype == "float16" else 1 is_dynamic_output = False for dim in out.shape: @@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out): try: const_size = utils.get_const_int(out_len) + + # Adjust block and thread to make sure they are dividable so that vectorize can be + # correctly applied. + if vector_width > 1 and const_size % vector_width == 0: + remain_total_size = const_size // vector_width + cand_sizes = [] + for max_size in [num_thread, max_block]: + cand_sizes.append( + max_size + if remain_total_size % max_size == 0 + else find_nearest_small_factor(remain_total_size, max_size) + ) + remain_total_size //= cand_sizes[-1] + + # If the product of candidate dividable (block * thread) is too small, + # then the performance may be worse even half2 is enabled. Note that 0.7 + # is just a heuristic ratio and may not be optimal for all workloads. + if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7: + num_thread, max_block = cand_sizes + need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py index b9d02d9c81d8..3aef96b24ed3 100644 --- a/python/tvm/topi/cuda/reduction.py +++ b/python/tvm/topi/cuda/reduction.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition """Schedule for reduce operators""" from __future__ import absolute_import as _abs +from operator import mul +from functools import reduce import tvm from tvm import te from .. import tag @@ -80,13 +82,18 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): if is_idx_reduce: sch[temp_idx_input].compute_at(sch[real_output], outer_in) sch[temp_val_input].compute_at(sch[real_output], outer_in) + sch[real_output].set_store_predicate( + tvm.tir.all( + thread_x.equal(0), block_x * num_thread + thread_y < reduce(mul, real_output.shape) + ) + ) else: if is_idx_reduce: spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis)) sch[real_output].bind(spatial_axis, te.thread_axis("blockIdx.x")) sch[temp_idx_input].compute_at(sch[real_output], spatial_axis) sch[temp_val_input].compute_at(sch[real_output], spatial_axis) - sch[real_output].set_store_predicate(thread_x.equal(0)) + sch[real_output].set_store_predicate(thread_x.equal(0)) return sch diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index cee13d7e01a2..fa7545cd323a 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -772,9 +772,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - # We combine all the indices dimensions but the first one into a single - # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the update dimensions. + atomic_add_return = ib.allocate( + updates.dtype, (1,), name="atomic_add_return", scope="local" + ) + fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i @@ -787,42 +788,91 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): for i in data_ptr.shape: fused_shape *= i - # For now we avoid parallizing over dimensions indexed by `indices` as - # there may be repeated indices and hadling parallel accumulation can - # be hard. So we parallelize over X_M .. X_{N-1} instead. This will - # work well when these dimensions are large enough to saturate memory - # bandwidth, but performance will be bad when these dimensions are - # small. - bx = te.thread_axis("blockIdx.x") - tx = te.thread_axis("threadIdx.x") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) tdim = min(max_threads, fused_updates_dimension) - ib.scope_attr(tx, "thread_extent", tdim) - bdim = ceil_div(fused_updates_dimension, tdim) - ib.scope_attr(bx, "thread_extent", bdim) - with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: - index = i * fused_updates_dimension + bx * tdim + tx + with ib.new_scope(): + bdim = ceil_div(fused_shape, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim) + ib.scope_attr(tx, "thread_extent", tdim) + + index = bx * tdim + tx with ib.if_scope(index < fused_shape): out[index] = data[index] - with ib.for_range(0, fused_indices_dimension) as i: - j = bx * tdim + tx - with ib.if_scope(j < fused_updates_dimension): - offset = fused_updates_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. - for l in reversed(range(indices_ptr.shape[0].value)): - # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[i + l * fused_indices_dimension] - offset *= data_ptr.shape[l] - if mode == "update": - out[index] = updates[i * fused_updates_dimension + j] - elif mode == "add": - out[index] += updates[i * fused_updates_dimension + j] - else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + # For better performance, we introduce blockIdx.y to implement for-loops + # within one thread. + # The code is parallel over the scattered indices, so we use atomic_add + # to guarantee correctness when mode=="add" + + # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64" + # So we fallback to normal algorithm, using "+=" rather than atomic_add + + # TODO (CaptainDuke): + # Since multiple threads compete for the same write index, which leads to + # non-determinstic output for update mode. We could add a new attribute, + # "allow_non_deterministic", which can be conditionally set to True by + # each frontend when non-determinsm is allowed. + cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind) + with ib.new_scope(): + if ( + mode == "add" + and cur_target_kind not in ["vulkan", "metal"] + and updates.dtype in ["int32", "float32"] + ): + bdim_x = fused_indices_dimension + bdim_y = ceil_div(fused_updates_dimension, tdim) + # In case of large input sizes, fused_indices_dimension might be too large. + # So we use blockIdx.x because holds larger scales. + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(by, "thread_extent", bdim_y) + ib.scope_attr(tx, "thread_extent", tdim) + + j = by * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. + up_index = bx * fused_updates_dimension + j + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[bx * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[bx + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out[index]), + updates[up_index], + ) + else: + bdim_x = ceil_div(fused_updates_dimension, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(tx, "thread_extent", tdim) + with ib.for_range(0, fused_indices_dimension) as i: + j = bx * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the + # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, + # ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index eb7c71ddf1c9..50bcafd9f9a7 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -19,7 +19,7 @@ import logging import math -from tvm import relay +from tvm import relay, tir from .. import nn @@ -54,14 +54,23 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - # Pad input and output channels to use tensorcore schedule. - if dtype in ["float16"]: # todo: support int8/int4 - B, M, K = x_tensor.shape - B, N, K = y_tensor.shape - M = M.value - K = K.value - N = N.value + B, M, K = x_tensor.shape + B, N, K = y_tensor.shape + if ( + isinstance(B, tir.expr.Any) + or isinstance(M, tir.expr.Any) + or isinstance(K, tir.expr.Any) + or isinstance(N, tir.expr.Any) + ): + # Dynamic shape do not support alter op layout now + return None + + M = M.value + K = K.value + N = N.value + # Pad input and output channels to use tensorcore schedule. + if dtype in ["float16", "int8", "uint8"]: # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) if ( (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) @@ -70,31 +79,32 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): ): # no need to pad return None - candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) - - if extra_flops > 2: - logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + elif dtype in ["int4", "uint4"]: + if M % 8 == 0 and K % 32 == 0 and N % 8 == 0: + # no need to pad return None - logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) - else: - y_ = y - out_ = relay.nn.batch_matmul(x_, y_) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) - else: - out = out_ - return out - return None + candidates = [(8, 32, 8)] + else: + return None + + (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates) + + if extra_flops > 2: + logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y + out_ = relay.nn.batch_matmul(x_, y_, attrs.out_dtype) + out = ( + relay.strided_slice(out_, begin=[0, 0, 0], end=[x.value for x in output_tensor.shape]) + if dm or dn + else out_ + ) + return out @nn.dense_legalize.register("cuda") @@ -115,6 +125,7 @@ def _dense_legalize(attrs, inputs, arg_types): result : tvm.relay.Expr The legalized expr """ + new_attrs = {k: attrs[k] for k in attrs.keys()} # Collect the input tensors. x_tensor, y_tensor = arg_types[0], arg_types[1] dtype = x_tensor.dtype @@ -125,18 +136,18 @@ def _dense_legalize(attrs, inputs, arg_types): # Collect the input exprs. x, y = inputs - # Pad input and output channels to use tensorcore schedule. - if dtype in ["float16"]: # todo: support int8/int4 - M, K = x_tensor.shape - N, K = y_tensor.shape - try: - M = M.value - K = K.value - N = N.value - except AttributeError: - # todo: deal with unfixed shape when compiling wdl model - return None + M, K = x_tensor.shape + N, K = y_tensor.shape + try: + M = M.value + K = K.value + N = N.value + except AttributeError: + # todo: deal with unfixed shape when compiling wdl model + return None + # Pad input and output channels to use tensorcore schedule. + if dtype in ["float16", "int8", "uint8"]: # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) if ( (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) @@ -147,30 +158,31 @@ def _dense_legalize(attrs, inputs, arg_types): return None candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] - (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) - - if extra_flops_ratio > 2: - logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + elif dtype in ["int4", "uint4"]: + if M % 8 == 0 and K % 32 == 0 and N % 8 == 0: + # no need to pad return None - - logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) - - if dm or dk: - x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) - else: - x_ = x - if dn or dk: - y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) - else: - y_ = y - out_ = relay.nn.dense(x_, y_) - if dm or dn: - original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) - else: - out = out_ - return out - return None + candidates = [(8, 32, 8)] + else: + return None + + (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates) + + if extra_flops_ratio > 2: + logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + return None + + logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) + + x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if dm or dk else x + y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) if dn or dk else y + out_ = relay.nn.dense(x_, y_, **new_attrs) + out = ( + relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) + if dm or dn + else out_ + ) + return out def pad_to_tensorcore(M, K, N, candidates): diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 42d0455665a1..d1abffd12972 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -23,6 +23,25 @@ from .. import tag +def get_1d_indices(indices, layout="NCW"): + """Get 1d indices""" + (cc, inum, ic) = (0, 0, 0) + if layout == "NWC": + n, x, c = indices + cc = None + elif layout == "NCW": + n, c, x = indices + cc = None + elif ncw_pack_layout(layout): + n, c, x, inum, ic = indices + else: + # else must be NCHWxc + assert ncw_xc_layout(layout) + n, c, x, cc = indices + + return n, c, x, cc, inum, ic + + def get_2d_indices(indices, layout="NCHW"): """Get 2d indices""" (cc, inum, ic) = (0, 0, 0) @@ -42,11 +61,39 @@ def get_2d_indices(indices, layout="NCHW"): return n, c, y, x, cc, inum, ic -def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic): +def get_3d_indices(indices, layout="NCDHW"): + """Get 3d indices""" + if layout == "NDHWC": + n, z, y, x, c = indices + cc = None + elif layout == "NCDHW": + n, c, z, y, x = indices + cc = None + else: + n, c, z, y, x, cc = indices + + return n, c, z, y, x, cc + + +def get_1d_pixel(data, layout, image_width, n, c, x, cc, ib, ic): + """Get 1d pixel""" + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + if layout == "NWC": + return data(n, x, c).astype("float") + if layout == "NCW": + return data(n, c, x).astype("float") + if ncw_pack_layout(layout): + return data(n, c, x, ib, ic).astype("float") + + # else must be NCHWxc + assert ncw_xc_layout(layout) + return data(n, c, x, cc).astype("float") + + +def get_2d_pixel(data, layout, image_height, image_width, n, c, y, x, cc, ib, ic): """Get 2d pixel""" - if boxes is None: - y = tvm.te.max(tvm.te.min(y, image_height - 1), 0) - x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + y = tvm.te.max(tvm.te.min(y, image_height - 1), 0) + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) if layout == "NHWC": return data(n, y, x, c).astype("float") if layout == "NCHW": @@ -59,53 +106,99 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, return data(n, c, y, x, cc).astype("float") -def get_iny_inx( - y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode -): - """Infer input x,y from output x,y with various coordinate transformation methods""" - scale_y = te.div(image_height.astype("float"), target_height.astype("float")) +def get_3d_pixel(data, layout, image_depth, image_height, image_width, n, c, z, y, x, cc): + """Get 3d pixel""" + z = tvm.te.max(tvm.te.min(z, image_depth - 1), 0) + y = tvm.te.max(tvm.te.min(y, image_height - 1), 0) + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + if layout == "NDHWC": + return data(n, z, y, x, c).astype("float") + if layout == "NCDHW": + return data(n, c, z, y, x).astype("float") + # else must be NCDHWxc + return data(n, c, z, y, x, cc).astype("float") + + +def get_inx(x, image_width, target_width, coordinate_transformation_mode): + """Infer input x from output x with various coordinate transformation methods""" scale_x = te.div(image_width.astype("float"), target_width.astype("float")) if coordinate_transformation_mode == "half_pixel": - in_y = (y + 0.5) * scale_y - 0.5 in_x = (x + 0.5) * scale_x - 0.5 elif coordinate_transformation_mode == "align_corners": - in_y = (image_height - 1).astype("float") / (target_height - 1) * y in_x = (image_width - 1).astype("float") / (target_width - 1) * x elif coordinate_transformation_mode == "asymmetric": - in_y = scale_y * y in_x = scale_x * x elif coordinate_transformation_mode == "pytorch_half_pixel": - in_y = te.if_then_else(target_height > 1, (y + 0.5) * scale_y - 0.5, 0.0) in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0) elif coordinate_transformation_mode == "tf_half_pixel_for_nn": - in_y = (y + 0.5) * scale_y in_x = (x + 0.5) * scale_x else: raise ValueError( "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) ) - return in_y, in_x + return in_x -def resize_nearest_neighbor( +def get_closest_index(in_x, rounding_method, boxes): + """get the closest index to a value based on a certain rounding method""" + if rounding_method == "round" or boxes is not None: + closest_x_index = te.round(in_x).astype("int32") + elif rounding_method == "round_prefer_floor": + closest_x_index = te.ceil(in_x - 0.5).astype("int32") + elif rounding_method == "round_prefer_ceil": + closest_x_index = te.floor(in_x + 0.5).astype("int32") + elif rounding_method == "floor": + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + closest_x_index = te.floor(in_x + epsilon).astype("int32") + elif rounding_method == "ceil": + # Subract epsilon from ceil to prevent gpu rounding errors. + epsilon = 1e-5 + closest_x_index = te.ceil(in_x - epsilon).astype("int32") + else: + raise ValueError("Uknown rounding method: {}".format(rounding_method)) + return closest_x_index + + +def _lerp(A, B, t): + """Perform Linear interpolation in 1D""" + return A * (1.0 - t) + B * t + + +def _cubic_spline_weights(t, alpha): + """create cubic spline weights in 1D""" + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] + + +def _cubic_kernel(inputs, w): + """perform cubic interpolation in 1D""" + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + +def _resize_1d( indices, data, - image_height, image_width, - target_height, target_width, boxes=None, box_indices=None, + method=None, extrapolation_value=None, - layout="NCHW", + layout="NCW", coordinate_transformation_mode="align_corners", rounding_method="", + alpha=-0.5, + exclude_outside=0, out_dtype=None, ): - """Perform resize operation with nearest neighbor method on the data. - For details about Nearest-neighbor interpolation please refer to - https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. + """Perform resize operation on the data with selected method and options. Parameters ---------- @@ -113,19 +206,13 @@ def resize_nearest_neighbor( The indices of input data data : tvm.te.Tensor - inputs is a 4-D tensor with shape - [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] - - image_height : integer - Input image height + inputs is a 3-D tensor with shape + [batch, channel, in_width] + or [batch, in_width, channel] image_width : integer Input image width - target_height : integer - The target resized image height - target_width : integer The target resized image width @@ -141,7 +228,7 @@ def resize_nearest_neighbor( Value used for extrapolation, when applicable. layout: string, optional - "NCHW", "NHWC", or "NCHWc". + "NCW", "NWC", or "NCWc". coordinate_transformation_mode: string, optional Describes how to transform the coordinate in the resized tensor @@ -153,6 +240,12 @@ def resize_nearest_neighbor( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] + alpha: float, optional + Bicubic spline coefficient + + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + out_dtype: string, optional Type to return. If left None will be same as input type. @@ -161,11 +254,6 @@ def resize_nearest_neighbor( output : out_dtype The computed result with type out_dtype """ - if rounding_method == "": - if coordinate_transformation_mode == "align_corners": - rounding_method = "round" - else: - rounding_method = "floor" def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: @@ -174,136 +262,127 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): dtype = data_dtype return value.astype(dtype) - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) + n, c, x, cc, inum, ic = get_1d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n if boxes is not None: - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) + # TODO(mbrookhart): Find an example of this + raise NotImplementedError("resize1d with image boxes not yet implemented") + in_x = get_inx( + x, + image_width, + target_width, + coordinate_transformation_mode, + ) - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x - else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + + value = get_1d_pixel( + data, + layout, image_width, - target_height, - target_width, - coordinate_transformation_mode, + box_idx, + c, + closest_x_index, + cc, + inum, + ic, ) + elif method == "linear": + x_int = te.floor(in_x).astype("int32") - if rounding_method == "round" or boxes is not None: - closest_x_index = te.round(in_x).astype("int32") - closest_y_index = te.round(in_y).astype("int32") - elif rounding_method == "round_prefer_floor": - closest_x_index = te.ceil(in_x - 0.5).astype("int32") - closest_y_index = te.ceil(in_y - 0.5).astype("int32") - elif rounding_method == "round_prefer_ceil": - closest_x_index = te.floor(in_x + 0.5).astype("int32") - closest_y_index = te.floor(in_y + 0.5).astype("int32") - elif rounding_method == "floor": - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - closest_y_index = te.floor(in_y + epsilon).astype("int32") - closest_x_index = te.floor(in_x + epsilon).astype("int32") - elif rounding_method == "ceil": - # Subract epsilon from ceil to prevent gpu rounding errors. - epsilon = 1e-5 - closest_y_index = te.ceil(in_y - epsilon).astype("int32") - closest_x_index = te.ceil(in_x - epsilon).astype("int32") - else: - raise ValueError("Uknown rounding method: {}".format(rounding_method)) + x_lerp = in_x - x_int - value = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - closest_y_index, - closest_x_index, - cc, - inum, - ic, - ) + p = [0 for i in range(2)] + for i in range(2): + p[i] = get_1d_pixel( + data, + layout, + image_width, + box_idx, + c, + x_int + i, + cc, + inum, + ic, + ) + + value = _lerp(*p, x_lerp) + + elif method == "cubic": + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) + + # Get the surrounding values + p = [0 for i in range(4)] + for i in range(4): + p[i] = get_1d_pixel( + data, + layout, + image_width, + box_idx, + c, + xint + i - 1, + cc, + inum, + ic, + ) + + wx = _cubic_spline_weights(xfract, alpha) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wx = sum(wx) + wx = [w / sum_wx for w in wx] + value = _cubic_kernel(p, wx) + + else: + raise ValueError("Unknown resize method:", method) if extrapolation_value is not None: - out = tvm.tir.if_then_else( - in_y < 0, - extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), - ) # use extrapolation_value if in_x is out of boundary value = tvm.tir.if_then_else( in_x < 0, extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, value), ) return _cast_output(value, data.dtype, out_dtype=out_dtype) -def resize_bilinear( - indices, +def resize1d( data, - image_height, - image_width, - target_height, - target_width, - boxes=None, - box_indices=None, - extrapolation_value=None, - layout="NCHW", - coordinate_transformation_mode="align_corners", + size, + layout="NCW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, + output_shape=None, ): - - """Perform resize operation with bilinear method on the data. - For details about Bilinear interpolation please refer to - https://en.wikipedia.org/wiki/Bilinear_interpolation. + """Perform resize operation on the data. Parameters ---------- - indices : tuple - The indices of input data - data : tvm.te.Tensor - inputs is a 4-D tensor with shape - [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] - - image_height : integer - Input image height - - image_width : integer - Input image width + inputs is a 3-D tensor with shape + [batch, channel in_width] + or [batch in_width, channel] - target_height : integer - The target resized image height - - target_width : integer - The target resized image width - - boxes : tvm.te.Tensor, optional - A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies - the coordinates of a box. - - box_indices : tvm.te.Tensor, optional - A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that - the i-th box refers to. - - extrapolation_value: float, optional - Value used for extrapolation, when applicable. + size: Tuple + Output resolution scale to layout: string, optional - "NCHW", "NHWC", or "NCHWc". + "NCW", "NWC", or "NCWc". coordinate_transformation_mode: string, optional Describes how to transform the coordinate in the resized tensor @@ -311,135 +390,69 @@ def resize_bilinear( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". + method: {"linear", "nearest_neighbor", "cubic"} + Method to be used for resizing. + out_dtype: string, optional Type to return. If left None will be same as input type. + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- - output : out_dtype - The computed result with type out_dtype + output : tvm.te.Tensor + 4-D with shape [batch, chananel, in_width*scale] + or [batch, in_width*scale, channel] + or 5-D with shape [batch, channel-major, in_width*scale, channel-minor] """ - - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout=layout) - box_idx = box_indices(n) if box_indices is not None else n - - if boxes is not None: - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) - - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) - - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x + method = method.lower() + if layout == "NWC": + in_n, in_w, in_c = data.shape + if output_shape is None: + output_shape = [in_n, size[0], in_c] + elif layout == "NCW": + in_n, in_c, in_w = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0]] + elif ncw_pack_layout(layout): # for NCWinic + in_n, in_c, in_w, in_inum, in_ic = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0], in_inum, in_ic] + elif ncw_xc_layout(layout): # for NCWxc + in_n, in_c, in_w, in_cc = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0], in_cc] else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, - image_width, - target_height, - target_width, - coordinate_transformation_mode, - ) - - top_y_index = te.floor(in_y).astype("int32") - bottom_y_index = te.ceil(in_y).astype("int32") - y_lerp = in_y - top_y_index - - left_x_index = te.floor(in_x).astype("int32") - right_x_index = te.ceil(in_x).astype("int32") - x_lerp = in_x - left_x_index + raise ValueError("%s layout is not supported." % layout) - top_left = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - top_y_index, - left_x_index, - cc, - inum, - ic, - ) - top_right = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - top_y_index, - right_x_index, - cc, - inum, - ic, - ) - bottom_left = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - bottom_y_index, - left_x_index, - cc, - inum, - ic, - ) - bottom_right = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - bottom_y_index, - right_x_index, - cc, - inum, - ic, - ) + if isinstance(size, tuple): + size = list(size) - top = _lerp(top_left, top_right, x_lerp) - bottom = _lerp(bottom_left, bottom_right, x_lerp) - value = _lerp(top, bottom, y_lerp) + for i in range(1): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) - # use extrapolation_value if in_y/in_x is out of boundary - if extrapolation_value is not None: - out = tvm.tir.if_then_else( - in_y < 0, - extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), - ) - value = tvm.tir.if_then_else( - in_x < 0, - extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + def compute_func(*indices): + return _resize_1d( + indices, + data, + in_w, + size[0], + method=method, + layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, + out_dtype=out_dtype, ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) + + return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) -def resize_bicubic( +def _resize_2d( indices, data, image_height, @@ -448,17 +461,17 @@ def resize_bicubic( target_width, boxes=None, box_indices=None, + method=None, extrapolation_value=None, layout="NCHW", coordinate_transformation_mode="align_corners", - out_dtype=None, + rounding_method="", alpha=-0.5, exclude_outside=0, + out_dtype=None, ): - """Perform resize operation with bicubic method on the data. - More details about Bicubic interpolation please refer to - https://en.wikipedia.org/wiki/Bicubic_interpolation. - This algorithm is doing a bicubic spline interpolation + + """Perform resize operation on the data with selected method and options. Parameters ---------- @@ -468,7 +481,7 @@ def resize_bicubic( data : tvm.te.Tensor inputs is a 4-D tensor with shape [batch, channel, in_height, in_width] - or [:batch, in_height, in_width, channel] + or [batch, in_height, in_width, channel] image_height : integer Input image height @@ -502,12 +515,19 @@ def resize_bicubic( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - out_dtype: string, optional - Type to return. If left None will be same as input type. + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] alpha: float, optional Bicubic spline coefficient + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + + out_dtype: string, optional + Type to return. If left None will be same as input type. + Returns ------- output : out_dtype @@ -523,7 +543,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n - if boxes is not None: y1, x1 = boxes(n, 0), boxes(n, 1) y2, x2 = boxes(n, 2), boxes(n, 3) @@ -536,77 +555,115 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - in_y, in_x = get_iny_inx( - y, - x, + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" + + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) + + value = get_2d_pixel( + data, + layout, image_height, image_width, - target_height, - target_width, - coordinate_transformation_mode, + box_idx, + c, + closest_y_index, + closest_x_index, + cc, + inum, + ic, ) + elif method == "linear": + y_int = te.floor(in_y).astype("int32") + x_int = te.floor(in_x).astype("int32") + + y_lerp = in_y - y_int + x_lerp = in_x - x_int + + p = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + p[j][i] = get_2d_pixel( + data, + layout, + image_height, + image_width, + box_idx, + c, + y_int + j, + x_int + i, + cc, + inum, + ic, + ) - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + top = _lerp(*p[0], x_lerp) + bottom = _lerp(*p[1], x_lerp) + value = _lerp(top, bottom, y_lerp) - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + elif method == "cubic": + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) - # Get the surrounding values - p = [[0 for i in range(4)] for j in range(4)] - for j in range(4): - for i in range(4): - p[j][i] = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - yint + j - 1, - xint + i - 1, - cc, - inum, - ic, - ) + yint = te.floor(in_y).astype("int32") + yfract = in_y - te.floor(in_y) + + # Get the surrounding values + p = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + p[j][i] = get_2d_pixel( + data, + layout, + image_height, + image_width, + box_idx, + c, + yint + j - 1, + xint + i - 1, + cc, + inum, + ic, + ) + + wx = _cubic_spline_weights(xfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + sum_wx = sum(wx) + sum_wy = sum(wy) + wx = [w / sum_wx for w in wx] + wy = [w / sum_wy for w in wy] + col0 = _cubic_kernel(p[0], wx) + col1 = _cubic_kernel(p[1], wx) + col2 = _cubic_kernel(p[2], wx) + col3 = _cubic_kernel(p[3], wx) + value = _cubic_kernel([col0, col1, col2, col3], wy) + + else: + raise ValueError("Unknown resize method:", method) - # Interpolate bicubically - def _cubic_spline_weights(t): - t2 = t * t - t3 = t * t * t - w1 = alpha * (t3 - 2 * t2 + t) - w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 - w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t - w4 = -alpha * t3 + alpha * t2 - return [w1, w2, w3, w4] - - def _cubic_kernel(inputs, w): - return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) - - wx = _cubic_spline_weights(xfract) - wy = _cubic_spline_weights(yfract) - if exclude_outside: - for i in range(4): - wx[i] = te.if_then_else(te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i]) - wy[i] = te.if_then_else(te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i]) - sum_wx = sum(wx) - sum_wy = sum(wy) - wx = [w / sum_wx for w in wx] - wy = [w / sum_wy for w in wy] - col0 = _cubic_kernel(p[0], wx) - col1 = _cubic_kernel(p[1], wx) - col2 = _cubic_kernel(p[2], wx) - col3 = _cubic_kernel(p[3], wx) - value = _cubic_kernel([col0, col1, col2, col3], wy) - - # use extrapolation_value if in_y/in_x is out of boundary if extrapolation_value is not None: out = tvm.tir.if_then_else( in_y < 0, extrapolation_value, tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), ) + # use extrapolation_value if in_x is out of boundary value = tvm.tir.if_then_else( in_x < 0, extrapolation_value, @@ -615,11 +672,11 @@ def _cubic_kernel(inputs, w): return _cast_output(value, data.dtype, out_dtype=out_dtype) -def resize( +def resize2d( data, size, layout="NCHW", - method="bilinear", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", bicubic_alpha=-0.5, @@ -648,7 +705,7 @@ def resize( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"bilinear", "nearest_neighbor", "bicubic"} + method: {"linear", "nearest_neighbor", "cubic"} Method to be used for resizing. out_dtype: string, optional @@ -692,58 +749,23 @@ def resize( if isinstance(size[i], int): size[i] = tvm.tir.IntImm("int32", size[i]) - def _nearest_neighbor(*indices): - return resize_nearest_neighbor( + def compute_func(*indices): + return _resize_2d( indices, data, in_h, in_w, size[0], size[1], + method=method, layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, rounding_method=rounding_method, - out_dtype=out_dtype, - ) - - def _bilinear(*indices): - return resize_bilinear( - indices, - data, - in_h, - in_w, - size[0], - size[1], - layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, - out_dtype=out_dtype, - ) - - def _bicubic(*indices): - return resize_bicubic( - indices, - data, - in_h, - in_w, - size[0], - size[1], - layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, - out_dtype=out_dtype, alpha=bicubic_alpha, exclude_outside=bicubic_exclude, + out_dtype=out_dtype, ) - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "bilinear": - compute_func = _bilinear - elif method == "bicubic": - compute_func = _bicubic - else: - raise ValueError("%s method is not supported." % method) - return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) @@ -818,9 +840,11 @@ def crop_and_resize( image_w = data.shape[3].astype("int32") else: raise ValueError("%s layout is not supported." % layout) + if method == "bilinear": + method = "linear" - def _bilinear(*indices): - return resize_bilinear( + def compute_func(*indices): + return _resize_2d( indices, data, image_h, @@ -829,50 +853,280 @@ def _bilinear(*indices): target_w, boxes, box_indices, - extrapolation_value, - layout, + method=method, + extrapolation_value=extrapolation_value, + layout=layout, out_dtype=out_dtype, ) - def _nearest_neighbor(*indices): - return resize_nearest_neighbor( - indices, + return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE) + + +def _resize_3d( + indices, + data, + image_depth, + image_height, + image_width, + target_depth, + target_height, + target_width, + boxes=None, + box_indices=None, + method=None, + extrapolation_value=None, + layout="NCHW", + coordinate_transformation_mode="align_corners", + rounding_method="", + alpha=-0.5, + exclude_outside=0, + out_dtype=None, +): + + """Perform resize operation on the data with selected method and options. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.te.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + image_depth : integer + Input image depth + + image_height : integer + Input image height + + image_width : integer + Input image width + + target_depth : integer + The target resized image depth + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.te.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.te.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + alpha: float, optional + Bicubic spline coefficient + + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) + + n, c, z, y, x, cc = get_3d_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + if boxes is not None: + # TODO(mbrookhart): Find an example of this + raise NotImplementedError("resize1d with image boxes not yet implemented") + in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" + + closest_z_index = get_closest_index(in_z, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + + value = get_3d_pixel( data, - image_h, - image_w, - target_h, - target_w, - boxes, - box_indices, - extrapolation_value, layout, - out_dtype=out_dtype, + image_depth, + image_height, + image_width, + box_idx, + c, + closest_z_index, + closest_y_index, + closest_x_index, + cc, ) + elif method == "linear": + z_int = te.floor(in_z).astype("int32") + y_int = te.floor(in_y).astype("int32") + x_int = te.floor(in_x).astype("int32") + + z_lerp = in_z - z_int + y_lerp = in_y - y_int + x_lerp = in_x - x_int + + p = [[[0 for i in range(2)] for j in range(2)] for k in range(2)] + for k in range(2): + for j in range(2): + for i in range(2): + p[k][j][i] = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + z_int + k, + y_int + j, + x_int + i, + cc, + ) + l = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + l[j][i] = _lerp(*p[j][i], x_lerp) + + top = _lerp(*l[0], y_lerp) + bottom = _lerp(*l[1], y_lerp) + value = _lerp(top, bottom, z_lerp) + + elif method == "cubic": + zint = te.floor(in_z).astype("int32") + zfract = in_z - te.floor(in_z) + + yint = te.floor(in_y).astype("int32") + yfract = in_y - te.floor(in_y) + + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) + + # Get the surrounding values + p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] + for k in range(4): + for j in range(4): + for i in range(4): + p[k][j][i] = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + zint + k - 1, + yint + j - 1, + xint + i - 1, + cc, + ) + + wz = _cubic_spline_weights(zfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + wx = _cubic_spline_weights(xfract, alpha) + if exclude_outside: + for i in range(4): + wz[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_height), 0.0, wx[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wz = sum(wz) + sum_wy = sum(wy) + sum_wx = sum(wx) + wz = [w / sum_wz for w in wz] + wy = [w / sum_wy for w in wy] + wx = [w / sum_wx for w in wx] + + l = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + l[j][i] = _cubic_kernel(p[j][i], wx) + col0 = _cubic_kernel(l[0], wy) + col1 = _cubic_kernel(l[1], wy) + col2 = _cubic_kernel(l[2], wy) + col3 = _cubic_kernel(l[3], wy) + value = _cubic_kernel([col0, col1, col2, col3], wz) - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "bilinear": - compute_func = _bilinear else: - raise ValueError("%s method is not supported." % method) + raise ValueError("Unknown resize method:", method) - return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE) + if extrapolation_value is not None: + out = tvm.tir.if_then_else( + in_z < 0, + extrapolation_value, + tvm.tir.if_then_else(in_z > image_depth - 1, extrapolation_value, value), + ) + out = tvm.tir.if_then_else( + in_y < 0, + extrapolation_value, + tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), + ) + # use extrapolation_value if in_x is out of boundary + value = tvm.tir.if_then_else( + in_x < 0, + extrapolation_value, + tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + ) + return _cast_output(value, data.dtype, out_dtype=out_dtype) def resize3d( data, size, layout="NCDHW", - method="nearest_neighbor", - coordinate_transformation_mode="align_corners", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, + output_shape=None, ): """Perform resize operation on the data. Parameters ---------- - inputs: tvm.te.Tensor + data : tvm.te.Tensor inputs is a 5-D tensor with shape [batch, channel, in_depth, in_height, in_width] or [batch, in_depth, in_height, in_width, channel] @@ -887,24 +1141,28 @@ def resize3d( Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. Refer to the ONNX Resize operator specification for details. - Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"trilinear", "nearest_neighbor"} + + method: {"linear", "nearest_neighbor", "cubic"} Method to be used for resizing. out_dtype: string, optional Type to return. If left None will be same as input type. + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- output : tvm.te.Tensor - 5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] + 4-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] or [batch, in_depth*scale, in_height*scale, in_width*scale, channel] - or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, - channel-minor] + or 5-D with shape + [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, channel-minor] """ - method = method.lower() + method = method.lower() if layout == "NDHWC": in_n, in_d, in_h, in_w, in_c = data.shape output_shape = [in_n, size[0], size[1], size[2], in_c] @@ -916,125 +1174,30 @@ def resize3d( in_n, in_c, in_d, in_h, in_w, in_cc = data.shape output_shape = [in_n, in_c, size[0], size[1], size[2], in_cc] - if coordinate_transformation_mode == "align_corners": - z_ratio = (in_d - 1).astype("float") / (size[0] - 1) - y_ratio = (in_h - 1).astype("float") / (size[1] - 1) - x_ratio = (in_w - 1).astype("float") / (size[2] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - z_ratio = (in_d).astype("float") / (size[0]) - y_ratio = (in_h).astype("float") / (size[1]) - x_ratio = (in_w).astype("float") / (size[2]) - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) - ) - - def _get_pixel(n, c, z, y, x, cc): - z = tvm.te.max(tvm.te.min(z, in_d - 1), 0) - y = tvm.te.max(tvm.te.min(y, in_h - 1), 0) - x = tvm.te.max(tvm.te.min(x, in_w - 1), 0) - if layout == "NDHWC": - return data(n, z, y, x, c).astype("float") - if layout == "NCDHW": - return data(n, c, z, y, x).astype("float") - # else must be NCDHWxc - return data(n, c, z, y, x, cc).astype("float") - - def _get_indices(*indices): - if layout == "NDHWC": - n, z, y, x, c = indices - cc = None - elif layout == "NCDHW": - n, c, z, y, x = indices - cc = None - else: - n, c, z, y, x, cc = indices - - return n, c, z, y, x, cc - - def _cast_output(value): - if out_dtype: - dtype = out_dtype - else: - dtype = data.dtype - return value.astype(dtype) - - # Nearest neighbor computation - def _nearest_neighbor(*indices): - n, c, z, y, x, cc = _get_indices(*indices) - - in_z = z_ratio * z - in_y = y_ratio * y - in_x = x_ratio * x - - if coordinate_transformation_mode == "align_corners": - zint = te.round(in_z).astype("int32") - yint = te.round(in_y).astype("int32") - xint = te.round(in_x).astype("int32") - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - zint = te.floor(in_z + epsilon).astype("int32") - yint = te.floor(in_y + epsilon).astype("int32") - xint = te.floor(in_x + epsilon).astype("int32") - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - return _cast_output(_get_pixel(n, c, zint, yint, xint, cc)) - - # Trilinear helper functions and computation. - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _trilinear(*indices): - n, c, z, y, x, cc = _get_indices(*indices) - - if coordinate_transformation_mode == "half_pixel": - in_z = z_ratio * (z + 0.5) - 0.5 - in_y = y_ratio * (y + 0.5) - 0.5 - in_x = x_ratio * (x + 0.5) - 0.5 - else: - in_z = z_ratio * z - in_y = y_ratio * y - in_x = x_ratio * x - - zint = te.floor(in_z).astype("int32") - zfract = in_z - te.floor(in_z) - - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + if isinstance(size, tuple): + size = list(size) - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + for i in range(3): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) - p000 = _get_pixel(n, c, zint, yint, xint, cc) - p001 = _get_pixel(n, c, zint, yint, xint + 1, cc) - p010 = _get_pixel(n, c, zint, yint + 1, xint, cc) - p011 = _get_pixel(n, c, zint, yint + 1, xint + 1, cc) - p100 = _get_pixel(n, c, zint + 1, yint, xint, cc) - p101 = _get_pixel(n, c, zint + 1, yint, xint + 1, cc) - p110 = _get_pixel(n, c, zint + 1, yint + 1, xint, cc) - p111 = _get_pixel(n, c, zint + 1, yint + 1, xint + 1, cc) - - dep00 = _lerp(p000, p100, zfract) - dep01 = _lerp(p001, p101, zfract) - dep10 = _lerp(p010, p110, zfract) - dep11 = _lerp(p011, p111, zfract) - col0 = _lerp(dep00, dep01, xfract) - col1 = _lerp(dep10, dep11, xfract) - value = _lerp(col0, col1, yfract) - return _cast_output(value) - - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "trilinear": - compute_func = _trilinear - else: - raise ValueError("%s method is not supported." % method) + def compute_func(*indices): + return _resize_3d( + indices, + data, + in_d, + in_h, + in_w, + size[0], + size[1], + size[2], + method=method, + layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, + out_dtype=out_dtype, + ) - return te.compute(output_shape, compute_func, name="resize3d", tag=tag.INJECTIVE) + return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) diff --git a/python/tvm/topi/intel_graphics/depthwise_conv2d.py b/python/tvm/topi/intel_graphics/depthwise_conv2d.py index fabd63b8778c..02af465248a6 100644 --- a/python/tvm/topi/intel_graphics/depthwise_conv2d.py +++ b/python/tvm/topi/intel_graphics/depthwise_conv2d.py @@ -20,7 +20,6 @@ from tvm import te from tvm import autotvm from ..utils import traverse_inline -from .. import tag from .. import nn from ..nn.depthwise_conv2d import depthwise_conv2d_infer_layout @@ -136,188 +135,6 @@ def _callback(op): return s -def schedule_depthwise_conv2d_nhwc(outs): - """Schedule for depthwise_conv2d nhwc forward. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of depthwise_conv2d - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for depthwise_conv2d nhwc. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _schedule(temp, Filter, DepthwiseConv2d): - s[temp].compute_inline() - FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) - if DepthwiseConv2d.op in s.outputs: - Output = DepthwiseConv2d - CL = s.cache_write(DepthwiseConv2d, "local") - else: - Output = outs[0].op.output(0) - s[DepthwiseConv2d].set_scope("local") - - block_x = te.thread_axis("blockIdx.x") - thread_x = te.thread_axis("threadIdx.x") - - b, h, w, c = s[Output].op.axis - - # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value - target = tvm.target.Target.current() - if target and (target.kind.name not in ["cuda", "nvptx"]): - num_thread = target.max_num_threads - xoc, xic = s[Output].split(c, factor=num_thread) - s[Output].reorder(xoc, b, h, w, xic) - xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2) - fused = s[Output].fuse(yo, xo) - fused = s[Output].fuse(fused, b) - fused = s[Output].fuse(fused, xoc) - - s[Output].bind(fused, block_x) - s[Output].bind(xic, thread_x) - - if DepthwiseConv2d.op in s.outputs: - s[CL].compute_at(s[Output], xic) - else: - s[DepthwiseConv2d].compute_at(s[Output], xic) - - _, _, ci, fi = s[FS].op.axis - s[FS].compute_at(s[Output], fused) - fused = s[FS].fuse(fi, ci) - s[FS].bind(fused, thread_x) - - scheduled_ops = [] - - def traverse(OP): - """Internal traverse function""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(OP.tag): - if OP not in s.outputs: - s[OP].compute_inline() - for tensor in OP.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - # schedule depthwise_conv2d - if OP.tag == "depthwise_conv2d_nhwc": - PaddedInput = OP.input_tensors[0] - Filter = OP.input_tensors[1] - if isinstance(Filter.op, tvm.te.ComputeOp) and "dilate" in Filter.op.tag: - s[Filter].compute_inline() - DepthwiseConv2d = OP.output(0) - _schedule(PaddedInput, Filter, DepthwiseConv2d) - - scheduled_ops.append(OP) - - traverse(outs[0].op) - return s - - -def schedule_depthwise_conv2d_backward_input_nhwc(outs): - """Schedule for depthwise_conv2d nhwc backward wrt input. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of depthwise_conv2d - backward wrt input in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for depthwise_conv2d backward - wrt input with layout nhwc. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _schedule(Padded_out_grad, In_grad): - s[Padded_out_grad].compute_inline() - - block_x = te.thread_axis("blockIdx.x") - thread_x = te.thread_axis("threadIdx.x") - _, h, w, c = In_grad.op.axis - - fused_hwc = s[In_grad].fuse(h, w, c) - xoc, xic = s[In_grad].split(fused_hwc, factor=128) - - s[In_grad].bind(xoc, block_x) - s[In_grad].bind(xic, thread_x) - - def traverse(OP): - # inline all one-to-one-mapping operators except the last stage (output) - if OP.tag == "depthwise_conv2d_backward_input_nhwc": - Padded_out_grad = OP.input_tensors[0] - Dilated_out_grad = Padded_out_grad.op.input_tensors[0] - s[Dilated_out_grad].compute_inline() - In_grad = OP.output(0) - _schedule(Padded_out_grad, In_grad) - else: - raise ValueError("Depthwise conv backward wrt input for non-NHWC is not supported.") - - traverse(outs[0].op) - return s - - -def schedule_depthwise_conv2d_backward_weight_nhwc(outs): - """Schedule for depthwise_conv2d nhwc backward wrt weight. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of depthwise_conv2d - backward wrt weight in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for depthwise_conv2d backward - wrt weight with layout nhwc. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _schedule(Weight_grad): - block_x = te.thread_axis("blockIdx.x") - thread_y = te.thread_axis("threadIdx.y") - thread_x = te.thread_axis("threadIdx.x") - - db, dh, dw = Weight_grad.op.reduce_axis - - fused_dbdhdw = s[Weight_grad].fuse(db, dh, dw) - _, ki = s[Weight_grad].split(fused_dbdhdw, factor=8) - BF = s.rfactor(Weight_grad, ki) - - fused_fwcm = s[Weight_grad].fuse(*s[Weight_grad].op.axis) - - xo, xi = s[Weight_grad].split(fused_fwcm, factor=32) - - s[Weight_grad].bind(xi, thread_x) - s[Weight_grad].bind(xo, block_x) - - s[Weight_grad].bind(s[Weight_grad].op.reduce_axis[0], thread_y) - s[BF].compute_at(s[Weight_grad], s[Weight_grad].op.reduce_axis[0]) - - def traverse(OP): - # inline all one-to-one-mapping operators except the last stage (output) - if OP.tag == "depthwise_conv2d_backward_weight_nhwc": - Padded_in = OP.input_tensors[1] - s[Padded_in].compute_inline() - Weight_grad = OP.output(0) - _schedule(Weight_grad) - else: - raise ValueError("Depthwise conv backward wrt weight for non-NHWC is not supported.") - - traverse(outs[0].op) - return s - - @depthwise_conv2d_infer_layout.register("intel_graphics") def _depthwise_conv2d_infer_layout(workload, _): """Infer input/output shapes and layouts from a workload and cfg. diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index 52fe011a70e9..f3ef55b9a30c 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -30,6 +30,7 @@ # reuse some compute declarations from ARM CPU from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw +from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nhwc logger = logging.getLogger("topi") @@ -95,37 +96,59 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs): def _callback(op): # schedule conv2d if "spatial_conv2d_output" in op.tag: - output = op.output(0) - conv = op.input_tensors[0] + _schedule_spatial_pack(cfg, s, op, layout="NCHW") + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_spatial_pack.mali") +def conv2d_nhwc_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with NHWC layout""" + return conv2d_spatial_pack_nhwc( + cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=3 + ) - data_vec = conv.op.input_tensors[0] - data_pad = data_vec.op.input_tensors[0] - s[data_pad].compute_inline() - kernel_vec = conv.op.input_tensors[1] - if kernel_vec.op.name == "kernel_vec": - kernel = kernel_vec.op.input_tensors[0] - else: - kernel = kernel_vec - if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() +@autotvm.register_topi_schedule("conv2d_nhwc_spatial_pack.mali") +def schedule_conv2d_nhwc_spatial_pack(cfg, outs): + """Create schedule for conv2d_nhwc""" + s = te.create_schedule([x.op for x in outs]) - _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec) + def _callback(op): + # schedule conv2d + if "spatial_conv_output_NHWC" in op.tag: + _schedule_spatial_pack(cfg, s, op, layout="NHWC") traverse_inline(s, outs[0].op, _callback) return s -def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): +def _schedule_spatial_pack(cfg, s, op, layout): """schedule the spatial packing for conv2d""" + + assert layout in ("NCHW", "NHWC") + + output = op.output(0) + conv = op.input_tensors[0] + data_vec = conv.op.input_tensors[0] + data_pad = data_vec.op.input_tensors[0] + s[data_pad].compute_inline() + kernel_vec = conv.op.input_tensors[1] + if kernel_vec.op.name == "kernel_vec": + kernel = kernel_vec.op.input_tensors[0] + else: + kernel = kernel_vec + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() data = s[data_vec].op.input_tensors[0] max_unroll = 16 vec_size = [1, 2, 4, 8, 16] # get tunable parameters (they are defined in compute) - BC, TC, VC = cfg["tile_co"].size - BH, TH, VH = cfg["tile_oh"].size - BW, TW, VW = cfg["tile_ow"].size + _, TC, VC = cfg["tile_co"].size + _, TH, VH = cfg["tile_oh"].size + _, TW, VW = cfg["tile_ow"].size # schedule padding if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag: @@ -133,21 +156,29 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[data_pad].compute_inline() # schedule data packing - if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated": - _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis + if layout == "NCHW": + if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated": + _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis + else: + _, h, w, ci, vh, vw = s[data_vec].op.axis + z, y, x, unroll1, unroll2 = h, w, ci, vh, vw else: - _, h, w, ci, vh, vw = s[data_vec].op.axis - tile_and_bind3d(s, data_vec, h, w, ci, 1) - if vh.dom.extent.value < max_unroll: - s[data_vec].unroll(vh) - if vw.dom.extent.value < max_unroll: - s[data_vec].unroll(vw) + if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated": + _, oho, owo, _, _, ic, ohi, owi = s[data_vec].op.axis + else: + _, oho, owo, ohi, owi, ic = s[data_vec].op.axis + z, y, x, unroll1, unroll2 = oho, owo, ohi, ic, owi + tile_and_bind3d(s, data_vec, z, y, x, 1) + if unroll1.dom.extent.value < max_unroll: + s[data_vec].unroll(unroll1) + if unroll2.dom.extent.value < max_unroll: + s[data_vec].unroll(unroll2) if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec": if not autotvm.GLOBAL_SCOPE.in_tuning: max_threads = tvm.target.Target.current(allow_none=False).max_num_threads - co, ci, kh, kw, vc = s[kernel_vec].op.axis - fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) + ax1, ax2, ax3, ax4, ax5 = s[kernel_vec].op.axis + fused = s[kernel_vec].fuse(ax1, ax2, ax3, ax4, ax5) fused, vec = s[kernel_vec].split(fused, VC) bb, tt = s[kernel_vec].split(fused, max_threads) s[kernel_vec].bind(bb, te.thread_axis("blockIdx.x")) @@ -156,25 +187,37 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[kernel_vec].vectorize(vec) # schedule convolution - n, c, h, w, vh, vw, vc = s[conv].op.axis - kc, kh, kw = s[conv].op.reduce_axis - - cfg["reorder_0"].apply(s, conv, [n, c, h, w, kc, kh, kw, vh, vw, vc]) - tile_and_bind3d(s, conv, c, h, w, TC, TH, TW) - + ic, kh, kw = s[conv].op.reduce_axis + if layout == "NCHW": + kh_dim, kw_dim = kernel_vec.shape[2], kernel_vec.shape[3] + else: + kh_dim, kw_dim = kernel_vec.shape[0], kernel_vec.shape[1] cfg["ann_reduce"].apply( s, conv, [kh, kw], - axis_lens=[get_const_int(kernel_vec.shape[2]), get_const_int(kernel_vec.shape[3])], + axis_lens=[get_const_int(kh_dim), get_const_int(kw_dim)], max_unroll=max_unroll, ) + if layout == "NCHW": + n, c, h, w, vh, vw, vc = s[conv].op.axis + cfg["reorder_0"].apply(s, conv, [n, c, h, w, ic, kh, kw, vh, vw, vc]) + tile_and_bind3d(s, conv, c, h, w, TC, TH, TW) + unroll_vec_axes = [vh, vw, vc] + axis_lens = [VH, VW, VC] + else: + n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis + cfg["reorder_conv"].apply(s, conv, [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci]) + tile_and_bind3d(s, conv, oho, owo, oco, TH, TW, TC) + unroll_vec_axes = [ohi, owi, oci] + axis_lens = [VH, VW, VC] + cfg["ann_spatial"].apply( s, conv, - [vh, vw, vc], - axis_lens=[VH, VW, VC], + unroll_vec_axes, + axis_lens, max_unroll=max_unroll, vec_size=vec_size, cfg=cfg, @@ -184,9 +227,12 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): if output.op not in s.outputs: # has bias s[output].compute_inline() output = s.outputs[0] - - _, co, oh, ow = s[output].op.axis - tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW) + if layout == "NCHW": + _, co, oh, ow = s[output].op.axis + tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW) + else: + _, oh, ow, co = s[output].op.axis + tile_and_bind3d(s, output, oh, ow, co, TH, TW, TC) return s diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index b6ed5a373e81..26d45feb0387 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -16,28 +16,50 @@ # under the License. """Batch matrix multiplication""" # pylint: disable=invalid-name +import logging import tvm from tvm import te, auto_scheduler from ..utils import get_const_tuple +logger = logging.getLogger("topi") -def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. Supports broadcasting for batch dimension. + +def batch_matmul( + tensor_a, + tensor_b, + oshape=None, + out_dtype=None, + transpose_a=False, + transpose_b=True, + auto_scheduler_rewritten_layout="", +): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- - x : tvm.te.Tensor - 3-D with shape [batch, M, K] + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. - y : tvm.te.Tensor - 3-D with shape [batch, N, K] + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. oshape : List[Optional] Explicit intended output shape of the computation. Can be useful in cases with dynamic input shapes. - auto_scheduler_rewritten_layout: str = "" + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + auto_scheduler_rewritten_layout: Optional[str] = "" The layout after auto-scheduler's layout rewrite pass. Returns @@ -45,35 +67,79 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - x_shape = get_const_tuple(x.shape) + assert len(tensor_a.shape) == 3, "tensor_a only support 3-dim" + if transpose_a: + XB, XK, XI = get_const_tuple(tensor_a.shape) + else: + XB, XI, XK = get_const_tuple(tensor_a.shape) if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout - y_shape = auto_scheduler.get_shape_from_rewritten_layout( - auto_scheduler_rewritten_layout, ["b", "j", "k"] + YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout( + auto_scheduler_rewritten_layout, ["b", "k", "j"] ) - auto_scheduler.remove_index_check(y) + auto_scheduler.remove_index_check(tensor_b) else: - y_shape = get_const_tuple(y.shape) - assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul" + assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim" + if transpose_b: + YB, YJ, YK = get_const_tuple(tensor_b.shape) + else: + YB, YK, YJ = get_const_tuple(tensor_b.shape) - XB = x_shape[0] - YB = y_shape[0] - _, M, K = x.shape - k = te.reduce_axis((0, K), name="k") + assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y are inconsistent" + k = te.reduce_axis((0, XK), name="k") if oshape is None: assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" - assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent" - batch = te.max(XB, YB) - N = y.shape[1] - oshape = (batch, M, N) + batch = ( + tvm.tir.expr.SizeVar("batch", "int32") + if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, tvm.tir.expr.Var) + else te.max(XB, YB) + ) + oshape = (batch, XI, YJ) + if out_dtype is None: + out_dtype = tensor_a.dtype + if tensor_a.dtype != tensor_b.dtype: + logger.warning( + "tensor_a has different data type with tensor_b: %s, %s", + tensor_a.dtype, + tensor_b.dtype, + ) + + if (transpose_a, transpose_b) == (True, True): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_TT" + elif (transpose_a, transpose_b) == (True, False): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_TN" + elif (transpose_a, transpose_b) == (False, True): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_NT" + else: # (transpose_a, transpose_b) == (False, False): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_NN" output = te.compute( oshape, - lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), + compute_lambda, + name=compute_name, tag="batch_matmul", - attrs={"layout_free_placeholders": [y]}, + attrs={"layout_free_placeholders": [tensor_b]}, ) - if auto_scheduler_rewritten_layout: output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 3f72bdc4b667..7cb4b09b8805 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -176,10 +176,13 @@ def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layou else: KH, KW, CIG, CO = get_const_tuple(kernel.shape) - pt, pl, pb, pr = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW))) dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) + pt, pl, pb, pr = get_pad_tuple( + padding, + (get_const_int((KH - 1) * dilation_h + 1), get_const_int((KW - 1) * dilation_w + 1)), + ) GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index a3639b57e7e0..48ffb8c6d9ff 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -24,7 +24,7 @@ from .dilate import dilate from .pad import pad from .utils import get_pad_tuple -from ..utils import simplify +from ..utils import simplify, get_const_tuple # workload description of depthwise-conv2d Workload = namedtuple( @@ -50,11 +50,47 @@ ) -def _get_workload(data, kernel, stride, padding, dilation, out_dtype): - """Get the workload structure.""" - _, in_channel, height, width = [x.value for x in data.shape] - channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape] - out_channel = channel * channel_multiplier +def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layout="NCHW"): + """Get the workload structure for a depthwise conv2d. + + Input data and filter should use NCHW layout. + """ + if data_layout == "NCHW": + _, in_channel, height, width = get_const_tuple(data.shape) + filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape) + elif data_layout == "NHWC": + _, height, width, in_channel = get_const_tuple(data.shape) + kh, kw, filter_channel, channel_multiplier = get_const_tuple(kernel.shape) + elif data_layout == "NCHWc": + _, in_channel_chunk, height, width, in_channel_block = get_const_tuple(data.shape) + in_channel = in_channel_chunk * in_channel_block + ( + filter_channel_chunk, + cm_chunk, + kh, + kw, + cm_block, + filter_channel_block, + ) = get_const_tuple(kernel.shape) + filter_channel = filter_channel_chunk * filter_channel_block + channel_multiplier = cm_chunk * cm_block + + assert ( + in_channel_block == filter_channel_block + ), "Incorrect dimensions, data has block size {}, but filter has block size {}".format( + in_channel_block, filter_channel_block + ) + + else: + raise ValueError("Data layout {} not supported".format(data_layout)) + + assert ( + in_channel == filter_channel + ), "Incorrect dimensions, data has {} channels but filter expects {} channels".format( + in_channel, filter_channel + ) + + out_channel = filter_channel * channel_multiplier dilation_h, dilation_w = ( dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) ) @@ -102,8 +138,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No Filter : tvm.te.Tensor 4-D with shape [in_channel, channel_multiplier, filter_height, filter_width] - stride : tuple of two ints - The spatial stride along height and width + stride : int or a list/tuple of two ints + The spatial stride, or (stride_height, stride_width). padding : int or str Padding size, or ['VALID', 'SAME'] diff --git a/python/tvm/topi/nn/mapping.py b/python/tvm/topi/nn/mapping.py index c048fc86d4d5..0e0b1825df30 100644 --- a/python/tvm/topi/nn/mapping.py +++ b/python/tvm/topi/nn/mapping.py @@ -29,7 +29,7 @@ def scale_shift_nchw(Input, Scale, Shift): Parameters ---------- Input : tvm.te.Tensor - Input tensor, layout is NCHW + 4-D input tensor, NCHW layout [batch, channel, height, width] Scale : tvm.te.Tensor Scale tensor, 1-D of size channel number @@ -54,7 +54,7 @@ def scale_shift_nhwc(Input, Scale, Shift): Parameters ---------- Input : tvm.te.Tensor - Input tensor, layout is NHWC + 4-D input tensor, NHWC layout [batch, height, width, channel] Scale : tvm.te.Tensor Scale tensor, 1-D of size channel number @@ -70,3 +70,30 @@ def scale_shift_nhwc(Input, Scale, Shift): return te.compute( Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name="ScaleShift" ) + + +@tvm.te.tag_scope(tag=tag.BROADCAST) +def scale_shift_nchwc(Input, Scale, Shift): + """Batch normalization operator in inference. + + Parameters + ---------- + Input : tvm.te.Tensor + 5-D input tensor, NCHWc layout [batch, channel_chunk, height, width, channel_block] + + Scale : tvm.te.Tensor + Scale tensor, 2-D of size [channel_chunk, channel_block] + + Shift : tvm.te.Tensor + Shift tensor, 2-D of size [channel_chunk, channel_block] + + Returns + ------- + Output : tvm.te.Tensor + Output tensor, layout is NHWC + """ + return te.compute( + Input.shape, + lambda b, cc, i, j, cb: Input[b, cc, i, j, cb] * Scale[cc, cb] + Shift[cc, cb], + name="ScaleShift", + ) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 73998db6f162..948847e60d92 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -31,7 +31,7 @@ def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr): Parameters ---------- data : tvm.te.Tensor - 2-D with shape [M, K], float32 + 2-D with shape [M, K] weight_data : tvm.te.Tensor 1-D with shape [nnz] (CSR) or @@ -78,7 +78,7 @@ def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight): 1-D with shape [(M + 1) // bs_r] (BSR) weight: - 2-D with shape [N, K], float32 + 2-D with shape [N, K] Returns ------- @@ -105,7 +105,7 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_ Parameters ---------- dense_data : tvm.te.Tensor - 2-D with shape [M, K], float32 + 2-D with shape [M, K] sparse_data : tvm.te.Tensor 1-D with shape [nnz] (CSR) or @@ -239,7 +239,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): Parameters ---------- sparse_data : tvm.te.Tensor - 1-D with shape [nonzeros], dtype of 'float32' + 1-D with shape [nonzeros] sparse_indices : tvm.te.Tensor 1-D with shape [nonzeros], dtype of 'int32' @@ -250,7 +250,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): Returns ------- out_data : tvm.te.Tensor - 1-D with shape [nonzeros], dtype of 'float32' + 1-D with shape [nonzeros] out_indices : tvm.te.Tensor 1-D with shape [nonzeros], dtype of 'int32' @@ -275,7 +275,7 @@ def sparse_transpose(sparse_data, sparse_indices, sparse_indptr): ins[0], ins[1], ins[2], outs[0], outs[1], outs[2] ), tag="sparse_transpose_csr", - dtype=["float32", "int32", "int32"], + dtype=[sparse_data.dtype, "int32", "int32"], name="out", ) diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index b95835f6e103..36b9349a139d 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -92,7 +92,9 @@ def upsampling( else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" - return topi.image.resize( + if method[0:2] == "bi": + method = method[2:] + return topi.image.resize2d( data, reshape_size, layout=layout, @@ -188,6 +190,8 @@ def upsampling3d( ) else: raise ValueError("not support this layout {} yet".format(layout)) + if method[0:3] == "tri": + method = method[3:] return topi.image.resize3d( data, resize_shape, diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index 8b6bb114b181..2ef97e2edc5c 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """Pseudorandom number kernels.""" +import numpy as np + import tvm import tvm.topi -import numpy as np + from ... import tir from ...tir import ir_builder diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py index f5796730f762..32a7e297b04c 100644 --- a/python/tvm/topi/scan.py +++ b/python/tvm/topi/scan.py @@ -23,7 +23,7 @@ from ..te import extern from ..tir import decl_buffer, generic, ir_builder from .math import cast -from .utils import get_const_int, prod +from . import utils def scanop( @@ -93,11 +93,11 @@ def maybe_cast(x): if axis is None: axis = 0 - cumsum_axis_len = prod(data.shape) + cumsum_axis_len = utils.prod(data.shape) shape = (cumsum_axis_len,) else: if not isinstance(axis, int): - axis = get_const_int(axis) + axis = utils.get_const_int(axis) shape = data.shape cumsum_axis_len = shape[axis] diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py index 39ba3332fc72..4d659c801103 100644 --- a/python/tvm/topi/sparse/csrmm.py +++ b/python/tvm/topi/sparse/csrmm.py @@ -20,6 +20,7 @@ from tvm import te from .. import tag from ..utils import simplify +from ...tir.generic import cast def csrmm_default(data, indices, indptr, weight, bias=None): @@ -57,6 +58,12 @@ def csrmm_default(data, indices, indptr, weight, bias=None): assert isinstance( weight, te.tensor.Tensor ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight)) + assert ( + data.dtype == weight.dtype + ), "Data and weight must have the same dtype, but they have %s and %s" % ( + data.dtype, + weight.dtype, + ) if bias is not None: assert len(bias.shape) == 1 M = simplify(indptr.shape[0] - 1) @@ -74,9 +81,9 @@ def csrmm_default_ir(data, indices, indptr, weight, out): _, N = weight.shape with irb.for_range(0, N, kind="vectorize", name="n") as n: with irb.for_range(0, M, kind="parallel", name="row") as row: - dot = irb.allocate("float32", (1,), name="dot", scope="local") - out_ptr[row * N + n] = 0.0 - dot[0] = 0.0 + dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + out_ptr[row * N + n] = cast(0, data.dtype) + dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] row_end = indptr_ptr[row + 1] row_elems = row_end - row_start @@ -92,7 +99,7 @@ def csrmm_default_ir(data, indices, indptr, weight, out): [data, indices, indptr, weight], lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="csrmm", - dtype="float32", + dtype=data.dtype, name="out", ) if bias is not None: diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py index a2d22afe01e0..3c2016c6513a 100644 --- a/python/tvm/topi/sparse/csrmv.py +++ b/python/tvm/topi/sparse/csrmv.py @@ -19,6 +19,7 @@ import tvm from tvm import te from .. import tag +from ...tir.generic import cast def csrmv_default(data, indices, indptr, weight, bias=None): @@ -50,6 +51,12 @@ def csrmv_default(data, indices, indptr, weight, bias=None): assert isinstance( weight, te.tensor.Tensor ), "weight matrix is assumed to be tvm.te.Tensor, but weight is `%s`" % (type(weight)) + assert ( + data.dtype == weight.dtype + ), "Data and weight must have the same dtype, but they have %s and %s" % ( + data.dtype, + weight.dtype, + ) if bias is not None: assert len(bias.shape) == 1 batch = indptr.shape[0] - 1 @@ -64,9 +71,9 @@ def csrmv_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) num_rows = indptr.shape[0] - 1 with irb.for_range(0, num_rows, kind="parallel", name="row") as row: - dot = irb.allocate("float32", (1,), name="dot", scope="local") - out_ptr[row] = 0.0 - dot[0] = 0.0 + dot = irb.allocate(data.dtype, (1,), name="dot", scope="local") + out_ptr[row] = cast(0, data.dtype) + dot[0] = cast(0, data.dtype) row_start = indptr_ptr[row] row_end = indptr_ptr[row + 1] row_elems = row_end - row_start @@ -82,7 +89,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out): [data, indices, indptr, weight], lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), tag="csrmv", - dtype="float32", + dtype=data.dtype, name="csrmv", ) if bias is not None: diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index afb251417315..d10c49f5c084 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -32,12 +32,14 @@ from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python from .correlation_nchw_python import correlation_nchw_python from .deformable_conv2d_python import deformable_conv2d_nchw_python, deformable_conv2d_nhwc_python -from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc +from .depthwise_conv2d_python import ( + depthwise_conv2d_python_nchw, + depthwise_conv2d_python_nhwc, + depthwise_conv2d_python_nchwc, +) from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python -from .upsampling_python import upsampling_python, upsampling3d_python -from .bilinear_resize_python import bilinear_resize_python -from .trilinear_resize3d_python import trilinear_resize3d_python +from .resize_python import resize1d_python, resize2d_python, resize3d_python from .reorg_python import reorg_python from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python from .roi_pool_python import roi_pool_nchw_python @@ -70,3 +72,4 @@ from .space_to_batch_nd import space_to_batch_nd_python from .batch_to_space_nd import batch_to_space_nd_python from .nll_loss import nll_loss +from .dense import dense diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py index 96d1fcbb5bc3..18fa7e8c4b33 100644 --- a/python/tvm/topi/testing/batch_matmul.py +++ b/python/tvm/topi/testing/batch_matmul.py @@ -19,7 +19,7 @@ import numpy as np -def batch_matmul(x, y, out_dtype=None): +def batch_matmul(x, y, out_dtype=None, trans_x=False, trans_y=True): """batch_matmul operator implemented in numpy. Parameters @@ -38,13 +38,22 @@ def batch_matmul(x, y, out_dtype=None): out : numpy.ndarray 3-D with shape [batch, M, N] """ - XB, M, _ = x.shape - YB, N, _ = y.shape + if trans_x: + XB, _, M = x.shape + else: + XB, M, _ = x.shape + if trans_y: + YB, N, _ = y.shape + else: + YB, _, N = y.shape batch = max(XB, YB) dtype = x.dtype if out_dtype is None else out_dtype out = np.zeros((batch, M, N)).astype(dtype) for i in range(batch): + xx = x[i if XB != 1 else 0].astype(dtype) + yy = y[i if YB != 1 else 0].astype(dtype) out[i] = np.dot( - x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype) + xx.T if trans_x else xx, + yy.T if trans_y else yy, ) return out diff --git a/python/tvm/topi/testing/bilinear_resize_python.py b/python/tvm/topi/testing/bilinear_resize_python.py deleted file mode 100644 index b1fb8b0b4845..000000000000 --- a/python/tvm/topi/testing/bilinear_resize_python.py +++ /dev/null @@ -1,105 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Bilinear Scale in python""" -import math -import numpy as np -from tvm.topi.utils import nchw_pack_layout - - -def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): - """Bilinear scaling using python""" - (new_h, new_w) = out_size - (ib, ic) = (1, 1) - - if layout == "NHWC": - (batch, h, w, channel) = image.shape - scaled_image = np.ones((batch, new_h, new_w, channel)) - # NCHWinic - elif nchw_pack_layout(layout): - (batch, channel, h, w, ib, ic) = image.shape - scaled_image = np.ones((batch, channel, new_h, new_w, ib, ic)) - else: - (batch, channel, h, w) = image.shape - scaled_image = np.ones((batch, channel, new_h, new_w)) - - if coordinate_transformation_mode == "align_corners": - height_scale = np.float32(h - 1) / np.float32(out_size[0] - 1) - width_scale = np.float32(w - 1) / np.float32(out_size[1] - 1) - else: - height_scale = np.float32(h) / np.float32(out_size[0]) - width_scale = np.float32(w) / np.float32(out_size[1]) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _img_scale(b, m, i, n): - for j in range(new_h): - for k in range(new_w): - if coordinate_transformation_mode == "half_pixel": - in_y = (j + 0.5) * height_scale - 0.5 - else: - in_y = j * height_scale - y0 = int(math.floor(in_y)) - y1 = max(min(y0 + 1, h - 1), 0) - y0 = max(y0, 0) - y_lerp = in_y - math.floor(in_y) - - if coordinate_transformation_mode == "half_pixel": - in_x = (k + 0.5) * width_scale - 0.5 - else: - in_x = k * width_scale - x0 = int(math.floor(in_x)) - x1 = max(min(x0 + 1, w - 1), 0) - x0 = max(x0, 0) - x_lerp = in_x - math.floor(in_x) - - if layout == "NHWC": - A = image[b][y0][x0][i] - B = image[b][y0][x1][i] - C = image[b][y1][x0][i] - D = image[b][y1][x1][i] - elif nchw_pack_layout(layout): - A = image[b][i][y0][x0][m][n] - B = image[b][i][y0][x1][m][n] - C = image[b][i][y1][x0][m][n] - D = image[b][i][y1][x1][m][n] - else: - A = image[b][i][y0][x0] - B = image[b][i][y0][x1] - C = image[b][i][y1][x0] - D = image[b][i][y1][x1] - - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - pixel = np.float32(_lerp(top, bottom, y_lerp)) - - if layout == "NHWC": - scaled_image[b][j][k][i] = pixel - elif nchw_pack_layout(layout): - scaled_image[b][i][j][k][m][n] = pixel - else: - scaled_image[b][i][j][k] = pixel - - for b in range(batch): - for m in range(ib): - for i in range(channel): - for n in range(ic): - _img_scale(b, m, i, n) - - return scaled_image diff --git a/python/tvm/topi/testing/dense.py b/python/tvm/topi/testing/dense.py new file mode 100644 index 000000000000..7871cd71892a --- /dev/null +++ b/python/tvm/topi/testing/dense.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Dense in python""" +import numpy as np + + +def dense(x, y, bias, use_bias=False, use_relu=False, out_dtype=None): + """dense operator implemented in numpy. + + Parameters + ---------- + x : numpy.ndarray + 2-D with shape [M, K] + + y : numpy.ndarray + 2-D with shape [N, K] + + bias: numpy.ndarray + 1-D with shape [M,] + + out_dtype: string, optional + Specify the dtype of output + + Returns + ------- + out : numpy.ndarray + 2-D with shape [M, N] + """ + dtype = x.dtype if out_dtype is None else out_dtype + if use_bias: + out = np.dot(x.astype(dtype), y.T.astype(dtype)) + bias + else: + out = np.dot(x.astype(dtype), y.T.astype(dtype)) + + if use_relu: + out = np.maximum(out, 0) + + return out diff --git a/python/tvm/topi/testing/depthwise_conv2d_python.py b/python/tvm/topi/testing/depthwise_conv2d_python.py index 2239c56134f5..1ec64b7e7b82 100644 --- a/python/tvm/topi/testing/depthwise_conv2d_python.py +++ b/python/tvm/topi/testing/depthwise_conv2d_python.py @@ -67,17 +67,15 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): ] elif padding == "SAME": out_channel = in_channel * channel_multiplier - out_height = np.int(np.ceil(float(in_height) / float(stride_h))) - out_width = np.int(np.ceil(float(in_width) / float(stride_w))) + out_height = int(np.ceil(float(in_height) / float(stride_h))) + out_width = int(np.ceil(float(in_width) / float(stride_w))) output_np = np.zeros((batch, out_channel, out_height, out_width)) - pad_along_height = np.int( - np.max((out_height - 1) * stride_h + filter_height - in_height, 0) - ) - pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2)) - pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2)) - pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2)) - pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2)) + pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) + pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) + pad_top_tvm = int(np.ceil(float(pad_along_height) / 2)) + pad_left_tvm = int(np.ceil(float(pad_along_width) / 2)) + pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2)) + pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2)) index_h = pad_top_scipy - pad_top_tvm index_w = pad_left_scipy - pad_left_tvm for i in range(batch): @@ -91,6 +89,63 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): return output_np +def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding): + """Depthwise convolution operator in NCHWc layout. + + Parameters + ---------- + input_np : numpy.ndarray + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + filter_np : numpy.ndarray + 6-D with shape [out_channel_chunk, channel_multiplier_chunk, + filter_height, filter_width, + channel_multiplier_block, out_channel_block] + + stride : list / tuple of 2 ints + [stride_height, stride_width] + + padding : str + 'VALID' or 'SAME' + + Returns + ------- + output_np : np.ndarray + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + # Transform to NCHW + batch_size, in_channel_chunk, in_height, in_width, in_channel_block = input_np.shape + input_nchw = input_np.transpose(0, 1, 4, 2, 3).reshape( + (batch_size, in_channel_chunk * in_channel_block, in_height, in_width) + ) + + ( + out_channel_chunk, + channel_multiplier_chunk, + filter_height, + filter_width, + channel_multiplier_block, + out_channel_block, + ) = filter_np.shape + filter_nchw = filter_np.transpose(0, 5, 1, 4, 2, 3).reshape( + ( + out_channel_chunk * out_channel_block, + channel_multiplier_chunk * channel_multiplier_block, + filter_height, + filter_width, + ) + ) + + # Perform conv2d + output_np = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding) + + # Transform back + batch_size, out_channel, out_height, out_width = output_np.shape + return output_np.reshape( + (batch_size, out_channel_chunk, out_channel_block, out_height, out_width) + ).transpose(0, 1, 3, 4, 2) + + def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): """Depthwise convolution operator in nchw layout. @@ -138,17 +193,15 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): ] if padding == "SAME": out_channel = in_channel * channel_multiplier - out_height = np.int(np.ceil(float(in_height) / float(stride_h))) - out_width = np.int(np.ceil(float(in_width) / float(stride_w))) + out_height = int(np.ceil(float(in_height) / float(stride_h))) + out_width = int(np.ceil(float(in_width) / float(stride_w))) output_np = np.zeros((batch, out_height, out_width, out_channel)) - pad_along_height = np.int( - np.max((out_height - 1) * stride_h + filter_height - in_height, 0) - ) - pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2)) - pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2)) - pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2)) - pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2)) + pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) + pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) + pad_top_tvm = int(np.ceil(float(pad_along_height) / 2)) + pad_left_tvm = int(np.ceil(float(pad_along_width) / 2)) + pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2)) + pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2)) index_h = pad_top_scipy - pad_top_tvm index_w = pad_left_scipy - pad_left_tvm for i in range(batch): diff --git a/python/tvm/topi/testing/poolnd_python.py b/python/tvm/topi/testing/poolnd_python.py index 43440d32f44e..28bf5fc26497 100644 --- a/python/tvm/topi/testing/poolnd_python.py +++ b/python/tvm/topi/testing/poolnd_python.py @@ -18,12 +18,60 @@ """Ground truth max and average pooling operators in python.""" import itertools import math -from typing import List, Tuple +from typing import List, Tuple, Optional import numpy as np import tvm +def _get_supported_layout(dims: int): + """ + Returns layout that is supported by poolnd_python based on number of + dimensions of input tensor + """ + assert dims in [3, 4, 5], f"{dims}-dimensional tensor is not supported" + if dims == 3: + return "NCW" + if dims == 4: + return "NCHW" + # dims == 5 + return "NCDHW" + + +def _convert_to_layout( + input_tensor: np.ndarray, + layout: str, +) -> np.ndarray: + """ + Converts back to original layout after the algorithm is finished + """ + supported_layout = _get_supported_layout(input_tensor.ndim) + if layout is not None and supported_layout != layout: + # Generate transpose list + transpose_list = [] + for d in layout: + transpose_list.append(supported_layout.index(d)) + return input_tensor.transpose(transpose_list) + return input_tensor + + +def _convert_from_layout( + input_tensor: np.ndarray, + layout: str, +) -> np.ndarray: + """ + Converts tensor to one of suppored layouts + """ + supported_layout = _get_supported_layout(input_tensor.ndim) + if layout is not None and supported_layout != layout: + # Generate transpose list + transpose_list = [] + for d in supported_layout: + transpose_list.append(layout.index(d)) + return input_tensor.transpose(transpose_list) + return input_tensor + + def get_slice( spatial_dimensions: int, pad_np: np.array, @@ -90,8 +138,12 @@ def poolnd_python( count_include_pad: bool = True, ceil_mode: bool = False, dtype: str = "float32", + layout: Optional[str] = None, ) -> np.array: """Ground truth pooling operator impelmented in numpy.""" + + np_data = _convert_from_layout(np_data, layout) + out_shape = [np_data.shape[0], np_data.shape[1]] for dim in range(2, len(np_data.shape)): i = dim - 2 @@ -158,4 +210,4 @@ def poolnd_python( else: raise ValueError("Pool type {} is not supported".format(pool_type)) - return ret_np + return _convert_to_layout(ret_np, layout) diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py new file mode 100644 index 000000000000..13b460f07e1d --- /dev/null +++ b/python/tvm/topi/testing/resize_python.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Upsampling in python""" +import math +import numpy as np +from tvm.topi.utils import nchw_pack_layout + + +def get_inx(x, image_width, target_width, coordinate_transformation_mode): + """Infer input x from output x with various coordinate transformation methods""" + scale = image_width / target_width + if coordinate_transformation_mode == "half_pixel": + in_x = (x + 0.5) * scale - 0.5 + elif coordinate_transformation_mode == "align_corners": + in_x = (image_width - 1) / (target_width - 1) * x if target_width > 1 else 0 + elif coordinate_transformation_mode == "asymmetric": + in_x = scale * x + else: + raise ValueError( + "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) + ) + return in_x + + +def get_index(x, image_width, target_width, coordinate_transformation_mode): + """get and round the nearest index for nearest_neighbor""" + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + if coordinate_transformation_mode == "align_corners": + # round prefer ceil + out = int(math.floor(in_x + 0.5)) + else: + out = int(math.floor(in_x)) + out = max(min(out, image_width - 1), 0) + return out + + +def resize3d_nearest(arr, scale, coordinate_transformation_mode): + """Populate the array by scale factor""" + d, h, w = arr.shape + out_d, out_h, out_w = [int(round(i * s)) for i, s in zip(arr.shape, scale)] + out = np.empty((out_d, out_h, out_w)) + for z in range(out_d): + for y in range(out_h): + for x in range(out_w): + in_z = get_index(z, d, out_d, coordinate_transformation_mode) + in_y = get_index(y, h, out_h, coordinate_transformation_mode) + in_x = get_index(x, w, out_w, coordinate_transformation_mode) + out[z, y, x] = arr[in_z, in_y, in_x] + return out + + +def resize3d_linear(data_in, scale, coordinate_transformation_mode): + """Trilinear 3d scaling using python""" + dtype = data_in.dtype + d, h, w = data_in.shape + new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] + data_out = np.ones((new_d, new_h, new_w)) + + indexes = np.mgrid[0:2, 0:2, 0:2] + + def _get_patch(zint, yint, xint): + # Get the surrounding values + indices = indexes.copy() + indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0) + indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0) + indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0) + p = data_in[indices[0], indices[1], indices[2]] + return p + + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + in_z = get_inx(m, d, new_d, coordinate_transformation_mode) + in_y = get_inx(j, h, new_h, coordinate_transformation_mode) + in_x = get_inx(k, w, new_w, coordinate_transformation_mode) + zint = math.floor(in_z) + zfract = in_z - math.floor(in_z) + + yint = math.floor(in_y) + yfract = in_y - math.floor(in_y) + + xint = math.floor(in_x) + xfract = in_x - math.floor(in_x) + + wz = np.array([1.0 - zfract, zfract], dtype=dtype) + wy = np.array([1.0 - yfract, yfract], dtype=dtype) + wx = np.array([1.0 - xfract, xfract], dtype=dtype) + + p = _get_patch(zint, yint, xint) + l = np.sum(p * wx, axis=-1) + col = np.sum(l * wy, axis=-1) + data_out[m, j, k] = np.sum(col * wz) + + return data_out + + +def resize3d_cubic(data_in, scale, coordinate_transformation_mode): + """Tricubic 3d scaling using python""" + dtype = data_in.dtype + d, h, w = data_in.shape + new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] + data_out = np.ones((new_d, new_h, new_w)) + + def _cubic_spline_weights(t, alpha=-0.5): + """create cubic spline weights in 1D""" + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return np.array([w1, w2, w3, w4]) + + indexes = np.mgrid[-1:3, -1:3, -1:3] + + def _get_patch(zint, yint, xint): + # Get the surrounding values + indices = indexes.copy() + indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0) + indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0) + indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0) + p = data_in[indices[0], indices[1], indices[2]] + return p + + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + in_z = get_inx(m, d, new_d, coordinate_transformation_mode) + in_y = get_inx(j, h, new_h, coordinate_transformation_mode) + in_x = get_inx(k, w, new_w, coordinate_transformation_mode) + zint = math.floor(in_z) + zfract = in_z - math.floor(in_z) + + yint = math.floor(in_y) + yfract = in_y - math.floor(in_y) + + xint = math.floor(in_x) + xfract = in_x - math.floor(in_x) + + wz = _cubic_spline_weights(zfract) + wy = _cubic_spline_weights(yfract) + wx = _cubic_spline_weights(xfract) + + p = _get_patch(zint, yint, xint) + + l = np.sum(p * wx, axis=-1) + col = np.sum(l * wy, axis=-1) + data_out[m, j, k] = np.sum(col * wz) + + return data_out + + +def resize3d_ncdhw( + data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners" +): + """reference kernel for 3D image resizing""" + ishape = data.shape + + oshape = ( + ishape[0], + ishape[1], + int(round(ishape[2] * scale[0])), + int(round(ishape[3] * scale[1])), + int(round(ishape[4] * scale[2])), + ) + + output_np = np.zeros(oshape, dtype=data.dtype) + + for b in range(oshape[0]): + for c in range(oshape[1]): + if method == "nearest_neighbor": + output_np[b, c, :, :, :] = resize3d_nearest( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + elif method == "linear": + output_np[b, c, :, :, :] = resize3d_linear( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + elif method == "cubic": + output_np[b, c, :, :, :] = resize3d_cubic( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + else: + raise ValueError("Unknown resize method", method) + + return output_np + + +def resize1d_python( + data, + scale, + layout="NCW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python version of 3D scaling using nearest neighbour""" + + if layout == "NWC": + data = data.transpose([0, 2, 1]) + + data = np.expand_dims(data, axis=[2, 3]) + output_np = resize3d_ncdhw(data, (1, 1) + scale, method, coordinate_transformation_mode) + output_np = np.squeeze(output_np, axis=2) + output_np = np.squeeze(output_np, axis=2) + + if layout == "NWC": + output_np = output_np.transpose([0, 2, 1]) + + return output_np + + +def resize2d_python( + data, + scale, + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python version of scaling using nearest neighbour""" + + if layout == "NHWC": + data = data.transpose([0, 3, 1, 2]) + elif nchw_pack_layout(layout): + ishape = data.shape + transposed = data.transpose([0, 4, 1, 5, 2, 3]) + tshape = transposed.shape + data = transposed.reshape( + tshape[0] * tshape[1], tshape[2] * tshape[3], tshape[4], tshape[5] + ) + + data = np.expand_dims(data, axis=2) + output_np = resize3d_ncdhw(data, (1,) + scale, method, coordinate_transformation_mode) + output_np = np.squeeze(output_np, axis=2) + + if layout == "NHWC": + output_np = output_np.transpose([0, 2, 3, 1]) + elif nchw_pack_layout(layout): + output_np = output_np.reshape(tshape[0:4] + output_np.shape[2:]) + output_np = output_np.transpose([0, 2, 4, 5, 1, 3]) + + return output_np + + +def resize3d_python( + data, + scale, + layout="NCDHW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python version of 3D scaling using nearest neighbour""" + + if layout == "NDHWC": + data = data.transpose([0, 4, 1, 2, 3]) + + output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode) + + if layout == "NDHWC": + output_np = output_np.transpose([0, 2, 3, 4, 1]) + + return output_np diff --git a/python/tvm/topi/testing/trilinear_resize3d_python.py b/python/tvm/topi/testing/trilinear_resize3d_python.py deleted file mode 100644 index d603e987d5ef..000000000000 --- a/python/tvm/topi/testing/trilinear_resize3d_python.py +++ /dev/null @@ -1,111 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks -"""Trilinear 3D resize in python""" -import math -import numpy as np - - -def trilinear_resize3d_python( - data_in, out_size, layout, coordinate_transformation_mode="align_corners" -): - """Trilinear 3d scaling using python""" - (new_d, new_h, new_w) = out_size - - if layout == "NDHWC": - (batch, d, h, w, channel) = data_in.shape - data_out = np.ones((batch, new_d, new_h, new_w, channel)) - else: - (batch, channel, d, h, w) = data_in.shape - data_out = np.ones((batch, channel, new_d, new_h, new_w)) - - if coordinate_transformation_mode == "align_corners": - depth_scale = np.float32(d - 1) / np.float32(out_size[0] - 1) - height_scale = np.float32(h - 1) / np.float32(out_size[1] - 1) - width_scale = np.float32(w - 1) / np.float32(out_size[2] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - depth_scale = np.float32(d) / np.float32(out_size[0]) - height_scale = np.float32(h) / np.float32(out_size[1]) - width_scale = np.float32(w) / np.float32(out_size[2]) - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) - ) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _in_coord(new_coord, scale, shape, mode): - if mode == "half_pixel": - in_coord = (new_coord + 0.5) * scale - 0.5 - else: - in_coord = new_coord * scale - coord0 = int(math.floor(in_coord)) - coord1 = max(min(coord0 + 1, shape - 1), 0) - coord0 = max(coord0, 0) - coord_lerp = in_coord - math.floor(in_coord) - return coord0, coord1, coord_lerp - - for b in range(batch): - for i in range(channel): - for m in range(new_d): - for j in range(new_h): - for k in range(new_w): - z0, z1, z_lerp = _in_coord( - m, depth_scale, d, coordinate_transformation_mode - ) - y0, y1, y_lerp = _in_coord( - j, height_scale, h, coordinate_transformation_mode - ) - x0, x1, x_lerp = _in_coord( - k, width_scale, w, coordinate_transformation_mode - ) - - if layout == "NDHWC": - A0 = data_in[b][z0][y0][x0][i] - B0 = data_in[b][z0][y0][x1][i] - C0 = data_in[b][z0][y1][x0][i] - D0 = data_in[b][z0][y1][x1][i] - A1 = data_in[b][z1][y0][x0][i] - B1 = data_in[b][z1][y0][x1][i] - C1 = data_in[b][z1][y1][x0][i] - D1 = data_in[b][z1][y1][x1][i] - else: - A0 = data_in[b][i][z0][y0][x0] - B0 = data_in[b][i][z0][y0][x1] - C0 = data_in[b][i][z0][y1][x0] - D0 = data_in[b][i][z0][y1][x1] - A1 = data_in[b][i][z1][y0][x0] - B1 = data_in[b][i][z1][y0][x1] - C1 = data_in[b][i][z1][y1][x0] - D1 = data_in[b][i][z1][y1][x1] - - A = _lerp(A0, A1, z_lerp) - B = _lerp(B0, B1, z_lerp) - C = _lerp(C0, C1, z_lerp) - D = _lerp(D0, D1, z_lerp) - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - pixel = np.float32(_lerp(top, bottom, y_lerp)) - - if layout == "NDHWC": - data_out[b][m][j][k][i] = pixel - else: - data_out[b][i][m][j][k] = pixel - - return data_out diff --git a/python/tvm/topi/testing/upsampling_python.py b/python/tvm/topi/testing/upsampling_python.py deleted file mode 100644 index dd187c4d8cff..000000000000 --- a/python/tvm/topi/testing/upsampling_python.py +++ /dev/null @@ -1,136 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Upsampling in python""" -import math -import numpy as np -from tvm.topi.utils import nchw_pack_layout - - -def upsample_nearest(arr, scale): - """Populate the array by scale factor""" - h, w = arr.shape - out_h = int(round(h * scale[0])) - out_w = int(round(w * scale[1])) - out = np.empty((out_h, out_w)) - for y in range(out_h): - for x in range(out_w): - in_y = math.floor(y / scale[0]) - in_x = math.floor(x / scale[1]) - out[y, x] = arr[in_y, in_x] - return out - - -def upsampling_python(data, scale, layout="NCHW"): - """Python version of scaling using nearest neighbour""" - - ishape = data.shape - if layout == "NCHW": - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[1]): - output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) - return output_np - # NCHWinic - if nchw_pack_layout(layout): - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - ishape[4], - ishape[5], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for ib in range(oshape[4]): - for c in range(oshape[1]): - for ic in range(oshape[5]): - output_np[b, c, :, :, ib, ic] = upsample_nearest( - data[b, c, :, :, ib, ic], scale - ) - return output_np - - if layout == "NHWC": - oshape = ( - ishape[0], - int(round(ishape[1] * scale[0])), - int(round(ishape[2] * scale[1])), - ishape[3], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[3]): - output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale) - return output_np - raise ValueError("not support this layout {} yet".format(layout)) - - -def upsample3d_nearest(arr, scale): - """Populate the array by scale factor""" - d, h, w = arr.shape - out_d = int(round(d * scale[0])) - out_h = int(round(h * scale[1])) - out_w = int(round(w * scale[2])) - out = np.empty((out_d, out_h, out_w)) - for z in range(out_d): - for y in range(out_h): - for x in range(out_w): - in_z = math.floor(z / scale[0]) - in_y = math.floor(y / scale[1]) - in_x = math.floor(x / scale[2]) - out[z, y, x] = arr[in_z, in_y, in_x] - return out - - -def upsampling3d_python(data, scale, layout="NCDHW"): - """Python version of 3D scaling using nearest neighbour""" - - ishape = data.shape - if layout == "NCDHW": - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - int(round(ishape[4] * scale[2])), - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[1]): - output_np[b, c, :, :, :] = upsample3d_nearest(data[b, c, :, :, :], scale) - return output_np - if layout == "NDHWC": - oshape = ( - ishape[0], - int(round(ishape[1] * scale[0])), - int(round(ishape[2] * scale[1])), - int(round(ishape[3] * scale[2])), - ishape[4], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[4]): - output_np[b, :, :, :, c] = upsample3d_nearest(data[b, :, :, :, c], scale) - return output_np - raise ValueError("not support this layout {} yet".format(layout)) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 3a056cfb4326..be3df2be5f6a 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -31,6 +31,16 @@ class InvalidShapeError(ValueError): """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" +def ncw_pack_layout(layout_info): + """Check whether the layout type is NCWinic""" + return layout_info[:3] == "NCW" and "c" in layout_info and "n" in layout_info + + +def ncw_xc_layout(layout_info): + """Check whether the layout type is NCWxc""" + return layout_info[:3] == "NCW" and "c" in layout_info and layout_info[3:-1].isnumeric() + + def nchw_pack_layout(layout_info): """Check whether the layout type is NCHWinic""" return layout_info[:4] == "NCHW" and "c" in layout_info and "n" in layout_info diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 37bdd09d6ca6..13ca851f0e38 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -20,52 +20,66 @@ from tvm import autotvm from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from .. import generic +from .. import generic, nn from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor @autotvm.register_topi_compute("batch_matmul.x86") -def batch_matmul(cfg, x, y, out_shape=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. Supports broadcasting in batch dimension. +def batch_matmul( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- cfg : ConfigSpace - Autotvm tuning space config file - x : tvm.te.Tensor - 3-D with shape [batch, M, K] - y : tvm.te.Tensor - 3-D with shape [batch, N, K] - out_shape : tuple or None - Shape of the outputs + Autotvm tuning space config file. + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - XB, M, XK = get_const_tuple(x.shape) - YB, N, YK = get_const_tuple(y.shape) - assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" - assert XK == YK, "shapes of x and y is inconsistent" - B = te.max(XB, YB) - K = XK - if out_shape is not None: - assert out_shape[0] == B, "got invalid output shape" - assert out_shape[1] == M, "got invalid output shape" - assert out_shape[2] == N, "got invalid output shape" if cfg.is_fallback: + if transpose_a: + _, K, M = get_const_tuple(tensor_a.shape) + else: + _, M, K = get_const_tuple(tensor_a.shape) + if transpose_b: + _, N, _ = get_const_tuple(tensor_b.shape) + else: + _, _, N = get_const_tuple(tensor_b.shape) _default_batch_matmul_config(cfg, M, N, K) - - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (B, M, N), - lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), - tag="batch_matmul", + return nn.batch_matmul( + tensor_a, + tensor_b, + out_shape, + out_dtype, + transpose_a, + transpose_b, ) - return C @autotvm.register_topi_schedule("batch_matmul.x86") @@ -137,20 +151,32 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -def batch_matmul_blas_common(cfg, x, y, out_shape, lib): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension. +def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib): + """Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and + `tensor_b` are data in batch, using one of BLAS libraries. Supports broadcasting in batch + dimension. Parameters ---------- cfg : ConfigSpace Autotvm tuning space config file - x : tvm.te.Tensor - 3-D with shape [batch, M, K] - y : tvm.te.Tensor - 3-D with shape [batch, N, K] - out_shape : tuple or None - Shape of the output + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + trans_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + trans_b : Optional[bool] = True + Whether the second tensor is in transposed format. + lib : A contrib module which implements batch_matmul function cblas and mkl are supported @@ -159,9 +185,15 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - XB, M, XK = get_const_tuple(x.shape) - YB, N, YK = get_const_tuple(y.shape) + assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul" + if trans_a: + XB, XK, M = get_const_tuple(tensor_a.shape) + else: + XB, M, XK = get_const_tuple(tensor_a.shape) + if trans_b: + YB, N, YK = get_const_tuple(tensor_b.shape) + else: + YB, YK, N = get_const_tuple(tensor_a.shape) assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistent" if out_shape is not None: @@ -169,13 +201,18 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) - return lib.batch_matmul(x, y, False, True) + return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b) @autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y, out_shape=None): +def batch_matmul_cblas( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Compute batch_matmul using cblas""" - return batch_matmul_blas_common(cfg, x, y, out_shape, cblas) + del out_dtype # Unused argument + return batch_matmul_blas_common( + cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas + ) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") @@ -185,9 +222,14 @@ def schedule_batch_matmul_cblas(_, outs): @autotvm.register_topi_compute("batch_matmul_mkl.x86") -def batch_matmul_mkl(cfg, x, y, out_shape=None): +def batch_matmul_mkl( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Compute batch_matmul using mkl""" - return batch_matmul_blas_common(cfg, x, y, out_shape, mkl) + del out_dtype # Unused argument + return batch_matmul_blas_common( + cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl + ) @autotvm.register_topi_schedule("batch_matmul_mkl.x86") diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 189ac5bd34bd..6f2c202e3f61 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -27,8 +27,7 @@ from tvm.contrib import mkldnn from .utils import get_fp32_len -from .injective import schedule_injective_from_existing -from .. import tag +from .. import generic, tag from ..utils import traverse_inline, get_const_tuple @@ -306,17 +305,6 @@ def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, tr return C -def schedule_matmul_blas_common(outs): - """Default matmul schedule for BLAS library""" - s = te.create_schedule([x.op for x in outs]) - te.schedule.AutoInlineInjective(s) - - for out in outs: - if "dense" not in out.op.tag and "matmul" not in out.op.tag: - schedule_injective_from_existing(s, out) - return s - - @autotvm.register_topi_compute("dense_cblas.x86") def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): """Compute dense using cblas. This is an alias of matmul_nt operator.""" @@ -326,7 +314,7 @@ def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule("dense_cblas.x86") def schedule_dense_cblas(_, outs): """Create schedule for dense_cblas. This is an alias of matmul_nt operator.""" - return schedule_matmul_blas_common(outs) + return generic.schedule_extern(outs) @autotvm.register_topi_compute("dense_mkl.x86") @@ -338,7 +326,7 @@ def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule("dense_mkl.x86") def schedule_dense_mkl(_, outs): """Create schedule for dense_mkl. This is an alias of matmul_nt operator.""" - return schedule_matmul_blas_common(outs) + return generic.schedule_extern(outs) @autotvm.register_topi_compute("dense_mkldnn.x86") @@ -350,7 +338,7 @@ def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule("dense_mkldnn.x86") def schedule_dense_mkldnn(_, outs): """Create schedule for dense_mkldnn. This is an alias of matmul_nt operator.""" - return schedule_matmul_blas_common(outs) + return generic.schedule_extern(outs) @autotvm.register_topi_compute("matmul_cblas.x86") @@ -366,7 +354,7 @@ def matmul_cblas( @autotvm.register_topi_schedule("matmul_cblas.x86") def schedule_matmul_cblas(_, outs): """Create schedule for matmul_cblas.""" - return schedule_matmul_blas_common(outs) + return generic.schedule_extern(outs) @autotvm.register_topi_compute("matmul_mkl.x86") @@ -382,7 +370,7 @@ def matmul_mkl( @autotvm.register_topi_schedule("matmul_mkl.x86") def schedule_matmul_mkl(_, outs): """Create schedule for matmul_mkl.""" - return schedule_matmul_blas_common(outs) + return generic.schedule_extern(outs) @autotvm.register_topi_compute("matmul_mkldnn.x86") @@ -398,4 +386,4 @@ def matmul_mkldnn( @autotvm.register_topi_schedule("matmul_mkldnn.x86") def schedule_matmul_mkldnn(_, outs): """Create schedule for matmul_mkldnn.""" - return schedule_matmul_blas_common(outs) + return generic.schedule_extern(outs) diff --git a/python/tvm/topi/x86/group_conv2d.py b/python/tvm/topi/x86/group_conv2d.py index 0501c5534cf2..0e10052e2428 100644 --- a/python/tvm/topi/x86/group_conv2d.py +++ b/python/tvm/topi/x86/group_conv2d.py @@ -43,7 +43,9 @@ def schedule_group_conv2d_nchw(outs): return schedule_group_conv2d_nchwc(outs) -def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout="NCHW"): +def _get_default_config( + cfg, data, kernel, strides, padding, dilation, groups, out_dtype, layout="NCHW" +): """ Get default schedule config for the workload """ @@ -55,7 +57,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, static_data_shape.append(dim) data = te.placeholder(static_data_shape, dtype=data.dtype) - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) + wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout) _fallback_schedule(cfg, wkl) @@ -159,6 +161,7 @@ def group_conv2d_nchw_spatial_pack( ), strides, padding, + dilation, groups, out_dtype, ) diff --git a/python/tvm/topi/x86/pooling.py b/python/tvm/topi/x86/pooling.py index 91108ac7485d..db0f9faf1970 100644 --- a/python/tvm/topi/x86/pooling.py +++ b/python/tvm/topi/x86/pooling.py @@ -26,8 +26,8 @@ def vectorize(fused_axis, num_parallel_axis, vectorize_limit=64): reorder_axis = [fused_axis] for i in range(num_parallel_axis, len(sch.op.axis) - 1): reorder_axis.append(sch.op.axis[i]) - kw, kh = sch.op.reduce_axis - fuse_k = sch.fuse(kw, kh) + k = sch.op.reduce_axis + fuse_k = sch.fuse(*k) c = sch.op.axis[len(sch.op.axis) - 1] reorder_axis += [fuse_k, c] sch.reorder(*reorder_axis) @@ -83,7 +83,7 @@ def schedule_pool(outs, layout): def _schedule(PaddedInput, Pool): if isinstance(PaddedInput.op, te.tensor.ComputeOp): s[PaddedInput].compute_inline() - do_vectorize = layout[-1] not in "HWhw" + do_vectorize = layout[-1] not in "DHWdhw" _parallel_sch(s[Pool], outs[0].shape, do_vectorize) def traverse(OP): diff --git a/rust/tvm-graph-rt/src/module/mod.rs b/rust/tvm-graph-rt/src/module/mod.rs index 511ba4b37132..a345758deca1 100644 --- a/rust/tvm-graph-rt/src/module/mod.rs +++ b/rust/tvm-graph-rt/src/module/mod.rs @@ -52,6 +52,7 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< values.len() as i32, &mut ret_val, &mut ret_type_code, + std::ptr::null_mut(), ); if exit_code == 0 { Ok(RetValue::from_tvm_value(ret_val, ret_type_code)) diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index d80bd9598246..170ccce0a9f1 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -84,21 +84,26 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed={}", build_path.display()); println!("cargo:rerun-if-changed={}/include", source_path.display()); - if cfg!(feature = "static-linking") { - println!("cargo:rustc-link-lib=static=tvm"); - // TODO(@jroesch): move this to tvm-build as library_path? - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } - - if cfg!(feature = "dynamic-linking") { - println!("cargo:rustc-link-lib=dylib=tvm"); - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); + match &std::env::var("CARGO_CFG_TARGET_ARCH").unwrap()[..] { + "wasm32" => {} + _ => { + if cfg!(feature = "static-linking") { + println!("cargo:rustc-link-lib=static=tvm"); + // TODO(@jroesch): move this to tvm-build as library_path? + println!( + "cargo:rustc-link-search=native={}/build", + build_path.display() + ); + } + + if cfg!(feature = "dynamic-linking") { + println!("cargo:rustc-link-lib=dylib=tvm"); + println!( + "cargo:rustc-link-search=native={}/build", + build_path.display() + ); + } + } } let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h"); diff --git a/rust/tvm-sys/src/device.rs b/rust/tvm-sys/src/device.rs index 1da64fd60483..1ebac09bf611 100644 --- a/rust/tvm-sys/src/device.rs +++ b/rust/tvm-sys/src/device.rs @@ -65,14 +65,14 @@ use thiserror::Error; #[repr(i64)] pub enum DeviceType { CPU = 1, - CUDA, - CUDAHost, - OpenCL, - Vulkan, - Metal, - VPI, - ROCM, - ExtDev, + CUDA = 2, + CUDAHost = 3, + OpenCL = 4, + Vulkan = 7, + Metal = 8, + VPI = 9, + ROCM = 10, + ExtDev = 12, } impl Default for DeviceType { diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs index f874e672bb66..f9ac3b461c69 100644 --- a/rust/tvm-sys/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -40,6 +40,7 @@ pub mod ffi { num_args: c_int, out_ret_value: *mut TVMValue, out_ret_tcode: *mut u32, + resource_handle: *mut c_void, ) -> c_int; } diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ba549959ac98..94db659e25c9 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) // and recursively mark the corresponding components for (size_t i = 0; i < simplified_result.size(); ++i) if (!used[i]) { - if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || - ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) + if (UsesVar(simplified_result[idx], + [v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) || + UsesVar(simplified_result[idx], + [v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; })) mark_used(i); } }; diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index f0634feac083..d81159bf05c9 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor DetectLinearEquation(const PrimExpr& e, const Array& vars) for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset_contains)) { + if (UsesVar(coeff[i - 2], vset_contains)) { return Array(); } } diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c1daae967b47..ac78c55ed610 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -515,7 +515,6 @@ class IterMapRewriter : public ExprMutator { */ Optional TryFuseIters(IterSumExpr expr) { if (!is_zero(expr->base)) return NullOpt; - if (expr->args.size() == 1) return expr->args[0]; // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; @@ -1086,6 +1085,21 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter return NormalizeIterMapToExpr(expr); }); +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective) { + Analyzer analyzer; + Array rewrite = + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + if (rewrite.empty()) { + return indices; + } + Array res; + res.reserve(rewrite.size()); + IterMapToExprNormalizer converter(&analyzer); + for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); + return res; +} + /*! * \brief Divider to divide the bindings into two sets of bindings(outer and inner) * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X. @@ -1385,5 +1399,130 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide") return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); }); +class InverseAffineIterMapTransformer { + public: + explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {} + + Map operator()(const Array& iter_map, + const Array& outputs) { + ICHECK(iter_map.size() == outputs.size()); + std::vector post_dfs_order = ReverseTopologyOrder(iter_map); + + // initialize back propagation accumulator + for (const IterMapExprNode* node : post_dfs_order) { + backprop_.Set(GetRef(node), Integer(0)); + } + for (size_t i = 0; i < iter_map.size(); i++) { + backprop_.Set(iter_map[i], outputs[i]); + } + + // run back propagation + for (const IterMapExprNode* node : post_dfs_order) { + if (node->IsInstance()) { + Visit_(Downcast(GetRef(node))); + } else { + ICHECK(node->IsInstance()); + Visit_(Downcast(GetRef(node))); + } + } + return std::move(inverse_); + } + + private: + void Visit_(const IterSumExpr& iter_map_expr) { + PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base; + + // Case 1: Propagate to the input node directly when the sum expression has only one components + if (iter_map_expr->args.size() == 1) { + const auto& source = iter_map_expr->args[0]; + backprop_.Set(source, backprop_.at(source) + input); + return; + } + + // Case 2: If the sum expression has multiple components, check the fuse pattern and then split + // the sum expression for each components. + // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2 + // we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the + // propagated value to get the corresponding components of i1 and i2, which are + // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively. + CheckFusePattern(iter_map_expr); + for (size_t i = iter_map_expr->args.size(); i > 0; i--) { + const IterSplitExpr& split = iter_map_expr->args[i - 1]; + backprop_.Set(split, + backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); + } + } + + std::vector ReverseTopologyOrder(const Array& iter_map) { + std::vector post_dfs_order; + std::unordered_map visited; + + std::function fvisit = [&](const IterMapExpr& expr) { + if (visited[expr]) { + return; + } + visited[expr] = true; + if (const auto* sum_expr = expr.as()) { + for (const IterSplitExpr& child : sum_expr->args) { + fvisit(child); + } + } else { + const auto* split_expr = expr.as(); + ICHECK(split_expr); + if (const auto* source = split_expr->source->source.as()) { + fvisit(GetRef(source)); + } + } + post_dfs_order.push_back(expr.get()); + }; + for (const IterSumExpr& expr : iter_map) { + fvisit(expr); + } + std::reverse(post_dfs_order.begin(), post_dfs_order.end()); + return post_dfs_order; + } + + void Visit_(const IterSplitExpr& iter_map_expr) { + PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor; + const IterMark& source = iter_map_expr->source; + if (source->source.as()) { + IterSumExpr source_expr = Downcast(source->source); + backprop_.Set(source_expr, backprop_.at(source_expr) + input); + } else { + Var source_var = Downcast(source->source); + if (inverse_.count(source_var)) { + inverse_.Set(source_var, inverse_.at(source_var) + input); + } else { + inverse_.Set(source_var, input); + } + } + } + + /* + * \brief Check the fuse pattern of sum_expr. We assume components of sum_expr is sorted in + * descending order of lower_factor. + */ + void CheckFusePattern(const IterSumExpr sum_expr) { + ICHECK(sum_expr->args.size()); + PrimExpr expected_scale = sum_expr->args.back()->scale; + for (size_t i = sum_expr->args.size(); i > 0; i--) { + ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale)); + expected_scale *= sum_expr->args[i - 1]->extent; + } + } + + Analyzer* analyzer_; + Map backprop_; // the accumulator of backpropgation + Map inverse_; // the result of inverse transformation +}; + +Map InverseAffineIterMap(const Array& iter_map, + const Array outputs) { + Analyzer analyzer; + return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); +} + +TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); + } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a58e4433dadd..ff6536ab066b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -882,6 +884,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); + TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7522f20523c8..54edbaee35cd 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; - } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); - alloc_storage_scope_[v] = op->value.as()->value; - PrintStmt(op->body); } else { // For now we ignore the unsupported AttrStmt PrintStmt(op->body); @@ -327,8 +323,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { auto tensor = Downcast(op->producer); - ICHECK(alloc_storage_scope_.count(tensor->op)); - if (!alloc_storage_scope_[tensor->op].empty()) { + if (!op->storage_scope.empty()) { PrintIndent(); stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { @@ -339,7 +334,7 @@ void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { stream << "), '"; PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[tensor->op] << "')\n"; + stream << op->storage_scope << "')\n"; } PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index b01ca2763e28..47c13f73022f 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -168,8 +168,6 @@ class CodeGenHybrid : public ExprFunctor, * \param tensor The tensor to allocate a name. */ std::string GetTensorID(const Tensor& tensor); - /*! \brief the storage scope of allocation */ - std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cd8173717d5f..2008fe5e47b8 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -88,7 +88,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, + return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, data_alignment, offset_factor, buffer_type); } @@ -222,6 +222,7 @@ Array CreatePassList(bool disable_loop_partition, bool for pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::FlattenBuffer()); } pass_list.push_back(tir::transform::BF16Legalize()); @@ -377,6 +378,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target Array mixed_pass_list = {BindTarget(target), tir::transform::VerifyMemory()}; + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } @@ -388,7 +390,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target if (target->GetAttr("unpacked-api").value_or(Bool(false))) { mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); } else { - mixed_pass_list.push_back(tir::transform::MakePackedAPI(0)); + mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); } mixed_pass_list.push_back(tir::transform::SplitHostDevice()); @@ -437,14 +439,18 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target } if (target->kind->device_type == kDLCPU && target_host == target) { - ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + // TODO(@jroesch): This check is no longer true we need to figure out if we care about this. + // We need to relax this check for just TIR functions. + // ICHECK(mdevice->functions.empty()) << "No device code should be generated when target " + // << "and host_target are both llvm target." + // << "\n"; } return {mhost, mdevice}; } +// Can we make this take one annotated IRModule? +// // Build for heterogeneous execution. runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc new file mode 100644 index 000000000000..3454b6011c9b --- /dev/null +++ b/src/ir/affine_type.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/affine_type.cc + * \brief The Type information for quantized nodes. + */ +#include +#include +#include + +namespace tvm { + +using tvm::ReprPrinter; +using namespace tvm::runtime; + +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) { + ObjectPtr n = make_object(); + n->scale = std::move(scale); + n->zero_point = std::move(zero_point); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode); + +TVM_REGISTER_GLOBAL("ir.TensorAffineType") + .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) { + return TensorAffineType(scale, zero_point, dtype); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", " + << node->dtype << ")"; + }); + +TupleAffineType::TupleAffineType(Array types) { + ObjectPtr n = make_object(); + n->types = std::move(types); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleAffineTypeNode); + +TVM_REGISTER_GLOBAL("ir.TupleAffineType").set_body_typed([](Array types) { + return TupleAffineType(types); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleAffineType(["; + for (size_t i = 0; i < node->types.size(); ++i) { + p->stream << node->types[i]; + if (i < node->types.size() - 1) { + p->stream << ", "; + } + } + p->stream << "])"; + }); + +} // namespace tvm diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 7340f6977943..4e79d0e74c59 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -40,7 +40,6 @@ Source::Source(SourceName src_name, std::string source) { // NB(@jroesch): std::string source_str = n->source; for (auto c : source_str) { - DLOG(INFO) << "char=" << c; if (c == '\n') { // Record the length of the line. n->line_map.back().second = length; diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc new file mode 100644 index 000000000000..17ba84e68df4 --- /dev/null +++ b/src/printer/model_library_format_printer.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include + +#include "text_printer.h" + +namespace tvm { +namespace printer { + +class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { + public: + ModelLibraryFormatPrinter(bool show_meta_data, + const runtime::TypedPackedFunc& annotate, + bool show_warning) + : text_printer_{show_meta_data, annotate, show_warning} {} + + const char* type_key() const override { return "model_library_format_printer"; } + + std::string Print(const ObjectRef& node) { + Doc doc; + doc << text_printer_.PrintFinal(node); + return doc.str(); + } + + TVMRetValue GetVarName(tir::Var var) { + TVMRetValue rv; + std::string var_name; + if (text_printer_.GetVarName(var, &var_name)) { + rv = var_name; + } + + return rv; + } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (name == "print") { + return TypedPackedFunc( + [sptr_to_self, this](ObjectRef node) { return Print(node); }); + } else if (name == "get_var_name") { + return TypedPackedFunc( + [sptr_to_self, this](tir::Var var) { return GetVarName(var); }); + } else { + return PackedFunc(); + } + } + + private: + TextPrinter text_printer_; +}; + +TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter") + .set_body_typed([](bool show_meta_data, + const runtime::TypedPackedFunc& annotate, + bool show_warning) { + return ObjectRef( + make_object(show_meta_data, annotate, show_warning)); + }); + +} // namespace printer +} // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 7a529cc0b914..0332a2d539d2 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -256,6 +257,13 @@ class TIRTextPrinter : public StmtFunctor, /*! \brief Print the node */ Doc Print(const ObjectRef& node); + /*! \brief Place into `s` the name used in the preceding Print call for `v`. + * \param v Var instance to check. Must point to a VarNode visited by Print. + * \param s String to receive the name. + * \return true when a name re-mapping was found. + */ + bool GetVarName(::tvm::tir::Var v, std::string* s); + private: /*! \brief whether show meta data */ bool show_meta_; @@ -394,6 +402,8 @@ class TextPrinter { /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; + bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } + Doc PrintFinal(const ObjectRef& node) { Doc doc; if (node->IsInstance()) { diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 04c5ea1cdf99..f232994480f8 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -35,6 +35,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -204,8 +205,8 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { if (!is_zero(buf->elem_offset)) { doc << ", elem_offset=" << Print(buf->elem_offset); } - if (buf->scope != "global") { - doc << ", scope=" << Doc::StrLiteral(buf->scope); + if (GetRef(buf).scope() != "global") { + doc << ", scope=" << Doc::StrLiteral(GetRef(buf).scope()); } if (buf->data_alignment != 128) { doc << ", align=" << buf->data_alignment; @@ -447,8 +448,9 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; + auto scope = GetPtrStorageScope(op->buffer_var); doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " - << Print(op->extents) << ")"; + << Print(op->extents) << "), storage_scope = " << scope; if (!is_one(op->condition)) { doc << " if " << Print(op->condition); } @@ -485,23 +487,6 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { return doc; } -inline const char* ForKind2String(ForKind t) { - switch (t) { - case ForKind::kSerial: - return "serial"; - case ForKind::kParallel: - return "parallel"; - case ForKind::kVectorized: - return "vectorized"; - case ForKind::kUnrolled: - return "unroll"; - case ForKind::kThreadBinding: - return "thread_binding"; - } - LOG(FATAL) << "Unknown ForKind"; - return "Unknown"; -} - Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " @@ -598,8 +583,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { << Print(alloc_buf->shape) << ")" << Doc::NewLine(); } for (const auto& match_buf : block_op->match_buffers) { - body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source) - << ")" << Doc::NewLine(); + body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" + << Doc::NewLine(); } if (block_op->init.defined()) { Doc init_block; @@ -734,5 +719,15 @@ Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { return doc; } +bool TIRTextPrinter::GetVarName(Var v, std::string* s) { + auto it = memo_var_.find(v); + if (it == memo_var_.end()) { + return false; + } + + *s = it->second.str(); + return true; +} + } // namespace tir } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 4bbe17064c87..cc7536b48cfd 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -301,8 +303,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { } else { doc << ", elem_offset=" << Print(buf->elem_offset); } - if (buf->scope != "global") { - doc << ", scope=" << Doc::StrLiteral(buf->scope); + if (buf.scope() != "global") { + doc << ", scope=" << Doc::StrLiteral(buf.scope()); } if (buf->data_alignment != -1) { doc << ", align=" << buf->data_alignment; @@ -335,29 +337,8 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { const Buffer& buf = op->buffer; buf_not_in_headers.insert(buf.get()); - Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" << Print(op->source); - if (!buf->strides.empty()) { - doc << ", strides=" << Print(buf->strides); - } - if (buf->offset_factor != 0 && buf->elem_offset->IsInstance()) { - Var elem_offset = Downcast(buf->elem_offset); - if (memo_var_.find(elem_offset) != memo_var_.end()) { - doc << ", elem_offset=" << Print(buf->elem_offset); - } else { - // implicitly define elem_offset - memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset"); - var_not_in_headers.insert(elem_offset.get()); - } - } else { - doc << ", elem_offset=" << Print(buf->elem_offset); - } - if (buf->data_alignment != -1) { - doc << ", align=" << buf->data_alignment; - } - if (buf->offset_factor != 0) { - doc << ", offset_factor=" << buf->offset_factor; - } - doc << ")"; + Doc doc = Print(op->buffer) << " = tir.match_buffer(" << Print(op->source) << ", " + << memo_buf_decl_[op->buffer] << ")"; return doc; } @@ -578,31 +559,6 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; - // merge attr with allocate when possible - if (op->node->IsInstance() && op->attr_key == "storage_scope" && - op->body->IsInstance()) { - const auto* alloc = Downcast(op->body).get(); - if (alloc->buffer_var.same_as(op->node)) { - var_not_in_headers.insert(alloc->buffer_var.get()); - if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype) - << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ") as " << Print(op->node) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); - } else { - doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", " - << PrintDType(alloc->dtype) << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ")" << Doc::NewLine() << PrintBody(alloc->body); - } - return doc; - } - } // merge attr with realize when possible if (op->node->IsInstance() && op->attr_key == "realize_scope" && op->body->IsInstance()) { @@ -680,8 +636,26 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be folded with Attr"; - return Doc(); + var_not_in_headers.insert(op->buffer_var.get()); + Doc doc; + auto storage_scope = GetPtrStorageScope(op->buffer_var); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) << ", " + << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ") as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << Print(op->buffer_var) << " = tir.allocate(" << Print(op->extents) << ", " + << PrintDType(op->dtype) << ", " << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(op->body); + } + return doc; } Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { @@ -709,23 +683,6 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { return doc; } -inline const char* ForKind2String(ForKind t) { - switch (t) { - case ForKind::kSerial: - return "serial"; - case ForKind::kParallel: - return "parallel"; - case ForKind::kVectorized: - return "vectorized"; - case ForKind::kUnrolled: - return "unroll"; - case ForKind::kThreadBinding: - return "thread_binding"; - } - LOG(FATAL) << "Unknown ForKind"; - return "Unknown"; -} - Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers.insert(op->loop_var.get()); @@ -1013,8 +970,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); for (const auto& var : vars) { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; - header_var << PrintDType(var->dtype) << ")"; + auto type = GetRef(var)->type_annotation; + if (auto* ptr_type = type.as()) { + auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.buffer_var("; + header_var << PrintDType(prim_type->dtype) << ", " + << Doc::StrLiteral(ptr_type->storage_scope) << ")"; + } else { + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << PrintDType(var->dtype) << ")"; + } } } doc << Doc::Indent(4, header_attr << header_var << header_buf << body); diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 840878390018..53c680b722cd 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -76,17 +76,20 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) } } -AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) { +AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& func_name, + const std::string& target) { auto ret = regions_.emplace(AnnotatedRegion()); (*ret.first)->id_ = region_id_++; (*ret.first)->target_ = target; + (*ret.first)->func_name_ = func_name; return *ret.first; } class AnnotatedRegionSet::Creator : protected MixedModeVisitor { public: - Creator(const Op& region_begin_op, const Op& region_end_op) - : begin_op_(region_begin_op), end_op_(region_end_op) {} + Creator(const Op& region_begin_op, const Op& region_end_op, + const std::string& func_name = "default") + : begin_op_(region_begin_op), end_op_(region_end_op), func_name_(func_name) {} AnnotatedRegionSet Create(const Expr& expr) { VisitExpr(expr); @@ -144,7 +147,7 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { ICHECK(!region.defined()); // Create a new region. - region = region_set_->MakeRegion(target); + region = region_set_->MakeRegion(func_name_, target); region->nodes_.insert(GetRef(call)); region->ins_.push_back(GetRef(call)); } else { @@ -213,10 +216,13 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { const Op begin_op_; /*! \brief Region 'end' annotation operator. */ const Op end_op_; + /*! \brief The unique function name that is used to be the name of this region set. */ + const std::string func_name_; }; -AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) { - return Creator(begin, end).Create(expr); +AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end, + const std::string& func_name) { + return Creator(begin, end, func_name).Create(expr); } TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 2e4eec23f733..aca42397916c 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -62,6 +62,9 @@ class AnnotatedRegionNode : public Object { /*! \brief Get the region ID. */ int GetID() const { return id_; } + /*! \brief Get the region name. */ + std::string GetName() const { return func_name_; } + /*! \brief Get the region target. */ std::string GetTarget() const { return target_; } @@ -80,6 +83,8 @@ class AnnotatedRegionNode : public Object { protected: /*! \brief The region ID. */ int id_{-1}; + /*! \brief The func name. */ + std::string func_name_ = "default"; /*! \brief The target for this region. */ std::string target_ = "default"; /*! \brief The inputs to this region. */ @@ -177,7 +182,7 @@ class AnnotatedRegionSetNode : public Object { * * \return The new region. */ - AnnotatedRegion MakeRegion(const std::string& target); + AnnotatedRegion MakeRegion(const std::string& func_name, const std::string& target); std::unordered_set regions_; /*! \brief The next region ID to assign. */ @@ -256,10 +261,12 @@ class AnnotatedRegionSet : public ObjectRef { * \param expr The relay expr from which to construct the set. * \param begin Region begin annotation operator. * \param end Region end annotation operator. + * \param func_name function name * * \return The created RegionSet for the expression. */ - static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end); + static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end, + const std::string& func_name = "default"); private: /*! \brief Helper class to construct a RegionSet from an expr.*/ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9b495adbdea8..221df958a8cb 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -439,7 +439,7 @@ class AOTExecutorCodegen : public ExprVisitor { fi_node->tir_primfuncs.Set(primfunc_target, primfunc); fi_node->relay_primfuncs.Set(primfunc_target, relay_func); } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); + function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node)); } void VisitExpr_(const CallNode* op) override { @@ -465,20 +465,18 @@ class AOTExecutorCodegen : public ExprVisitor { << "(i.e functions composed of fusable operator invocations)"; } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compile_engine_->Lower(key, mod_name_); ICHECK(ext_func.defined()) << "External function is not defined."; UpdateConstants(func, ¶ms_); // Generate the TIR function call - CreateFuncCall(GetRef(op), ext_func->func_name); + CreateFuncCall(GetRef(op), ext_func->prim_fn_var->name_hint); return; } @@ -503,8 +501,10 @@ class AOTExecutorCodegen : public ExprVisitor { } target = targets_[call_dev_type]; } - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); + + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_); + if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); } @@ -513,7 +513,7 @@ class AOTExecutorCodegen : public ExprVisitor { UpdateFunctionMetadata(lowered_func, func, target); // Generate the TIR function call - CreateFuncCall(GetRef(op), lowered_func->func_name); + CreateFuncCall(GetRef(op), lowered_func->prim_fn_var->name_hint); } void VisitExpr_(const VarNode* op) override { @@ -625,8 +625,6 @@ class AOTExecutorCodegen : public ExprVisitor { // so we don't pay the price of allocation for every inference if (!allocated[sid]) { body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body); - body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"), - body); } allocated[sid] = true; } @@ -652,7 +650,7 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief mod */ runtime::Module* mod_; /*! \brief list of input expressions (i.e., variable passed by the user) */ - std::vector input_vars_; + std::vector input_vars_; /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief target device */ @@ -722,7 +720,8 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { - te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); + te::Var buffer_var(MakeString("sid_", sid), + PointerType(PrimType(DataType::Int(8)), "global")); sids_table_[sid] = buffer_var; } } @@ -783,8 +782,12 @@ class AOTExecutorCodegen : public ExprVisitor { ret.lowered_funcs.Set(target_host_str, mod_run); } ret.function_metadata = std::move(function_metadata_); - ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(), - runtime::kTvmExecutorAot, mod_name); + + std::vector input_var_names(input_vars_.size()); + std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), + [](Var input_var) -> String { return input_var->name_hint(); }); + ret.metadata = + runtime::Metadata(input_var_names, return_sid_.size(), runtime::kTvmExecutorAot, mod_name); return ret; } }; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ea53c34c793b..f407436e5868 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -313,57 +313,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs; - Array entry_functions{"main"}; - pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); - pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - - // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - - pass_seqs.push_back(transform::SimplifyInference()); - - // Convert Dynamic ops to static versions - pass_seqs.push_back(transform::DynamicToStatic()); - - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - *rv = false; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == DataType::Int(32)) { - *rv = true; - } - } - } - }); - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); - pass_seqs.push_back(transform::SimplifyExpr()); - pass_seqs.push_back(transform::CombineParallelConv2D(3)); - pass_seqs.push_back(transform::CombineParallelDense(3)); - pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); - pass_seqs.push_back(transform::FoldConstant()); - pass_seqs.push_back(transform::FoldScaleAxis()); - pass_seqs.push_back(transform::CanonicalizeCast()); - pass_seqs.push_back(transform::CanonicalizeOps()); - - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { - pass_seqs.push_back(transform::InferType()); - pass_seqs.push_back(transform::AlterOpLayout()); - } - - // Fast math optimizations. - pass_seqs.push_back(transform::FastMath()); - pass_seqs.push_back(transform::FoldConstant()); + Array pass_seqs = GetPassPrefix(targets, false); if (targets.size() == 1) { const auto& target = (*targets.begin()).second; diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index f0b43b14c650..6142e8323dea 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -46,569 +46,14 @@ #include "../../runtime/meta_data.h" #include "../transforms/pass_utils.h" +#include "te_compiler_cache.h" #include "utils.h" namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(LoweredOutputNode); -TVM_REGISTER_NODE_TYPE(CachedFuncNode); -TVM_REGISTER_NODE_TYPE(CCacheKeyNode); -TVM_REGISTER_NODE_TYPE(CCacheValueNode); TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); -LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { - auto n = make_object(); - n->outputs = std::move(outputs); - n->implementation = std::move(impl); - data_ = std::move(n); -} - -CCacheKey::CCacheKey(Function source_func, Target target) { - auto n = make_object(); - n->source_func = std::move(source_func); - n->target = std::move(target); - data_ = std::move(n); -} - -Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = tir::as_const_int(val); - if (pval != nullptr) { -#ifndef TVM_INDEX_DEFAULT_I64 - ICHECK_LE(pval[0], std::numeric_limits::max()); - ICHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(IntImm(DataType::Int(32), *pval)); -#else - res.push_back(val); -#endif // TVM_INDEX_DEFAULT_I64 - } else if (val->IsInstance()) { - res.push_back(val.as()->ToVar()); - } else { - res.push_back(val); - } - } - return res; -} - -// The getter to get schedule from compile engine. -// Get schedule from functor. -class ScheduleGetter : public backend::MemoizedExprTranslator> { - public: - explicit ScheduleGetter(Target target) - : target_(target), device_copy_op_(Op::Get("device_copy")) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - } - - CachedFunc Create(const Function& prim_func) { - auto cache_node = make_object(); - cache_node->target = target_; - for (Var param : prim_func->params) { - Array inputs; - if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - cache_node->inputs.push_back(tensor); - inputs.push_back(tensor); - } - } - memo_[param] = inputs; - } - readable_name_stream_ << "fused"; - cache_node->outputs = this->VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - ICHECK(anchor_op_.defined()); - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : cache_node->outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule; - // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } - cache_node->schedule = std::move(schedule); - return CachedFunc(cache_node); - } - - Array VisitExpr_(const VarNode* op) final { - LOG(FATAL) << "Free variable " << op->name_hint(); - return {}; - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(op->is_scalar()); - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "compile_engine_const", topi::kBroadcast); - scalars_.push_back(value->op); - return {value}; - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); - static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); - ICHECK(flower_call) << "relay.backend.lower_call is not registered."; - - Array inputs; - int count_tuple = 0; - for (Expr arg : call_node->args) { - if (arg->checked_type().as()) { - ++count_tuple; - } - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - - Array outputs; - OpImplementation impl; - // Skip fcompute for device copy operators as it is not registered. - if (op == device_copy_op_) { - const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); - } else { - LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; - } - - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern > anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - if (outputs.size() != 1) { - const auto* tuple_type = call_node->checked_type().as(); - ICHECK(tuple_type) << "Expect output to be a tuple type"; - ICHECK_EQ(tuple_type->fields.size(), outputs.size()); - } - // Set the name to `__copy`. It will be detected in graph executor to perform - // data copy across devices. - if (op == device_copy_op_) { - readable_name_stream_.str(std::string()); - readable_name_stream_ << "__copy"; - } else { - readable_name_stream_ << '_' << op->name; - } - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - const auto* tuple_type = op->tuple->type_as(); - Array tuple = VisitExpr(op->tuple); - ICHECK_EQ(tuple_type->fields.size(), tuple.size()); - ICHECK_GE(op->index, 0); - ICHECK_LT(static_cast(op->index), tuple.size()); - return {tuple[op->index]}; - } - - private: - tvm::Target target_; - Op anchor_op_; - Attrs anchor_attrs_; - int anchor_op_pattern_{-1}; - OpImplementation anchor_implementation_; - std::ostringstream readable_name_stream_; - Array scalars_; - bool use_auto_scheduler_; - // Cache device copy op for equivalence checking to reduce registry lookup - // overhead for each invocation of call node when retrieving schedules. - const Op& device_copy_op_; -}; - -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target) { - return ScheduleGetter(target).Create(source_func); -} - -// Creates shape function from functor. -class MakeShapeFunc : public backend::MemoizedExprTranslator> { - public: - MakeShapeFunc() {} - - std::pair Create(const Function& prim_func) { - for (auto param : prim_func->params) { - param_states_[param] = kNoNeed; - Array data_inputs; - Array shape_inputs; - - auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { - // Add data placeholder - Shape shape = GetShape(ttype->shape); - tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); - data_inputs.push_back(data_tensor); - // Add shape placeholder - int64_t ndim = shape.size(); - Shape sshape; - if (ndim > 0) { - sshape.push_back(tvm::Integer(ndim)); - } - tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); - shape_inputs.push_back(shape_tensor); - }; - - if (const auto* ttype = param->checked_type().as()) { - add_placeholder(ttype); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - // TODO(@icemelon): Support recursive tuple - ICHECK(tuple_type); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype); - add_placeholder(ttype); - } - } - param_data_[param] = data_inputs; - param_shapes_[param] = shape_inputs; - } - readable_name_stream_ << "shape_func"; - auto cache_node = make_object(); - cache_node->outputs = VisitExpr(prim_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - cache_node->func_name = candidate_name; - - // set inputs - for (auto param : prim_func->params) { - int state = param_states_[param]; - cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); - if (state & kNeedInputData) { - for (auto t : param_data_[param]) { - cache_node->inputs.push_back(t); - } - } - if (state & kNeedInputShape) { - for (auto t : param_shapes_[param]) { - cache_node->inputs.push_back(t); - } - } - } - - CachedFunc cfunc(cache_node); - // generate schedule for shape func - Array out_ops; - for (auto t : cache_node->outputs) { - out_ops.push_back(t->op); - } - auto schedule = te::create_schedule(out_ops); - tvm::te::AutoInlineInjective(schedule); - for (const auto& scalar : scalars_) { - auto scalar_op = scalar->op; - if (schedule->Contain(scalar_op)) { - schedule[scalar_op].compute_inline(); - } - } - return std::make_pair(schedule, cfunc); - } - - Array VisitExpr(const Expr& expr) final { - if (expr.as()) { - // Do not memoize vars because shape functions could use either the data - // or the shape of a var each time. - return ExprFunctor::VisitExpr(expr); - } - // For other case, do memoized visit - return backend::MemoizedExprTranslator>::VisitExpr(expr); - } - - Array VisitExpr_(const VarNode* var_node) final { - auto var = GetRef(var_node); - auto it = param_states_.find(var); - if (it == param_states_.end()) { - LOG(FATAL) << "Free variable " << var->name_hint(); - return {}; - } else { - ICHECK(data_dependents_per_input_.size()); - auto data_dependent = data_dependents_per_input_.back(); - if (data_dependent) { - param_states_[var] |= kNeedInputData; - return param_data_[var]; - } else { - param_states_[var] |= kNeedInputShape; - return param_shapes_[var]; - } - } - } - - Array VisitExpr_(const ConstantNode* op) final { - using tir::make_const; - ICHECK(data_dependents_per_input_.size()); - bool data_dependent = data_dependents_per_input_.back(); - if (!op->is_scalar()) { - // This is a constant weight, extract the shape of the weight tensor. - // This can not be data dependent. - CHECK(!data_dependent); - auto ttype = op->checked_type().as(); - int ndim = static_cast(ttype->shape.size()); - Array out_shape{ndim}; - te::Tensor value = tvm::te::compute( - out_shape, - [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = make_const(DataType::Int(64), 0); - for (int i = 0; i < ndim; i++) { - ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); - } - return ret; - }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - if (data_dependent) { - void* data = op->data->data; - DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute( - {}, - [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, - "data_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } else { - auto value = tvm::te::compute( - {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, - "shape_const", topi::kBroadcast); - scalars_.push_back(value); - return {value}; - } - } - - Array VisitExpr_(const CallNode* call_node) final { - static auto fshape_func = Op::GetAttrMap("FShapeFunc"); - static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); - ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; - Op op = Downcast(call_node->op); - ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependent shape func"; - ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; - ICHECK_GT(tshape_data_dependent.count(op), 0) - << "Internal error, cannot find TShapeDataDependent for " << op->name; - - Array dep_spec = tshape_data_dependent[op]; - if (dep_spec.size() == 1) { - // This is for cases when data dependence is specified per op - // Replicate 0 or 1 flag to all arguments - for (size_t i = 1; i < call_node->args.size(); ++i) { - dep_spec.push_back(dep_spec[0]); - } - } - - // Visit all inputs - Array inputs; - int count_tuple = 0; - for (size_t i = 0; i < call_node->args.size(); ++i) { - Expr arg = call_node->args[i]; - if (arg->checked_type().as()) { - ++count_tuple; - } - data_dependents_per_input_.push_back(dep_spec[i]->value != 0); - for (te::Tensor tensor : VisitExpr(arg)) { - inputs.push_back(tensor); - } - data_dependents_per_input_.pop_back(); - } - if (count_tuple) { - ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; - } - // Get output ndims - auto ret_type = call_node->checked_type(); - Array out_ndims; - if (const auto* ttype = ret_type.as()) { - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } else { - auto rtype = ret_type.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(rtype); - for (size_t i = 0; i < rtype->fields.size(); ++i) { - auto ttype = rtype->fields[i].as(); - ICHECK(ttype); - out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); - } - } - // Call shape function - auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); - readable_name_stream_ << "_" << op->name; - return outputs; - } - - Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; - return Array(); - } - - Array VisitExpr_(const LetNode* op) final { - Array val = VisitExpr(op->value); - ICHECK(!memo_.count(op->var)); - memo_[op->var] = val; - return VisitExpr(op->body); - } - - Array VisitExpr_(const TupleNode* op) final { - Array fields; - for (Expr field : op->fields) { - ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; - Array res = VisitExpr(field); - ICHECK_EQ(res.size(), 1); - fields.push_back(res[0]); - } - return fields; - } - - Array VisitExpr_(const TupleGetItemNode* op) final { - Array input_shapes = VisitExpr(op->tuple); - Array out; - out.push_back(input_shapes[op->index]); - return out; - } - - private: - /*! \brief String stream for function name */ - std::ostringstream readable_name_stream_; - /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; - /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; - /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; - /*! \brief Stack of data dependencies for shape function, specified per each op input */ - std::vector data_dependents_per_input_; - /*! \brief Scalars used in the shape function */ - Array scalars_; -}; - class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. @@ -616,19 +61,19 @@ class CompileEngineImpl : public CompileEngineNode { return LowerInternal(key, mangle_fn)->cached_func; } + CachedFunc Lower(const CCacheKey& key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; + + return Lower(key, mangle_fn); + } + // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { auto mangle_fn = [](String name) { return name; }; CCacheValue value = LowerInternal(key, mangle_fn); if (value->packed_func != nullptr) return value->packed_func; - // build the function. - tvm::runtime::Module m; - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - m = (*f)(value->cached_func->funcs, key->target); - } else { - m = build(value->cached_func->funcs, key->target, Target(nullptr)); - } - value->packed_func = m.GetFunction(value->cached_func->func_name); + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); return value->packed_func; } @@ -643,6 +88,7 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); ICHECK(code_gen.defined()) << "No external codegen is set"; @@ -651,7 +97,9 @@ class CompileEngineImpl : public CompileEngineNode { auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false); + << AsText(src_func, false) << "\n" + << "Functions with external codegen must have the " + << tvm::attr::kGlobalSymbol << " attr set."; std::string sn = symbol_name.value(); if (!cached_symbol.count(sn)) { @@ -669,7 +117,12 @@ class CompileEngineImpl : public CompileEngineNode { src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); runtime::Module ext_mod = (*pf)(src_func); - ICHECK(ext_mod.defined()) << "No external runtime is generated."; + // todo(@zhiics, @jroesch): Should this be a user visible error? + ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name + << "even though it was requested" + "by the annotated function " + << PrettyPrint(src_func); + ret.push_back(ext_mod); } } @@ -734,44 +187,49 @@ class CompileEngineImpl : public CompileEngineNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto cache_node = make_object(); + auto ir_module = IRModule(); const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node.value()); - cache_node->target = Target("ext_dev"); - cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func); - value->cached_func = CachedFunc(cache_node); + auto func_name = std::string(name_node.value()); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + ir_module->Add(global_var, key->source_func); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } + // Enforce use the target. With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object(*(cfunc.operator->())); + auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(mangle_fn(name), &name_map_); + }); // Skip lowering for device copy node. const Expr body = (key->source_func)->body; if (const CallNode* call_node = body.as()) { if (call_node->attrs.as()) { - value->cached_func = CachedFunc(cache_node); + value->cached_func = cfunc; return value; } } - cache_node->func_name = GetUniqueName(mangle_fn(cache_node->func_name)); // NOTE: array will copy on write. - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { all_args.push_back(arg); } // lower the function std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds); + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; - value->cached_func = CachedFunc(cache_node); return value; } + // implement lowered shape func CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); @@ -790,47 +248,17 @@ class CompileEngineImpl : public CompileEngineNode { With target_scope(key->target); ICHECK(!value->cached_func.defined()); - auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object(*(spair.second.operator->())); - cache_node->func_name = GetUniqueName(cache_node->func_name); - cache_node->target = key->target; - - Array all_args = cache_node->inputs; - for (te::Tensor arg : cache_node->outputs) { - all_args.push_back(arg); - } - using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); - std::unordered_map binds; - cache_node->funcs = tvm::LowerSchedule(spair.first, all_args, cache_node->func_name, binds); - value->cached_func = CachedFunc(cache_node); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, &name_map_); + }); + + value->cached_func = cached_func; return value; } - /*! - * \brief Get unique name from name. - * \param name The orginal name. - * \return Updated name which is unique. - */ - std::string GetUniqueName(std::string name) { - for (size_t i = 0; i < name.length(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - while (true) { - auto it = name_map_.find(name); - if (it == name_map_.end()) { - name_map_[name] = 1; - return name; - } else { - std::ostringstream os; - os << name << "_" << it->second; - ++(it->second); - name = os.str(); - } - } - return name; - } + /*! \brief compiler cache lock*/ std::mutex mutex_; /*! \brief internal name map to get an unique name */ @@ -874,10 +302,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { - return runtime::get_name_mangled(mod_name, name); - }; - return self->Lower(key, mangle_fn); + return self->Lower(key, mod_name); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index f766fcf97ea7..4afdc6d30485 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -19,8 +19,12 @@ /*! * \file relay/backend/compile_engine.h - * \brief Internal compialtion engine handle function cache. - * and interface to low level code generation. + * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * This layer represents the older design of the Relay compilation flow and is being deprecated + * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of + * Relay functions. + * */ #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ @@ -36,157 +40,12 @@ #include #include +#include "te_compiler_cache.h" + namespace tvm { namespace relay { -/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ -enum ShapeFuncParamState { - kNoNeed = 0, - kNeedInputData = 1, - kNeedInputShape = 2, - kNeedBoth = 3, -}; - -struct LoweredOutputNode : public Object { - /*! \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The implementation used to compute the output */ - OpImplementation implementation; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("outputs", &outputs); - v->Visit("implementation", &implementation); - } - - static constexpr const char* _type_key = "relay.LoweredOutput"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); -}; - -class LoweredOutput : public ObjectRef { - public: - TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); - - TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); -}; - -/*! \brief Node container to represent a cached function. */ -struct CachedFuncNode : public Object { - /* \brief compiled target */ - tvm::Target target; - /*! \brief Function name */ - std::string func_name; - /* \brief The inputs to the function */ - tvm::Array inputs; - /* \brief The outputs to the function */ - tvm::Array outputs; - /*! \brief The schedule to the function */ - te::Schedule schedule; - /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule(Map({})); - - /*! \brief Parameter usage states in the shape function. */ - tvm::Array shape_func_param_states; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("target", &target); - v->Visit("func_name", &func_name); - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - v->Visit("schedule", &schedule); - v->Visit("funcs", &funcs); - v->Visit("shape_func_param_states", &shape_func_param_states); - } - - static constexpr const char* _type_key = "relay.CachedFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); -}; - -class CachedFunc : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); -}; - -class CCacheKey; -/*! \brief Compile cache key */ -class CCacheKeyNode : public Object { - public: - /*! \brief The source function to be lowered. */ - Function source_func; - /*! \brief The hardware target.*/ - Target target; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("source_func", &source_func); - v->Visit("target", &target); - } - /*! \return The hash value of CCacheKey. */ - inline size_t Hash() const; - /*! - * \brief check content equality - * \param other The other value. - * \return The result of equality check. - */ - inline bool Equal(const CCacheKeyNode* other) const; - - static constexpr const char* _type_key = "relay.CCacheKey"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); - - private: - /*! - * \brief internal cached hash value. - */ - mutable size_t hash_{0}; -}; - -/*! \brief cache entry used in compile engine */ -class CCacheKey : public ObjectRef { - public: - CCacheKey() {} - explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} - - /*! - * \brief The constructor - * \param source_func The source function. - * \param target The target device. - */ - TVM_DLL CCacheKey(Function source_func, Target target); - - const CCacheKeyNode* operator->() const { return static_cast(get()); } - // comparator - inline bool operator==(const CCacheKey& other) const { - ICHECK(defined() && other.defined()); - return (*this)->Equal(other.operator->()); - } - using ContainerType = CCacheKeyNode; -}; - -/*! \brief Node container for compile cache. */ -class CCacheValueNode : public Object { - public: - /*! \brief The corresponding function */ - CachedFunc cached_func; - /*! \brief Result of Packed function generated by JIT */ - PackedFunc packed_func; - /*! \brief usage statistics */ - int use_count{0}; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cached_func", &cached_func); - v->Visit("use_count", &use_count); - } - static constexpr const char* _type_key = "relay.CCacheValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); -}; - -/*! \brief cache entry used in compile engine */ -class CCacheValue : public ObjectRef { - public: - CCacheValue() {} - explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { return static_cast(get_mutable()); } - const CCacheValueNode* operator->() const { return static_cast(get()); } - using ContainerType = CCacheValueNode; -}; +using namespace tvm::relay::tec; /*! * \brief Backend compilation engine for @@ -199,10 +58,18 @@ class CompileEngineNode : public Object { /*! * \brief Get lowered result. * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions + * \param mod_name The mangling function for mangling names. * \return The result. */ virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \param mod_name The module name to mangle the functions. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; /*! * \brief Just in time compile to get a PackedFunc. * \param key The key to the cached function. @@ -242,49 +109,7 @@ class CompileEngine : public ObjectRef { TVM_DLL static CompileEngine& Global(); }; -/*! - * \brief Create schedule for target. - * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \return Pair of schedule and cache. - * The funcs field in cache is not yet populated. - */ -CachedFunc CreateSchedule(const Function& source_func, const Target& target); - -/*! - * \brief Check if the type is dynamic. - * \param ty The type to be checked. - * \return The result. - */ -bool IsDynamic(const Type& ty); - -// implementations -inline size_t CCacheKeyNode::Hash() const { - if (hash_ != 0) return hash_; - // do structral hash, avoid 0. - hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); - if (hash_ == 0) hash_ = 1; - return hash_; -} - -inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { - if (Hash() != other->Hash()) return false; - return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); -} - } // namespace relay } // namespace tvm -namespace std { -// overload hash -template <> -struct hash<::tvm::relay::CCacheKey> { - size_t operator()(const ::tvm::relay::CCacheKey& key) const { - ICHECK(key.defined()); - return key->Hash(); - } -}; -} // namespace std #endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index bca8e8244093..cc54a52be200 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -35,11 +36,13 @@ #include #include -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { namespace relay { +// TODO(@jroesch, @csullivan): declare directly elsewhere +backend::StaticMemoryPlan GraphPlanMemory(const Function& func); namespace backend { class GraphNode; @@ -52,7 +55,6 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -using TargetsMap = std::unordered_map; /*! \brief Node types */ enum GraphNodeType { @@ -176,112 +178,86 @@ class GraphOpNode : public GraphNode { const std::string op_type_name_{"tvm_op"}; }; -/*! \brief Code generator for graph executor */ +/*! \brief Code generator for the graph executor, produces a module containing the graph JSON, + * module, and parameters. + */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: - GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) { - compile_engine_ = CompileEngine::Global(); + GraphExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets) : mod_(mod) { targets_ = targets; } - /*! - * \brief Update the "main" control function's metadata - * - * \param func The main function that contains calls to relay primitive functions - */ - void UpdateMainWorkspaceSize(const Function& func) { - // This is a Map> - std::unordered_map> sid_workspace; - // This is a Map - std::unordered_map device_io; - // This is a Map - std::unordered_map device_consts; - - // Initialize the maps to zero - for (const auto& kv : storage_device_map_) { - auto sids = kv.second[0]; - auto devices = kv.second[1]; - CHECK_EQ(sids.size(), devices.size()); - for (uint32_t i = 0; i < sids.size(); i++) { - sid_workspace[devices[i]][sids[i]] = 0; - device_io[devices[i]] = 0; - device_consts[devices[i]] = 0; - } - } + StorageInfo GetStorageInfo(const Expr& e) { + size_t count = memory_plan_->expr_to_storage_info.count(e); + ICHECK_GT(count, 0) << "Expr is not existing in storage plan"; + auto storage_info = memory_plan_->expr_to_storage_info[e]; + return storage_info; + } - // Collect sizes of tensors - for (const auto& kv : storage_device_map_) { - auto size_bytes = CalculateRelayExprSizeBytes(kv.first->checked_type()); - auto sids = kv.second[0]; - auto devices = kv.second[1]; - if (kv.first->IsInstance()) { - for (const auto& dev : devices) { - device_consts[dev] += size_bytes; - } - continue; - } else if (kv.first->IsInstance() || kv.first == func->body) { - for (const auto& dev : devices) { - device_io[dev] += size_bytes; - } - continue; - } - for (uint32_t i = 0; i < sids.size(); i++) { - // Here we record the largest size of the tensor - // that share the same storage id, because storage_id will - // be shared between multiple tensors that are not live simultaneously. - if (size_bytes > sid_workspace[devices[i]][sids[i]]) { - sid_workspace[devices[i]][sids[i]] = size_bytes; - } - } - } + LoweredOutput Codegen(relay::Function func, String mod_name) { + mod_name_ = mod_name; - // This is a Map - std::unordered_map device_workspace; - // Once we know the sizes of sids, we need to accumulate per device - for (const auto& dev_sid_size : sid_workspace) { - auto dev = dev_sid_size.first; - device_workspace[dev] = 0; - for (const auto& sid_size : dev_sid_size.second) { - device_workspace[dev] += sid_size.second; - } - } + // TODO(@jroesch): we need to split device planning and memory planning + // first we run device assignment, then we perform lowering, and then + // storage planning in ideal world. + + memory_plan_ = GraphPlanMemory(func); + + // This first phase moves from implicit use of compile engine, + // to instead explicitly lowering the incoming IRModule, and then + // performing the preexisting graph executor code generation phase. + IRModule mod = IRModule::FromExpr(func); + + // Build a map from each operation to device. + tec::DeviceMap device_context_map; + for (const auto& it : memory_plan_->expr_to_storage_info) { + auto expr = it.first; + auto storage_info = it.second; + auto device_types = storage_info->device_types; + // CHECK_EQ(device_types.size(), 1); + tvm::Device dev; + dev.device_id = 0; + dev.device_type = device_types[0]; + device_context_map.insert({expr, dev}); + } + + auto lowered_module = tec::LowerTE( + mod, targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } - // Populate FunctionInfo - auto fi_node = make_object(); - // Initialize all target workspaces to zero - for (const auto& kv : targets_) { - auto tgt = kv.second; - fi_node->workspace_sizes.Set(tgt, 0); - } - for (const auto& dev_and_size : device_workspace) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->workspace_sizes.Set(tgt, dev_and_size.second); - fi_node->relay_primfuncs.Set(tgt, func); - } - for (const auto& dev_and_size : device_io) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->io_sizes.Set(tgt, dev_and_size.second); - } - for (const auto& dev_and_size : device_consts) { - auto tgt = GetTargetFromInteger(dev_and_size.first); - fi_node->constant_sizes.Set(tgt, dev_and_size.second); - } + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + }); - function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node)); - } + function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); + auto main_module = lowered_module.main_module; + main_module = relay::transform::InferType()(main_module); + relay::Function main_func = Downcast(main_module->Lookup("main")); + + // Now that we have lowered all operators to TIR code, we can proceed with compilation. + // + // We need to unfortunately re-plan as the previous results have been invalidated by lowering + // we will fix this in future refactors. + memory_plan_ = GraphPlanMemory(main_func); + + // The graph planner also can not handle planning calls to global variables to we must remap - LoweredOutput Codegen(relay::Function func, String mod_name) { - auto pf = GetPackedFunc("relay.backend.GraphPlanMemory"); - storage_device_map_ = (*pf)(func); - mod_name_ = mod_name; - UpdateMainWorkspaceSize(func); // First we convert all the parameters into input nodes. - for (auto param : func->params) { + for (auto param : main_func->params) { auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs()); var_map_[param.get()] = AddNode(node_ptr, param); } - heads_ = VisitExpr(func->body); + + heads_ = VisitExpr(main_func->body); std::ostringstream os; + dmlc::JSONWriter writer(&os); GetJSON(&writer); LoweredOutput ret; @@ -292,17 +268,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(param_storage_ids_[param.first]), param.second))); } - - for (auto& kv : lowered_funcs_) { - if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); - } - auto& mod = ret.lowered_funcs[kv.first]; - mod->Update(kv.second); - ret.lowered_funcs.Set(kv.first, mod); - } - ret.external_mods = compile_engine_->LowerExternalFunctions(); ret.function_metadata = std::move(function_metadata_); + ret.lowered_funcs = lowered_module.per_target_module; + ret.external_mods = lowered_module.external_mods; return ret; } @@ -331,20 +299,18 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator AddNode(GraphObjectPtr node, Expr expr) { auto checked_type = expr->checked_type(); - size_t count = storage_device_map_.count(expr); - ICHECK_GT(count, 0) << "Expr is not existing in storage plan"; - auto storage_device_info = storage_device_map_[expr]; - ICHECK_EQ(storage_device_info.size(), 3); + + auto storage_info = GetStorageInfo(expr); // storage - std::vector storage_info; - for (auto& v : storage_device_info[0]) { - storage_info.push_back(v->value); + std::vector storage_ids; + for (auto v : storage_info->storage_ids) { + storage_ids.push_back(v); } - node->attrs_["storage_id"] = std::move(storage_info); + node->attrs_["storage_id"] = std::move(storage_ids); // type std::vector device_types; - for (auto& v : storage_device_info[1]) { - device_types.push_back(v->value); + for (auto v : storage_info->device_types) { + device_types.push_back(static_cast(v)); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) { @@ -404,7 +370,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorvalue; + param_storage_ids_[name] = GetStorageInfo(expr)->storage_ids[0]; params_[name] = op->data; return to_return; } @@ -420,8 +386,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, const std::string& op_name, - const std::string& func_name, GraphAttrs attrs) { + bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { + StorageInfo lit = GetStorageInfo(lhs); + StorageInfo rit = GetStorageInfo(rhs); + int64_t lhs_storage_id = lit->storage_ids[0]; + int64_t rhs_storage_id = rit->storage_ids[0]; + return lhs_storage_id == rhs_storage_id; + } + + std::vector GraphAddCallNode(const CallNode* op, const std::string& func_name, + GraphAttrs attrs) { std::vector inputs; for (auto arg : op->args) { auto res = VisitExpr(arg); @@ -429,161 +403,52 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(op)); - } - bool ShareSameStorage(const Expr& lhs, const Expr& rhs) { - auto lit = storage_device_map_.find(lhs); - auto rit = storage_device_map_.find(rhs); - ICHECK(lit != storage_device_map_.end()); - ICHECK(rit != storage_device_map_.end()); - int64_t lhs_storage_id = ((*lit).second)[0][0]->value; - int64_t rhs_storage_id = ((*rit).second)[0][0]->value; - return lhs_storage_id == rhs_storage_id; - } + /// An adapted version of the storage optimization for the time being. + bool reshape_only = false; + if (op->attrs.defined()) { + if (auto tir_call_attrs = op->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + if (metadata.count(attr::kReshapeOnly) && + Downcast(metadata[attr::kReshapeOnly])->value == 1) { + reshape_only = true; + } - /*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ - Target GetTargetFromInteger(int64_t dev_type) { - if (targets_.size() == 1) { - // homogeneous execution. - const auto& it = targets_.begin(); - return (*it).second; - } else { - // heterogeneous execution. - std::string call_dev_name; - if (dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(dev_type); - } - if (targets_.count(dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; - } - return targets_[dev_type]; - } - } + auto relay_attrs = Downcast(tir_call_attrs->metadata["relay_attrs"]); - /*! - * \brief Update the function metadata for a given cached function and its relay - * primitive function. - * - * \param cfunc The cached function as provided the by the compile engine - * \param relay_func The source relay primitive function - * \param relay_target The target associated with relay primitive function - */ - void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func, - const Target& relay_target) { - auto fi_node = make_object(); - for (const auto& kv : cfunc->funcs->functions) { - auto primfunc = Downcast(kv.second); - auto workspace_byte_alignment = relay_target->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - Target primfunc_target = relay_target; - if (primfunc->attrs->dict.count("target")) { - primfunc_target = Downcast(primfunc->attrs->dict["target"]); - } - fi_node->workspace_sizes.Set(primfunc_target, workspace_size); - // Calculating size for I/O - for (auto const& param : primfunc->params) { - auto p_shape = primfunc->buffer_map[param]->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : p_shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; + for (auto p : relay_attrs->dict) { + if (p.second.as()) { + attrs[p.first] = std::string(Downcast(p.second)); } } - int element_size = primfunc->buffer_map[param]->dtype.bytes(); - fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements); } - fi_node->constant_sizes.Set(primfunc_target, 0); - fi_node->tir_primfuncs.Set(primfunc_target, primfunc); - fi_node->relay_primfuncs.Set(primfunc_target, relay_func); - } - function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node)); - } - - std::vector VisitExpr_(const CallNode* op) override { - Expr expr = GetRef(op); - Function func; - if (op->op.as()) { - LOG(FATAL) << "Operators should be transformed away; try applying" - << "the fuse_ops transformation to the expression."; - } else if (op->op.as()) { - LOG(FATAL) << "Not implemented"; - } else if (op->op.as()) { - func = GetRef(op->op.as()); - } else { - LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); - } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "TVM only support calls to primitive functions " - << "(i.e functions composed of fusable operator invocations)"; } - // Copy attrs from function into the graph node - // For now we only handle strings - GraphAttrs attrs; - for (auto p : func->attrs->dict) { - if (p.second.as()) { - attrs[p.first] = std::string(Downcast(p.second)); - } + if (reshape_only && ShareSameStorage(GetRef(op), op->args[0])) { + auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); + return AddNode(node, GetRef(op)); } - auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); - auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); - Target target; - // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); - CCacheKey key = (*pf0)(func, target); - CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_); - ICHECK(ext_func.defined()) << "External function is not defined."; - UpdateConstants(func, ¶ms_); - return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name, attrs); - } - - // In the current flat memory allocation scenario - // the flat memory allocator can always allocate input - // and output of the reshape to the same memory, we can turn reshape only - // function to a nop. - // - // NOTE that for non-flat memory this is not necessarily true. - // - // TODO(tvm-team) Update checks of flat memory enablement when we support - // opaque-nd memory planning to skip this path. - if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) { - return GraphAddCallNode(op, "reshape_nop", "__nop", attrs); - } + // Compute the operator name, because we used the get unique name when generating the kernel. + auto op_name = _GetUniqueName(func_name); + auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs); + return AddNode(node, GetRef(op)); + } - ICHECK_GE(storage_device_map_.count(expr), 0); - auto& device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; - target = GetTargetFromInteger(call_dev_type); - // Normal Relay Function + std::vector VisitExpr_(const CallNode* call_node) override { + relay::Call call = GetRef(call_node); + if (auto global_node = call->op.as()) { + auto prim_fn_name = global_node->name_hint; - CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_); - if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = IRModule(Map({})); + return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); + } else { + ICHECK(false) << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to have their callee " + "normalized to a GlobalVar but found a " + << call->GetTypeKey() << "." + << "AST: " << PrettyPrint(call) << PrettyPrint(call) << std::endl; + return {}; } - lowered_funcs_[target->str()]->Update(lowered_func->funcs); - - // Update function metadata via looking at all primfuncs - UpdateFunctionMetadata(lowered_func, func, target); - return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name, - attrs); } std::vector VisitExpr_(const LetNode* op) override { @@ -714,7 +579,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> var_map_; /*! \brief target device */ - TargetsMap targets_; + tec::TargetMap targets_; /*! * \brief parameters (i.e. ConstantNodes found in the graph). * These are take as inputs to the GraphExecutor. @@ -724,17 +589,13 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator params_; std::unordered_map param_storage_ids_; /*! \brief plan memory of device result */ - Map> storage_device_map_; + StaticMemoryPlan memory_plan_; /*! \brief the module name we use to mangle the function names */ String mod_name_; - /*! \brief lowered funcs */ - std::unordered_map lowered_funcs_; - /*! \brief lowered funcs */ + /*! \brief function metadata */ Map function_metadata_; /*! \brief name map */ std::unordered_map name_map_; - /*! \brief compile engine */ - CompileEngine compile_engine_; }; class GraphExecutorCodegenModule : public runtime::ModuleNode { @@ -747,11 +608,11 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { << "runtime::Module mod and Map targets"; void* mod = args[0]; Map tmp = args[1]; - TargetsMap targets; + tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); ICHECK(dev_type); - targets[dev_type->value] = it.second; + targets[static_cast(dev_type->value)] = it.second; } codegen_ = std::make_shared(reinterpret_cast(mod), targets); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 351469d6e1ca..93c823d8a007 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -23,15 +23,20 @@ * the program in the graph executor. */ #include +#include #include #include +#include #include #include "../../support/arena.h" +#include "./utils.h" namespace tvm { namespace relay { +using backend::StaticMemoryPlan; +using backend::StorageInfo; using IntegerArray = Array; struct StorageToken { @@ -48,6 +53,18 @@ struct StorageToken { int64_t storage_id{-1}; }; +std::ostream& operator<<(std::ostream& os, StorageToken tok) { + return os << "StorageToken: " << std::endl + << "ref_counter: " << tok.ref_counter << std::endl + << "max_bytes: " << tok.max_bytes << std::endl + << "tttype: " << tok.ttype + << std::endl + // ok idk how to print this properly + << "tttype shape: " << tok.ttype->shape << std::endl + << "device_type: " << tok.device_type << std::endl + << "storage_id: " << tok.storage_id << std::endl; +} + class StorageAllocaBaseVisitor : public ExprVisitor { public: // run the visitor on a function. @@ -114,7 +131,8 @@ class StorageAllocaBaseVisitor : public ExprVisitor { const std::vector& GetToken(const Expr& expr) { this->VisitExpr(expr); auto it = token_map_.find(expr.operator->()); - ICHECK(it != token_map_.end()); + ICHECK(it != token_map_.end()) + << "Expression: `" << PrettyPrint(expr) << "` not found in storage map."; return it->second; } /*! @@ -168,6 +186,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { void VisitExpr_(const CallNode* op) final { // create token for the call node. CreateToken(op, true); + // for each input, visit argument token. for (Expr arg : op->args) { for (StorageToken* tok : GetToken(arg)) { @@ -196,31 +215,32 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Run storage allocation for a function. - Map > Plan(const Function& func) { + StaticMemoryPlan Plan(const Function& func) { prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); this->Run(func); // The value of smap contains two integer arrays where the first array // contains the planned storage ids and the second holds the device types. - Map > smap; + Map smap; int num_annotated_nodes = 0; int num_nodes = 0; for (const auto& kv : token_map_) { - std::vector storage_ids; - std::vector device_types; - std::vector sid_sizes_byte; + std::vector storage_ids; + std::vector device_types; + std::vector sid_sizes_byte; + for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; } num_nodes++; storage_ids.push_back(tok->storage_id); - device_types.push_back(tok->device_type); + device_types.push_back(static_cast(tok->device_type)); sid_sizes_byte.push_back(GetMemorySize(tok)); } - smap.Set(GetRef(kv.first), - Array({storage_ids, device_types, sid_sizes_byte})); + auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte); + smap.Set(GetRef(kv.first), storage_info); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { @@ -228,7 +248,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { << "expressions are assigned with virtual device types. Either all " "or none of the expressions are expected to be annotated."; } - return smap; + + return backend::StaticMemoryPlan(smap); } protected: @@ -279,6 +300,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { args.push_back(tok); } } + // Under the flat-memory setting. // we can force aliasing the input and output of reshape // to make it an nop. Note that this is not true @@ -288,12 +310,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // TODO(tvm-team) Update checks of flat memory enablement when we support // opaque-nd memory planning to skip this path. if (IsReshape(op)) { + // TODO(@electriclilies, jroesch): This check is failing because the size of args is 3 + // I can't figure out where the extra args are coming from, I assume it must be related + // to the relay_attrs field we added to the TIRCallArgs, but I don't know where / how + // that's happening... ICHECK_EQ(args.size(), 1U); ReuseInputToken(op, args[0]); } else { // create token for the call node. CreateToken(op, true); } + // check if there is orphaned output that can be released immediately. for (StorageToken* tok : token_map_.at(op)) { CheckForRelease(tok); @@ -320,6 +347,15 @@ class StorageAllocator : public StorageAllocaBaseVisitor { if (const auto* fn = call->op.as()) { return fn->HasNonzeroAttr(attr::kReshapeOnly); } + + if (call->attrs.defined()) { + if (auto tir_call_attrs = call->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); + } + } + return false; } /*! @@ -419,9 +455,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::unordered_map > prototype_; }; -Map > GraphPlanMemory(const Function& func) { - return StorageAllocator().Plan(func); -} +StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index eeba010dc164..6ebb17e93eca 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -32,7 +32,9 @@ #include #include +#include "../transforms/pass_utils.h" #include "compile_engine.h" +#include "te_compiler.h" namespace tvm { namespace relay { @@ -213,9 +215,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: Interpreter(IRModule mod, Device device, Target target) - : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) { - engine_ = CompileEngine::Global(); - } + : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {} template T WithFrame(const Frame& fr, const std::function& f) { @@ -285,7 +285,7 @@ class Interpreter : public ExprFunctor, Array ComputeDynamicShape(const Function& func, const Array& args) { CCacheKey key(func, Target("llvm")); - auto cfunc = engine_->LowerShapeFunc(key); + auto cfunc = compiler_->LowerShapeFunc(key); size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); std::vector values(arity); @@ -381,7 +381,7 @@ class Interpreter : public ExprFunctor, } else { m = build(cfunc->funcs, cfunc->target, Target(nullptr)); } - shape_func = m.GetFunction(cfunc->func_name); + shape_func = m.GetFunction(cfunc->prim_fn_var->name_hint); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); // Get output shapes @@ -484,7 +484,7 @@ class Interpreter : public ExprFunctor, out_shapes = ComputeDynamicShape(func, args); } - PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_)); + PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_)); TVMRetValue rv; if (const TupleTypeNode* rtype = func->body->checked_type().as()) { ICHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); @@ -554,11 +554,11 @@ class Interpreter : public ExprFunctor, // We should not find operators after running fusion, // and operator lowering. // - // We have some functions cotaining chunks of operators + // We have some functions containing chunks of operators // which will be loaded into operator map. if (const auto* op_node = call->op.as()) { LOG(FATAL) << "found " << op_node->name - << "; operators should be removed by future passes; try " + << "; operators should have been removed by previous passes; try " "fusing and lowering"; } if (auto con = call->op.as()) { @@ -568,9 +568,9 @@ class Interpreter : public ExprFunctor, ObjectRef fn_val = Eval(call->op); if (const InterpreterClosureObj* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); - return this->Invoke(closure, args); + return Invoke(closure, args); } else if (const RecClosureObj* closure_node = fn_val.as()) { - return this->Invoke(closure_node->clos, args, closure_node->bind); + return Invoke(closure_node->clos, args, closure_node->bind); } else { LOG(FATAL) << "internal error: type error, expected function value in the call " << "position"; @@ -709,17 +709,17 @@ class Interpreter : public ExprFunctor, Target target_; // Object stack. Stack stack_; - // Backend compile engine. - CompileEngine engine_; + // TE-to-TIR lowerer (compiler). + TECompiler compiler_; // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; }; TypedPackedFunc CreateInterpreter(IRModule mod, Device device, Target target) { if (mod.defined()) { - // eta expand to support constructors in argument position - transform::Sequential seq({transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false), + transform::Sequential seq({// eta expand to support constructors in argument position + transform::EtaExpand( + /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType()}); transform::PassContext pass_ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc new file mode 100644 index 000000000000..208e6356355d --- /dev/null +++ b/src/relay/backend/te_compiler.cc @@ -0,0 +1,719 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "te_compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "te_compiler.h" +#include "te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +// TODO(@jroesch, @csullivan): declare directly elsewhere +backend::StaticMemoryPlan GraphPlanMemory(const Function& func); + +namespace tec { + +using namespace tvm::relay::transform; + +TVM_REGISTER_OBJECT_TYPE(TECompilerNode); + +class TECompilerImpl : public TECompilerNode { + public: + // Lower the function. + CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { + return LowerInternal(key, mangle_fn)->cached_func; + } + + CachedFunc Lower(const CCacheKey& key, const String mod_name) { + auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; + + return Lower(key, mangle_fn); + } + + // For now, build one module per function. + PackedFunc JIT(const CCacheKey& key) final { + auto mangle_fn = [](String name) { return name; }; + CCacheValue value = LowerInternal(key, mangle_fn); + if (value->packed_func != nullptr) { + return value->packed_func; + } + auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); + value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); + return value->packed_func; + } + + CachedFunc LowerShapeFunc(const CCacheKey& key) final { + return LowerShapeFuncInternal(key)->cached_func; + } + + Map GetLoweredFunctions() { + Map lowered_functions; + for (const auto& it : cache_) { + auto source_func = it.first; + auto lowered_func = it.second; + auto target = source_func->target; + + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); + } + + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + } + return lowered_functions; + } + + Array LowerExternalFunctions() { + Array ret; + std::unordered_map cached_symbol; + std::vector cached_ext_funcs; + for (const auto& it : cache_) { + auto src_func = it.first->source_func; + ICHECK(src_func.defined()); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); + std::string code_gen_name = code_gen.value(); + cached_ext_funcs.push_back(it.first); + + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" + << AsText(src_func, false); + + std::string sn = symbol_name.value(); + if (cached_symbol.count(sn)) { + cached_symbol[sn] = code_gen_name; + } else { + ICHECK_NE(sn, code_gen_name) + << "Found duplicated symbol: " << sn << " for: " << code_gen_name; + } + + std::string ext_name = "relay.ext." + code_gen_name; + auto pf = tvm::runtime::Registry::Get(ext_name); + ICHECK(pf) << "Failed to find the codegen tool for " << ext_name; + // No need to keep compiler attribute at this point, functions have been + // extracted for specific codegen. + src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); + runtime::Module ext_mod = (*pf)(src_func); + + ICHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } + } + + // No need to cache external functions as we collected them all to create + // external runtime modules. + for (const auto& it : cached_ext_funcs) { + cache_.erase(it); + } + return ret; + } + + void Clear() final { cache_.clear(); } + + // List all items in the cache. + Array ListItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : cache_) { + items.push_back(kv.first); + items.push_back(kv.second); + } + return items; + } + + /*! + * \brief Get the cache key of the function that is being lowered currently + * \return the cache key + */ + CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } + + private: + // implement lowered func + CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = cache_.find(key); + if (it != cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 1; + cache_[key] = value; + } + cur_ccache_key_ = key; + + // No need to lower external functions for now. We will invoke the external + // codegen tool once and lower all functions together. + if (key->source_func->GetAttr(attr::kCompiler).defined()) { + auto ir_module = IRModule(); + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "External function has not been attached a name yet."; + auto func_name = GetUniqueName(name_node.value(), &name_map_); + auto target = Target("ext_dev"); + auto global_var = GlobalVar(func_name); + global_var->checked_type_ = key->source_func->checked_type(); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + return value; + } + + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { + auto mangled = mangle_fn(name); + return GetUniqueName(mangled, &name_map_); + }); + + // Skip lowering for device copy node. + const Expr body = (key->source_func)->body; + if (const CallNode* call_node = body.as()) { + if (call_node->attrs.as()) { + value->cached_func = cfunc; + return value; + } + } + + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); + value->cached_func = cfunc; + return value; + } + + // implement lowered shape func + CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = shape_func_cache_.find(key); + if (it != shape_func_cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_object()); + value->use_count = 0; + shape_func_cache_[key] = value; + } + // Enforce use the target. + With target_scope(key->target); + + ICHECK(!value->cached_func.defined()); + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { + return GetUniqueName(name, &name_map_); + }); + + value->cached_func = cached_func; + return value; + } + + std::unordered_map GetOpWeights() { + std::unordered_map weights; + for (auto pair : cache_) { + auto value = pair.second; + auto name = value->cached_func->prim_fn_var->name_hint; + weights[name] = value->use_count; + } + return weights; + } + + /*! \brief compiler cache lock*/ + std::mutex mutex_; + /*! \brief internal name map to get an unique name */ + std::unordered_map name_map_; + /*! \brief internal compiler cache */ + std::unordered_map cache_; + /*! \brief internal compiler cache for shape funcs */ + std::unordered_map shape_func_cache_; + /*! \brief the cache key of the function that is being lowered currently*/ + CCacheKey cur_ccache_key_; +}; + +TECompiler::TECompiler() { + auto object = make_object(); + data_ = object; +} + +using AnalysisRemapping = std::unordered_map; + +std::tuple IsDeviceCopy(const Function& func) { + if (auto call_node = func->body.as()) { + if (auto op_node = call_node->op.as()) { + if (op_node->name == "device_copy") { + auto attrs = call_node->attrs.as(); + auto dst = attrs->dst_dev_type; + auto src = attrs->src_dev_type; + return std::tuple(true, src, dst); + } + } + } + + return std::tuple(false, -1, -1); +} + +class LowerTensorExpr : public ExprMutator { + public: + LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map, + ProcessFn process_fn, const String& module_name, TECompiler compiler) + : module_(module), + targets_(targets), + device_context_map_(device_ctx_map), + process_fn(process_fn), + module_name_(module_name), + compiler_(compiler) {} + + Expr VisitExpr_(const CallNode* call) override { + Call expr = GetRef(call); + Function func; + + if (call->op.as()) { + func = GetRef(call->op.as()); + } else { + return ExprMutator::VisitExpr_(call); + } + + if (!func->HasNonzeroAttr(attr::kPrimitive)) { + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func); + return ExprMutator::VisitExpr_(call); + } + + // Process inputs. + Array args; + for (size_t i = 0; i < expr->args.size(); i++) { + args.push_back(VisitExpr(expr->args[i])); + } + + Target target; + + if (func->GetAttr(attr::kCompiler).defined()) { + target = Target("ext_dev"); + CCacheKey key = CCacheKey(func, target); + CachedFunc ext_func = compiler_->Lower(key, module_name_); + ICHECK(ext_func.defined()) << "Lowering returned undefined function for " + << ext_func->prim_fn_var->name_hint; + + Map prim_fns; + + for (auto prim_fn : ext_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + } + + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target); + + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func_with_metadata); + + auto ret_call = Call(ext_func->prim_fn_var, args, {}); + return std::move(ret_call); + } + + ICHECK_GE(device_context_map_.count(expr), 0) + << "Could not find an entry in the device context map for " << PrettyPrint(expr) + << "The memory planning was either not performed for this precise node, or there is bug " + "in the memory planner."; + + auto& device_context = this->device_context_map_[expr]; + target = GetTargetFromInteger(device_context.device_type, targets_); + // Non-External Relay Function + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compiler_->Lower(key, module_name_); + + Map prim_fns; + + for (auto prim_fn : lowered_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + } + + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target); + + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn(func_with_metadata); + + auto tir_call_attrs = make_object(); + if (func->HasNonzeroAttr(attr::kReshapeOnly)) { + tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + } + + auto device_copy = IsDeviceCopy(func); + if (std::get<0>(device_copy)) { + auto source_device = std::get<1>(device_copy); + auto dst_device = std::get<2>(device_copy); + tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); + tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); + } + + tir_call_attrs->metadata.Set("relay_attrs", func->attrs); + + Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs)); + return std::move(ret_call); + } + + IRModule module_; + TargetMap targets_; + DeviceMap device_context_map_; + ProcessFn process_fn; + String module_name_; + TECompiler compiler_; +}; + +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ +Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { + if (targets.size() == 1) { + // The homogeneous execution case, return the only target. + const auto& it = targets.begin(); + return (*it).second; + } else { + // The heterogeneous execution case, return the target associated with the + // given device type. + // If "dev_type" equals to 0, the device name only can be got from + // "targets", and it may not be "llvm", so here just set it to "unknown". + std::string dev_name = "unknown"; + if (dev_type != 0) { + dev_name = runtime::DeviceName(dev_type); + } + + if (targets.count(dev_type) == 0) { + std::stringstream msg; + msg << "No target is specified for provided device name: `" << dev_name << "`\n\n" + << dev_name << " mapped to device type (" << dev_type + << ") which was not found in the target map.\n" + << "Availible targets: \n"; + for (auto target : targets) { + msg << " " << target.first << "-> " << target.second << "\n"; + } + LOG(FATAL) << msg.str(); + } + return targets[dev_type]; + } +} + +/*! + * \brief Update the "main" control function's metadata + * + * \param mod The module + * \param targets Map of targets + * \return function_infos Function info for each function in the module + */ + +backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets, + Map storage_info_map) { + CHECK_EQ(mod->functions.size(), 1) + << "There should only be one function in the module passed to UpdateMainWorkspaceSize"; + Function func = Downcast(mod->Lookup("main")); + + // This is a Map> + std::unordered_map, EnumClassHash> sid_workspace; + // This is a Map + std::unordered_map device_io; + // This is a Map + std::unordered_map device_consts; + + // Initialize the mapping from all storage identifiers to workspace sizes, + // the amount of device io, and the device constants. + for (const auto& kv : storage_info_map) { + backend::StorageInfo storage_info = kv.second; + std::vector storage_ids = storage_info->storage_ids; + std::vector devices = storage_info->device_types; + + CHECK_EQ(storage_ids.size(), devices.size()); + for (uint32_t i = 0; i < devices.size(); i++) { + sid_workspace[devices[i]][storage_ids[i]] = 0; + device_io[devices[i]] = 0; + device_consts[devices[i]] = 0; + } + } + + // Iterate the storage map to compute all the tensor sizes in the program. + // There are 3 cases in this code: + // + // First we need to compute the sizes of all + // inline constants. + // + // Second we compute the size of any bound variable as these are input and output + // sizes of the program. + // + // Finally for all other expressions we check which storage identifier they have + // been assigned and we compute the maximal size of the storage, as tensors can + // share storage with other tensors which are the same size or larger. + // + // In this final case there is only one allocation for all tensors which share storage + // which will be the maximal size of all tensors which were assigned to it. + for (const auto& kv : storage_info_map) { + Expr expr = kv.first; + int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type()); + backend::StorageInfo storage_info = kv.second; + std::vector storage_ids = storage_info->storage_ids; + std::vector devices = storage_info->device_types; + + if (expr->IsInstance()) { + for (const auto& dev : devices) { + device_consts[dev] += size_bytes; + } + continue; + } else if (expr->IsInstance() || expr.same_as(func->body)) { + CHECK_GE(devices.size(), 1) << "must be at least one device"; + for (const auto& dev : devices) { + device_io[dev] += size_bytes; + } + continue; + } + + // TODO(@electriclilies): This code is never being called which means sid_workspace is not + // updated.. This means that storage info is probably not being created correctly. Or is not + // equivalent to what was here previously + for (uint32_t i = 0; i < storage_ids.size(); i++) { + // Here we record the largest size of the tensor + // that share the same storage id, because storage_id will + // be shared between multiple tensors that are not live simultaneously. + if (size_bytes > sid_workspace[devices[i]][storage_ids[i]]) { + sid_workspace[devices[i]][storage_ids[i]] = size_bytes; + } + } + } + + // This is a Map + std::unordered_map device_workspace; + // Once we know the sizes of sids, we need to accumulate per device + for (const auto& dev_sid_size : sid_workspace) { + auto dev = dev_sid_size.first; + device_workspace[dev] = 0; + for (const auto& sid_size : dev_sid_size.second) { + device_workspace[dev] += sid_size.second; + } + } + + Map workspace_sizes; + Map io_sizes; + Map constant_sizes; + Map tir_primfuncs; + Map relay_primfuncs; + + // Initialize all target workspaces to zero + for (const auto& kv : targets) { + auto tgt = kv.second; + workspace_sizes.Set(tgt, 0); + } + + for (const auto& dev_and_size : device_workspace) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + workspace_sizes.Set(tgt, dev_and_size.second); + relay_primfuncs.Set(tgt, func); + } + for (const auto& dev_and_size : device_io) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + io_sizes.Set(tgt, dev_and_size.second); + } + + for (const auto& dev_and_size : device_consts) { + auto tgt = GetTargetFromInteger(dev_and_size.first, targets); + constant_sizes.Set(tgt, dev_and_size.second); + } + + return backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs, + relay_primfuncs); +} + +// TODO(@electriclilies): Is the function passed in here relay_func?? +// Also should this be inlined? +/*! + * \brief A function to create the function metadata for an input function (ie calculate buffer + * input/output sizes) + * \param relay_func The function to calculate function metadata for + * \param function_metadata The map that stores all the function metadatas + */ +void UpdateFunctionMetadata(Function relay_func, + Map& function_metadata) { // NOLINT(*) + // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored + // there Now the goal is to take only one func because process_fn should be controlling the + // iteration However, to do the workspace calculations we need the primfuncs. So process_fn needs + // to either access the cached funcs or be directly passed primfuncs This is bad and ideally we + // don't want process_fn to look at primfuncs There's also the question now of what the function + // metadatas are and how they are used if we can do something else to replicate the behavior of + // the function metadatas that might be good (ie annotating functions or something). + Map workspace_sizes; + Map io_sizes; + Map constant_sizes; + Map tir_primfuncs; + Map relay_primfuncs; + + Optional> prim_fns = + relay_func->GetAttr>("prim_funcs"); + CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler."; + + Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); + CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; + + Optional relay_target = relay_func->GetAttr("target"); + CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; + + for (const auto& kv : prim_fns.value()) { + auto prim_fn = Downcast(kv.second); + CHECK(prim_fn.defined()) << "the primitive function must be defined"; + + auto workspace_byte_alignment = + relay_target.value()->GetAttr("workspace_byte_alignment").value_or(16); + + Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment); + + // Workspace sizes + Target prim_fn_target; + if (prim_fn->attrs->dict.count("target")) { + prim_fn_target = Downcast(prim_fn->attrs->dict["target"]); + } else { + prim_fn_target = relay_target.value(); + } + + workspace_sizes.Set(prim_fn_target, workspace_size); + + // Calculating size for I/O + for (auto const& param : prim_fn->params) { + auto p_shape = prim_fn->buffer_map[param]->shape; + int num_of_elements = 1; + for (const auto& dim_index_expr : p_shape) { + if (dim_index_expr->IsInstance()) { + num_of_elements *= dim_index_expr.as()->value; + } else { + // If shape is dynamic, we cannot calculate workspace in compile time. + num_of_elements = 0; + } + } + int element_size = prim_fn->buffer_map[param]->dtype.bytes(); + io_sizes.Set(prim_fn_target, element_size * num_of_elements); + } + + constant_sizes.Set(prim_fn_target, 0); + tir_primfuncs.Set(prim_fn_target, prim_fn); + relay_primfuncs.Set(prim_fn_target, relay_func); + } + + backend::FunctionInfo fi = backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, + tir_primfuncs, relay_primfuncs); + + // The primitive function name here corresponds to the string we will use to generate + // this Relay function at the low level. + function_metadata.Set(prim_fn_var.value()->name_hint, fi); +} + +LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { + TECompiler compiler; + + CHECK_EQ(module->functions.size(), 1) + << "There should only be one function in the module passed to LowerTE"; + + auto pass = CreateFunctionPass( + [=](Function func, IRModule module, PassContext ctx) { + LowerTensorExpr lower_te(module, targets, device_context_map, process_fn, module_name, + compiler); + return Downcast(lower_te.VisitExpr(func)); + }, + 0, "LowerTensorExpr", {}); + + // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize + backend::FunctionInfo func_info = + UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + + auto updated_module = pass(module); + + // A temporary solution until we can rewrite the auto-scheduler task extraction code to work + // in a more reasonable way. + if (backend::IsAutoSchedulerEnabled()) { + const auto* te_compiler_update_weights = + runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights"); + + ICHECK(te_compiler_update_weights != nullptr) + << "auto_scheduler.relay_integration.te_compiler_update_weights"; + + Map weight_map; + + for (auto pair : compiler->GetOpWeights()) { + weight_map.Set(pair.first, pair.second); + } + + (*te_compiler_update_weights)(weight_map); + } + + LoweredModule lowered_module; + lowered_module.main_module = updated_module; + lowered_module.per_target_module = compiler->GetLoweredFunctions(); + lowered_module.external_mods = compiler->LowerExternalFunctions(); + lowered_module.main_func_info = func_info; + return lowered_module; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h new file mode 100644 index 000000000000..8376b99d79cd --- /dev/null +++ b/src/relay/backend/te_compiler.h @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tir_compiler.h + * * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. + * + * + * This represents the new design of the Relay compilation flow and will replace the interface + * contained in compile_engine.h as we migrate towards a standard pass based lowering of + * Relay functions. + * + * This files provides an internal API which lowers Relay programs to components which + * can be combined with TVM produced kernels to compile an entire program. + * + * The result of lowering contains a combination of `runtime::Module`s produced by external + * compilers and a set of lowered PrimFns which can be code generated for targets. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" +#include "../transforms/pass_utils.h" +#include "./te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +// This class is needed to avoid a GCC 5 bug that prevents maps containing enums +// from being compiled. If i386 GCC version is increased, we can remove it. +struct EnumClassHash { + template + std::size_t operator()(T t) const { + return static_cast(t); + } +}; + +// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// we should a version of context which works in Map +using TargetMap = std::unordered_map; +using DeviceMap = + std::unordered_map; +using ProcessFn = std::function; + +/*! + * \brief A compiler which lowers primitive Relay functions to tensor expressions + * and schedules them into TIR functions. + */ +class TECompilerNode : public Object { + public: + /*! \brief destructor */ + virtual ~TECompilerNode() {} + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; + + /*! + * \brief Get lowered result. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; + + /* Return all functions which have been lowered by the compiler, keyed by target. */ + virtual Map GetLoweredFunctions() = 0; + + /*! + * \brief Just in time compile to get a PackedFunc. + * \param key The key to the cached function. + * \return The result. + */ + virtual PackedFunc JIT(const CCacheKey& key) = 0; + /*! + * \brief Lower the shape function. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; + /*! + * \brief Lower the external function using external codegen tools. + * \return The runtime moduels for each needed external codegen tool. + */ + virtual tvm::Array LowerExternalFunctions() = 0; + + virtual std::unordered_map GetOpWeights() = 0; + + /*! \brief clear the cache. */ + virtual void Clear() = 0; + + void VisitAttrs(AttrVisitor*) {} + + static constexpr const char* _type_key = "relay.TECompiler"; + TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object); +}; + +/*! \brief cache entry used in compile engine */ +class TECompiler : public ObjectRef { + public: + TECompiler(); + explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} + TECompilerNode* operator->() { return static_cast(get_mutable()); } + using ContainerType = TECompilerNode; +}; + +/*! \brief The result of lowering a module, for now we need to pass an aggregate data structure + * which contains more then a single module in order to interact with the today API. + */ +struct LoweredModule { + /*! \brief The module which contains the Relay code. */ + IRModule main_module; + /*! \brief The module which contains per target code. */ + Map per_target_module; + /*! \brief The external runtime modules which must be combined with the lowered code. */ + Array external_mods; + // TODO(@electriclilies): THis might need to become a map + /*! \brief The info for this function (not sure what a better description is??) + * + */ + backend::FunctionInfo main_func_info; +}; + +/*! + * \brief A function to create the function metadata for an input function (ie calculate buffer + * input/output sizes) + * \param relay_func The function to calculate function metadata for + * \param function_metadata The map that stores all the function metadatas + */ +void UpdateFunctionMetadata(Function relay_func, + Map& function_metadata); // NOLINT(*) + +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. + * + * \param dev_type + * \return Target + */ +Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); + +/*! \brief Lower an IRModule's primitive functions to TIR. + * + * This is the "back half" of the Relay compiler which lowers "primitive functions" + * to TE expressions, schedules them, and then to TIR. + * + * \param compiler The TE-to-TIR compliler (which caches lowered functions) + * \param module The IRModule. + * \param targets The mapping for devices to targets. + * \param device_map An analysis result mapping each sub-expression to a device. + * \return The lowered module, see above. + */ +// TODO(@electriclilies): Not sure if this default initialization is correct... +LoweredModule LowerTE( + const IRModule& module, TargetMap targets, DeviceMap device_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + ProcessFn process_fn = [](Function f) {}); + +} // namespace tec +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_H_ diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc new file mode 100644 index 000000000000..7cae7fcd4b09 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.cc @@ -0,0 +1,695 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./te_compiler_cache.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_utils.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +TVM_REGISTER_NODE_TYPE(LoweredOutputNode); +TVM_REGISTER_NODE_TYPE(CachedFuncNode); +TVM_REGISTER_NODE_TYPE(CCacheKeyNode); +TVM_REGISTER_NODE_TYPE(CCacheValueNode); + +LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation impl) { + auto n = make_object(); + n->outputs = std::move(outputs); + n->implementation = std::move(impl); + data_ = std::move(n); +} + +CCacheKey::CCacheKey(Function source_func, Target target) { + auto n = make_object(); + n->source_func = std::move(source_func); + n->target = std::move(target); + data_ = std::move(n); +} + +CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, IRModule funcs) { + auto n = make_object(); + n->target = target; + n->prim_fn_var = prim_fn_var; + n->inputs = inputs; + n->outputs = outputs; + n->schedule = schedule; + n->shape_func_param_states = shape_func_param_states; + n->funcs = funcs; + data_ = std::move(n); +} + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. + Array res; + for (IndexExpr val : shape) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + ICHECK_LE(pval[0], std::numeric_limits::max()) + << "dimension must be less then int32_t's max value"; + ICHECK_GE(pval[0], std::numeric_limits::min()) + << "dimension must be less then int32_t's max value"; + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { + // currently all 'any' we meet in shape function are non-negative. + res.push_back(val.as()->ToSizeVar()); + } else { + res.push_back(val); + } + } + return res; +} + +// Construct a schedule for a given Relay primitive function and target. +class ScheduleBuilder : public backend::MemoizedExprTranslator> { + public: + explicit ScheduleBuilder(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + } + + CachedFunc Create(const Function& prim_func, std::function renamer) { + Array fn_inputs; + for (Var param : prim_func->params) { + Array inputs; + if (const auto* ttype = param->checked_type().as()) { + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(ttype != nullptr); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + fn_inputs.push_back(tensor); + inputs.push_back(tensor); + } + } + memo_[param] = inputs; + } + readable_name_stream_ << "fused"; + auto outputs = this->VisitExpr(prim_func->body); + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // NB(@jroesch): unfortunately the graph runtime deals with copy in + // a totally hacky way, we really need to rectify this but this will + // have to work for now. + std::string prim_fn_name = candidate_name; + if (prim_fn_name != "__copy") { + prim_fn_name = renamer(prim_fn_name); + } + auto prim_fn_var = GlobalVar(prim_fn_name); + prim_fn_var->checked_type_ = prim_func->checked_type(); + + ICHECK(anchor_op_.defined()); + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(prim_fn_name, tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + + // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined()) { + ICHECK(anchor_implementation_.defined()); + schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + } + + Array VisitExpr_(const VarNode* op) final { + LOG(FATAL) << "Unexpected free variable " << op->name_hint(); + return {}; + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + scalars_.push_back(value->op); + return {value}; + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) + << "Only functions with a single tuple input are allowed, but " << count_tuple + << " were provided."; + } + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + Array outputs; + OpImplementation impl; + // Skip fcompute for device copy operators as it is not registered. + if (op == device_copy_op_) { + const auto* copy_input = inputs[0].operator->(); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); + } else { + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); + outputs = lowered_out->outputs; + impl = lowered_out->implementation; + } + + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + anchor_implementation_ = impl; + } + if (outputs.size() != 1) { + const auto* tuple_type = call_node->checked_type().as(); + ICHECK(tuple_type) << "Expected output to be a tuple type " + << PrettyPrint(call_node->checked_type()); + + ICHECK_EQ(tuple_type->fields.size(), outputs.size()); + } + // Set the name to `__copy`. It will be detected in graph runtime to perform + // data copy across devices. + if (op == device_copy_op_) { + readable_name_stream_.str(std::string()); + readable_name_stream_ << "__copy"; + } else { + readable_name_stream_ << '_' << op->name; + } + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Primitive Functions can not contain nested functions."; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + const auto* tuple_type = op->tuple->type_as(); + Array tuple = VisitExpr(op->tuple); + ICHECK_EQ(tuple_type->fields.size(), tuple.size()); + ICHECK_GE(op->index, 0); + ICHECK_LT(static_cast(op->index), tuple.size()); + return {tuple[op->index]}; + } + + private: + tvm::Target target_; + Op anchor_op_; + Attrs anchor_attrs_; + int anchor_op_pattern_{0}; + OpImplementation anchor_implementation_; + std::ostringstream readable_name_stream_; + Array scalars_; + bool use_auto_scheduler_; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; +}; + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer) { + return ScheduleBuilder(target).Create(source_func, renamer); +} + +// Creates shape function from functor. +class MakeShapeFunc : public backend::MemoizedExprTranslator> { + public: + MakeShapeFunc() {} + + CachedFunc Create(const Function& prim_func, const Target& target, + std::function renamer) { + Array inputs; + TShapeDataDependent shape_func_param_states; + + for (auto param : prim_func->params) { + param_states_[param] = kNoNeed; + Array data_inputs; + Array shape_inputs; + + auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + // Add data placeholder + Shape shape = GetShape(ttype->shape); + tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); + data_inputs.push_back(data_tensor); + // Add shape placeholder + int64_t ndim = shape.size(); + Shape sshape; + if (ndim > 0) { + sshape.push_back(tvm::Integer(ndim)); + } + tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); + shape_inputs.push_back(shape_tensor); + }; + + if (const auto* ttype = param->checked_type().as()) { + add_placeholder(ttype); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + // TODO(@icemelon): Support recursive tuple + ICHECK(tuple_type); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + ICHECK(ttype); + add_placeholder(ttype); + } + } + param_data_[param] = data_inputs; + param_shapes_[param] = shape_inputs; + } + + // Setup the name; + readable_name_stream_ << "shape_func"; + + // Create the `te::Tensor`s which represent the output. + auto outputs = VisitExpr(prim_func->body); + + // Generate a name. + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // Set all the inputs correctly. + for (auto param : prim_func->params) { + int state = param_states_[param]; + shape_func_param_states.push_back(IntImm(DataType::Int(32), state)); + if (state & kNeedInputData) { + for (auto t : param_data_[param]) { + inputs.push_back(t); + } + } + if (state & kNeedInputShape) { + for (auto t : param_shapes_[param]) { + inputs.push_back(t); + } + } + } + + auto func_name = renamer(candidate_name); + auto prim_fn_gvar = GlobalVar(func_name); + prim_fn_gvar->checked_type_ = prim_func->checked_type(); + + // generate schedule for shape func + Array out_ops; + for (auto t : outputs) { + out_ops.push_back(t->op); + } + auto schedule = te::create_schedule(out_ops); + tvm::te::AutoInlineInjective(schedule); + for (const auto& scalar : scalars_) { + auto scalar_op = scalar->op; + if (schedule->Contain(scalar_op)) { + schedule[scalar_op].compute_inline(); + } + } + + Array all_args = Array(inputs); + for (te::Tensor arg : outputs) { + all_args.push_back(arg); + } + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + + std::unordered_map binds; + IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); + + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, + ir_module); + } + + Array VisitExpr(const Expr& expr) final { + if (expr.as()) { + // Do not memoize vars because shape functions could use either the data + // or the shape of a var each time. + return ExprFunctor::VisitExpr(expr); + } + // For other case, do memoized visit + return backend::MemoizedExprTranslator>::VisitExpr(expr); + } + + Array VisitExpr_(const VarNode* var_node) final { + auto var = GetRef(var_node); + auto it = param_states_.find(var); + if (it == param_states_.end()) { + LOG(FATAL) << "Unexpected free variable " << var->name_hint(); + return {}; + } else { + ICHECK(data_dependents_per_input_.size()); + auto data_dependent = data_dependents_per_input_.back(); + if (data_dependent) { + param_states_[var] |= kNeedInputData; + return param_data_[var]; + } else { + param_states_[var] |= kNeedInputShape; + return param_shapes_[var]; + } + } + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(data_dependents_per_input_.size()); + bool data_dependent = data_dependents_per_input_.back(); + if (!op->is_scalar()) { + // This is a constant weight, extract the shape of the weight tensor. + // This can not be data dependent. + CHECK(!data_dependent); + auto ttype = op->checked_type().as(); + int ndim = static_cast(ttype->shape.size()); + Array out_shape{ndim}; + te::Tensor value = tvm::te::compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = make_const(DataType::Int(64), 0); + for (int i = 0; i < ndim; i++) { + ret = tvm::if_then_else(idx == i, ttype->shape[i], ret); + } + return ret; + }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + if (data_dependent) { + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } else { + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fshape_func = Op::GetAttrMap("FShapeFunc"); + static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back()) + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependent shape func"; + ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name; + ICHECK_GT(tshape_data_dependent.count(op), 0) + << "Internal error, cannot find TShapeDataDependent for " << op->name; + + Array dep_spec = tshape_data_dependent[op]; + if (dep_spec.size() == 1) { + // This is for cases when data dependence is specified per op + // Replicate 0 or 1 flag to all arguments + for (size_t i = 1; i < call_node->args.size(); ++i) { + dep_spec.push_back(dep_spec[0]); + } + } + + // Visit all inputs + Array inputs; + int count_tuple = 0; + for (size_t i = 0; i < call_node->args.size(); ++i) { + Expr arg = call_node->args[i]; + if (arg->checked_type().as()) { + ++count_tuple; + } + data_dependents_per_input_.push_back(dep_spec[i]->value != 0); + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + data_dependents_per_input_.pop_back(); + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + // Get output ndims + auto ret_type = call_node->checked_type(); + Array out_ndims; + if (const auto* ttype = ret_type.as()) { + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } else { + auto rtype = ret_type.as(); + // TODO(@icemelon): Allow recursive tuple + ICHECK(rtype); + for (size_t i = 0; i < rtype->fields.size(); ++i) { + auto ttype = rtype->fields[i].as(); + ICHECK(ttype); + out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size())); + } + } + // Call shape function + auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); + readable_name_stream_ << "_" << op->name; + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) + << "Expected a Tuple of Tensor, but got " << PrettyPrint(field->checked_type()); + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + Array input_shapes = VisitExpr(op->tuple); + Array out; + out.push_back(input_shapes[op->index]); + return out; + } + + private: + /*! \brief String stream for function name */ + std::ostringstream readable_name_stream_; + /*! \brief Map from parameter to its shape function usage state */ + std::unordered_map param_states_; + /*! \brief Map from parameter to list of data placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; + /*! \brief Map from parameter to list of shape placeholder */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; + /*! \brief Stack of data dependencies for shape function, specified per each op input */ + std::vector data_dependents_per_input_; + /*! \brief Scalars used in the shape function */ + Array scalars_; +}; + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer) { + return MakeShapeFunc().Create(prim_func, target, renamer); +} + +/*! + * \brief Get unique name from name. + * \param name The orginal name. + * \return Updated name which is unique. + */ +std::string GetUniqueName(std::string name, std::unordered_map* name_map_) { + for (size_t i = 0; i < name.length(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + while (true) { + auto it = name_map_->find(name); + if (it == name_map_->end()) { + (*name_map_)[name] = 1; + return name; + } else { + std::ostringstream os; + os << name << "_" << it->second; + ++(it->second); + name = os.str(); + } + } + return name; +} + +} // namespace tec +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h new file mode 100644 index 000000000000..1c7511ffd7d2 --- /dev/null +++ b/src/relay/backend/te_compiler_cache.h @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/tec_compiler_cache.h + * \brief Utilities for compiling tensor expressions inside of the Relay compiler. + */ +#ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ +#define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { +namespace tec { + +/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ +enum ShapeFuncParamState { + kNoNeed = 0, + kNeedInputData = 1, + kNeedInputShape = 2, + kNeedBoth = 3, +}; + +struct LoweredOutputNode : public Object { + /*! \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The implementation used to compute the output */ + OpImplementation implementation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("outputs", &outputs); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "relay.LoweredOutput"; + TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); +}; + +class LoweredOutput : public ObjectRef { + public: + TVM_DLL LoweredOutput(tvm::Array outputs, OpImplementation impl); + + TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode); +}; + +class CCacheKey; +/*! \brief Compile cache key */ +class CCacheKeyNode : public Object { + public: + /*! \brief The source function to be lowered. */ + Function source_func; + /*! \brief The hardware target.*/ + Target target; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("source_func", &source_func); + v->Visit("target", &target); + } + /*! \return The hash value of CCacheKey. */ + inline size_t Hash() const; + /*! + * \brief check content equality + * \param other The other value. + * \return The result of equality check. + */ + inline bool Equal(const CCacheKeyNode* other) const; + + static constexpr const char* _type_key = "relay.CCacheKey"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); + + private: + /*! + * \brief internal cached hash value. + */ + mutable size_t hash_{0}; +}; + +/*! \brief cache entry used in compile engine */ +class CCacheKey : public ObjectRef { + public: + CCacheKey() {} + explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief The constructor + * \param source_func The source function. + * \param target The target device. + */ + TVM_DLL CCacheKey(Function source_func, Target target); + + const CCacheKeyNode* operator->() const { return static_cast(get()); } + // comparator + inline bool operator==(const CCacheKey& other) const { + ICHECK(defined() && other.defined()); + return (*this)->Equal(other.operator->()); + } + using ContainerType = CCacheKeyNode; +}; + +/*! \brief Node container to represent a cached function. */ +struct CachedFuncNode : public Object { + /* \brief compiled target */ + tvm::Target target; + /*! \brief Primitive Function Name */ + GlobalVar prim_fn_var; + /* \brief The inputs to the function */ + tvm::Array inputs; + /* \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The schedule to the function */ + te::Schedule schedule; + /*! \brief Parameter usage states in the shape function. */ + tvm::Array shape_func_param_states; + /*! \brief The lowered functions to support the function. */ + IRModule funcs = IRModule(Map({})); + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("target", &target); + v->Visit("prim_fn_var", &prim_fn_var); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("schedule", &schedule); + v->Visit("funcs", &funcs); + v->Visit("shape_func_param_states", &shape_func_param_states); + } + + static constexpr const char* _type_key = "relay.CachedFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); +}; + +class CachedFunc : public ObjectRef { + public: + CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, + tvm::Array outputs, te::Schedule schedule, + tvm::Array shape_func_param_states, + IRModule funcs = IRModule(Map({}))); + + public: + TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); +}; + +/*! \brief Node container for compile cache. */ +class CCacheValueNode : public Object { + public: + /*! \brief The corresponding function */ + CachedFunc cached_func; + /*! \brief Result of Packed function generated by JIT */ + PackedFunc packed_func; + /*! \brief usage statistics */ + int use_count{0}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cached_func", &cached_func); + v->Visit("use_count", &use_count); + } + static constexpr const char* _type_key = "relay.CCacheValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); +}; + +/*! \brief cache entry used in compile engine */ +class CCacheValue : public ObjectRef { + public: + CCacheValue() {} + explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } + using ContainerType = CCacheValueNode; +}; + +Array GetShape(const Array& shape); + +/*! + * \brief Create schedule for target. + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \return Pair of schedule and cache. + * The funcs field in cache is not yet populated. + */ +CachedFunc PrimFuncFor(const Function& source_func, const Target& target, + std::function renamer); + +CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, + std::function renamer); + +std::string GetUniqueName(std::string name, std::unordered_map* name_map); + +// implementations +inline size_t CCacheKeyNode::Hash() const { + if (hash_ != 0) return hash_; + // do structral hash, avoid 0. + hash_ = tvm::StructuralHash()(this->source_func); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); + if (hash_ == 0) hash_ = 1; + return hash_; +} + +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { + if (Hash() != other->Hash()) return false; + return this->target->str() == other->target->str() && + tvm::StructuralEqual()(this->source_func, other->source_func); +} + +} // namespace tec +} // namespace relay +} // namespace tvm + +namespace std { +// overload hash +template <> +struct hash<::tvm::relay::tec::CCacheKey> { + size_t operator()(const ::tvm::relay::tec::CCacheKey& key) const { + ICHECK(key.defined()); + return key->Hash(); + } +}; +} // namespace std + +#endif // TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 3ea15438fe8f..4b4844599e29 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -24,6 +24,8 @@ #include "utils.h" +#include + namespace tvm { namespace relay { namespace backend { @@ -39,6 +41,30 @@ StorageInfo::StorageInfo(std::vector storage_ids, std::vector ids; + for (auto id : si->storage_ids) { + ids.push_back(id); + } + return ids; +}); + +TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) { + Array device_types; + for (auto id : si->device_types) { + device_types.push_back(id); + } + return device_types; +}); + +TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageSizes").set_body_typed([](StorageInfo si) { + Array storage_sizes_in_bytes; + for (auto id : si->storage_sizes_in_bytes) { + storage_sizes_in_bytes.push_back(id); + } + return storage_sizes_in_bytes; +}); + TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode); StaticMemoryPlan::StaticMemoryPlan(Map expr_to_storage_info) { @@ -73,6 +99,94 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { TVM_REGISTER_NODE_TYPE(FunctionInfoNode); +FunctionInfo::FunctionInfo(Map workspace_sizes, Map io_sizes, + Map constant_sizes, + Map tir_primfuncs, + Map relay_primfuncs) { + ObjectPtr n = make_object(); + n->workspace_sizes = std::move(workspace_sizes); + n->io_sizes = std::move(io_sizes); + n->constant_sizes = std::move(constant_sizes); + n->tir_primfuncs = std::move(tir_primfuncs); + n->relay_primfuncs = std::move(relay_primfuncs); + data_ = std::move(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionInfoNode(\n" + << "workspace_sizes=" << node->workspace_sizes << ",\n io_sizes=" << node->io_sizes + << ",\n constant_sizes=" << node->constant_sizes + << ",\n tir_primfuncs=" << node->tir_primfuncs + << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; + }); + +Array GetPassPrefix(const Map& targets, bool is_vm) { + Array pass_seqs; + Array entry_functions{"main"}; + pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); + pass_seqs.push_back(transform::ToBasicBlockNormalForm()); + // Run all dialect legalization passes. + pass_seqs.push_back(relay::qnn::transform::Legalize()); + + // Legalize pass is restricted to homogeneous execution for now. + if (targets.size() == 1) { + pass_seqs.push_back(transform::Legalize()); + } + + pass_seqs.push_back(transform::SimplifyInference()); + + if (is_vm) { + // eta expand to support constructors in argument position + pass_seqs.push_back(transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)); + } else { + // Convert Dynamic ops to static versions + pass_seqs.push_back(transform::DynamicToStatic()); + } + + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == DataType::Int(32)) { + *rv = true; + } + } + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); + if (is_vm) { + pass_seqs.push_back(transform::InlinePrimitives()); + } + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::CombineParallelDense(3)); + pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeCast()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + if (!is_vm) { + pass_seqs.push_back(transform::InferType()); + } + pass_seqs.push_back(transform::AlterOpLayout()); + } + + // Fast math optimizations. + pass_seqs.push_back(transform::FastMath()); + pass_seqs.push_back(transform::FoldConstant()); + return pass_seqs; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 7d7f026c298e..a0c7a5aad26d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -44,7 +44,12 @@ namespace tvm { namespace relay { +namespace transform { +Pass InlinePrimitives(); +} + namespace backend { +using Pass = tvm::transform::Pass; /*! * \brief The static storage information produced by memory planning. @@ -114,6 +119,10 @@ struct FunctionInfoNode : public Object { class FunctionInfo : public ObjectRef { public: + FunctionInfo(Map workspace_sizes, Map io_sizes, + Map constant_sizes, Map tir_primfuncs, + Map relay_primfuncs); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FunctionInfo, ObjectRef, FunctionInfoNode); }; @@ -406,6 +415,18 @@ inline bool IsCompileEngineCacheDisabled() { .value(); } +/*! + * \brief Get the sequence of Relay optimization passes based on backend type. + * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight + * difference. This function unifies the shared optimization pass prefix between vm and graph + * runtime, and returns the pass prefix given the backend type. + * + * \param targets The device type to `Target` mapping. + * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. + * \return An array of passes. + */ +Array GetPassPrefix(const Map& targets, bool is_vm); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c50f2f65f949..ddb1911a6b71 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -978,7 +978,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // update primitive function map size_t primitive_index = 0; for (const auto& cfunc : context_.cached_funcs) { - exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); + exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } } @@ -1042,57 +1042,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, mod->Add(gvar, f); } - Array pass_seqs; - Array entry_functions{"main"}; - pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); - pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - - // eta expand to support constructors in argument position - pass_seqs.push_back(transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)); - - pass_seqs.push_back(transform::SimplifyInference()); - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == DataType::Int(32)) { - *rv = true; - } - } - } - *rv = false; - }); - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); - pass_seqs.push_back(transform::SimplifyExpr()); - pass_seqs.push_back(transform::InlinePrimitives()); - - pass_seqs.push_back(transform::CombineParallelConv2D(3)); - pass_seqs.push_back(transform::CombineParallelDense(3)); - pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); - pass_seqs.push_back(transform::FoldConstant()); - pass_seqs.push_back(transform::FoldScaleAxis()); - pass_seqs.push_back(transform::CanonicalizeCast()); - pass_seqs.push_back(transform::CanonicalizeOps()); - - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { - pass_seqs.push_back(transform::AlterOpLayout()); - } - - // Fast math optimizations. - pass_seqs.push_back(transform::FastMath()); - pass_seqs.push_back(transform::FoldConstant()); + Array pass_seqs = relay::backend::GetPassPrefix(targets, true); if (targets_.size() > 1) { // Handle heterogeneous compilation. @@ -1173,8 +1123,9 @@ void VMCompiler::Codegen() { if (target->kind->device_type == kDLExtDev) { // Collect metadata in functions that are handled by external codegen. - ICHECK(mod->ContainGlobalVar(cfunc->func_name)); - Function func = Downcast(mod->Lookup(cfunc->func_name)); + auto name = cfunc->prim_fn_var->name_hint; + ICHECK(mod->ContainGlobalVar(name)); + Function func = Downcast(mod->Lookup(name)); backend::UpdateConstants(func, ¶ms_); } else if (funcs.count(target) == 0) { funcs.Set(target, mod); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index c9920a621b56..83ac55fce085 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -62,9 +62,17 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body - << ", " << node->type_params << ", " << node->attrs << ")"; + // TODO(@jroesch): previously this had a debug printer, the debug printer + // can cause exponential behavior and is currently dangerous, for these + // cases we need some kind of de-duping. + // + // See old implementation: + // + // auto* node = static_cast(ref.get()); + // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << + // node->body + // << ", " << node->type_params << ", " << node->attrs << ")"; + p->stream << PrettyPrint(ref); }); } // namespace relay diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index 87cf89a223ec..002105f4d565 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -31,10 +31,10 @@ namespace tvm { namespace relay { namespace dyn { -TVM_REGISTER_NODE_TYPE(ResizeAttrs); +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); -bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // {data, size, out} ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -42,7 +42,7 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, static const Layout kNCHW("NCHW"); - const ResizeAttrs* param = attrs.as(); + const Resize2DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); @@ -66,24 +66,24 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - double bicubic_exclude, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize2D(Expr data, Expr size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + double cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; - attrs->bicubic_alpha = bicubic_alpha; - attrs->bicubic_exclude = bicubic_exclude; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("dyn.image.resize"); + static const Op& op = Op::Get("dyn.image.resize2d"); return Call(op, {data, size}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize").set_body_typed(MakeResize); +TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize2d").set_body_typed(MakeResize2D); -RELAY_REGISTER_OP("dyn.image.resize") +RELAY_REGISTER_OP("dyn.image.resize2d") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape @@ -100,12 +100,12 @@ RELAY_REGISTER_OP("dyn.image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("size", "Tensor", "The output size tensor.") .set_support_level(5) - .add_type_rel("DynResize", ResizeRel) + .add_type_rel("DynResize2D", Resize2DRel) .set_attr("TOpPattern", kInjective); } // namespace dyn diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b672c7f87c05..ee779841505c 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -31,8 +31,6 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(ResizeAttrs); - template InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, @@ -58,15 +56,90 @@ InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params)); } -bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +TVM_REGISTER_NODE_TYPE(Resize1DAttrs); + +bool Resize1DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCW("NCW"); + + const Resize1DAttrs* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCW); + ICHECK(layout_converter.defined()) + << "Resize only support input layouts that are convertible from NCW." + << " But got " << in_layout; + + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, param->size[0]); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // assign output type + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); + return true; +} + +// Positional relay function to create image operator +// used by frontend FFI. +Expr MakeResize1D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); + attrs->size = std::move(size); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("image.resize1d"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.resize1d").set_body_typed(MakeResize1D); + +RELAY_REGISTER_OP("image.resize1d") + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. + +- **data**: data is 3D array of shape + (batch_size, channels, in_width) for NCW + (batch_size, in_width, channels) for NWC + +- **out**: Output is 3D array of shape + for layout NCW + (batch_size, channels, size[0]) + + for layout NWC + (batch_size, size[0], channels) +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize1D", Resize1DRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) + .set_attr("TOpPattern", kInjective); + +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + +bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; static const Layout kNCHW("NCHW"); - const ResizeAttrs* param = attrs.as(); + const Resize2DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); @@ -90,25 +163,25 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - int bicubic_exclude, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize2D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; - attrs->bicubic_alpha = bicubic_alpha; - attrs->bicubic_exclude = bicubic_exclude; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("image.resize"); + static const Op& op = Op::Get("image.resize2d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize); +TVM_REGISTER_GLOBAL("relay.op.image._make.resize2d").set_body_typed(MakeResize2D); -RELAY_REGISTER_OP("image.resize") +RELAY_REGISTER_OP("image.resize2d") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. - **data**: data is 4D array of shape @@ -122,17 +195,17 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) - .add_type_rel("Resize", ResizeRel) - .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) + .add_type_rel("Resize2D", Resize2DRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) .set_attr("TOpPattern", kInjective); -TVM_REGISTER_NODE_TYPE(Resize3dAttrs); +TVM_REGISTER_NODE_TYPE(Resize3DAttrs); -bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, +bool Resize3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -140,7 +213,7 @@ bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, static const Layout kNCDHW("NCDHW"); - const Resize3dAttrs* param = attrs.as(); + const Resize3DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); @@ -165,19 +238,23 @@ bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize3d(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize3D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize3d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3d); +TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3D); RELAY_REGISTER_OP("image.resize3d") .describe(R"code( @@ -194,11 +271,11 @@ Perform resize3d to input array with nearest neighbour or bilinear interpolation for layout NDHWC (batch_size, size[0], size[1], size[2], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) - .add_type_rel("Resize3d", Resize3dRel) + .add_type_rel("Resize3d", Resize3DRel) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 6f4db5ab268a..43ce6656cdc0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -49,7 +49,7 @@ Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtyp Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); -Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype); +Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype, bool transpose_a, bool transpose_b); Expr MakeExpandDims(Expr data, int axis, int num_newaxis); @@ -101,9 +101,9 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); -Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - int bicubic_exclude, DataType out_dtype); +Expr MakeResize2D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 4eaa12b17d7b..a96f167df2bb 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -932,89 +932,43 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,). .set_support_level(1) .add_type_rel("GroupNorm", GroupNormRel); -// relay.nn.batch_matmul +// ------------------- relay.nn.batch_matmul TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs); -bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* x = types[0].as(); - const auto* y = types[1].as(); - if (x == nullptr || y == nullptr) return false; - - const auto* param = attrs.as(); - Array y_shape; - if (param->auto_scheduler_rewritten_layout.size() == 0) { - y_shape = y->shape; - } else { - y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, - {"b", "j", "k"}); - } - - ICHECK(x->shape.size() == 3 && y_shape.size() == 3); - bool is_dyn = false; - Array oshape; - for (size_t i = 0; i < 3; ++i) { - if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { - is_dyn = true; - oshape.push_back(Any()); - } else { - if (i == 0) { - oshape.push_back(max(x->shape[i], y_shape[i])); - } else { - oshape.push_back(x->shape[i]); - } - } - } - if (!is_dyn) { - ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) || - reporter->AssertEQ(y_shape[0], 1)) - << "BatchDot: batch dimensions don't match, " - << " x shape=" << x->shape << ", y shape=" << y_shape; - ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2])) - << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape << ", y shape=" << y_shape; - - oshape.Set(2, y_shape[1]); - } - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = x->dtype; - } - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - // Positional relay function to create batch_matmul operator used by frontend FFI. -Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) { +Expr MakeBatchMatmul(Expr tensor_a, Expr tensor_b, DataType out_dtype, bool transpose_a, + bool transpose_b) { auto attrs = make_object(); attrs->out_dtype = out_dtype; + attrs->transpose_a = transpose_a; + attrs->transpose_b = transpose_b; static const Op& op = Op::Get("nn.batch_matmul"); - return Call(op, {x, y}, Attrs(attrs), {}); + return Call(op, {tensor_a, tensor_b}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") - .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` -are data in batch. + .describe(R"code(Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + +Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format +(transpose_a=False, transpose_b=True) by default. .. math:: - batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T) + batch\_matmul(A, B)[i, :, :] = matmul(A[i, :, :], B[i, :, :]^T) -- **x**: `(b, m, k)` -- **y**: `(b, n, k)` +- **tensor_a**: `(b, m, k)` or `(b, k, m)` +- **tensor_b**: `(b, k, n)` or `(b, n, k)` - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) .set_num_inputs(2) - .add_argument("x", "3D Tensor", "First input.") - .add_argument("y", "3D Tensor", "Second input.") + .add_argument("tensor_a", "3D Tensor", "The first input.") + .add_argument("tensor_b", "3D Tensor", "The second input.") .set_support_level(10) - .add_type_rel("BatchMatmul", BatchMatmulRel); + .add_type_rel("BatchMatmul", BatchMatmulRel); +// ------------------- relay.nn.batch_matmul // relay.nn.cross_entropy bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 29f200c67c59..3dc63b31a205 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -24,10 +24,12 @@ #ifndef TVM_RELAY_OP_NN_NN_H_ #define TVM_RELAY_OP_NN_NN_H_ +#include #include #include #include +#include #include #include "../op_common.h" @@ -137,6 +139,59 @@ bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } +template +bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; + + const AttrType* param = attrs.as(); + ICHECK(param != nullptr); + bool transpose_a = param->transpose_a; + bool transpose_b = param->transpose_b; + const Array& y_shape = + param->auto_scheduler_rewritten_layout.size() == 0 + ? y->shape + : auto_scheduler::GetShapeFromRewrittenLayout( + param->auto_scheduler_rewritten_layout, + transpose_b ? tvm::runtime::Array({"b", "j", "k"}) + : tvm::runtime::Array({"b", "k", "j"})); + ICHECK(x->shape.size() == 3 && y_shape.size() == 3); + const PrimExpr& xb = x->shape[0]; + const PrimExpr& xi = x->shape[transpose_a ? 2 : 1]; + const PrimExpr& xk = x->shape[transpose_a ? 1 : 2]; + const PrimExpr& yb = y_shape[0]; + const PrimExpr& yk = y_shape[transpose_b ? 2 : 1]; + const PrimExpr& yj = y_shape[transpose_b ? 1 : 2]; + + bool is_dyn = false; + for (size_t i = 0; i < 3; ++i) { + if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { + is_dyn = true; + break; + } + } + if (!is_dyn) { + ICHECK(reporter->AssertEQ(xb, yb) || reporter->AssertEQ(xb, 1) || reporter->AssertEQ(yb, 1)) + << "BatchDot: batch dimensions don't match, " + << " x shape=" << x->shape << ", y shape=" << y_shape; + ICHECK(reporter->AssertEQ(xk, yk)) << "BatchDot: shapes of x and y is inconsistent, " + << " x shape=" << x->shape << ", y shape=" << y_shape; + } + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = x->dtype; + } + // assign output type + const auto& out_b = + xb->IsInstance() || yb->IsInstance() ? tir::Any() : max(xb, yb); + reporter->Assign(types[2], TensorType(Array({out_b, xi, yj}), out_dtype)); + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_NN_H_ diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 2d14ee37fde0..9f9ed1c075cd 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2353,7 +2353,15 @@ bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs const InitOpAttrs* param = attrs.as(); ICHECK(param); - DataType out_dtype = types[0].as()->dtype; + DataType out_dtype; + if (auto ttype = types[0].as()) { + out_dtype = ttype->dtype; + } else { + ICHECK(types[0].as()) + << "Broadcast: expect to be TensorType but get " << types[0]; + return false; + } + std::vector oshape; const Array& cshape_array = param->shape.value(); diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc new file mode 100644 index 000000000000..4b0bcacacaa1 --- /dev/null +++ b/src/relay/qnn/op/batch_matmul.cc @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/batch_matmul.cc + * \brief Property def of qnn batch_matmul operator. + */ + +#include +#include +#include +#include + +#include "../../op/nn/nn.h" +#include "../../transforms/pattern_utils.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace qnn { + +// relay.op.qnn.batch_matmul + +bool QnnBatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Expected Types: x, y, x_zero_point, y_zero_point, x_scale, y_scale, + // out_type + ICHECK_EQ(types.size(), 7); + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "BatchMatmulAttrs cannot be nullptr."; + ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8)) + << "Expected quantized batch_matmul type(int8, uint8) for input but was " << x->dtype; + ICHECK(y->dtype == DataType::Int(8) || y->dtype == DataType::UInt(8)) + << "Expected quantized batch_matmul type(int8, uint8) for weight but was " << y->dtype; + ICHECK(param->out_dtype == DataType::Int(32)) + << "Expected quantized batch_matmul type(int32) for output but was " << param->out_dtype; + + // Check the types of scale and zero points. + for (size_t i = 2; i < 5; ++i) { + if (types[i].as()) { + return false; + } + } + ICHECK(IsScalarType(types[2], DataType::Int(32))); // x_zero_point + ICHECK(IsScalarType(types[3], DataType::Int(32))); // y_zero_point + ICHECK(IsScalarType(types[4], DataType::Float(32))); // x_scale + ICHECK(IsScalarType(types[5], DataType::Float(32))); // y_scale + + ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + + // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay + // BatchMatmul infer type function. + Array tensor_types = {types[0], types[1], types[6]}; + return BatchMatmulRel(tensor_types, 3, attrs, reporter); +} + +// Positional relay function to create quantized batch_matmul operator used by frontend FFI. +Expr MakeQuantizedBatchMatmul(Expr x, Expr y, Expr x_zero_point, Expr y_zero_point, Expr x_scale, + Expr y_scale, DataType out_dtype) { + auto attrs = make_object(); + attrs->out_dtype = out_dtype; + // For legacy reason, currently `qnn.batch_matmul` only supports + // (transpose_a=false, transpose_b=true) + // TODO(jcf94): extent to support all tensor format + attrs->transpose_a = false; + attrs->transpose_b = true; + static const Op& op = Op::Get("qnn.batch_matmul"); + return Call(op, {x, y, x_zero_point, y_zero_point, x_scale, y_scale}, Attrs(attrs), {}); +} + +Expr BatchMatmulFirstTerm(const Expr& quantized_x, const Expr& quantized_y, + const BatchMatmulAttrs* attrs) { + ICHECK(attrs->transpose_a == false && attrs->transpose_b == true) + << "Currently qnn.batch_matmul only supports (transpose_a=false, transpose_b=true)."; + return MakeBatchMatmul(quantized_x, quantized_y, attrs->out_dtype, attrs->transpose_a, + attrs->transpose_b); +} + +Expr BatchMatmulSecondTerm(const Expr& x_quantized_data, const Expr& y_zero_point) { + Array axes = {2}; + return Multiply(y_zero_point, Sum(Cast(x_quantized_data, DataType::Int(32)), axes, true, false)); +} + +Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, const Expr& x_zero_point, + int broadcast_dim_size) { + Array axes = {2}; + auto reducemult = + Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), axes, true, false)); + Array newshape; + newshape = {1, 1, broadcast_dim_size}; + return Reshape(reducemult, newshape); +} + +Expr BatchMatmulFourthTerm(int x_zero_point_int, int y_zero_point_int, int reduction_dim_size) { + int32_t scalar_term = x_zero_point_int * y_zero_point_int * reduction_dim_size; + return MakeConstantScalar(DataType::Int(32), scalar_term); +} + +Expr BatchMatmulCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, + const Expr& term4) { + auto data1_term = Subtract(term1, term2); + auto data2_term = Subtract(term4, term3); + return Add(data1_term, data2_term); +} + +/* + * \brief Forward rewrite the qnn batch_matmul op. + * \param attrs The QNN batch_matmul attrs. + * \param new_args The new mutated args to the call node. + * \param arg_types The types of input and output. + * \return The sequence of Relay ops for qnn batch_matmul op. + * \note Lowering of the qnn.batch_matmul operator + * A quantized tensor is represented in following manner + * A = scale_a x (QA - zp_A) + * where QA is quantized tensor, scale_a and zp_A are quantization + * params. + * + * Quantized batch_matmul multiplies two quantized tensors and returns a + * quantized tensor of default dtype of int32, with scale equaling to the + * product of scales of input tensors, and a zero point of zero. + * + * The lowering for asymmetric quantized batch_matmul looks similar to + * quantized conv2d and dense and originally was discussed here: + * https://discuss.tvm.apache.org/t/tf-lite-quantized-conv2d-operator-conversion/2651/7 + * + * The computation gets unrolled into following 4 terms + * C(m, n) = Sigma(k) (X(m, k) * Y(n, k)) + * + * RHS becomes + * Sigma(k) ([QX(m, k) - zp_x] * [QY(n, k) - zp_y]) + * + * Unrolling leads to following sequence + * Sigma(k) QX(m, k) * QX(n, k) // Term1 + * - Sigma(k) zp_y * QX(m, k) // Term2 + * - Sigma(k) zp_x * QY(n, k) // Term3 + * - Sigma(k) * zp_x * zp_y // Term4 + * + * Term4 can be computed at compile time, everything else depending on the + * input type. + */ +Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const Array& new_args, + const Array& arg_types) { + ICHECK_EQ(new_args.size(), 6); + Expr quantized_x = new_args[0]; + Expr quantized_y = new_args[1]; + Expr x_zero_point = new_args[2]; + Expr y_zero_point = new_args[3]; + + const auto in_shape = get_shape(arg_types[0]); + const int reduction_dim_size = get_const_int(in_shape[2]); + + const auto y_shape = get_shape(arg_types[1]); + const int broadcast_dim_size = get_const_int(y_shape[1]); + + const auto* qnn_batch_matmul_attrs = attrs.as(); + + // Extract the integer zero points. + auto y_zero_point_int = GetScalarFromConstant(y_zero_point); + auto x_zero_point_int = GetScalarFromConstant(x_zero_point); + + // Get all the terms as described in the comments. + auto term1 = BatchMatmulFirstTerm(quantized_x, quantized_y, qnn_batch_matmul_attrs); + auto term2 = BatchMatmulSecondTerm(quantized_x, y_zero_point); + auto term3 = BatchMatmulThirdTerm(quantized_y, x_zero_point, broadcast_dim_size); + auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, reduction_dim_size); + + // Combine those 4 terms depending on the zero points to get the best lowering. + if (x_zero_point_int == 0 && y_zero_point_int == 0) { + // term 2, 3 and 4 become zero. + return term1; + } else if (x_zero_point_int == 0 && y_zero_point_int != 0) { + // term 3 and term 4 become zero. + return Subtract(term1, term2); + } else if (x_zero_point_int != 0 && y_zero_point_int == 0) { + // term 2 and term 4 become zero. + return Subtract(term1, term3); + } else { + return BatchMatmulCombineTerms(term1, term2, term3, term4); + } +} + +RELAY_REGISTER_OP("qnn.batch_matmul") + .describe(R"code(Applies a linear transformation: :math:`Z = XY`. +- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` +- **weight**: quantized(int8, unit8) `(units, input_dim)` +- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(6) + .add_argument("x", "quantized 2D Tensor", "First input data.") + .add_argument("y", "quantized 2D Tensor", "Second input data.") + .add_argument("x_scale", "Tensor", "The quantization scale of the x input tensor.") + .add_argument("x_zero_point", "Tensor", "The quantization zero_point of the x input tensor.") + .add_argument("y_scale", "Tensor", "The quantization scale of the y input tensor.") + .add_argument("y_zero_point", "Tensor", "The quantization zero_point of the y input tensor.") + .set_support_level(11) + .add_type_rel("QBatchMatmul", QnnBatchMatmulRel) + .set_attr("TNonComputational", true) + .set_attr("FTVMQnnCanonicalize", QnnBatchMatmulCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.batch_matmul").set_body_typed(MakeQuantizedBatchMatmul); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 751abfc5ca81..2f1d7d8da16c 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -51,14 +51,20 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; - axis = (axis < 0) ? data->shape.size() + axis : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << quantize_attrs->axis << " is out of range"; + auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + PrimExpr axis_shape; + if (rank > 0) { + axis_shape = data->shape[axis]; + } else { + axis_shape = Integer(1); + } // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + AssignType(types[1], DataType::Float(32), axis_shape, reporter); // scale + AssignType(types[2], DataType::Int(32), axis_shape, reporter); // zero point const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 769f37205790..46de3522061b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -279,14 +279,20 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const RequantizeAttrs* requantize_attrs = attrs.as(); int axis = requantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1 : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << requantize_attrs->axis << " is out of range"; + auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << requantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range"; + PrimExpr axis_shape; + if (rank > 0) { + axis_shape = data->shape[axis]; + } else { + axis_shape = Integer(1); + } // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt + AssignType(types[1], DataType::Float(32), axis_shape, reporter); // input_scale + AssignType(types[2], DataType::Int(32), axis_shape, reporter); // input_zero_pt // For now, requantize output tensor is limited to full tensor uniform quantization. ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index da0bd35a332a..7a86af8aeffa 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -126,7 +126,7 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - CreateSchedule(GetRef(func), Target::Current()); + PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc index f8c46d93c675..ddab87a4893e 100644 --- a/src/relay/transforms/combine_parallel_batch_matmul.cc +++ b/src/relay/transforms/combine_parallel_batch_matmul.cc @@ -68,6 +68,16 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { // shape[2] is the contraction axis and automatically consistent // if it were valid batch_matmul ops + // TODO(jcf94): Add full support of layout format + if (!(attrs_a->transpose_a == false && attrs_a->transpose_b == true && + attrs_b->transpose_a == false && attrs_b->transpose_b == true)) { + LOG(WARNING) << "For legacy reason, this pass only supports" + << " (transpose_a=false, transpose_b=true) now, skip combining these two with:" + << " batch_matmul_a: " << attrs_a->transpose_a << ", " << attrs_a->transpose_b + << " batch_matmul_b: " << attrs_b->transpose_a << ", " << attrs_b->transpose_b; + return false; + } + auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) && (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) && eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype); @@ -86,7 +96,8 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); - return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype)); + return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype, + origin_attrs->transpose_a, origin_attrs->transpose_b)); } bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; } diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 3cd9cca4fec4..d5404ba30f90 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -72,7 +72,8 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner { CHECK_EQ(num_args, 2); const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); - return Downcast(MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype)); + return Downcast( + MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype, false, true)); } virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc index 6abf4c31d359..d7a9bfde57c3 100644 --- a/src/relay/transforms/defuse_ops.cc +++ b/src/relay/transforms/defuse_ops.cc @@ -41,19 +41,13 @@ class DefuseOpsMutator : public ExprMutator { public: class FuncBodyMutator : public ExprMutator { public: - explicit FuncBodyMutator(const Array& args) : ExprMutator() { args_ = args; } - - Expr VisitExpr_(const VarNode* n) { - const std::string& name = n->name_hint(); - ICHECK(!name.empty() && (name[0] == 'p')); - std::string id_str = name.substr(1); - int id = std::stoi(id_str); - ICHECK(id >= 0 && size_t(id) < args_.size()); - return args_[id]; - } + explicit FuncBodyMutator(std::unordered_map args) + : ExprMutator(), name_to_args_(std::move(args)) {} + + Expr VisitExpr_(const VarNode* n) { return name_to_args_[n->name_hint()]; } private: - Array args_; + std::unordered_map name_to_args_; }; Expr VisitExpr_(const CallNode* n) { @@ -62,7 +56,15 @@ class DefuseOpsMutator : public ExprMutator { if (const auto* call = new_n.as()) { if (const auto* func = call->op.as()) { if (func->body->IsInstance()) { - return FuncBodyMutator(call->args).Mutate(func->body); + std::unordered_map name_to_args; + for (size_t i = 0; i < func->params.size(); ++i) { + const std::string& pname = func->params[i]->name_hint(); + ICHECK(name_to_args.cend() == name_to_args.find(pname)) + << "Found multiple parameters share the same variable name `" << pname + << "` which introduces uncertainty in DefuseOps pass"; + name_to_args[pname] = call->args[i]; + } + return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body); } } } diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index e744fb51e0a6..02f9d474411a 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -53,7 +53,18 @@ bool IsOnDeviceNode(const ExprNode* node) { bool IsDeviceCopyNode(const ExprNode* node) { if (!node->IsInstance()) return false; const auto* call_node = static_cast(node); - return call_node->attrs.as(); + + if (call_node->attrs.as()) { + return true; + } + + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs) { + auto metadata = tir_call_attrs->metadata; + return metadata.count("source_device") == 1 && metadata.count("dst_device") == 1; + } + + return false; } } // namespace @@ -395,16 +406,31 @@ class DeviceInfo { const auto* call_node = static_cast(node); auto attrs = call_node->attrs.as(); - num_device_copy_ops_++; - dev_type_ = attrs->src_dev_type; - for (auto& arg : call->args) { - Visit(arg); - // restore the type for remaining arguments + if (attrs) { + num_device_copy_ops_++; dev_type_ = attrs->src_dev_type; + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = attrs->src_dev_type; + } + device_tag_[call] = attrs->dst_dev_type; + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = attrs->dst_dev_type; + } else { + auto attrs = call_node->attrs.as(); + CHECK(attrs) << "must be non-null"; + num_device_copy_ops_++; + dev_type_ = Downcast(attrs->metadata["source_device"]); + for (auto& arg : call->args) { + Visit(arg); + // restore the type for remaining arguments + dev_type_ = Downcast(attrs->metadata["source_device"]); + } + device_tag_[call] = Downcast(attrs->metadata["dst_device"]); + // update the out_dev_type_, which should be the dst_dev_type of last copy + out_dev_type_ = Downcast(attrs->metadata["dst_device"]); } - device_tag_[call] = attrs->dst_dev_type; - // update the out_dev_type_, which should be the dst_dev_type of last copy - out_dev_type_ = attrs->dst_dev_type; } else { for (auto& arg : call->args) { int cur_dev_type = dev_type_; diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 7c947ba109bf..318022fb86f5 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -106,20 +106,20 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, - {Op::Get("dyn.image.resize"), + {Op::Get("dyn.image.resize2d"), [this](const CallNode* call_node) { auto args = PrepareArgs(call_node); if (const ConstantNode* size = args[1].as()) { - const ResizeAttrs* param = call_node->attrs.as(); + const Resize2DAttrs* param = call_node->attrs.as(); ICHECK(param); auto size_int = ToVector(size->data); Array size_prim; for (size_t i = 0; i < size_int.size(); ++i) { size_prim.push_back(size_int[i]); } - return MakeResize(call_node->args[0], size_prim, param->layout, param->method, - param->coordinate_transformation_mode, param->rounding_method, - param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); + return MakeResize2D(call_node->args[0], size_prim, param->layout, param->method, + param->coordinate_transformation_mode, param->rounding_method, + param->cubic_alpha, param->cubic_exclude, param->out_dtype); } return Expr(nullptr); }}, diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index f883b4113656..b5f434e74c43 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -23,10 +23,14 @@ * to actual integer operations. */ +#include #include #include #include +namespace tvm { +namespace relay { + /* Description of FakeQuantizationToInteger * * The purpose of this pass is to find regions of the graph that follow @@ -63,65 +67,6 @@ * rewritten subgraph and the processing continues */ -namespace tvm { -namespace relay { - -/*! - * \brief AffineType representation - * \sa AffineType - */ -class AffineTypeNode : public Object { - public: - /*! \brief The scale of this type */ - Expr scale; - /*! \brief The zero point of this type */ - Expr zero_point; - /*! \brief The data type of this type */ - DataType dtype; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("scale", &scale); - v->Visit("zero_point", &zero_point); - v->Visit("dtype", &dtype); - } - - bool SEqualReduce(const AffineTypeNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(scale, other->scale) && equal(zero_point, other->zero_point) && - equal(dtype, other->dtype); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); - hash_reduce(scale); - hash_reduce(zero_point); - hash_reduce(dtype); - } - - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; - static constexpr const char* _type_key = "AffineTypeNode"; - TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); -}; - -/*! - * \brief Managed reference to AffineTypes. - * \sa AffineTypeNode - */ -class AffineType : public ObjectRef { - public: - TVM_DLL AffineType(Expr scale, Expr zero_point, DataType dtype) { - ObjectPtr n = make_object(); - n->scale = std::move(scale); - n->zero_point = std::move(zero_point); - n->dtype = std::move(dtype); - data_ = std::move(n); - } - TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); -}; - -TVM_REGISTER_NODE_TYPE(AffineTypeNode); - using ExprSet = std::unordered_set; using ExprMap = std::unordered_map; using AffineTypeMap = Map; @@ -147,8 +92,14 @@ class SubgraphExtractor : public ExprVisitor { } const AffineTypeMap GetAffineTypes() { return affine_types_; } void VisitExpr(const Expr& expr) override { + // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, + // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we + // abort the rewrite. if (expr.as() == nullptr && expr.as() == nullptr && - expr.as() == nullptr) { + expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr) { + LOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" + << " a fake quantize region, aborting this rewrite"; is_fake_quantized_ = false; } else { ExprVisitor::VisitExpr(expr); @@ -162,13 +113,14 @@ class SubgraphExtractor : public ExprVisitor { VisitExpr(call_node->args[0]); // Collect type of quantize ops affine_types_.Set(GetRef(call_node), - AffineType(call_node->args[1], call_node->args[2], - call_node->checked_type().as()->dtype)); + TensorAffineType(call_node->args[1], call_node->args[2], + call_node->checked_type().as()->dtype)); } else if (call_node->op == dequantize_op_) { // Collect type of dequantize ops - affine_types_.Set(GetRef(call_node), - AffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype)); + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], + call_node->args[0]->checked_type().as()->dtype)); } else { // run normally on everything else. ExprVisitor::VisitExpr_(call_node); @@ -225,19 +177,38 @@ class SubgraphMutator : public ExprMutator { } // Call the rewrite Array vals = fqfq[op](expr, affine_types_); - // Save teh outputs of the rewrite - ICHECK(vals.size() == 4) + // Save the outputs of the rewrite + ICHECK(vals.size() == 2) << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " << AsText(op, false); out = Downcast(vals[0]); - affine_types_.Set(out, AffineType(Downcast(vals[1]), Downcast(vals[2]), - DataType(String2DLDataType(Downcast(vals[3]))))); + affine_types_.Set(out, Downcast(vals[1])); } else { ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " << AsText(GetRef(call_node), false); } return out; } + + Expr VisitExpr_(const TupleNode* node) { + Expr expr = ExprMutator::VisitExpr_(node); + auto new_node = expr.as(); + Array types; + for (Expr field : new_node->fields) { + ICHECK(affine_types_[field].as()); + types.push_back(Downcast(affine_types_[field])); + } + affine_types_.Set(expr, TupleAffineType(types)); + return expr; + } + + Expr VisitExpr_(const TupleGetItemNode* node) { + Expr expr = ExprMutator::VisitExpr_(node); + auto tuple_type = affine_types_[expr.as()->tuple].as(); + affine_types_.Set(expr, tuple_type->types[node->index]); + return expr; + } + ExprSet subgraph_; AffineTypeMap affine_types_; AffineType out_type_; diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index a93532895b5a..7056dfe79fee 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -202,7 +202,7 @@ using FForwardRewrite = TypedPackedFunc Prepare(const Expr& body) { this->Update(body, NullValue()); @@ -585,15 +585,22 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> ObjectRef { - auto it = message.find(call.get()); - if (it != message.end()) { - return it->second; - } else { - return ObjectRef(nullptr); + for (const auto& m : message) { + if (m.second.defined()) { + // run optimization + auto fcontext = [&](const Call& call) -> ObjectRef { + auto it = message.find(call.get()); + if (it != message.end()) { + return it->second; + } else { + return ObjectRef(nullptr); + } + }; + return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); } - }; - return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); + } + // no messages - no optimization + return data; } //---------------------------------------- @@ -618,7 +625,7 @@ using FBackwardTransform = // Generic Visitors for FScaleAxisBackward //---------------------------------------------- -class BackwardPrep : private ExprVisitor { +class BackwardPrep : private MixedModeVisitor { public: // The message on each node. std::unordered_map Prepare(const Expr& body) { @@ -643,6 +650,14 @@ class BackwardPrep : private ExprVisitor { // We only allow propagation of scale backward // if the expression is only referred by a single parent. if (rit->second != 1) return; + Array in_messages = GetInMessages(call); + Message out_message = f(GetRef(call), in_messages); + if (out_message.defined()) { + message_[call] = out_message; + } + } + + Array GetInMessages(const CallNode* call) { Array in_messages; for (Expr arg : call->args) { auto it = message_.find(arg.get()); @@ -652,52 +667,34 @@ class BackwardPrep : private ExprVisitor { in_messages.push_back(NullValue()); } } - Message out_message = f(GetRef(call), in_messages); - if (out_message.defined()) { - message_[call] = out_message; - } + return in_messages; } }; -class BackwardTransformerNode : public Object, private ExprMutator { +/* + * Hybrid apporach is used with the transformation + * itself is recursive but the traversal is non-recursive + */ +class BackwardTransformerNode : public Object, private MixedModeMutator { public: + using MixedModeMutator::Mutate; // Run forward transform. Expr Fold(Expr expr) { message_ = BackwardPrep().Prepare(expr); - return this->Mutate(expr); - } - /*! - * \brief Transform the expr to consider the scaling. - * - * \param expr The input expression. - * \param axes The axes to scale. - * \param scale The scale applied to the axes. - * \return The result of transformation. - */ - Expr Transform(const Expr& expr, Message message, Expr scale) { - // NOTE: the result of Transform is memoized. - if (const CallNode* call_node = expr.as()) { - return Transform(call_node, message, scale); - } else { - ICHECK(!message.defined()) << "outstanding scale"; - return ExprMutator::VisitExpr(expr); + for (const auto& m : message_) { + if (m.second.defined()) { + // run optimization + return this->Mutate(expr); + } } + // no messages - no optimization + return expr; } + /*! - * \brief Normal way of mutating call node. - * \param call_node The call node to be mutated. - * \return the result of the call Mutation. + * \brief Transform the expr to consider the scaling. */ - Expr NormalCallTransform(const CallNode* call_node) { - const Call call = GetRef(call_node); - const auto it = memo_.find(call); - if (it != memo_.end()) { - return it->second; - } - Expr new_expr = ExprMutator::VisitExpr_(call_node); - memo_[call] = new_expr; - return new_expr; - } + Expr Transform(const Expr& expr, Message message, Expr scale); /*! * \brief Get the message propogated to the expr. * \param expr The expresison. @@ -719,11 +716,12 @@ class BackwardTransformerNode : public Object, private ExprMutator { // Valid axes on each node. std::unordered_map message_; // Override mutation of call. - Expr VisitExpr_(const CallNode* call_node) final { - return Transform(call_node, NullValue(), NullValue()); + Expr Rewrite_(const CallNode* call_node, const Expr& post) final { + return Transform(GetRef(call_node), NullValue(), NullValue()); } - // Transform of CallNode. - Expr Transform(const CallNode* call_node, Message message, Expr scale); + + public: + Expr NormalCallTransform(const CallNode* call_node) { return ExprMutator::VisitExpr_(call_node); } }; class BackwardTransformer : public ObjectRef { @@ -736,21 +734,39 @@ class BackwardTransformer : public ObjectRef { using ContainerType = BackwardTransformerNode; }; -Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = Op::GetAttrMap("FScaleAxisBackwardTransform"); - auto f = ftransform.get(call_node->op, nullptr); - if (f != nullptr) { +/*! + * \brief Transform the expr to consider the scaling. + * + * \param expr The input expression. + * \param message The axes to scale. + * \param scale The scale applied to the axes. + * \return The result of transformation. + */ +Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, Expr scale) { + if (const CallNode* call_node = expr.as()) { + static const auto& ftransform = + Op::GetAttrMap("FScaleAxisBackwardTransform"); + auto f = ftransform.get(call_node->op, nullptr); const Call call = GetRef(call_node); - const auto it = memo_.find(call); - if (it != memo_.end()) { - return it->second; + // ignore if there is a message + if (!message.defined()) { + const auto it = memo_.find(call); + if (it != memo_.end()) { + return it->second; + } + } + Expr new_expr = NullValue(); + if (f != nullptr) { + new_expr = f(call, message, scale, GetRef(this)); + } else { + ICHECK(!message.defined()) << "outstanding scale"; + new_expr = NormalCallTransform(call.operator->()); } - Expr new_expr = f(GetRef(call_node), message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { ICHECK(!message.defined()) << "outstanding scale"; - return NormalCallTransform(call_node); + return this->Mutate(expr); } } @@ -813,6 +829,7 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } + Message lhs_message = transformer->GetMessage(call->args[0]); Message rhs_message = transformer->GetMessage(call->args[1]); StructuralEqual equal; @@ -959,7 +976,9 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp } else { wscale = ReshapeToMatchAxis(scale, weight->type_as()->shape, {big_ko_axis, small_ko_axis}); - if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->()); + if (!wscale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } } weight = Multiply(weight, wscale); return Call(call->op, {data, weight}, call->attrs, call->type_args); diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 03473b7d7455..b61567d0bae0 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ #include "../backend/compile_engine.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./pass_utils.h" #include "let_list.h" #include "pattern_utils.h" @@ -66,9 +68,18 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { - if (auto* func = expr.as()) { + if (const FunctionNode* func = expr.as()) { return func->HasNonzeroAttr(attr::kReshapeOnly); } + if (const CallNode* call = expr.as()) { + if (call->attrs.defined()) { + if (auto tir_call_attrs = call->attrs.as()) { + Map metadata = tir_call_attrs->metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); + } + } + } return false; } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 68f31a17ab1b..b48fbe44bd11 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -113,12 +113,19 @@ struct RegionFuncMetadata { class Partitioner : public MixedModeMutator { public: explicit Partitioner(const IRModule& module) : module_(module) { + std::set func_names; for (auto f : module->functions) { GlobalVar f_var = f.first; BaseFunc f_func = f.second; + std::string f_name = f_var.as()->name_hint; + while (func_names.find(f_name) != func_names.end()) { + f_name += "_a"; + } + func_names.insert(f_name); // Creating regionset per function in the module. - auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp()); + auto region_set = + AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp(), f_name); regions_sets_[region_set] = f_func; } } @@ -301,7 +308,7 @@ class Partitioner : public MixedModeMutator { } std::string target = end_node->attrs.as()->compiler; - std::string name = target + "_" + std::to_string(region->GetID()); + std::string name = target + "_" + region->GetName() + "_" + std::to_string(region->GetID()); // Constant propagation if (!params_bind.empty()) { @@ -502,7 +509,7 @@ class NameMangleExtFuncs : public MixedModeMutator { // Walk the tree and mangle the functions. Then replace compiler functions // with mangled functions in the module - IRModule new_module; + IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports()); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 4c6013792426..f29087dcc049 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -205,8 +205,13 @@ class TypeInferencer : private ExprFunctor, this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables " << "without a module"); } - relay::Function e = Downcast(mod_->Lookup(var)); - return e->checked_type(); + + if (mod_->ContainGlobalVar(var->name_hint)) { + relay::Function e = Downcast(mod_->Lookup(var)); + return e->checked_type(); + } else { + return op->checked_type_; + } } Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index 426d2f24ddf5..ff748b8826de 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -89,6 +89,10 @@ void ConvEntry::CleanWorkspace() { workspace_size = 0; } +SoftmaxEntry::SoftmaxEntry() { MIOPEN_CALL(miopenCreateTensorDescriptor(&shape_desc)); } + +SoftmaxEntry::~SoftmaxEntry() { MIOPEN_CALL(miopenDestroyTensorDescriptor(shape_desc)); } + } // namespace miopen } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index d3a8c7b9ad64..76913696b0b9 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -62,11 +62,18 @@ struct ConvEntry { void CleanWorkspace(); }; // ConvThreadEntry +struct SoftmaxEntry { + miopenTensorDescriptor_t shape_desc; + SoftmaxEntry(); + ~SoftmaxEntry(); +}; // SoftmaxEntry + struct MIOpenThreadEntry { MIOpenThreadEntry(); ~MIOpenThreadEntry(); miopenHandle_t handle{nullptr}; ConvEntry conv_entry; + SoftmaxEntry softmax_entry; runtime::DeviceAPI* rocm_api{nullptr}; static MIOpenThreadEntry* ThreadLocal(); }; // MIOpenThreadEntry diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc new file mode 100644 index 000000000000..5a0f24ed7a84 --- /dev/null +++ b/src/runtime/contrib/miopen/softmax.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/miopen/softmax.cc + * \brief Use external miopen softmax function + */ +#include +#include + +#include "miopen_utils.h" + +namespace tvm { +namespace contrib { +namespace miopen { + +using namespace runtime; + +void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) { + DLTensor* x = args[0]; + DLTensor* y = args[1]; + int axis = args[2]; + int ndim = x->ndim; + int64_t* shape = x->shape; + if (axis < 0) axis += ndim; + ICHECK(axis >= 0 && axis < ndim); + // just fp32 for now + ICHECK(TypeMatch(x->dtype, kDLFloat, 32)); + ICHECK(TypeMatch(y->dtype, kDLFloat, 32)); + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + + miopenSoftmaxMode_t mode; + if (axis == ndim - 1) { + int64_t N = 1; + for (int i = 0; i < ndim - 1; ++i) { + N *= shape[i]; + } + mode = MIOPEN_SOFTMAX_MODE_INSTANCE; + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->softmax_entry.shape_desc, miopenFloat, + static_cast(N), static_cast(shape[ndim - 1]), + 1, 1)); + } else { + int64_t pre_axis_dim = 1; + int64_t post_axis_dim = 1; + for (int i = 0; i < ndim; ++i) { + if (i < axis) { + pre_axis_dim *= shape[i]; + } else if (i > axis) { + post_axis_dim *= shape[i]; + } + } + mode = MIOPEN_SOFTMAX_MODE_CHANNEL; + MIOPEN_CALL(miopenSet4dTensorDescriptor( + entry_ptr->softmax_entry.shape_desc, miopenFloat, static_cast(pre_axis_dim), + static_cast(shape[axis]), static_cast(post_axis_dim), 1)); + } + + const float alpha = 1.f; + const float beta = 0.f; + MIOPEN_CALL(miopenSoftmaxForward_V2(entry_ptr->handle, &alpha, + entry_ptr->softmax_entry.shape_desc, x->data, &beta, + entry_ptr->softmax_entry.shape_desc, y->data, alg, mode)); +} + +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") + .set_body([](TVMArgs args, TVMRetValue* ret) { + softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") + .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); + +} // namespace miopen +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc new file mode 100644 index 000000000000..b9ba8f9984e9 --- /dev/null +++ b/src/runtime/contrib/papi/papi.cc @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace profiling { + +#define PAPI_CALL(func) \ + { \ + int e = (func); \ + if (e != PAPI_OK) { \ + LOG(FATAL) << "PAPIError: in function " #func " " << e << " " \ + << std::string(PAPI_strerror(e)); \ + } \ + } + +static const std::unordered_map> default_metric_names = { + {kDLCPU, + {"perf::CYCLES", "perf::STALLED-CYCLES-FRONTEND", "perf::STALLED-CYCLES-BACKEND", + "perf::INSTRUCTIONS", "perf::CACHE-MISSES"}}, + {kDLCUDA, {"cuda:::event:elapsed_cycles_sm:device=0"}}}; + +/*! \brief Object that holds the values of counters at the start of a function call. */ +struct PAPIEventSetNode : public Object { + /*! \brief The starting values of counters for all metrics of a specific device. */ + std::vector start_values; + /*! \brief The device these counters are for. */ + Device dev; + + explicit PAPIEventSetNode(std::vector start_values, Device dev) + : start_values(start_values), dev(dev) {} + + static constexpr const char* _type_key = "PAPIEventSetNode"; + TVM_DECLARE_FINAL_OBJECT_INFO(PAPIEventSetNode, Object); +}; + +/* Get the PAPI component id for the given device. + * \param dev The device to get the component for. + * \returns PAPI component id for the device. Returns -1 if the device is not + * supported by PAPI. + */ +int component_for_device(Device dev) { + std::string component_name; + switch (dev.device_type) { + case kDLCPU: + component_name = "perf_event"; + break; + case kDLCUDA: + component_name = "cuda"; + break; + case kDLROCM: + component_name = "rocm"; + break; + default: + LOG(WARNING) << "PAPI does not support device " << DeviceName(dev.device_type); + return -1; + } + int cidx = PAPI_get_component_index(component_name.c_str()); + if (cidx < 0) { + LOG(FATAL) << "Cannot find PAPI component \"" << component_name + << "\". Maybe you need to build PAPI with support for this component (use " + "`./configure --with-components=" + << component_name << "`)."; + } + return cidx; +} + +/*! \brief MetricCollectorNode for PAPI metrics. + * + * PAPI (Performance Application Programming Interface) collects metrics on a + * variety of platforms including cpu, cuda and rocm. + * + * PAPI is avaliable at https://bitbucket.org/icl/papi/src/master/. + */ +struct PAPIMetricCollectorNode final : public MetricCollectorNode { + /*! \brief Construct a metric collector that collects a specific set of metrics. + * + * \param metrics A mapping from a device type to the metrics that should be + * collected on that device. You can find the names of available metrics by + * running `papi_native_avail`. + */ + explicit PAPIMetricCollectorNode(Map> metrics) { + for (auto& p : metrics) { + papi_metric_names[p.first->device] = {}; + for (auto& metric : p.second) { + papi_metric_names[p.first->device].push_back(metric); + } + } + } + explicit PAPIMetricCollectorNode() {} + + /*! \brief Initialization call. + * \param devices The devices this collector will be running on + */ + void Init(Array devices) { + if (!PAPI_is_initialized()) { + if (sizeof(long_long) > sizeof(int64_t)) { + LOG(WARNING) << "PAPI's long_long is larger than int64_t. Overflow may occur when " + "reporting metrics."; + } + CHECK_EQ(PAPI_library_init(PAPI_VER_CURRENT), PAPI_VER_CURRENT) + << "Error while initializing PAPI"; + } + + // If no metrics were provided we use the default set. The names were not + // initialized in the constructor because we did not know which devices we + // were running on. + if (papi_metric_names.size() == 0) { + for (auto wrapped_device : devices) { + Device device = wrapped_device->device; + auto it = default_metric_names.find(device.device_type); + if (it != default_metric_names.end()) { + papi_metric_names[device] = it->second; + } + } + } + + // create event sets for each device + for (auto wrapped_device : devices) { + Device device = wrapped_device->device; + int cidx = component_for_device(device); + // unknown device, skipping + if (cidx < 0) { + continue; + } + + auto it = papi_metric_names.find(device); + // skip devices with no metrics defined + if (it == papi_metric_names.end() || it->second.size() == 0) { + continue; + } + auto& metric_names = it->second; + + const PAPI_component_info_t* component = PAPI_get_component_info(cidx); + if (component->disabled) { + std::string help_message = ""; + switch (device.device_type) { + case kDLCPU: + help_message = + "Try setting `sudo sh -c 'echo 1 >/proc/sys/kernel/perf_event_paranoid'`"; + break; + case kDLCUDA: + help_message = + "Try enabling gpu profiling with `modprobe nvidia " + "NVreg_RestrictProfilingToAdminUsers=0`. If that does not work, try adding " + "`options nvidia \"NVreg_RestrictProfilingToAdminUsers=0\"` to " + "`/etc/modprobe.d/nvidia-kernel-common.conf`."; + break; + default: + break; + } + LOG(WARNING) << "PAPI could not initialize counters for " << DeviceName(device.device_type) + << ": " << component->disabled_reason << "\n" + << help_message; + continue; + } + + int event_set = PAPI_NULL; + PAPI_CALL(PAPI_create_eventset(&event_set)); + PAPI_CALL(PAPI_assign_eventset_component(event_set, cidx)); + if (device.device_type == kDLCPU) { + // we set PAPI_INHERIT to make it so threads created after this inherit the event_set. + PAPI_option_t opt; + memset(&opt, 0x0, sizeof(PAPI_option_t)); + opt.inherit.inherit = PAPI_INHERIT_ALL; + opt.inherit.eventset = event_set; + PAPI_CALL(PAPI_set_opt(PAPI_INHERIT, &opt)); + } + + if (static_cast(metric_names.size()) > PAPI_num_cmp_hwctrs(cidx)) { + PAPI_CALL(PAPI_set_multiplex(event_set)); + } + + // add all the metrics + for (auto metric : metric_names) { + int e = PAPI_add_named_event(event_set, metric.c_str()); + if (e != PAPI_OK) { + LOG(FATAL) << "PAPIError: " << e << " " << std::string(PAPI_strerror(e)) << ": " << metric + << "."; + } + } + // Because we may have multiple calls in flight at the same time, we + // start all the timers when we initialize. Then we calculate the metrics + // counts for a call by comparing counter values at the start vs end of + // the call. + PAPI_CALL(PAPI_start(event_set)); + event_sets[device] = event_set; + } + } + /*! \brief Called right before a function call. Reads starting values of the + * measured metrics. + * + * \param dev The device the function will be run on. + * \returns A `PAPIEventSetNode` containing values for the counters at the + * start of the call. Passed to a corresponding `Stop` call. + */ + ObjectRef Start(Device dev) final { + // Record counter values at the start of the call, so we can calculate the + // metrics for the call by comparing the values at the end of the call. + auto it = event_sets.find(dev); + if (it != event_sets.end()) { + int event_set = it->second; + std::vector values(papi_metric_names[dev].size()); + PAPI_CALL(PAPI_read(event_set, values.data())); + return ObjectRef(make_object(values, dev)); + } else { + return ObjectRef(nullptr); + } + } + /*! \brief Called right after a function call. Reads ending values of the + * measured metrics. Computes the change in each metric from the + * corresponding `Start` call. + * + * \param obj `PAPIEventSetNode` created by a call to `Start`. + * \returns A mapping from metric name to value. + */ + Map Stop(ObjectRef obj) final { + const PAPIEventSetNode* event_set_node = obj.as(); + std::vector end_values(papi_metric_names[event_set_node->dev].size()); + PAPI_CALL(PAPI_read(event_sets[event_set_node->dev], end_values.data())); + std::unordered_map reported_metrics; + for (size_t i = 0; i < end_values.size(); i++) { + if (end_values[i] < event_set_node->start_values[i]) { + LOG(WARNING) << "Detected overflow when reading performance counter, setting value to -1."; + reported_metrics[papi_metric_names[event_set_node->dev][i]] = + ObjectRef(make_object(-1)); + } else { + reported_metrics[papi_metric_names[event_set_node->dev][i]] = + ObjectRef(make_object(end_values[i] - event_set_node->start_values[i])); + } + } + return reported_metrics; + } + + ~PAPIMetricCollectorNode() final { + for (auto p : event_sets) { + PAPI_CALL(PAPI_stop(p.second, NULL)); + PAPI_CALL(PAPI_cleanup_eventset(p.second)); + PAPI_CALL(PAPI_destroy_eventset(&p.second)); + } + } + + /*! \brief Device-specific event sets. Contains the running counters (the int values) for that + * device. */ + std::unordered_map event_sets; + /*! \brief Device-specific metric names. Order of names matches the order in the corresponding + * `event_set`. */ + std::unordered_map> papi_metric_names; + + static constexpr const char* _type_key = "runtime.profiling.PAPIMetricCollector"; + TVM_DECLARE_FINAL_OBJECT_INFO(PAPIMetricCollectorNode, MetricCollectorNode); +}; + +/*! \brief Wrapper for `PAPIMetricCollectorNode`. */ +class PAPIMetricCollector : public MetricCollector { + public: + explicit PAPIMetricCollector(Map> metrics) { + data_ = make_object(metrics); + } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PAPIMetricCollector, MetricCollector, + PAPIMetricCollectorNode); +}; + +MetricCollector CreatePAPIMetricCollector(Map> metrics) { + return PAPIMetricCollector(metrics); +} + +TVM_REGISTER_OBJECT_TYPE(PAPIEventSetNode); +TVM_REGISTER_OBJECT_TYPE(PAPIMetricCollectorNode); + +TVM_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") + .set_body_typed([](Map> metrics) { + return PAPIMetricCollector(metrics); + }); + +} // namespace profiling +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index d8182b0e8378..08ac2ae0ec45 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -163,10 +163,19 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { auto profile = builder_->createOptimizationProfile(); for (int i = 0; i < network_->getNbInputs(); ++i) { auto name = network_->getInput(i)->getName(); - auto dims = network_->getInput(i)->getDimensions(); - profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); + const uint32_t entry_id = entry_id_map_[name]; + std::vector shape(data_entry_[entry_id]->shape, + data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); + auto dims = VectorToTrtDims(shape); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims); profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims); + // Set minimum batch size to 1 when dynamic batching is used. + if (network_->getInput(i)->getDimensions().nbDims >= 1 && + network_->getInput(i)->getDimensions().d[0] == -1) { + dims.d[0] = 1; + } + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); } config_->addOptimizationProfile(profile); } diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index 6358e59ce3bc..5562f853383c 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -140,6 +140,12 @@ class TensorRTRuntime : public JSONRuntimeBase { const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j); int binding_index = engine->getBindingIndex(name.c_str()); ICHECK_NE(binding_index, -1); + if (!use_implicit_batch_) { + std::vector shape(data_entry_[eid]->shape, + data_entry_[eid]->shape + data_entry_[eid]->ndim); + auto dims = VectorToTrtDims(shape); + ICHECK(context->setBindingDimensions(binding_index, dims)); + } if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { @@ -300,7 +306,7 @@ class TensorRTRuntime : public JSONRuntimeBase { helper.DeclareField("inputs", &engine_and_context.inputs); helper.DeclareField("outputs", &engine_and_context.outputs); helper.ReadAllFields(&reader); - const int batch_size = 1; + const int batch_size = GetBatchSize(); trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context; return true; } diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index bf64096441be..7b7690b66528 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -1088,8 +1088,7 @@ int TVMGraphExecutor_SetupOpExecs(TVMGraphExecutor* executor) { printf("tvm_op: creating %s with node_id=%d\n", inode->param.func_name, nid); #endif // TVM_CRT_DEBUG TVMPackedFunc pf; - TVMGraphExecutor_CreateTVMOp(executor, &(inode->param), args, args_count, inode->inputs_count, - &pf); + TVMGraphExecutor_CreateTVMOp(executor, &(inode->param), args, args_count, &pf); executor->op_execs[nid] = pf; } } @@ -1109,7 +1108,7 @@ typedef struct TVMOpArgs { int32_t TVMGraphExecutor_CreateTVMOp(TVMGraphExecutor* executor, const TVMOpParam* param, DLTensorPtr* args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc* pf) { + TVMPackedFunc* pf) { int status = 0; uint32_t idx; TVMOpArgs arg_ptr; diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h index 47ef474778e0..c9b3ebe5c643 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h @@ -116,6 +116,6 @@ int TVMGraphExecutor_GetOutput(TVMGraphExecutor* executor, const int32_t idx, DL int32_t TVMGraphExecutor_CreateTVMOp(TVMGraphExecutor* executor, const TVMOpParam* param, DLTensorPtr* args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc* pf); + TVMPackedFunc* pf); #endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ diff --git a/src/runtime/crt/memory/stack_allocator.c b/src/runtime/crt/memory/stack_allocator.c index 7a41ca4241ab..ba205f8f209b 100644 --- a/src/runtime/crt/memory/stack_allocator.c +++ b/src/runtime/crt/memory/stack_allocator.c @@ -79,8 +79,15 @@ tvm_crt_error_t StackMemoryManager_Free(tvm_workspace_t* tvm_runtime_workspace, tvm_crt_error_t StackMemoryManager_Init(tvm_workspace_t* tvm_runtime_workspace, uint8_t* g_aot_memory, size_t workspace_size) { - tvm_runtime_workspace->next_alloc = g_aot_memory; - tvm_runtime_workspace->workspace = g_aot_memory; - tvm_runtime_workspace->workspace_size = workspace_size; + // We need to round up g_aot_memory in case it is not aligned to + // TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES. + uintptr_t unaligned_mask = TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - 1; + uint8_t* memory_aligned = + (uint8_t*)(((uintptr_t)g_aot_memory + unaligned_mask) & ~unaligned_mask); + uint32_t offset = (uintptr_t)(memory_aligned - g_aot_memory); + + tvm_runtime_workspace->next_alloc = memory_aligned; + tvm_runtime_workspace->workspace = memory_aligned; + tvm_runtime_workspace->workspace_size = workspace_size - offset; return kTvmErrorNoError; } diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index a877bc634300..7d6879a62aba 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -153,12 +153,12 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -168,10 +168,10 @@ class CUDAWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), 0, strm, void_args, nullptr); + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); @@ -201,8 +201,8 @@ class CUDAWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; class CUDAPrepGlobalBarrier { @@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 32dd1d8020c9..35832e83f59c 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); - writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags); + writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags); writer->EndObject(); } @@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { std::vector sarg_types; helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); - helper.DeclareField("thread_axis_tags", &thread_axis_tags); + helper.DeclareOptionalField("launch_param_tags", &launch_param_tags); + helper.DeclareOptionalField("thread_axis_tags", + &launch_param_tags); // for backward compatibility helper.ReadAllFields(reader); arg_types.resize(sarg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); - writer->Write(thread_axis_tags); + writer->Write(launch_param_tags); } bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; - if (!reader->Read(&thread_axis_tags)) return false; + if (!reader->Read(&launch_param_tags)) return false; return true; } diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc index 1ea01b19e8aa..2fa73971d000 100644 --- a/src/runtime/graph_executor/debug/graph_executor_debug.cc +++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc @@ -276,16 +276,20 @@ class GraphExecutorDebug : public GraphExecutor { * the module compared to GraphRuntimeDebug::RunIndividual as it runs the * entire graph in order. * + * \param collectors Optional user defined `MetricCollector`s to use with this profiling run. + * * \returns A table of per-op runtimes and total times. */ - profiling::Report Profile() { + profiling::Report Profile(Array collectors) { + std::vector cs(collectors.begin(), collectors.end()); + profiling::Profiler prof(devices_, cs); + // warm up. 1 iteration does not seem enough. for (int i = 0; i < 3; i++) { GraphExecutor::Run(); } - profiling::Profiler prof; - prof.Start(devices_); + prof.Start(); for (size_t i = 0; i < op_execs_.size(); ++i) { if (op_execs_[i]) { // get argument shapes @@ -359,7 +363,10 @@ PackedFunc GraphExecutorDebug::GetFunction(const std::string& name, *rv = this->RunIndividual(number, repeat, min_repeat_ms); }); } else if (name == "profile") { - return TypedPackedFunc([sptr_to_self, this]() { return this->Profile(); }); + return TypedPackedFunc)>( + [sptr_to_self, this](Array collectors) { + return this->Profile(collectors); + }); } else { return GraphExecutor::GetFunction(name, sptr_to_self); } diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 1084b4ee3ec4..7aae12b32377 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -380,7 +380,7 @@ void GraphExecutor::SetupOpExecs() { ICHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; std::shared_ptr op_args = nullptr; - std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size()); + std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args); for (size_t i = 0; i < inode.inputs.size(); i++) { uint32_t eid = this->entry_id(inode.inputs[i]); @@ -393,8 +393,7 @@ void GraphExecutor::SetupOpExecs() { } std::pair, std::shared_ptr > -GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector& args, - size_t num_inputs) { +GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector& args) { std::shared_ptr arg_ptr = std::make_shared(); // setup address. arg_ptr->args = args; diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 631605f630da..42b5c405b406 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -381,11 +381,10 @@ class TVM_DLL GraphExecutor : public ModuleNode { * \brief Create an execution function given input. * \param attrs The node attributes. * \param args The arguments to the functor, including inputs and outputs. - * \param num_inputs Number of inputs. * \return The created executor. */ std::pair, std::shared_ptr> CreateTVMOp( - const TVMOpParam& attrs, const std::vector& args, size_t num_inputs); + const TVMOpParam& attrs, const std::vector& args); // Get node entry index. uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; } // Get node entry index. diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index e3ec155dc291..66d9a44099da 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -54,8 +54,8 @@ inline String get_name_mangled(const String& module_name, const String& name) { */ class MetadataNode : public Object { public: - /*! \brief number of inputs of the main function */ - int num_inputs = 1; + /*! \brief input information for the main function */ + Array inputs; /*! \brief number of outputs of the main function */ int num_outputs = 1; /*! \brief the executor to be used to run the model */ @@ -73,9 +73,9 @@ class MetadataNode : public Object { */ class Metadata : public ObjectRef { public: - TVM_DLL Metadata(int num_inputs, int num_outputs, String executor, String mod_name) { + TVM_DLL Metadata(Array inputs, int num_outputs, String executor, String mod_name) { auto n = make_object(); - n->num_inputs = num_inputs; + n->inputs = inputs; n->num_outputs = num_outputs; n->executor = executor; n->mod_name = mod_name; @@ -99,11 +99,14 @@ Module MetadataModuleCreate( const std::unordered_map& metadata, const std::unordered_map>& sym_vars); +/*! \brief A tag to specify whether or not dynamic shared memory is used */ +constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; + /*! \brief function information needed by device */ struct FunctionInfo { std::string name; std::vector arg_types; - std::vector thread_axis_tags; + std::vector launch_param_tags; void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 88501880557e..1e81ac1bbb34 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -178,7 +178,7 @@ void SaveToBinary(dmlc::Stream* stream) final { // initialize the METAL function. void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = metal::MetalWorkspace::Global(); m_ = m; sptr_ = sptr; @@ -186,7 +186,7 @@ void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_na num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; std::fill(scache_.begin(), scache_.end(), (id)nil); - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int dev_id = t->device.device_id; scache_[dev_id] = m->GetPipelineState(dev_id, func_name); @@ -201,7 +201,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons if (scache_[device_id] == nil) { scache_[device_id] = m_->GetPipelineState(device_id, func_name_); } - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2); auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup; CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup); @@ -242,8 +242,8 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons // Device state cache per device. // mark as mutable, to enable lazy initialization mutable std::array, kMetalMaxNumDevice> scache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; PackedFunc MetalModuleNode::GetFunction(const std::string& name, @@ -261,7 +261,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); pf = PackFuncNonBufferArg(f, info.arg_types); }; return pf; diff --git a/src/runtime/micro/standalone/microtvm_graph_executor.cc b/src/runtime/micro/standalone/microtvm_graph_executor.cc index d91d0ea74ba4..d09efc4be50e 100644 --- a/src/runtime/micro/standalone/microtvm_graph_executor.cc +++ b/src/runtime/micro/standalone/microtvm_graph_executor.cc @@ -321,7 +321,7 @@ void MicroGraphExecutor::SetupStorage() { } std::function CreateTVMOp(const DSOModule& module, const TVMOpParam& param, - const DynArray& args, size_t num_inputs) { + const DynArray& args) { typedef union { void* v_handle; } TVMValue; @@ -389,7 +389,7 @@ void MicroGraphExecutor::SetupOpExecs() { args[index + inode.inputs.size()] = data_entry_[eid].ToDLTensor(); } assert(inode.op_type == "tvm_op"); - op_execs_[nid] = CreateTVMOp(*module_, inode.param, args, inode.inputs.size()); + op_execs_[nid] = CreateTVMOp(*module_, inode.param, args); } } diff --git a/src/runtime/module.cc b/src/runtime/module.cc index acc7fc7286d1..cff65452671e 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -113,7 +113,10 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { if (pf == nullptr) { const PackedFunc* f = Registry::Get(name); ICHECK(f != nullptr) << "Cannot find function " << name - << " in the imported modules or global registry"; + << " in the imported modules or global registry." + << " If this involves ops from a contrib library like" + << " cuDNN, ensure TVM was built with the relevant" + << " library."; return f; } else { import_cache_.insert(std::make_pair(name, std::make_shared(pf))); diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 23594b5d7d8a..1892ce780a4c 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -138,8 +138,12 @@ class TypeContext { std::string TypeIndex2Key(uint32_t tindex) { std::lock_guard lock(mutex_); - ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) - << "Unknown type index " << tindex; + if (tindex != 0) { + // always return the right type key for root + // for non-root type nodes, allocated slots should not equal 0 + ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) + << "Unknown type index " << tindex; + } return type_table_[tindex].name; } diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 4040d82b33e7..f6c7f6232819 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -40,14 +40,14 @@ class OpenCLWrappedFunc { // initialize the OpenCL function. void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { w_ = m->GetGlobalWorkspace(); m_ = m; sptr_ = sptr; entry_ = entry; func_name_ = func_name; arg_size_ = arg_size; - thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); + launch_param_config_.Init(arg_size.size(), launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { @@ -73,8 +73,8 @@ class OpenCLWrappedFunc { OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg)); } cl_command_queue queue = w_->GetQueue(t->device); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); + ThreadWorkLoad wl = launch_param_config_.Extract(args); + cl_uint work_dim = static_cast(launch_param_config_.work_dim()); for (cl_uint i = 0; i < work_dim; ++i) { wl.work_size[i] *= wl.work_size[i + 3]; } @@ -96,8 +96,8 @@ class OpenCLWrappedFunc { std::string func_name_; // convert code for void argument std::vector arg_size_; - // thread axis config - ThreadAxisConfig thread_axis_cfg_; + // launch parameters config + LaunchParamConfig launch_param_config_; }; OpenCLModuleNode::~OpenCLModuleNode() { @@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types); } diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index ab9d674fad50..596b6ace8831 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -23,8 +23,10 @@ */ #include +#include #include #include +#include #include #include @@ -100,16 +102,37 @@ TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); namespace profiling { -void Profiler::Start(const std::vector& devs) { - CHECK(global_timers_.empty()) << "You can only call Start once per Profiler."; +Profiler::Profiler(std::vector devs, std::vector metric_collectors) + : devs_(devs), collectors_(metric_collectors) { + is_running_ = false; + std::vector wrapped_devs; for (auto dev : devs) { - global_timers_.emplace_back(dev, Timer::Start(dev)); + wrapped_devs.push_back(DeviceWrapper(make_object(dev))); + } + for (auto& x : collectors_) { + x->Init(wrapped_devs); + } + // reset the thread pool so that PAPI eventset hooks are set in all threads. + threading::ResetThreadPool(); +} + +void Profiler::Start() { + is_running_ = true; + for (auto dev : devs_) { + StartCall("Total", dev, {}); } } void Profiler::StartCall(String name, Device dev, std::unordered_map extra_metrics) { - in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics}); + std::vector> objs; + for (auto& collector : collectors_) { + ObjectRef obj = collector->Start(dev); + if (obj.defined()) { + objs.emplace_back(collector, obj); + } + } + in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics, objs}); } void Profiler::StopCall(std::unordered_map extra_metrics) { @@ -118,14 +141,21 @@ void Profiler::StopCall(std::unordered_map extra_metrics for (auto& p : extra_metrics) { cf.extra_metrics[p.first] = p.second; } + // collect the extra metrics from user defined collectors + for (const auto& obj : cf.extra_collectors) { + auto collector_metrics = obj.first->Stop(obj.second); + for (auto& p : collector_metrics) { + cf.extra_metrics[p.first] = p.second; + } + } in_flight_.pop(); calls_.push_back(cf); } void Profiler::Stop() { - // Stop all global timers. We wait to synchronize until we are making the report. - for (auto p : global_timers_) { - p.second->Stop(); + is_running_ = false; + for (size_t i = 0; i < devs_.size(); i++) { + StopCall(); } } @@ -198,6 +228,71 @@ String ReportNode::AsCSV() const { return s.str(); } +namespace { +void print_metric(std::ostream& os, ObjectRef o) { + if (o.as()) { + os << "\"" << Downcast(o) << "\""; + } else if (const CountNode* n = o.as()) { + os << "{\"count\":" << std::to_string(n->value) << "}"; + } else if (const DurationNode* n = o.as()) { + os << "{\"microseconds\":" << std::to_string(n->microseconds) << "}"; + } else if (const PercentNode* n = o.as()) { + os << "{\"percent\":" << std::to_string(n->percent) << "}"; + } else { + LOG(FATAL) << "Unprintable type " << o->GetTypeKey(); + } +} +} // namespace + +String ReportNode::AsJSON() const { + std::ostringstream s; + // DMLC's JSONWriter does not allow us to write a key value pair without + // implementing Write for the value. We want a specific write for the value, + // so we would have to implement a custom data structure for each type of + // value we want to print. Instead we construct the json by hand because it + // is easier. + s << "{"; + s << "\"calls\":["; + for (size_t i = 0; i < calls.size(); i++) { + size_t j = 0; + s << "{"; + for (const auto& kv : calls[i]) { + s << "\"" << kv.first << "\":"; + print_metric(s, kv.second); + if (j < calls[i].size() - 1) { + s << ","; + } + j++; + } + s << "}"; + if (i < calls.size() - 1) { + s << ","; + } + } + s << "],"; + s << "\"device_metrics\":{"; + size_t i = 0; + for (const auto& dev_kv : device_metrics) { + size_t j = 0; + s << "\"" << dev_kv.first << "\":{"; + for (const auto& metric_kv : dev_kv.second) { + s << "\"" << metric_kv.first << "\":"; + print_metric(s, metric_kv.second); + if (j < dev_kv.second.size() - 1) { + s << ","; + } + j++; + } + s << "}"; + if (i < device_metrics.size() - 1) { + s << ","; + } + i++; + } + s << "}}"; + return s.str(); +} + String ReportNode::AsTable(bool sort, bool aggregate) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes std::vector> aggregated_calls; @@ -396,31 +491,11 @@ std::string DeviceString(Device dev) { } Report Profiler::Report(bool aggregate, bool sort) { - std::vector> global_times; - for (auto p : global_timers_) { - global_times.emplace_back(p.first, p.second->SyncAndGetElapsedNanos() / 1e3); - } - - double overall_time = 0; - for (auto p : global_times) { - overall_time = std::max(overall_time, p.second); - } - - std::unordered_map> device_metrics; - for (auto p : global_times) { - std::unordered_map row; - row["Name"] = String("Total"); - row["Duration (us)"] = ObjectRef(make_object(p.second)); - row["Percent"] = ObjectRef(make_object(p.second / overall_time * 100)); - row["Device"] = String(DeviceString(p.first)); - device_metrics[DeviceString(p.first)] = row; - } - - std::vector> rows; + // sync all timers and normalize rows + std::vector> rows; for (auto& cf : calls_) { std::unordered_map row; double us = cf.timer->SyncAndGetElapsedNanos() / 1e3; - row["Percent"] = ObjectRef(make_object(us / overall_time * 100)); row["Duration (us)"] = ObjectRef(make_object(us)); row["Count"] = ObjectRef(make_object(1)); row["Name"] = cf.name; @@ -431,7 +506,30 @@ Report Profiler::Report(bool aggregate, bool sort) { rows.push_back(row); } - return profiling::Report(rows, device_metrics); + // the last couple of call frames are the overall times + double overall_time_us = 0; + std::unordered_map> device_metrics; + for (size_t i = 0; i < devs_.size(); i++) { + auto row = rows[rows.size() - 1]; + rows.pop_back(); + device_metrics[Downcast(row["Device"])] = row; + overall_time_us = + std::max(overall_time_us, row["Duration (us)"].as()->microseconds); + } + + // Calculate percentages + for (auto& row : rows) { + row["Percent"] = ObjectRef(make_object( + row["Duration (us)"].as()->microseconds / overall_time_us * 100)); + } + + // convert to map + std::vector> converted_rows; + for (const auto& row : rows) { + converted_rows.push_back(row); + } + + return profiling::Report(converted_rows, device_metrics); } Report::Report(Array> calls, @@ -446,8 +544,16 @@ TVM_REGISTER_OBJECT_TYPE(DurationNode); TVM_REGISTER_OBJECT_TYPE(PercentNode); TVM_REGISTER_OBJECT_TYPE(CountNode); TVM_REGISTER_OBJECT_TYPE(ReportNode); +TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); +TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); +TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { + return n->AsJSON(); +}); +TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { + return DeviceWrapper(dev); +}); } // namespace profiling } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 567557c56794..487ad23e16b9 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -147,12 +147,12 @@ class ROCMWrappedFunc { public: // initialize the ROCM function. void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& thread_axis_tags) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); - thread_axis_cfg_.Init(num_void_args, thread_axis_tags); + launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { @@ -164,13 +164,14 @@ class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. - ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), - wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); + ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr, + reinterpret_cast(&config))); } private: @@ -183,8 +184,8 @@ class ROCMWrappedFunc { // Device function cache per device. // mark as mutable, to enable lazy initialization mutable std::array fcache_; - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; }; PackedFunc ROCMModuleNode::GetFunction(const std::string& name, @@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name, if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; ROCMWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncPackedArg(f, info.arg_types); } diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index cab04ec0db4a..c11e9f7ac084 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -258,26 +258,29 @@ class SpscTaskQueue { class ThreadPool { public: ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) { - for (int i = 0; i < num_workers_; ++i) { - // The SpscTaskQueue only hosts ONE item at a time - queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); - } const char* exclude_worker0 = getenv("TVM_EXCLUDE_WORKER0"); if (exclude_worker0 && atoi(exclude_worker0) == 0) { exclude_worker0_ = false; } - threads_ = std::unique_ptr( - new tvm::runtime::threading::ThreadGroup( - num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, - exclude_worker0_ /* include_main_thread */)); - num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); + Init(); } + ~ThreadPool() { for (std::unique_ptr& q : queues_) { q->SignalForKill(); } threads_.reset(); } + + void Reset() { + for (std::unique_ptr& q : queues_) { + q->SignalForKill(); + } + queues_.clear(); + threads_.reset(); + Init(); + } + int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) { ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); ICHECK(!launcher->is_worker) @@ -323,6 +326,19 @@ class ThreadPool { } private: + // Shared initialization code + void Init() { + for (int i = 0; i < num_workers_; ++i) { + // The SpscTaskQueue only hosts ONE item at a time + queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); + } + threads_ = std::unique_ptr( + new tvm::runtime::threading::ThreadGroup( + num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, + exclude_worker0_ /* include_main_thread */)); + num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); + } + // Internal worker function. void RunWorker(int worker_id) { SpscTaskQueue* queue = queues_[worker_id].get(); @@ -359,6 +375,10 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRe ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); }); +namespace threading { +void ResetThreadPool() { tvm::runtime::ThreadPool::ThreadLocal()->Reset(); } +} // namespace threading + } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index c0393600b60c..ac8260ffbe39 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -19,7 +19,7 @@ /*! * \file thread_storage_scope.h - * \brief Extract thread axis configuration from TVMArgs. + * \brief Extract launch parameters configuration from TVMArgs. */ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ @@ -29,6 +29,8 @@ #include #include +#include "meta_data.h" + namespace tvm { namespace runtime { @@ -118,7 +120,9 @@ struct StorageScope { */ static StorageScope Create(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s.empty()) { + r.rank = StorageRank::kGlobal; + } else if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { @@ -180,6 +184,8 @@ struct ThreadScope { struct ThreadWorkLoad { // array, first three are thread configuration. size_t work_size[6]; + // Dynamic shared memory allocation size in bytes. + size_t dyn_shmem_size{0}; /*! * \param i The block dimension. * \return i-th block dim @@ -191,17 +197,23 @@ struct ThreadWorkLoad { */ inline size_t grid_dim(size_t i) const { return work_size[i]; } }; -/*! \brief Thread axis configuration */ -class ThreadAxisConfig { +/*! \brief Launch parameters configuration */ +class LaunchParamConfig { public: - void Init(size_t base, const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& launch_param_tags) { base_ = base; std::vector filled(6, false); - for (size_t i = 0; i < thread_axis_tags.size(); ++i) { - const std::string& tag = thread_axis_tags[i]; - ThreadScope ts = ThreadScope::Create(tag); - arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); - filled[ts.rank * 3 + ts.dim_index] = true; + for (size_t i = 0; i < launch_param_tags.size(); ++i) { + const std::string& tag = launch_param_tags[i]; + if (tag == kUseDynamicSharedMemoryTag) { + ICHECK_EQ(i, launch_param_tags.size() - 1) + << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; + use_dyn_shared_memory_ = true; + } else { + ThreadScope ts = ThreadScope::Create(tag); + arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); + filled[ts.rank * 3 + ts.dim_index] = true; + } } work_dim_ = 1; for (int i = 0; i < 3; ++i) { @@ -221,6 +233,9 @@ class ThreadAxisConfig { w.work_size[arg_index_map_[i]] = size; } } + if (use_dyn_shared_memory_) { + w.dyn_shmem_size = static_cast(x.values[base_ + arg_index_map_.size()].v_int64); + } return w; } // return the work dim @@ -233,6 +248,8 @@ class ThreadAxisConfig { size_t work_dim_; /*! \brief The index mapping. */ std::vector arg_index_map_; + /*! \brief Whether or not use dynamic shared memory. */ + bool use_dyn_shared_memory_{false}; }; } // namespace runtime diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index e8b948d3d2ae..c2dc0307e166 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -234,8 +234,8 @@ TVMByteArray Executable::Save() { } void Executable::SaveGlobalSection(dmlc::Stream* strm) { - std::vector > globals(this->global_map.begin(), - this->global_map.end()); + std::vector> globals(this->global_map.begin(), + this->global_map.end()); auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; @@ -273,6 +273,20 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { primitive_names[packed_index] = it.first; } strm->Write(primitive_names); + std::map> primitive_attrs; + for (const auto& it : this->op_attrs) { + auto packed_index = static_cast(it.first); + std::map attrs; + for (const auto& elem : it.second) { + // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer, so we just serialize + // strings for now + if (elem.second.as()) { + attrs[elem.first] = Downcast(elem.second); + } + } + primitive_attrs[packed_index] = attrs; + } + strm->Write(primitive_attrs); } // Serialize a virtual machine instruction. It creates a list that contains the @@ -569,6 +583,16 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { for (size_t i = 0; i < primitive_names.size(); i++) { this->primitive_map.insert({primitive_names[i], i}); } + + std::map> primitive_attrs; + STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs"); + for (const auto& fn : primitive_attrs) { + std::vector> attrs; + for (const auto& elem : fn.second) { + attrs.push_back({elem.first, String(elem.second)}); + } + this->op_attrs[fn.first] = Map(attrs.begin(), attrs.end()); + } } // Extract the `cnt` number of fields started at `start` from the list @@ -851,8 +875,8 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetV const auto* exec = dynamic_cast(mod.operator->()); ICHECK(exec); int idx = args[1]; - std::vector > globals(exec->global_map.begin(), - exec->global_map.end()); + std::vector> globals(exec->global_map.begin(), + exec->global_map.end()); auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index a7d65944d581..6d893114d623 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -43,26 +43,31 @@ namespace vm { PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "profile") { - return TypedPackedFunc([sptr_to_self, this](String arg_name) { - std::vector devices; - for (auto dev : devices_) { - if (dev.device_type > 0) { - devices.push_back(dev); - } - } - - auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self); - // warmup - for (int i = 0; i < 3; i++) { - invoke(arg_name); - } - - prof_ = profiling::Profiler(); // reset profiler - prof_.Start(devices); - invoke(arg_name); - prof_.Stop(); - return prof_.Report(); - }); + return TypedPackedFunc)>( + [sptr_to_self, this](String arg_name, Array collectors) { + std::vector devices; + for (auto dev : devices_) { + if (dev.device_type > 0) { + devices.push_back(dev); + } + } + + std::vector cs(collectors.begin(), collectors.end()); + prof_ = profiling::Profiler(devices, cs); + + auto invoke = VirtualMachine::GetFunction("invoke", sptr_to_self); + // warmup + for (int i = 0; i < 3; i++) { + invoke(arg_name); + } + + prof_.operator*().Start(); + invoke(arg_name); + prof_.operator*().Stop(); + auto report = prof_.operator*().Report(); + prof_ = dmlc::optional(); // releases hardware counters + return report; + }); } else { return VirtualMachine::GetFunction(name, sptr_to_self); } @@ -80,7 +85,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun Index output_size, const std::vector& args) { ICHECK(exec_); ICHECK(!devices_.empty()) << "Device has not been initialized yet."; - if (prof_.IsRunning()) { + if (prof_ && prof_.operator*().IsRunning()) { // The device of any input of the operator is used for synchronization. ICHECK_GT(arg_count, 0U); ObjectRef arg = args[0]; @@ -122,11 +127,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& fun } metrics["Argument Shapes"] = profiling::ShapeString(shapes); - prof_.StartCall(packed_index_map_[packed_index], dev, metrics); + prof_.operator*().StartCall(packed_index_map_[packed_index], dev, metrics); } VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); - if (prof_.IsRunning()) { - prof_.StopCall(); + if (prof_ && prof_.operator*().IsRunning()) { + prof_.operator*().StopCall(); } } diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 521a9bd454e7..1efefda52b97 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -25,6 +25,7 @@ #ifndef TVM_RUNTIME_VM_PROFILER_VM_H_ #define TVM_RUNTIME_VM_PROFILER_VM_H_ +#include #include #include @@ -39,7 +40,7 @@ namespace vm { class VirtualMachineDebug : public VirtualMachine { public: - VirtualMachineDebug() : VirtualMachine() {} + VirtualMachineDebug() : VirtualMachine(), prof_({}) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; @@ -52,7 +53,7 @@ class VirtualMachineDebug : public VirtualMachine { const std::vector& args) final; std::unordered_map packed_index_map_; - profiling::Profiler prof_; + dmlc::optional prof_; }; } // namespace vm diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index 16987210b232..156f86dbb03e 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -156,6 +156,27 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, device_name = properties.properties.deviceName; driver_version = properties.properties.driverVersion; + switch (properties.properties.deviceType) { + case VK_PHYSICAL_DEVICE_TYPE_OTHER: + device_type = "other"; + break; + case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU: + device_type = "integrated"; + break; + case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU: + device_type = "discrete"; + break; + case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU: + device_type = "virtual"; + break; + case VK_PHYSICAL_DEVICE_TYPE_CPU: + device_type = "cpu"; + break; + default: + LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType; + break; + } + // By default, use the maximum API version that the driver allows, // so that any supported features can be used by TVM shaders. // However, if we can query the conformance version, then limit to diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 045628bc9092..412542029209 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -92,7 +92,8 @@ struct VulkanDeviceProperties { uint32_t max_storage_buffer_range{1 << 27}; uint32_t max_per_stage_descriptor_storage_buffer{4}; uint32_t max_shared_memory_per_block{16384}; - std::string device_name{"unknown device name"}; + std::string device_type{"unknown_device_type"}; + std::string device_name{"unknown_device_name"}; uint32_t driver_version{0}; uint32_t vulkan_api_version{VK_API_VERSION_1_0}; uint32_t max_spirv_version{0x10000}; diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 1fede98f7211..b4987eb321cf 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -50,6 +50,28 @@ VulkanDeviceAPI::VulkanDeviceAPI() { devices_.push_back(std::move(device)); } } + + // Move discrete GPUs to the start of the list, so the default + // device_id=0 preferentially uses a discrete GPU. + auto preference = [](const VulkanDevice& device) { + const std::string& type = device.device_properties.device_type; + if (type == "discrete") { + return 0; + } else if (type == "integrated") { + return 1; + } else if (type == "virtual") { + return 2; + } else if (type == "cpu") { + return 3; + } else { + return 4; + } + }; + + std::stable_sort(devices_.begin(), devices_.end(), + [&preference](const VulkanDevice& a, const VulkanDevice& b) { + return preference(a) < preference(b); + }); } VulkanDeviceAPI::~VulkanDeviceAPI() {} @@ -214,8 +236,8 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, if (property == "max_shared_memory_per_block") { *rv = int64_t(prop.max_shared_memory_per_block); } - if (property == ":string device_name") { - *rv = prop.device_name; + if (property == "device_name") { + *rv = String(prop.device_name); } if (property == "driver_version") { *rv = int64_t(prop.driver_version); diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 103b2aa7692c..0712f723bb64 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -33,13 +33,13 @@ namespace vulkan { void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags) { + const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; - thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags); } void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, @@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } const auto& pipeline = scache_[device_id]; - ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + ThreadWorkLoad wl = launch_param_config_.Extract(args); std::vector descriptor_buffers; descriptor_buffers.resize(num_buffer_args_); for (size_t i = 0; i < num_buffer_args_; ++i) { @@ -197,7 +197,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, - info.thread_axis_tags); + info.launch_param_tags); return PackFuncNonBufferArg(std::move(f), info.arg_types); } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index a174f22eba59..cd4774bf0f5a 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -58,7 +58,7 @@ class VulkanWrappedFunc { public: void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, - const std::vector& thread_axis_tags); + const std::vector& launch_param_tags); void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; @@ -73,11 +73,10 @@ class VulkanWrappedFunc { size_t num_buffer_args_; // number of packed arguments. size_t num_pack_args_; + // launch parameters configuration + LaunchParamConfig launch_param_config_; // Device state cache per device. // mark as mutable, to enable lazy initialization - // thread axis configuration - ThreadAxisConfig thread_axis_cfg_; - mutable std::array, kVulkanMaxNumDevice> scache_; }; diff --git a/src/target/build_common.h b/src/target/build_common.h index d2fe6468eef8..c66c2b52822e 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -53,7 +53,12 @@ inline std::unordered_map ExtractFuncInfo(co if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); + info.launch_param_tags.push_back(thread_axis[i]->thread_tag); + } + } + if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { + if (opt.value()) { + info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 78f8a50e4e1b..7770e42086de 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -72,50 +72,45 @@ class CodeGenAMDGPU : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; + buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16), + llvm::GlobalValue::ExternalLinkage); + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the AMD devices - if (info.alignment > 16) { - info.alignment = 16; - } - if (info.scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } - buf = alloca; - } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(info.alignment)) { + // maximum necessary alignment in the AMD devices + if (info.alignment > 16) { + info.alignment = 16; + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); } - buf = global; } buf = builder_->CreatePointerCast( diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 48ccefafe3c4..b83748b784b6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; - *p_native_bits = NativeVectorBits(info.scope); + *p_native_bits = + NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; @@ -523,6 +524,22 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp *p_alignment = align_bits / 8; } +llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, + int alignment, + llvm::GlobalValue::LinkageTypes linkage) { + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); +#if TVM_LLVM_VERSION >= 100 + global->setAlignment(llvm::Align(alignment)); +#else + global->setAlignment(alignment); +#endif + return global; +} + std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique(); @@ -1390,11 +1407,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_info_[v].scope = - runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d5fcfab6d889..52c5b98a0025 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor, protected: /*! \brief The storage information */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; /*! \brief The alignment of allocation */ int alignment{0}; }; @@ -294,6 +292,11 @@ class CodeGenLLVM : public ExprFunctor, const Var& loop_var, const Stmt& body); // add alias information. void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index); + + llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size, + unsigned int shared_address_space, int alignment, + llvm::GlobalValue::LinkageTypes linkage); + // The IRBuilder. using IRBuilder = llvm::IRBuilder; // The current function diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 9e56529ec9ef..15543eda423f 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -48,48 +48,44 @@ class CodeGenNVPTX : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } // maximum necessary alignment in the NV devices if (info.alignment > 16) { info.alignment = 16; } - if (info.scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { -#if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); -#else - alloca->setAlignment(info.alignment); -#endif - } - buf = alloca; - } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); + buf = + AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + if (storage_scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - buf = global; + } + buf = alloca; + } else { + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, + llvm::GlobalValue::PrivateLinkage); + } } buf = builder_->CreatePointerCast( diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 24fb3dc95819..15a1493b8585 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -223,8 +223,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { found_linked_params = true; continue; } - ICHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " + << kv.second->GetTypeKey(); + continue; + } auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()); @@ -234,7 +238,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { } funcs.push_back(f); } - ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); + // TODO(@jroesch): follow up on this condition. + // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params)); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 99c9452975d4..8397044e8b93 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -861,12 +861,11 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - if (it != alloc_storage_scope_.end()) { - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - } + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -882,10 +881,6 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { BindThreadIndex(iv); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_scope_[v] = op->value.as()->value; } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index ae451f39f89b..834c57ac10fd 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -39,6 +39,7 @@ #include #include +#include "../../tir/transforms/ir_utils.h" #include "codegen_source_base.h" namespace tvm { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6e76c3538e71..7897490730a3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -525,6 +525,8 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; + } else if (scope == "shared.dyn") { + os << "extern __shared__ "; } } @@ -703,14 +705,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + std::string scope = GetPtrStorageScope(op->buffer_var); const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - ICHECK(it != alloc_storage_scope_.end()) - << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key"; - - std::string scope = it->second; if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -724,18 +720,28 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } - constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && - scope == "shared") { - constant_size = constant_size / (32 / op->dtype.bits()); + + if (scope == "shared.dyn") { + stream << ' ' << vid << "[];\n"; + } else { + int32_t constant_size = op->constant_allocation_size(); + ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + + if (scope.find("wmma.") == 0) { + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + } + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); + } + stream << ' ' << vid << '[' << constant_size << "];\n"; } - stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ac4d7e3666ea..7728773b13d7 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -192,25 +192,26 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << "}\n"; } - void GenerateEntrypointForUnpackedAPI(const std::string& run_func) { + void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, + const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func << "("; - int total_args = (metadata_->num_inputs + metadata_->num_outputs); - for (int i = 0; i < total_args; ++i) { - code_ << "arg" << i; + unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + for (unsigned int i = 0; i < total_args; ++i) { + code_ << "void* arg" << i; if (i + 1 != total_args) { code_ << ","; } } code_ << ");\n"; - code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main; + code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func << "("; - for (int i = 0; i < metadata_->num_inputs; ++i) { + for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; } for (int i = 0; i < metadata_->num_outputs; ++i) { - int j = metadata_->num_inputs + i; + int j = metadata_->inputs.size() + i; code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; if (i + 1 != metadata_->num_outputs) { code_ << ","; @@ -220,11 +221,12 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "}\n"; } - void GenerateEntrypointForPackedAPI(const std::string& run_func) { + void GenerateEntrypointForPackedAPI(const std::string& entrypoint_name, + const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle);\n"; - code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main; + code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func; @@ -232,25 +234,70 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "}\n"; } + void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func, + const std::string& mod_name) { + code_ << "#include <" << mod_name << ".h>\n"; + code_ << "TVM_DLL int32_t " << run_func << "("; + unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); + for (unsigned int i = 0; i < total_args; ++i) { + code_ << "void* arg" << i; + if (i + 1 != total_args) { + code_ << ","; + } + } + code_ << ");\n"; + code_ << "int32_t " << entrypoint_name << "("; + code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs," + << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs" + << ") {"; + code_ << "return " << run_func << "("; + for (const auto& input : metadata_->inputs) { + code_ << "inputs->" << input << ","; + } + if (metadata_->num_outputs == 1) { + code_ << "outputs->output"; + } else { + for (int i = 0; i < metadata_->num_outputs; ++i) { + code_ << "outputs->output" << i; + if (i + 1 != metadata_->num_outputs) { + code_ << ","; + } + } + } + code_ << ");\n"; + code_ << "}\n"; + } + void GenerateAOTDescriptor() { - const std::string run_func = ::tvm::runtime::symbol::tvm_run_func_suffix; - const std::string run_func_mangled = runtime::get_name_mangled(metadata_->mod_name, run_func); + const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix; + const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix; + const std::string run_func_mangled = + runtime::get_name_mangled(metadata_->mod_name, run_func_suffix); + const std::string entrypoint_mangled = + runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix); const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network"); - code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n"; + auto unpacked_api = target_->GetAttr("unpacked-api").value_or(Bool(false)); + auto interface_api = target_->GetAttr("interface-api").value_or(String("packed")); + code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; - code_ << "extern \"C\"\n"; + code_ << "extern \"C\" {\n"; code_ << "#endif\n"; - if (target_->GetAttr("unpacked-api").value_or(Bool(false))) { - GenerateEntrypointForUnpackedAPI(run_func_mangled); + + if (unpacked_api) { + if (interface_api == "c") { + GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); + } else { + GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled); + } } else { - GenerateEntrypointForPackedAPI(run_func_mangled); + ICHECK_EQ(interface_api, "packed") << "Packed interface required for packed operators"; + GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled); } - code_ << "const tvm_model_t " << network_mangled << " = {\n" - << " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << ",\n" - << " .num_input_tensors = " << metadata_->num_inputs << ",\n" - << " .num_output_tensors = " << metadata_->num_outputs << ", \n" - << "};\n"; + + code_ << "#ifdef __cplusplus\n"; + code_ << "}\n"; + code_ << "#endif\n"; } void CreateSource() { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 5d52bee44e98..42d0027a326f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -32,6 +32,7 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" #include "../../runtime/vulkan/vulkan_shader.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { namespace codegen { @@ -42,7 +43,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; - uint32_t num_buffer = 0; + uint32_t i_buffer = 0; // Currently, all storage and uniform buffer arguments are passed as // a single descriptor set at index 0. If ever non-zero, must @@ -52,24 +53,25 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: for (Var arg : f->params) { DataType t = arg.dtype(); if (t.is_handle()) { - if (auto* ptr = arg->type_annotation.as()) { - auto* prim = ptr->element_type.as(); - ICHECK(prim); - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::UInt(1)) { - // We need a physically addressable buffer type to support boolean tensors. - // The loaded byte is cast to bool inside the LoadNode visitor below. - value_storage_type = DataType::UInt(8); - } - spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), - descriptor_set, num_buffer); - builder_->SetName(arg_value, arg->name_hint); - storage_info_[arg.get()].UpdateContentType(value_storage_type); - var_map_[arg.get()] = arg_value; - } else { - LOG(FATAL) << "require all handles to be typed"; + auto* ptr = arg->type_annotation.as(); + ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + auto* prim = ptr->element_type.as(); + ICHECK(prim) << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + DataType value_storage_type = prim->dtype; + if (value_storage_type == DataType::Bool()) { + // We need a physically addressable buffer type to support boolean tensors. + // The loaded byte is cast to bool inside the LoadNode visitor below. + value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); } - ++num_buffer; + spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), + descriptor_set, i_buffer++); + builder_->SetName(arg_value, arg->name_hint); + storage_info_[arg.get()].SetContentType(value_storage_type, arg->name_hint); + var_map_[arg.get()] = arg_value; } else { pod_args.push_back(arg); } @@ -94,7 +96,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: } else { shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO; // If we need to pass more arguments than push constants could handle, we use UBO. - spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer); + spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, i_buffer++); for (size_t i = 0; i < pod_args.size(); ++i) { spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; @@ -403,14 +405,19 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { ICHECK(is_one(op->predicate)); - auto it = storage_info_.find(op->buffer_var.get()); + + DataType desired_read_type = op->dtype; + if (desired_read_type == DataType::Bool()) { + desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); + } + + const VarNode* buffer_var = op->buffer_var.get(); + auto it = storage_info_.find(buffer_var); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - if (!info.content_fixed) { - info.UpdateContentType(op->dtype); - } + info.CheckContentType(desired_read_type, op->index.dtype().lanes()); - spirv::SType content_type = builder_->GetSType(info.content_type); + spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -418,47 +425,38 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (info.is_volatile) { mask |= spv::MemoryAccessVolatileMask; } - if (op->dtype.lanes() == 1) { + + if (desired_read_type == info.element_type) { + // Requested a single value from an array. This may be a scalar load + // or a vectorized load, based on the array element type. spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); - if (op->dtype == DataType::UInt(1)) { - // A bool tensor is backed by a byte buffer, we cast to bool here. - auto bool_ty = builder_->GetSType(DataType::UInt(1)); - return builder_->Cast(bool_ty, loaded); - } else { - ICHECK_EQ(info.content_type, op->dtype) - << "Vulkan only allow one type access to the same buffer"; - return loaded; + // OpTypeBool have no physical address/storage. Here, cast from + // the storage type to an OpTypeBool. + if (op->dtype == DataType::Bool()) { + auto spirv_bool = builder_->GetSType(DataType::Bool()); + loaded = builder_->Cast(spirv_bool, loaded); } + return loaded; + + } else if (desired_read_type.element_of() == info.element_type) { + // Requested several elements returned as an array. Read out each + // element and concatenate into the result. + std::vector values; + auto f = [&](int i, spirv::Value index) { + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); + }; + this->Scalarize(op->index, f); + return builder_->Concat(values); + } else { - if (op->dtype.element_of() == info.content_type) { - // because content type is element type, we can only do scalarize load. - std::vector values; - auto f = [&](int i, spirv::Value index) { - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); - }; - this->Scalarize(op->index, f); - return builder_->Concat(values); - } else { - if (const RampNode* ramp = op->index.as()) { - if (is_one(ramp->stride)) { - ICHECK_EQ(ramp->lanes, op->dtype.lanes()); - arith::ModularSet me = analyzer_->modular_set(ramp->base); - ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) - << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = - analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); - return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); - } - } - } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; + LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint + << "' with element type " << info.element_type << " using index of type " + << op->index->dtype << " to produce output of type " << op->dtype; + return spirv::Value(); } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; - return spirv::Value(); } void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { @@ -481,12 +479,9 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { auto it = storage_info_.find(op->buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; + info.CheckContentType(op->value.dtype(), op->index.dtype().lanes()); - if (!info.content_fixed) { - info.UpdateContentType(op->value.dtype()); - } - - spirv::SType content_type = builder_->GetSType(info.content_type); + spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::Value value = MakeValue(op->value); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -496,37 +491,29 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { mask |= spv::MemoryAccessVolatileMask; } - if (op->value.dtype().lanes() == 1) { - ICHECK_EQ(info.content_type, op->value.dtype()) + if (op->value.dtype() == info.element_type) { + // Requested store of a single value. This may be a scalar store + // or a vectorized store, based on the array element type. + ICHECK_EQ(info.element_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); + + } else if (op->value.dtype().element_of() == info.element_type) { + // Requested store of several arbitrarily located values. Extract + // each value from the composite, then assign to the buffer. + auto f = [&](int i, spirv::Value index) { + spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + builder_->MakeInst(spv::OpStore, ptr, elem, mask); + }; + this->Scalarize(op->index, f); + } else { - if (op->value.dtype().element_of() == info.content_type) { - // because content type is element type, we can only do scalarize load. - auto f = [&](int i, spirv::Value index) { - spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - builder_->MakeInst(spv::OpStore, ptr, elem, mask); - }; - this->Scalarize(op->index, f); - } else { - if (const RampNode* ramp = op->index.as()) { - if (is_one(ramp->stride)) { - ICHECK_EQ(ramp->lanes, op->value.dtype().lanes()); - arith::ModularSet me = analyzer_->modular_set(ramp->base); - ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) - << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = - analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); - builder_->MakeInst(spv::OpStore, ptr, value, mask); - return; - } - } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; - } + LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '" + << op->buffer_var->name_hint << "' with element type " << info.element_type + << " using index of type " << op->index->dtype; } } @@ -644,13 +631,14 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + spirv::Value buf; - StorageInfo& info = storage_info_[op->buffer_var.get()]; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); - if (info.scope.rank == runtime::StorageRank::kLocal) { + if (storage_scope.rank == runtime::StorageRank::kLocal) { buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); - } else if (info.scope.rank == runtime::StorageRank::kShared) { + } else if (storage_scope.rank == runtime::StorageRank::kShared) { // Shared memory buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); @@ -660,8 +648,10 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { builder_->SetName(buf, op->buffer_var->name_hint); - ICHECK(!info.content_fixed); - info.UpdateContentType(op->dtype); + StorageInfo& info = storage_info_[op->buffer_var.get()]; + ICHECK(!info.element_type_known); + info.SetContentType(op->dtype, op->buffer_var->name_hint); + ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -677,10 +667,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 3868322a74e0..8b14754f617f 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -114,49 +114,104 @@ class CodeGenSPIRV : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: - /*! \brief The storage information */ + /*! \brief Storage information for a buffer */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; + /*! \brief The name of the tir::Var for the buffer + * + * Used for error messages. + */ + std::string name_hint; + /*! \brief Whether it is volatile */ bool is_volatile{false}; - /*! \brief Whether it is volatile */ - bool content_fixed{false}; - /*! \brief Current content type */ - DataType content_type{DataType::Handle()}; - - // Update content type if it hasn't beenupdated. - void UpdateContentType(DataType type) { - if (content_fixed) { - ICHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; - } else { - this->content_type = type; - content_fixed = true; - } + + /*! \brief Whether the element type of the buffer is known. + * + * This value is determined based on the type_annotation of the + * buffer variable (AllocateNode) or of the parameter (shader + * arguments). + */ + bool element_type_known{false}; + + /*! \brief The known element type of the buffer. + * + * This value is determined based on the type_annotation of the + * buffer variable (AllocateNode) or of the parameter (shader + * arguments). + */ + DataType element_type{DataType()}; + + /* \brief Check that the access type matches the known type + * + * Asserts that the type given is the same as the type previously + * stored in this array. + * + * @param type The data type being stored/loaded in the buffer + * + * @param index_lanes The number of lanes of the index. The + * number of lanes in the value being stored/loaded should be the + * product of the number of lanes of the buffer element type and + * the number of lanes of the index. + */ + void CheckContentType(DataType type, int index_lanes = 1) { + ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint + << " no previous element type defined"; + DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); + ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint + << " as element type " << type << " using an index of size " + << index_lanes << " when the element type is " << element_type; + } + + // Update content type if it hasn't been updated. + void SetContentType(DataType type, std::string name_hint) { + ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint + << " a second time."; + this->element_type = type; + this->name_hint = name_hint; + element_type_known = true; } }; // Reset the state so it works for a new function. void InitFuncState(); // Get the thread index spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); + spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); + // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; + // The builder std::unique_ptr builder_; + // Work group size of three uint32_t workgroup_size_[3]; + // Likely branch uint32_t weight_likely_branch_{128}; + + /* The data type used for the backing array for booleans. + * + * Currently matched to the data type used in Buffer::vstore and + * Buffer::vload. In the future, this should be the smallest + * integer type supported by the device, as not all Vulkan + * implementations support int8. + */ + DataType boolean_storage_type_{DataType::Int(8)}; + // the storage scope of allocation std::unordered_map storage_info_; + // The definition of local variable. std::unordered_map var_map_; + // The analyzer. std::unique_ptr analyzer_; + // deep comparison of PrimExpr ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map let_binding_; }; diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d037b9dfdbdb..3ad04eb3d577 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -249,7 +249,7 @@ Map UpdateVulkanAttrs(Map attrs) { "driver_version", "vulkan_api_version", "max_spirv_version"}; - std::vector str_opts = {"device_name"}; + std::vector str_opts = {"device_name", "device_type"}; for (auto& key : bool_opts) { if (!attrs.count(key)) { @@ -299,6 +299,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("runtime") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") + .add_attr_option("interface-api") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("c", kDLCPU) @@ -310,6 +311,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("executor") .add_attr_option("workspace-byte-alignment") .add_attr_option("unpacked-api") + .add_attr_option("interface-api") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) @@ -387,6 +389,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("max_per_stage_descriptor_storage_buffer") .add_attr_option("max_shared_memory_per_block") // Other device properties + .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_version") .add_attr_option("vulkan_api_version") diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 76fed053fda1..240adf14b31d 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -834,7 +834,7 @@ std::pair ImplicationNotContainingVars( return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) && (pair_b.first || pair_a.second) && (pair_a.second || pair_b.second)}; - } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { + } else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { return {cond, const_true()}; } else { return {const_true(), cond}; @@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond, // Keep only those variables of the new vars which are used in the new_expr Array used_res_variables; for (const Var& var : res->dst->variables) { - if (ExprUseVar(new_expr, var)) { + if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) { ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred."; used_res_variables.push_back(var); } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb35619..c73a6e0ce120 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self, Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = tir::ProducerRealize(t, bounds, const_true(), realize); + realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); @@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_mapshape, tensor->dtype, tensor->GetNameHint()); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; // Step 3. Add Buffer to root_alloc @@ -270,7 +270,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { const te::Tensor& tensor = op.output(0); // Check op is in op list ICHECK(info.IsArg(tensor)); - const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name); + const Buffer& buffer = + decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); info.tensor2buffers[tensor] = buffer; } else if (const auto* compute_op = op.as()) { // Case 2. ComputeOp (te.compute) diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index da20dd875ba5..2ed5fd4029a2 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -146,7 +146,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), - PointerType(PrimType(t))); + PointerType(PrimType(t), "local")); lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array init_value = combiner->identity_element; @@ -177,7 +177,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { DataType dtype = reduces[idx]->dtype; - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype))); + res_handles[idx] = + Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); freduce_args.push_back(res_handles[idx]); } @@ -224,12 +225,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = - AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 1c9a3cb336ae..b602efcfc28b 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self, Stmt ExternOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { @@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 65b8660ca1fb..5d2412abb3d2 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self, Stmt HybridOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index c51e53e16cd1..4f5df7ad3024 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self, Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { return body; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index a555e86097b7..39689bd9654a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self, } Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { arith::Analyzer analyzer; ICHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); @@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::ProducerRealize(t, bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope); } return ret; } diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 0aa279fb9246..447fc501d03b 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -140,13 +140,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage, auto fbanned = [&](const VarNode* node) { return banned.count(node); }; for (const PrimExpr& pred : n.main_predicates) { - if (tir::ExprUseVar(pred, fbanned)) { + if (tir::UsesVar(pred, fbanned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { - if (tir::ExprUseVar(pred, fbanned)) { + if (tir::UsesVar(pred, fbanned)) { LOG(FATAL) << "Tensorize failed, split condition " << pred << " relies on var defined inside tensorize scope"; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 355e3c39494b..825092d20ac0 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,11 +51,8 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline); - // use attribute to mark scope of the operation. - pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); - return pipeline; + return s->op->BuildRealize(s, dom_map, pipeline, s->scope); } // inject the operator's realization on the stmt. @@ -175,8 +172,7 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { @@ -218,7 +214,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); + Stmt ret = + ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 5c59961fe011..2063fc7cad6a 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,12 +49,12 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor) { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); } - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); return buffer; } @@ -67,10 +67,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - // TODO(tvm-team): remove realize_scope, turn the info into - // Buffer's scope field in this pass. - if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope) { Stmt body = op->body; Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { @@ -95,7 +92,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const ProducerRealizeNode* op) final { Tensor tensor = Downcast(op->producer); - Buffer buffer = GetOrAllocBuffer(tensor); + Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); @@ -122,14 +119,16 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { + return GetBuffer(tensor, storage_scope, true); + } - Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; - auto buffer = CreateBufferFor(tensor); + auto buffer = CreateBufferFor(tensor, storage_scope); buffer_map_[tensor] = buffer; return buffer; } diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index b1da536f1dad..8f87ef920784 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -26,6 +26,7 @@ #include #include +#include "../transforms/ir_utils.h" namespace tvm { namespace tir { @@ -65,8 +66,12 @@ class BlockReadWriteDetector : public StmtExprVisitor { std::vector> read_regions_; /*! \brief The write regions of the current block */ std::vector> write_regions_; + /*! \brief The opaque regions of the current block */ + std::vector> opaque_regions_; /*! \brief The outside buffer data mapping to its buffer */ Map buffer_var_map_; + /*! \brief The target buffer var mapping to its matching */ + std::unordered_map match_buffers_; /*! \brief The analyzer for simplifying*/ arith::Analyzer analyzer_; @@ -78,14 +83,18 @@ class BlockReadWriteDetector : public StmtExprVisitor { * \param region The provided region */ void Update(std::vector* buffers, std::vector>* regions, - const Buffer& buffer, const std::vector& region); + Buffer buffer, std::vector region); /*! \brief Helper function to collect access regions. */ Array CollectRegions(const std::vector& buffers, const std::vector>& regions); - /*! \brief Helper function to add a opaque buffer. */ - void AddOpaque(const Var& buffer_var); + /*! \brief Helper function to convert matched access region to source region. */ + std::vector ConvertMatchedRegion(const MatchBufferRegion& match_buffer, + const std::vector& int_sets) const; + + /*! \brief Helper function to update a opaque access. */ + void UpdateOpaque(const Var& buffer_var); void VisitStmt_(const ForNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -97,8 +106,13 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { - ICHECK(stmt.as() != nullptr) - << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + const auto* block = stmt.as(); + ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + for (const MatchBufferRegion& match_buffer : block->match_buffers) { + const Var& target_var = match_buffer->buffer->data; + match_buffers_[target_var.get()] = match_buffer; + buffer_var_map_.Set(target_var, match_buffer->buffer); + } StmtExprVisitor::operator()(stmt); } @@ -111,18 +125,13 @@ Array BlockReadWriteDetector::CollectWrites() { } Array BlockReadWriteDetector::CollectOpaques() { - Array res; - res.reserve(opaque_buffers_.size()); - for (const Buffer& buffer : opaque_buffers_) { - res.push_back(BufferRegion::FullRegion(buffer)); - } - return res; + return CollectRegions(opaque_buffers_, opaque_regions_); } -void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { AddOpaque(GetRef(op)); } +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); ExprVisitor::VisitExpr_(op); } @@ -143,7 +152,7 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { } void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { - AddOpaque(op->buffer_var); + UpdateOpaque(op->buffer_var); StmtVisitor::VisitStmt_(op); } @@ -184,11 +193,39 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { } } +std::vector BlockReadWriteDetector::ConvertMatchedRegion( + const MatchBufferRegion& match_buffer, const std::vector& int_sets) const { + const Buffer& buffer = match_buffer->buffer; + + Region region; + region.reserve(int_sets.size()); + ICHECK_EQ(buffer->shape.size(), int_sets.size()); + for (size_t i = 0; i < int_sets.size(); ++i) { + const tvm::arith::IntSet& int_set = int_sets[i]; + region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); + } + + region = ConvertRegion(match_buffer, region); + + std::vector result; + result.reserve(region.size()); + for (const Range& range : region) { + result.push_back(arith::EvalSet(range, dom_map_)); + } + return result; +} + void BlockReadWriteDetector::Update(std::vector* buffers, - std::vector>* regions, - const Buffer& buffer, - const std::vector& region) { + std::vector>* regions, Buffer buffer, + std::vector region) { if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return; + // Handle match_buffer remap + auto it = match_buffers_.find(buffer->data.get()); + if (it != match_buffers_.end()) { + const MatchBufferRegion& match_buffer = it->second; + buffer = match_buffer->source->buffer; + region = ConvertMatchedRegion(match_buffer, std::move(region)); + } ICHECK_EQ(buffers->size(), regions->size()) << " Expected the buffer and regions to have the same size "; for (size_t i = 0; i < regions->size(); ++i) { @@ -200,8 +237,8 @@ void BlockReadWriteDetector::Update(std::vector* buffers, return; } } - buffers->push_back(buffer); - regions->push_back(region); + buffers->push_back(std::move(buffer)); + regions->push_back(std::move(region)); } Array BlockReadWriteDetector::CollectRegions( @@ -213,8 +250,9 @@ Array BlockReadWriteDetector::CollectRegions( for (size_t i = 0; i < regions.size(); ++i) { Array region; region.reserve(regions[i].size()); + ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { - tvm::arith::IntSet range = regions[i][j]; + const tvm::arith::IntSet& range = regions[i][j]; region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } res.push_back(BufferRegion(buffers[i], region)); @@ -222,14 +260,18 @@ Array BlockReadWriteDetector::CollectRegions( return res; } -void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) { +void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { auto it = buffer_var_map_.find(buffer_var); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; - for (const Buffer& opaque_buffer : opaque_buffers_) { - if (buffer.same_as(opaque_buffer)) return; + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + const Region& region = buffer_region->region; + std::vector int_set; + int_set.reserve(region.size()); + for (const Range& range : region) { + int_set.push_back(arith::EvalSet(range, dom_map_)); } - opaque_buffers_.push_back(buffer); + Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); } } diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 6f2622f3a61e..e680d689735d 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -85,9 +85,17 @@ class LCADetector : public StmtExprVisitor { for (const Buffer& buf : op->alloc_buffers) { buffer_var_map_.emplace(buf->data.get(), buf.get()); } + const ScopeInfo* parent_scope = ancestor_scopes_.back(); auto* current_scope = arena_.make(parent_scope, op, n); + ancestor_scopes_.push_back(current_scope); + // Update match_buffers + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + UpdateBufferLCA(match_buffer->source->buffer.get()); + match_buffers_.insert(match_buffer->buffer.get()); + } + StmtExprVisitor::VisitStmt_(op); ancestor_scopes_.pop_back(); } @@ -129,8 +137,11 @@ class LCADetector : public StmtExprVisitor { } void UpdateBufferLCA(const BufferNode* buffer) { - const ScopeInfo*& lca = buffer_lca_[buffer]; - lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + if (match_buffers_.find(buffer) == match_buffers_.end()) { + // Ingore buffer created by block match_buffer + const ScopeInfo*& lca = buffer_lca_[buffer]; + lca = LowestCommonAncestor(lca, ancestor_scopes_.back()); + } } static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) { @@ -164,6 +175,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; + /*! \brief The match buffers inside blocks. */ + std::unordered_set match_buffers_ = {}; /*! \brief Internal arena. */ support::Arena arena_; }; diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 7eb8013f2a85..7f48cc439234 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } + if (lhs.as()) { + return false; + } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); } diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index 40a2cce70ae9..c4acd2b74aad 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -22,23 +22,33 @@ * \brief Implementation of simple passes */ #include -#include #include namespace tvm { namespace tir { -class VarTouchVisitor : public ExprVisitor { +class VarTouchVisitor : public StmtExprVisitor { public: - explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} + explicit VarTouchVisitor(std::function var_set) + : var_set_(std::move(var_set)) {} + + void VisitStmt(const Stmt& stmt) final { + if (use_var_) return; + StmtExprVisitor::VisitStmt(stmt); + } void VisitExpr(const PrimExpr& e) final { if (use_var_) return; - ExprVisitor::VisitExpr(e); + StmtExprVisitor::VisitExpr(e); } void VisitExpr_(const VarNode* op) final { Handle(op); } + void VisitStmt_(const StoreNode* op) final { + Handle(op->buffer_var.get()); + StmtVisitor::VisitStmt_(op); + } + void VisitExpr_(const LoadNode* op) final { Handle(op->buffer_var.get()); ExprVisitor::VisitExpr_(op); @@ -54,9 +64,15 @@ class VarTouchVisitor : public ExprVisitor { std::function var_set_; }; -bool ExprUseVar(const PrimExpr& e, std::function var_set) { - VarTouchVisitor visitor(var_set); - visitor(e); +bool UsesVar(const Stmt& stmt, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + visitor(stmt); + return visitor.use_var_; +} + +bool UsesVar(const PrimExpr& expr, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + visitor(expr); return visitor.use_var_; } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index afd3c7add605..10d857bdc953 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -30,6 +30,8 @@ #include #include +#include "../transforms/ir_utils.h" + namespace tvm { namespace tir { @@ -58,11 +60,12 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); + auto scope = GetPtrStorageScope(op->buffer_var); // visit an allocation of a buffer in shared memory, record its size - if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { + if (scope == "local") { size_t size = static_cast(op->constant_allocation_size()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); - } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { + } else if (scope == "shared") { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } @@ -78,15 +81,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - std::string op_value = op->value.as()->value; - if (op_value == "local") { - visited_local_buffers_.insert(op->node.as()); - } else if (op_value == "shared") { - visited_shared_buffers_.insert(op->node.as()); - } - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -211,8 +206,6 @@ class GPUCodeVerifier : public StmtExprVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -230,8 +223,6 @@ class GPUCodeVerifier : public StmtExprVisitor { std::vector errors_; void Reset_() { - visited_local_buffers_.clear(); - visited_shared_buffers_.clear(); local_memory_per_block_ = 0; shared_memory_per_block_ = 0; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 3c29e4e84bca..2089ead98168 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -170,7 +170,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; + ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1667eb7d1fbd..335ff19dd775 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -32,6 +32,8 @@ #include #include +#include "../../arith/pattern_match.h" + namespace tvm { namespace tir { @@ -45,10 +47,11 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, Span span) { +Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, + Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); - return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape, - Array(), PrimExpr(), name, "", 0, 0, kDefault, span); + return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, + Array(), PrimExpr(), name, 0, 0, kDefault, span); } // Split the given expression w.r.t the add operator @@ -180,7 +183,12 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. - PrimExpr simplified_base = analyzer->Simplify(base); + PrimExpr simplified_base = base; + arith::PVar x, y; + if ((floordiv(x, y) * y + floormod(x, y)).Match(simplified_base)) { + simplified_base = x.Eval(); + } + simplified_base = analyzer->Simplify(simplified_base); std::vector eles = ExprSplitAddition(simplified_base); std::list mult_exprs; std::list > mod_exprs; @@ -311,6 +319,15 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { } } +String Buffer::scope() const { + const auto* ptr_type = (*this)->data->type_annotation.as(); + ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + if (ptr_type->storage_scope.empty()) { + return "global"; + } + return ptr_type->storage_scope; +} + Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; @@ -350,7 +367,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, + return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->data_alignment, 0, n->buffer_type); } @@ -383,8 +400,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type, Span span) { + PrimExpr elem_offset, String name, int data_alignment, int offset_factor, + BufferType buffer_type, Span span) { DataType storage_dtype = dtype; // specially handle bool if (storage_dtype == DataType::Bool()) { @@ -401,10 +418,6 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); - if (scope.length() == 0) { - scope = "global"; - } - n->scope = std::move(scope); if (!elem_offset.defined()) { elem_offset = make_const(n->DefaultIndexType(), 0); } @@ -436,11 +449,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(BufferNode); TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 11); - auto buffer_type = args[9].operator String(); + ICHECK_EQ(args.size(), 10); + auto buffer_type = args[8].operator String(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], - type, args[10]); + *ret = + Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, args[9]); }); TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); @@ -449,5 +462,7 @@ TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); +TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h new file mode 100644 index 000000000000..8dac41a02e57 --- /dev/null +++ b/src/tir/ir/buffer_common.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tir/ir/buffer_common.h + * \brief Common utils for buffer access + */ +#ifndef TVM_TIR_IR_BUFFER_COMMON_H_ +#define TVM_TIR_IR_BUFFER_COMMON_H_ + +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Returns the type of object pointed to. + * + * \param type The type to be checked. + * + * \return A (bool, DataType) pair. If the type is a pointer to a + * primitive, the boolean is true and the DataType is the pointed-to + * type. Otherwise, the boolean is false and the DataType is + * default-constructed. This can be replaced with std::optional with + * C++17 if/when C++17 is required. + */ +inline std::pair GetPointerType(const Type& type) { + if (type.defined()) { + if (auto* ptr_type = type.as()) { + if (auto* prim_type = ptr_type->element_type.as()) { + return {true, prim_type->dtype}; + } + } + } + + return {false, DataType()}; +} + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_BUFFER_COMMON_H_ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 352f75abdf5e..afc5c36ebb92 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -25,10 +25,8 @@ #include #include -#include -#include - #include "../../support/str_escape.h" +#include "buffer_common.h" namespace tvm { namespace tir { @@ -618,8 +616,42 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S ICHECK(buffer_var.defined()); ICHECK(predicate.defined()); ICHECK(index.defined()); - ICHECK_EQ(dtype.lanes(), index.dtype().lanes()); - ICHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); + + // Assume that the array elements have 1 lane, unless a type + // annotation tells us otherwise. + int element_lanes = 1; + auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); + if (pointer_type.first) { + // Cannot check element type of array, as it may be different than + // the loaded type in some cases. + // + // 1. Booleans use DataType::Int(8) while stored, and the codegens + // handle cast to boolean. + // + // 2. The StorageRewrite pass can merge multiple allocations at + // the same scope, regardless of element type. The codegen is + // then responsible for casting to the output type. + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK(dtype.element_of() == pointer_type.second.element_of()) + // << "Type mismatch, cannot load type " << dtype << " from buffer " << + // buffer_var->name_hint + // << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); + } + + // The C-based codegens assume that all loads occur on a array with + // non-vectorized elements, and cast between + // vectorized/non-vectorized arrays as needed. Ideally, these + // should be changed to explicit casts in the TIR graph, rather than + // being handled at the code-gen level. + ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); + ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->dtype = dtype; diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index c15b3bb47bf4..f265a8ae2b1b 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -36,15 +36,12 @@ namespace tir { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { public: - explicit ScriptCompleter(Map* buffer_var_map, bool contain_root) - : buffer_var_map_(buffer_var_map), contain_root_(contain_root) {} + explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} /*! \brief Whether the stmt contains at least one block. */ bool contains_block = false; private: Map* buffer_var_map_; - bool contain_root_; - bool visited_root_ = false; Stmt VisitStmt_(const BlockRealizeNode* op) override { contains_block = true; Stmt body = StmtMutator::VisitStmt_(op); @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator { } Stmt VisitStmt_(const BlockNode* op) override { - bool is_root_block = contain_root_ && !visited_root_; - visited_root_ = true; // Buffers allocated in the block can be accessed by its body. for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->Set(target_buffer->data, target_buffer); + } Block block = Downcast(StmtMutator::VisitStmt_(op)); // Remove buffers allocated inside block to detect its access region for (const auto& alloc_buffer : op->alloc_buffers) { buffer_var_map_->erase(alloc_buffer->data); } + for (const auto& match_buffer : op->match_buffers) { + const Buffer& target_buffer = match_buffer->buffer; + buffer_var_map_->erase(target_buffer->data); + } // Get access detection mask // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write int mask = 0; @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator { } // ignore root block or blocks which already has reads/writes regions if (mask != 0) { - if (op->iter_vars.empty()) { - // non-root opaque block is not allowed - CHECK(is_root_block) - << "ValueError: Can not auto detect buffer access region for an opaque block. Please " - "annotate the access region manually."; - return std::move(block); - } auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); const Array& reads = access_region[0]; const Array& writes = access_region[1]; @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } bool contain_root = root_allocates.empty() && func->body->IsInstance() && Downcast(func->body)->block->iter_vars.empty(); - ScriptCompleter script_completer(&buffer_var_map, contain_root); + ScriptCompleter script_completer(&buffer_var_map); // generate surrounding loops automatically Stmt res = script_completer(func->body); // generate root block automatically diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc new file mode 100644 index 000000000000..aa5f271c20c2 --- /dev/null +++ b/src/tir/ir/specialize.cc @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/specialize.cc + * \brief Specialize parameters of PrimFunc. + */ +#include +#include +#include +#include + +#include + +#include "functor_common.h" + +namespace tvm { +namespace tir { + +using VarMap = std::unordered_map; + +/**************** Helper functions ****************/ + +/*! \brief Helper function to check whether the given var is in function parameter list. */ +inline bool IsParam(const PrimFunc& func, const Var& param) { + return std::any_of(func->params.begin(), func->params.end(), + [&](const Var& var) { return var.same_as(param); }); +} + +/**************** Specializer ****************/ + +/*! \brief Mutator to specialize function and remove const parameters */ +class PrimFuncSpecializer : public StmtExprMutator { + public: + explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {} + + static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) { + PrimFuncSpecializer specializer(var_map); + // Updating Buffer map + Map buffer_map; + bool buffer_map_updated = false; + for (const auto& it : f->buffer_map) { + const Var& var = it.first; + const Buffer& buffer = it.second; + Buffer new_buffer = specializer.MutateBuffer(buffer); + buffer_map.Set(var, new_buffer); + if (!new_buffer.same_as(buffer)) { + buffer_map_updated = true; + specializer.buffer_map_[buffer] = new_buffer; + } + } + + // Updating parmeters + Array params; + bool param_updated = false; + for (const auto& var : f->params) { + // Remove parmeters which has been specialized. + if (var_map.find(var) == var_map.end()) { + params.push_back(var); + } else { + param_updated = true; + } + } + + // Updating function body + Stmt body = specializer(f->body); + + if (param_updated || buffer_map_updated || !f->body.same_as(body)) { + PrimFuncNode* f_ptr = f.CopyOnWrite(); + f_ptr->params = std::move(params); + f_ptr->buffer_map = std::move(buffer_map); + f_ptr->body = std::move(body); + } + return f; + } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + // Step.0. Define buffer mappings which is allocated inside the block + Array alloc_buffers = MutateArray( + op->alloc_buffers, + std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1)); + + // Step.1. Recursively visit block body + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + Array reads = MutateArray( + op->reads, + std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, + std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); + + if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) { + return GetRef(op); + } else { + ObjectPtr n = CopyOnWrite(op); + n->alloc_buffers = std::move(alloc_buffers); + n->reads = std::move(reads); + n->writes = std::move(writes); + return Stmt(n); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + auto it = buffer_map_.find(op->buffer); + if (it == buffer_map_.end()) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->buffer = it->second; + return Stmt(n); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); + auto it = buffer_map_.find(op->buffer); + if (it == buffer_map_.end()) { + return GetRef(op); + } else { + auto n = make_object(*op); + n->buffer = it->second; + return PrimExpr(n); + } + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_map_.find(GetRef(op)); + if (it == var_map_.end()) { + return GetRef(op); + } else { + return it->second; + } + } + + private: + Buffer MutateBuffer(const Buffer& buffer) const { + Array shape = + MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + Array strides = + MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + + PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_); + + if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) && + buffer->strides.same_as(strides)) { + return buffer; + } else { + auto n = make_object(*buffer.get()); + n->elem_offset = std::move(elem_offset); + n->shape = std::move(shape); + n->strides = std::move(strides); + return Buffer(n); + } + } + + Range MutateRange(const Range& range) { + PrimExpr min = this->VisitExpr(range->min); + PrimExpr extent = this->VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + return Range::FromMinExtent(std::move(min), std::move(extent)); + } + } + + Buffer MutateAllocBuffer(const Buffer& alloc_buf) { + Buffer buf = MutateBuffer(alloc_buf); + if (buf.same_as(alloc_buf)) { + return alloc_buf; + } else { + ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end()); + buffer_map_[alloc_buf] = buf; + return buf; + } + } + + BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { + auto it = buffer_map_.find(buffer_region->buffer); + Array region = + MutateArray(buffer_region->region, + std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); + if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(it->second, std::move(region)); + } + } + + private: + /*! \brief The vars to be substitute and their values */ + const VarMap& var_map_; + /*! \brief map from old buffer to mutated buffer */ + std::unordered_map buffer_map_; +}; + +/*! + * \brief Update Specialize var map with buffer matching. + * \param func The function to be specialized. + * \param param The given function parameter + * \param specific_buf The matching buffer. + * \param var_map The var mapping to be updated. + * \note This function will match target buffer's shape, strides and element_offset + * For example, we define a buffer in PrimFunc: + * A = tir.match_buffer(a, [m, n]) + * + * Then we match it with a buffer B = tir.decl_buffer((8, 16)) + * + * It means we have two var mappings here: m = 8 and n = 16 + * + * If the buffer signature is not a Var, the mapping will fail. + * e.g. A = tir.match_buffer(a, [m * 2, n + 1]) + */ +void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf, + VarMap* var_map) { + // preliminaries + tir::ExprDeepEqual equal; + + auto it = func->buffer_map.find(param); + CHECK(it != func->buffer_map.end()) + << "ValueError: specialize expects param to be in PrimFunc's buffer_map"; + const Buffer& buf_to_specialize = (*it).second; + + // build var mapping using specific_buf's parameters + auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) { + if (!equal(new_expr, old_expr)) { + CHECK(old_expr->IsInstance()) + << "TypeError: The signature of target buffer exprected an independent Var, but got " + << old_expr << "."; + const Var& var = Downcast(old_expr); + auto it = var_map->find(var); + if (it != var_map->end()) { + CHECK(equal(it->second, new_expr)) + << "ValueError: The assigned value of var " << var << " mismatched. " << it->second + << " vs. " << new_expr << "."; + } else { + (*var_map)[var] = new_expr; + } + } + }; + + // Check buffer dimensions + CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size()) + << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size() + << " vs. " << specific_buf->shape.size() << "."; + + CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size()) + << "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size() + << " vs. " << specific_buf->strides.size() << "."; + + // Updating var mapping using specific_expr + for (size_t i = 0; i < specific_buf->shape.size(); ++i) { + build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); + } + for (size_t i = 0; i < specific_buf->strides.size(); ++i) { + build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]); + } + build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset); + + // Check data_alignment and offset_factor. + // These two signatures are int, so we do not need map them. + CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment) + << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment + << " vs. " << specific_buf->data_alignment << "."; + + CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor) + << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor + << " vs. " << specific_buf->offset_factor << "."; +} + +/*! + * \brief Update Specialize var map with parameter value. + * \param func The function to be specialized. + * \param param The given function parameter + * \param specific_expr The parameter value. + * \param var_map The var mapping to be updated. + */ +void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr, + VarMap* var_map) { + // check param is in PrimFunc's parameters + CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params"; + // specialize a param not in buffer_map + CHECK_EQ(func->buffer_map.count(param), 0) + << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map"; + // build var mapping using specific_expr + (*var_map)[param] = specific_expr; +} + +/**************** Implementation ****************/ + +PrimFunc Specialize(PrimFunc func, const Map& param_map) { + VarMap var_map; + for (const auto& kv : param_map) { + const Var& param = kv.first; + const ObjectRef& instance = kv.second; + if (instance->IsInstance()) { + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); + } else if (instance->IsInstance()) { + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); + } else { + LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " + << instance->GetTypeKey(); + } + } + return PrimFuncSpecializer::Specialize(func, std::move(var_map)); +} + +/**************** FFI ****************/ + +TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b2016eb74c91..d59c94dc5753 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -20,11 +20,14 @@ /*! * \file tvm/tir/stmt.cc */ +#include #include #include #include #include +#include "buffer_common.h" + namespace tvm { namespace tir { @@ -234,8 +237,29 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); - ICHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); - ICHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); + + // Assume that the array elements have 1 lane, unless a type + // annotation tells us otherwise. + int element_lanes = 1; + auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); + if (pointer_type.first) { + // Currently cannot check element type of array, see Load::Load + // for details. + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of()) + // << "Type mismatch, cannot store type " << value.dtype() << " into buffer " + // << buffer_var->name_hint << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); + } + + ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || + (value.dtype().lanes() == index.dtype().lanes())); + ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) || + (value.dtype().lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); @@ -360,13 +384,15 @@ TVM_REGISTER_NODE_TYPE(AllocateNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); + const auto* ptr_type = op->buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; p->PrintIndent(); p->stream << "allocate " << op->buffer_var << "[" << op->dtype; for (size_t i = 0; i < op->extents.size(); ++i) { p->stream << " * "; p->Print(op->extents[i]); } - p->stream << "]"; + p->stream << "], storage_scope = " << ptr_type->storage_scope; if (!is_one(op->condition)) { p->stream << " if "; p->Print(op->condition); @@ -377,7 +403,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, Span span) { + Stmt body, String storage_scope, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); @@ -394,13 +420,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); + node->storage_scope = std::move(storage_scope); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span) { - return ProducerRealize(producer, bounds, condition, body, span); + String storage_scope, Span span) { + return ProducerRealize(producer, bounds, condition, body, storage_scope, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); @@ -632,6 +659,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // BufferRegion BufferRegion::BufferRegion(Buffer buffer, Array region) { + CHECK_EQ(buffer->shape.size(), region.size()) + << "The dimension between " << buffer << " and region " << region + << " mismatched, the buffer is " << buffer; ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->region = std::move(region); @@ -679,6 +709,49 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // MatchBufferRegion MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { + const Buffer& source_buffer = source->buffer; + arith::Analyzer analyzer; + // Check scope and dtype + CHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " + << source_buffer.scope(); + CHECK_EQ(buffer->dtype, source_buffer->dtype) + << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " + << source_buffer->dtype; + + // Check data_alignment + CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) + << "Trying to match buffer to another one with lower alignment requirement " + << " required_alignment=" << buffer->data_alignment + << ", provided_alignment=" << source_buffer->data_alignment; + + // Check BufferType. AutoBroadcast is not allowed for now. + CHECK(buffer->buffer_type == BufferType::kDefault && + source_buffer->buffer_type == BufferType::kDefault) + << "AutoBroadcast is not allowed in MatchBuffer"; + + // Validate shape + CHECK(source->region.size() >= buffer->shape.size()) + << "Dimension of source Region expected to be larger or equal than target buffer shape, but " + "got " + << source->region.size() << " vs. " << buffer->shape.size(); + size_t offset = source->region.size() - buffer->shape.size(); + for (size_t i = 0; i < offset; ++i) { + CHECK(analyzer.CanProve(source->region[i]->extent == 1)) + << "The higher dimension should be 1, but got " << source->region[i]->extent << "."; + } + for (size_t i = 0; i < buffer->shape.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const PrimExpr& buffer_shape = buffer->shape[i]; + if (!buffer_shape->IsInstance()) { + CHECK(analyzer.CanProve(source_range->extent == buffer_shape)) + << "The dimension mismatched between source region and target buffer shape, got " + << source_range->extent << " vs. " << buffer_shape << "."; + } + } + // Note that we do not check elem_offset and strides in this function + + // Construction ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->source = std::move(source); @@ -695,7 +768,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->buffer->name << " = match_buffer_region("; + p->stream << op->buffer->name << " = match_buffer("; p->Print(op->source); p->stream << ")\n"; }); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index af78804837ba..aca6d1b50b0e 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -112,24 +112,31 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; - // Only do very simple type coversion - // int->float, DataType::Int(32)->int(64) - // require the types to be relatively consistent - // This will the reduce amount code generated by operators - // and also help user to find potential type conversion problems. - if (!lhs.dtype().is_float() && - (rhs.dtype().is_float() || - datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { - // int->float + + // We keep dtypes conversion to be relatively consistent to reduce the amount code generated by + // operators. This can be helpful for users to find potential type conversion problems. The + // following are exceptions: + if (lhs.dtype().is_float() && rhs.dtype().is_float()) { + // Given two dissimilar floats, cast the lower bit version to the higher bit version. + // E.g. fp16 + fp32 --> fp32 + fp32 + if (lhs.dtype().bits() < rhs.dtype().bits()) { + lhs = cast(rhs.dtype(), lhs); + } else if (lhs.dtype().bits() > rhs.dtype().bits()) { + rhs = cast(lhs.dtype(), rhs); + } + } else if (!lhs.dtype().is_float() && + (rhs.dtype().is_float() || + datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) { + // Cast int->float when the other operand is a float lhs = cast(rhs.dtype(), lhs); } else if ((lhs.dtype().is_float() || datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) && !rhs.dtype().is_float()) { - // int->float + // Cast int->float when the other operand is a float rhs = cast(lhs.dtype(), rhs); } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) || (lhs.dtype().is_uint() && rhs.dtype().is_uint())) { - // promote int to higher bits + // Promote int to higher bits e.g. int8 + int16 --> int16 + int16 if (lhs.dtype().bits() < rhs.dtype().bits()) { lhs = cast(rhs.dtype(), lhs); } else { @@ -137,6 +144,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) || (lhs.dtype().is_uint() && rhs.dtype().is_int())) { + // Handle mixing signed and unsigned integers int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index dd7fee37e2d1..9baf4b5245ea 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -21,6 +21,9 @@ #include +#include +#include + namespace tvm { namespace tir { @@ -41,30 +44,35 @@ void VerifySRefTree(const ScheduleState& self); */ void VerifyCachedFlags(const ScheduleState& self); -/******** Scope ********/ +/******** IR Module ********/ /*! - * \brief Gets the sref to the scope root block, exclusive - * \param sref The block or loop sref to be retrieved - * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + * \brief Get PrimFunc and GlobalVar that the root block belongs to + * \param mod The IRModule + * \param root_block The root block of the PrimFunc + * \param result_g_var The result GlobalVar + * \return The result PrimFunc where the root block belongs to + * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write */ -Optional GetScopeRoot(const StmtSRef& sref); +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var); +/******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it - * \param prim The name of the schedule primitive * \param self The schedule state * \param sref The sref whose scope is to be checked + * \param require_stage_pipeline A boolean indicating whether to check stage pipeline * \throw ScheduleError if the sref has been the root of the AST (so it has no scope root), or its * scope root is not a stage pipeline * \return The block sref to the scope root */ -StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref); +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline); /*! * \brief Checks whether the block is a complete block under the scope * \param self The schedule state * \param block_sref The block to be checked - * \param scope_root The sref to the root block of the scope that `block_sref` is in + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in * \return A boolean indicating if the block is a complete block * \note Definition of a complete block: * 1) All block vars are data parallel @@ -73,10 +81,10 @@ StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const Stmt * 3) No overlap between the buffers the block reads and writes */ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& scope_root); + const StmtSRef& scope_root_sref); /*! - * \brief Checks if the block is a complete block + * \brief Check if the block is a complete block under the scope * \param self The schedule state * \param block_sref The sref to the block whose completeness is to be checked * \param scope_root_sref The scope root of the block @@ -85,6 +93,33 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check whether the block is a reduction block under the scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return A boolean indicating if the block is a reduction block + * \note Definition of a reduction block: + * 1) The block has the `init` statement + * 2) All the block bindings are quasi-affine expressions + * 3) All block vars are either data parallel block vars or reduction block vars + * 4) Dominant: the block is the only writer of its output, dominating the reader of its output + * buffers + * 5) The reduction block vars are not used to index the output buffers + */ +bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + +/*! + * \brief Check if the block is a reduction block under the scope + * \param self The schedule state + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the block is not a reduction block + */ +void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -119,29 +154,75 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, */ Map GetBindings(const BlockRealize& realize); +/*! + * \brief Get the vars involved in the bindings of data parallel block vars and reduction block + * vars, respectively + * \param block_realize The BlockRealize to be analyzed + * \param data_par_vars The vars that appear in the binding of any data parallel block iter + * \param reduce_vars The vars that appear in the binding of any reduction block iter + * \return A boolean indicating whether the block has block iters that is neither a data parallel + * block iter nor a reduction block iter + */ +bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, + std::unordered_set* data_par_vars, + std::unordered_set* reduce_vars); + /******** Block-loop relation ********/ + /*! - * \brief Retrieves blocks in a specific function with its name + * \brief Gets StmtSRefs of leaf blocks of a scope where a specific block/loop is in * \param self The schedule state - * \param name The name of the blocks to be retrieved - * \param func_name The name of the function - * \return A list of blocks with the specific name + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return A list of StmtSRefs of leaf block */ -Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); +Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); + /*! - * \brief Gets the parent loops of the block in its scope, from outer to inner - * \param self The schedule state - * \param block_sref The query block - * \return A list of loops above the given block in its scope, from outer to inner + * \brief Gets the BlockRealize of the leaf blocks of a scope where a specific block/loop is in + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return A list of leaf BlockRealize */ -Array GetLoops(const StmtSRef& block_sref); +Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); + /*! - * \brief Gets the leaf blocks of a scope where a specific block/loop is in + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks * \param self The schedule state * \param parent_sref The StmtSRef that points to the parent block/loop - * \return A list of leaf blocks + * \return The BlockRealize of the single child block + * \throw ScheduleError If there is 0 or multiple child blocks + */ +BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref); +/*! + * \brief Get the BlockRealize of the input block + * \param self The schedule state + * \param block_sref The StmtSRef of the queried block + * \return The BlockRealize of the input block + */ +BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); + +/******** Commutative Reducer ********/ + +/*! + * \brief Get the list of the registered reducer-getter functions + * \return The list of the registered reducer-getter functions + * \sa ReducerRegistry + */ +std::vector> GetReducerGetters(); + +/*! + * \brief Given the input identity and the combiner BufferStore of a reduction, extract the + * corresponding commutative reducer and its lhs, rhs if possible. + * \param identity The identity of the reduction + * \param combiner The combiner of the reduction + * \param result_reducer The extracted CommReducer + * \param lhs The extracted lhs of the reducer + * \param rhs The extracted rhs of the reducer + * \return A boolean indicating whether a corresponding commutative reducer is found */ -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, + CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d58dece3c644..3ee98ec5b7d2 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -21,8 +21,37 @@ namespace tvm { namespace tir { +/******** IR Module ********/ + +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var) { + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + if (const auto* realize = func->body.as()) { + if (realize->block.get() == root_block) { + if (result_g_var != nullptr) { + *result_g_var = g_var; + } + return func; + } + } + } + } + LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " + "statement:\n" + << GetRef(root_block); + throw; +} + /******** Scope ********/ +/*! + * \brief Gets the sref to the scope root block, exclusive + * \param sref The block or loop sref to be retrieved + * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + */ Optional GetScopeRoot(const StmtSRef& sref) { for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) { if (p->stmt->IsInstance()) { @@ -32,7 +61,8 @@ Optional GetScopeRoot(const StmtSRef& sref) { return NullOpt; } -StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref) { +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, + bool require_stage_pipeline) { class RootBlockError : public ScheduleError { public: explicit RootBlockError(IRModule mod) : mod_(mod) {} @@ -75,7 +105,7 @@ Definition of a scope that is a stage pipeline: throw RootBlockError(self->mod); } bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; - if (stage_pipeline == false) { + if (require_stage_pipeline && stage_pipeline == false) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); throw NotStagePipelineError(self->mod, GetRef(block)); } @@ -106,20 +136,29 @@ bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { return true; } -bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& scope_root) { - BlockScope scope = self->GetBlockScope(scope_root); +/*! + * \brief A helper function that checks whether a given block is a complete block under the scope, + * or return the condition it violates if it is not a complete block + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return 0 if the block is a complete block, or a positive integer indicating which condition is + * first violated + */ +int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); // Cond 1. All block vars are data parallel - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != kDataPar) { - return false; + return 1; } } // Cond 2. Dominant: the block is the only writer of its output, // dominating the reader of its output buffers if (!IsDominantBlock(scope, block_sref)) { - return false; + return 2; } // Cond 3. No overlap between the buffers the block reads and writes std::unordered_set written_buffers; @@ -129,35 +168,150 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, } for (const BufferRegion& read : block->reads) { if (written_buffers.count(read->buffer.get())) { - return false; + return 3; } } - return true; + return 0; +} + +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0; } void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { class IncompleteBlockError : public ScheduleError { public: - explicit IncompleteBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} String FastErrorString() const final { return "ScheduleError: Incomplete block"; } String DetailRenderTemplate() const final { - return R"(The block {0} is not a complete block. -Definition of a complete block: + std::ostringstream os; + os << "The block {0} is not a complete block - it violates condition #" << violated_cond_ + << ".\n" + << R"(Definition of a complete block: 1) All block vars are data parallel 2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers 3) No overlap between the buffers the block reads and writes)"; + return os.str(); } IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; + int violated_cond_; }; - bool result = IsCompleteBlock(self, block_sref, scope_root_sref); - if (result == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); - throw IncompleteBlockError(self->mod, GetRef(block)); + int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); + if (error_code != 0) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw IncompleteBlockError(self->mod, GetRef(block), error_code); + } +} + +/*! + * \brief A helper function that checks whether a given block is a reduction block under the scope, + * or return the condition it violates if it is not a reduction block + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return 0 if the block is a reduction block, or a positive integer indicating which condition is + * first violated + */ +int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + // Cond 1. The block has the `init` statement. + if (!block->init.defined()) { + return 1; + } + // Cond 2. All the block bindings are quasi-affine expressions. + if (!self->IsAffineBlockBinding(block_sref)) { + return 2; + } + // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, + // we collect all the reduction block vars. + std::unordered_set reduction_block_vars; + reduction_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + return 3; + } else if (iter_var->iter_type == kCommReduce) { + reduction_block_vars.insert(iter_var->var.get()); + } + } + // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its + // output buffers. + if (!IsDominantBlock(scope, block_sref)) { + return 4; + } + // Cond 5. The reduction block vars are not used to index the output buffers. + std::unordered_set buffer_written; + buffer_written.reserve(block->writes.size()); + for (const BufferRegion& write_region : block->writes) { + buffer_written.insert(write_region->buffer.get()); + } + bool affected = false; + PreOrderVisit(block->body, [&](const ObjectRef& obj) { + if (affected) { + return false; + } + if (const auto* store = obj.as()) { + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (UsesVar(index, [&reduction_block_vars](const VarNode* var) { + return reduction_block_vars.count(var); + })) { + affected = true; + return false; + } + } + return false; + } + return true; + }); + return !affected ? 0 : 5; +} + +bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; +} + +void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class NotReductionBlockError : public ScheduleError { + public: + explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} + String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_ + << ".\n" + << R"(Definition of a reduction block: +1) The block has the `init` statement +2) All the block bindings are quasi-affine expressions +3) All block vars are either data parallel block vars or reduction block vars +4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +5) The reduction block vars are not used to index the output buffers)"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + int violated_cond_; + }; + + int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); + if (error_code != 0) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw NotReductionBlockError(self->mod, GetRef(block), error_code); } } @@ -229,74 +383,465 @@ Map GetBindings(const BlockRealize& realize) { return result; } -/******** Block-loop relation ********/ - -Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { - struct Finder : public StmtVisitor { - explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} +bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, + std::unordered_set* data_par_vars, + std::unordered_set* reduce_vars) { + Block block = block_realize->block; + ICHECK(block_realize->block.same_as(block)) + << "ValueError: The input `block_realize` is required to be the exact BlockRealize of the " + "input block"; - void VisitStmt_(const BlockNode* block) override { - if (block->name_hint == name_) { - auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); - results_.push_back(it->second); - } - StmtVisitor::VisitStmt_(block); + bool has_block_vars_of_other_types = false; + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n = static_cast(block->iter_vars.size()); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& iter_value = block_realize->iter_values[i]; + std::unordered_set* set = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + set = data_par_vars; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + set = reduce_vars; + } else { + has_block_vars_of_other_types = true; } - const ScheduleState& self_; - const String& name_; - Array results_; - }; + Array vars_in_binding = UndefinedVars(iter_value); + for (const Var& var : vars_in_binding) { + set->insert(var.get()); + } + } - BaseFunc func = self->mod->Lookup(func_name); - const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); - Finder finder(self, name); - finder(prim_func->body); - return std::move(finder.results_); + return has_block_vars_of_other_types; } -Array GetLoops(const StmtSRef& block_sref) { - std::vector result; - for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); - parent = parent->parent) { - result.push_back(GetRef(parent)); +/******** Block-loop relation ********/ + +Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + Array child_block_srefs; + child_block_srefs.reserve(child_block_realize.size()); + + for (BlockRealize realize : child_block_realize) { + child_block_srefs.push_back(self->stmt2ref.at(realize->block.get())); } - return {result.rbegin(), result.rend()}; + return child_block_srefs; } -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { +Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { - public: - static Array Collect(const ScheduleState& self, const Stmt& stmt) { - Collector collector(self); + static Array Collect(const Stmt& stmt) { + Collector collector; collector(stmt); return std::move(collector.result_); } - private: - explicit Collector(const ScheduleState& self) : self_(self) {} - - void VisitStmt_(const BlockNode* block) final { - auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); - result_.push_back(it->second); + void VisitStmt_(const BlockRealizeNode* block_realize) final { + result_.push_back(GetRef(block_realize)); } - const ScheduleState& self_; - Array result_; + Array result_; }; if (parent_sref->stmt->IsInstance()) { const auto* loop = static_cast(parent_sref->stmt); - return Collector::Collect(self, loop->body); + return Collector::Collect(loop->body); } else if (parent_sref->stmt->IsInstance()) { const auto* block = static_cast(parent_sref->stmt); - return Collector::Collect(self, block->body); + return Collector::Collect(block->body); } ICHECK(false) << "Unreachable"; throw; } +BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + class NonSingleChildBlockError : public ScheduleError { + public: + explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) + : mod_(std::move(mod)), stmt_(GetRef(sref->stmt)) { + sref_type_ = stmt_.as() != nullptr ? "block" : "loop"; + } + + String FastErrorString() const final { + std::ostringstream os; + os << "ScheduleError: The " << sref_type_ << " is required to have only one child block"; + return os.str(); + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The " << sref_type_ << " {0} is required to have only one child block"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {stmt_}; } + + IRModule mod_; + Stmt stmt_; + String sref_type_; + }; + + Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + if (child_block_realize.size() != 1) { + throw NonSingleChildBlockError(self->mod, parent_sref); + } + return child_block_realize[0]; +} + +BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) { + struct BlockRealizeFinder : public StmtVisitor { + explicit BlockRealizeFinder(const BlockNode* target_block) + : target_block(target_block), result(nullptr) {} + + void VisitStmt(const Stmt& stmt) final { + if (result != nullptr) { + return; + } + StmtVisitor::VisitStmt(stmt); + } + + void VisitStmt_(const BlockRealizeNode* block_realize) final { + if (block_realize->block.get() == target_block) { + result = block_realize; + } + // No need to visit recursively, since the deeper BlockRealizes must not be the result. + } + + const BlockNode* target_block; + const BlockRealizeNode* result; + }; + + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block_sref->parent == nullptr) { + const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); + return Downcast(func->body); + } else { + BlockRealizeFinder finder(block); + finder(GetRef(block_sref->parent->stmt)); + ICHECK(finder.result != nullptr) + << "InternalError: Cannot find the BlockRealize of block " << GetRef(block); + return GetRef(finder.result); + } +} + +/******** Pattern Matcher ********/ + +/*! + * \brief PrimExpr pattern matcher. + * + * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated + * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific + * patterns. + * + * The code below shows how to use the pattern matcher. + * + * \code + * + * Var x("x"), y("y"); + * // use PrimExpr to declare patterns, x, y are holes that can be filled with + * PatternMatcher pattern_matcher(x + y); + * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match + * pattern_matcher.Match(expr); + * + * if (pattern_matcher.Success()) { + * pattern_matcher.Eval(x) // C[i, j] + * pattern_matcher.Eval(y) // A[i, k] * B[k, j] + * } + * + * \endcode + */ +class PatternMatcher : public ExprVisitor { + public: + explicit PatternMatcher(PrimExpr pattern) : pattern_(std::move(pattern)) {} + + void VisitExpr_(const VarNode* op) final { + auto it = filled_map_.find(op); + if (it == filled_map_.end()) { + filled_map_[op] = expr_to_match_; + } else { + ExprDeepEqual equal; + if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return; + match_success_ = false; + } + } + + void VisitExpr_(const LoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer_var.same_as(ptr->buffer_var)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->predicate; + VisitExpr(op->predicate); + expr_to_match_ = ptr->index; + VisitExpr(op->index); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const LetNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->var; + VisitExpr(op->var); + expr_to_match_ = ptr->value; + VisitExpr(op->value); + expr_to_match_ = ptr->body; + VisitExpr(op->body); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const CallNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->op.same_as(ptr->op)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->args.size(); ++i) { + expr_to_match_ = ptr->args[i]; + VisitExpr(op->args[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + +#define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \ + void VisitExpr_(const OpName* op) { \ + const auto* ptr = expr_to_match_.as(); \ + if (ptr == nullptr) { \ + match_success_ = false; \ + } else { \ + PrimExpr current = expr_to_match_; \ + expr_to_match_ = ptr->a; \ + VisitExpr(op->a); \ + expr_to_match_ = ptr->b; \ + VisitExpr(op->b); \ + std::swap(expr_to_match_, current); \ + } \ + } + + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode); + + void VisitExpr_(const CastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!runtime::TypeEqual(op->dtype, ptr->dtype)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const NotNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->a; + VisitExpr(op->a); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const SelectNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->condition; + VisitExpr(op->condition); + expr_to_match_ = ptr->true_value; + VisitExpr(op->true_value); + expr_to_match_ = ptr->false_value; + VisitExpr(op->false_value); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const RampNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->base; + VisitExpr(op->base); + expr_to_match_ = ptr->stride; + VisitExpr(op->stride); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const BroadcastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const ShuffleNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + for (size_t i = 0; i < op->vectors.size(); ++i) { + expr_to_match_ = ptr->vectors[i]; + VisitExpr(op->vectors[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const IntImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const FloatImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const StringImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const BufferLoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void Match(const PrimExpr& expr_to_match) { + this->match_success_ = true; + this->filled_map_.clear(); + this->expr_to_match_ = expr_to_match; + this->operator()(pattern_); + } + + PrimExpr Eval(const Var& var) { + auto it = filled_map_.find(var.operator->()); + ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; + ICHECK(match_success_) << "Match failed"; + return it->second; + } + + bool Success() const { return match_success_; } + + private: + bool match_success_{true}; + PrimExpr pattern_, expr_to_match_; + std::unordered_map filled_map_; +}; + +/******** Commutative Reducer ********/ + +bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, + const BufferLoad& load, PrimExpr* lhs, PrimExpr* rhs) { + if (!ExprDeepEqual()(reducer->identity_element[0], identity)) { + return false; + } + PatternMatcher pattern_matcher(reducer->result[0]); + pattern_matcher.Match(combiner); + if (pattern_matcher.Success()) { + PrimExpr lhs_tmp = pattern_matcher.Eval(reducer->lhs[0]); + PrimExpr rhs_tmp = pattern_matcher.Eval(reducer->rhs[0]); + if (ExprDeepEqual()(load, lhs_tmp)) { + *lhs = std::move(lhs_tmp); + *rhs = std::move(rhs_tmp); + } + return true; + } + return false; +} + +bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, + CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs) { + BufferLoad load(combiner->buffer, combiner->indices); + // Check reduction patterns. + for (const TypedPackedFunc& reducer_getter : GetReducerGetters()) { + CommReducer reducer = reducer_getter(identity.dtype()); + if (MatchReducer(reducer, identity, combiner->value, load, lhs, rhs)) { + *result_reducer = std::move(reducer); + return true; + } + } + return false; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 0563d39427b1..df2f06815133 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -258,6 +258,93 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { } /******** Schedule: loops manipulation ********/ + +LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { + CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; + Array loop_srefs = this->GetSRefs(loop_rvs); + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Fuse(state_, loop_srefs); + TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, + const Array>& factor_rvs) { + class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + }; + + class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + }; + // Prepare for the splitting + StmtSRef loop_sref = this->GetSRef(loop_rv); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + Array factors; + factors.reserve(factor_rvs.size()); + int infer_index = -1; + PrimExpr tot_length = 1; + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + // infer factor if needed and check validity of factors + for (size_t i = 0; i < factor_rvs.size(); i++) { + if (!factor_rvs[i].defined()) { + factors.push_back(Integer(-1)); + if (infer_index == -1) { + infer_index = i; + } else { + throw NotSingleInferFactorError(state_->mod); + } + } else { + PrimExpr factor = this->Get(factor_rvs[i].value()); + factors.push_back(factor); + tot_length *= factor; + } + } + if (infer_index != -1) { + factors.Set(infer_index, + this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { + throw WrongFactorProductError(state_->mod, GetRef(loop)); + } + results = tir::Split(state_, loop_sref, factors); + TVM_TIR_SCHEDULE_END("split", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(results); +} + /******** Schedule: compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { @@ -277,6 +364,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { /******** Schedule: loop binding/annotation ********/ /******** Schedule: cache read/write ********/ /******** Schedule: reduction ********/ + +BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); + TVM_TIR_SCHEDULE_END("rfactor", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: blockize & tensorize ********/ /******** FFI ********/ diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8945fb9ee0dc..c44ec05d660b 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,26 +68,35 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline Array GetSRefs(const Array& rvs) const; + inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } using ScheduleNode::GetSRef; public: - /******** Block/Loop relation ********/ + /******** Schedule: Sampling ********/ + /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; - /******** Schedule: loops manipulation ********/ - /******** Schedule: compute location ********/ + /******** Schedule: Transform loops ********/ + LoopRV Fuse(const Array& loop_rvs) override; + Array Split(const LoopRV& loop_rv, const Array>& factors) override; + /******** Schedule: Manipulate ForKind ********/ + /******** Schedule: Insert cache stages ********/ + /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; - /******** Schedule: loop binding/annotation ********/ - /******** Schedule: cache read/write ********/ - /******** Schedule: reduction ********/ - /******** Schedule: blockize & tensorize ********/ + /******** Schedule: Reduction ********/ + BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; + /******** Schedule: Blockize & Tensorize ********/ + /******** Schedule: Annotation ********/ + /******** Schedule: Misc ********/ + void EnterPostproc() override {} - /******** Utility functions ********/ protected: + /******** Utility functions ********/ /*! * \brief Copy the schedule state, as well as the symbol table * \param new_state The ScheduleState copied @@ -132,28 +141,27 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); - const auto* block = TVM_SREF_TO_BLOCK(block, sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref); return GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); - const auto* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); return GetRef(loop); } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - auto it = this->symbol_table_.find(expr_rv); - if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv; - } - const ObjectRef& obj = (*it).second; - const auto* expr_node = obj.as(); - if (expr_node == nullptr) { - LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); - } - return GetRef(expr_node); + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + auto it = this->symbol_table_.find(var); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; + } + const ObjectRef& obj = (*it).second; + const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); + return Integer(int_imm->value); + }); + return this->analyzer_->Simplify(transformed); } inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { @@ -198,6 +206,24 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { return GetRef(sref); } +template +inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { + Array result; + result.reserve(rvs.size()); + for (const T& rv : rvs) { + result.push_back(sch->GetSRef(rv)); + } + return result; +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + /******** Adding/Removing elements in the symbol table ********/ template diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc new file mode 100644 index 000000000000..af721767c32f --- /dev/null +++ b/src/tir/schedule/instruction.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, + Array outputs) { + ObjectPtr n = make_object(); + n->kind = std::move(kind); + n->inputs = std::move(inputs); + n->attrs = std::move(attrs); + n->outputs = std::move(outputs); + this->data_ = std::move(n); +} + +using InstructionKindRegistry = AttrRegistry; + +InstructionKind InstructionKind::Get(const String& name) { + const InstructionKindRegEntry* reg = InstructionKindRegistry::Global()->Get(name); + ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is not registered"; + return reg->inst_kind_; +} + +InstructionKindRegEntry::InstructionKindRegEntry(uint32_t reg_index) { + this->inst_kind_ = InstructionKind(make_object()); +} + +InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const String& name) { + return InstructionKindRegistry::Global()->RegisterOrGet(name); +} + +/**************** Repr ****************/ + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const auto* self = obj.as(); + ICHECK_NOTNULL(self); + Array inputs; + inputs.reserve(self->inputs.size()); + for (const ObjectRef& obj : self->inputs) { + if (!obj.defined()) { + inputs.push_back(String("None")); + } else if (obj->IsInstance() || obj->IsInstance()) { + inputs.push_back(String("_")); + } else if (const auto* str_obj = obj.as()) { + inputs.push_back(String('"' + std::string(str_obj->data) + '"')); + } else if (obj->IsInstance() || obj->IsInstance()) { + inputs.push_back(obj); + } else if (const auto* expr = obj.as()) { + PrimExpr new_expr = + Substitute(GetRef(expr), [](const Var& var) -> Optional { + ObjectPtr new_var = make_object(*var.get()); + new_var->name_hint = "_"; + return Var(new_var); + }); + std::ostringstream os; + os << new_expr; + inputs.push_back(String(os.str())); + } else { + LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << obj->GetTypeKey(); + throw; + } + } + p->stream << self->kind->f_as_python( + /*inputs=*/inputs, + /*attrs=*/self->attrs, + /*decision=*/NullOpt, + /*outputs=*/Array(self->outputs.size(), String("_"))); + }); + +/**************** FFI ****************/ + +TVM_REGISTER_NODE_TYPE(InstructionNode); +TVM_REGISTER_NODE_TYPE(InstructionKindNode); + +TVM_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); +TVM_REGISTER_GLOBAL("tir.schedule.Instruction") + .set_body_typed([](InstructionKind kind, Array inputs, Array attrs, + Array outputs) -> Instruction { + return Instruction(kind, inputs, attrs, outputs); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h new file mode 100644 index 000000000000..95d636467aa0 --- /dev/null +++ b/src/tir/schedule/instruction_traits.h @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ +#define TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Register an InstructionKind using a trait class + * \param InstructionKindTraits A traits class of an InstructionKind + * + * Example: + * + * \code + * + * struct SomeInstructionKindTraits { + * static constexpr const char* kName = "name-of-the-instruction"; + * static constexpr bool kIsPure = false; + * + * // Convertible to `InstructionKindNode::FInstructionApply` + * static Array ApplyToSchedule( + * const tir::Schedule& sch, + * const Array& inputs, + * const Array& attrs, + * const Optional& decision); + * + * // Convertible to `InstructionKindNode::FInstructionAsPython` + * static String AsPython( + * const Array& inputs, + * const Array& attrs, + * const Optional& decision, + * const Array& outputs); + * + * // Convertible to `InstructionKindNode::FInstructionAttrsAsJSON` + * static Array AttrsAsJSON( + * const Array& attrs); + * + * // Convertible to `InstructionKindNode::FInstructionAttrsFromJSON` + * static Array AttrsFromJSON( + * const Array& attrs_record); + * }; + * + * TVM_REGISTER_INST_KIND_TRAITS(SomeInstructionKindTraits); + * + * \endcode + */ +#define TVM_REGISTER_INST_KIND_TRAITS(InstructionKindTraits) \ + TVM_REGISTER_INST_KIND(InstructionKindTraits::kName) \ + .set_is_pure(InstructionKindTraits::kIsPure) \ + .set_apply_to_schedule(InstructionKindTraits::ApplyToSchedule) \ + .set_attrs_as_json(InstructionKindTraits::AttrsAsJSON) \ + .set_attrs_from_json(InstructionKindTraits::AttrsFromJSON) \ + .set_as_python(InstructionKindTraits::AsPython) + +/*! + * \brief A helper to conveniently register an InstructionKind. When inherited in curiously + * recursive template pattern, the derived class `TTraits` only needs to define two functions on the + * unpacked inputs, and the helper handles unpacking and downcasting. See the example for more + * details. + * + * \tparam TTraits The derived class + * + * Example: + * + * \code + * + * struct SamplePerfectTileTraits : public UnpackedInstTraits { + * // The name of this kind of instruction + * static constexpr const char* kName = "SamplePerfectTile"; + * // A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state + * static constexpr bool kIsPure = true; + * // The number of inputs in this kind of instruction + * static constexpr size_t kNumInputs = 1; + * // The number of attributes in this kind of instruction + * static constexpr size_t kNumAttrs = 2; + * // The number of decisions in this kind of instruction (only 0 or 1 is allowed) + * static constexpr size_t kNumDecisions = 1; + * + * // Calling convention: + * // - All the arguments must be ObjectRef + * // - The 1st argument is Schedule + * // - The next `kNumInputs` arguments are input random variables + * // - The next `kNumAttrs` arguments are attributes + * // - The next argument is decision, if `kNumDecisions == 1` + * static Array UnpackedApplyToSchedule( + * Schedule sch, + * LoopRV loop_rv, + * Integer n, + * Integer max_innermost_factor, + * Optional> decision) { + * return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); + * } + * + * // Calling convention: + * // - All the arguments must be ObjectRef + * // - The 1st argument is an array containing names of output random variables + * // - The next `kNumInputs` arguments are names of input random variables + * // - The next `kNumAttrs` arguments are attributes + * // - The next argument is decision, if `kNumDecisions == 1` + * static String UnpackedAsPython( + * Array outputs, + * String loop_rv, + * Integer n, + * Integer max_innermost_factor, + * Optional> decision) { + * PythonAPICall py("sample_perfect_tile"); + * py.Input("loop", loop_rv); + * py.Input("n", n->value); + * py.Input("max_innermost_factor", max_innermost_factor->value); + * py.Decision(decision); + * py.OutputList(outputs); + * return py.Str(); + * } + * + * template + * friend struct UnpackedInstTraits; + * }; + * + * TVM_REGISTER_INST_KIND(SamplePerfectTileTraits); + * \endcode + */ +template +struct UnpackedInstTraits { + /*! + * \brief Unpack the arguments in the calling convention, and feed them into + * `TTraits::UnpackedApplyToSchedule` + * \sa InstructionKindNode::f_apply_to_schedule + */ + static Array ApplyToSchedule(const Schedule& sch, const Array& inputs, + const Array& attrs, + const Optional& decision); + + /*! + * \brief Unpack the arguments in the calling convention, and feed them into + * `TTraits::UnpackedAsPython` + * \sa InstructionKindNode::f_as_python + */ + static String AsPython(const Array& inputs, const Array& attrs, + const Optional& decision, const Array& outputs); + + /*! \brief No customized serializer by default */ + static constexpr std::nullptr_t AttrsAsJSON = nullptr; + + /*! \brief No customized deserializer by default */ + static constexpr std::nullptr_t AttrsFromJSON = nullptr; + + protected: + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs); + template + static TVM_ALWAYS_INLINE void _SetAttrs(const runtime::TVMArgsSetter& setter, + const Array& attrs); + template + static TVM_ALWAYS_INLINE void _SetDecision(const runtime::TVMArgsSetter& setter, + const Optional& decision); + static TVM_ALWAYS_INLINE Array _ConvertOutputs(const TVMRetValue& rv); +}; + +/*! + * \brief A helper class that constructs schedule API call in python syntax, + * which helps convert an Inst to a python statement. + * \sa InstructionKindNode::f_as_python + */ +class PythonAPICall { + public: + /*! + * \brief Constructor + * \param method_name The name of the schedule API to be called + */ + explicit PythonAPICall(String method_name) : method_name_(method_name), output_(NullOpt) {} + /*! \brief Add an intger input */ + inline void Input(String arg_name, int arg); + /*! \brief Add an intger input */ + inline void Input(String arg_name, int64_t arg); + /*! \brief Add a double input */ + inline void Input(String arg_name, double arg); + /*! \brief Add an input random variable */ + inline void Input(String arg_name, String arg); + /*! \brief Add an input, dispatched to different implementations according to the object's type */ + inline void Input(String arg_name, ObjectRef arg); + /*! \brief Add the decision */ + inline void Decision(ObjectRef decision); + /*! + * \brief Add a single output random variable + * \param unit_array An array containing only one element + */ + inline void SingleOutput(Array unit_array); + /*! \brief Add a list of output random variables */ + inline void OutputList(Array outputs); + /*! \returns The schedule API call in python syntax */ + inline String Str() const; + + private: + /*! \brief Converts a TVM object to python string and print to the output stream */ + inline void AsPythonString(const ObjectRef& obj, std::ostream& os); + + private: + /*! \brief The name of the API to call */ + String method_name_; + /*! \brief The output of the instruction */ + Optional output_; + /*! \brief The names of input arguments */ + std::vector arg_names_; + /*! \brief The values of input arguments */ + std::vector args_; +}; + +/********** implementation details **********/ + +// forward declaration +namespace details { + +template +struct _ArgsPacker; + +template <> +struct _ArgsPacker<> { + static constexpr bool checked = true; +}; + +template +struct _ArgsPacker { + static constexpr bool checked = + std::is_base_of::value && _ArgsPacker::checked; +}; + +template +struct _MethodType {}; + +template +struct _MethodType { + using return_type = TReturn; + using argument_type = _ArgsPacker; +}; + +template +struct _NumArgs {}; + +template +struct _NumArgs { + static constexpr size_t value = sizeof...(Args); +}; + +template +struct _IsTVMArray : std::false_type {}; + +template +struct _IsTVMArray> : std::true_type {}; + +template +struct _IsSingleObject + : std::integral_constant::value && !_IsTVMArray::value> { +}; + +template +using ReturnType = typename _MethodType>::return_type; + +template +static constexpr bool ArgumentAreAllObjects = + _MethodType>::argument_type::checked; + +template +static constexpr size_t NumArgs = _NumArgs>::value; + +template +static constexpr int IsTVMArray = _IsTVMArray>::value; + +template +static constexpr int IsSingleObject = _IsSingleObject>::value; + +}; // namespace details + +template +Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch, + const Array& inputs, + const Array& attrs, + const Optional& decision) { + using method_type = decltype(TTraits::UnpackedApplyToSchedule); + using return_type = details::ReturnType; + static_assert(details::ArgumentAreAllObjects, + "All arguments to `UnpackedApplyToSchedule` must be subclasses of ObjectRef"); + constexpr size_t kNumArgs = details::NumArgs; + constexpr size_t kNumInputs = TTraits::kNumInputs; + constexpr size_t kNumAttrs = TTraits::kNumAttrs; + constexpr size_t kNumDecisions = TTraits::kNumDecisions; + static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions, + "length of argument list mismatch"); + TVMValue tvm_values[kNumArgs]; + int tvm_type_codes[kNumArgs]; + runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes); + setter(0, sch); + TTraits::template _SetInputs<1>(setter, inputs); + TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs); + TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision); + PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { + using runtime::detail::unpack_call; + constexpr size_t kNumArgs = details::NumArgs; + ICHECK_EQ(args.size(), kNumArgs); + unpack_call(nullptr, TTraits::UnpackedApplyToSchedule, args, rv); + }); + TVMRetValue rv; + pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv); + return TTraits::_ConvertOutputs(rv); +} + +template +String UnpackedInstTraits::AsPython(const Array& inputs, + const Array& attrs, + const Optional& decision, + const Array& outputs) { + using method_type = decltype(TTraits::UnpackedAsPython); + using return_type = details::ReturnType; + static_assert(details::ArgumentAreAllObjects, + "All arguments to `UnpackedAsPython` must be subclasses of ObjectRef"); + constexpr size_t kNumArgs = details::NumArgs; + constexpr size_t kNumInputs = TTraits::kNumInputs; + constexpr size_t kNumAttrs = TTraits::kNumAttrs; + constexpr size_t kNumDecisions = TTraits::kNumDecisions; + static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions, + "length of argument list mismatch"); + TVMValue tvm_values[kNumArgs]; + int tvm_type_codes[kNumArgs]; + runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes); + setter(0, outputs); + TTraits::template _SetInputs<1>(setter, inputs); + TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs); + TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision); + PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { + using runtime::detail::unpack_call; + constexpr size_t kNumArgs = details::NumArgs; + ICHECK_EQ(args.size(), kNumArgs); + unpack_call(nullptr, TTraits::UnpackedAsPython, args, rv); + }); + TVMRetValue rv; + pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv); + String result = rv; + return result; +} + +template +template +TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + constexpr size_t kNumInputs = TTraits::kNumInputs; + ICHECK_EQ(kNumInputs, inputs.size()) + << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName; + const ObjectRef* ptr = inputs.template as()->begin(); + for (size_t i = 0; i < kNumInputs; ++i) { + setter(i + index_offset, *(ptr + i)); + } +} + +template +template +TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetAttrs(const runtime::TVMArgsSetter& setter, + const Array& attrs) { + constexpr size_t kNumAttrs = TTraits::kNumAttrs; + ICHECK_EQ(kNumAttrs, attrs.size()) + << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName; + const ObjectRef* ptr = attrs.as()->begin(); + for (size_t i = 0; i < kNumAttrs; ++i) { + setter(i + index_offset, *(ptr + i)); + } +} + +template +template +TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetDecision( + const runtime::TVMArgsSetter& setter, const Optional& decision) { + constexpr size_t kNumDecisions = TTraits::kNumDecisions; + static_assert(kNumDecisions <= 1, "an instruction is supposed to have at most 1 decision"); + if (kNumDecisions == 1) { + setter(index_offset, decision); + } else { + ICHECK(!decision.defined()); + } +} + +template +TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs( + const TVMRetValue& rv) { + using method_type = decltype(TTraits::UnpackedApplyToSchedule); + using return_type = details::ReturnType; + constexpr int is_array = details::IsTVMArray; + constexpr int is_single_obj = details::IsSingleObject; + constexpr int is_void = std::is_void::value; + static_assert(is_array || is_single_obj || is_void, "return type not supported"); + static_assert(is_array + is_single_obj + is_void == 1, "internal template error"); + if (is_void) { + return {}; + } else if (is_single_obj) { + ObjectRef obj = rv; + return {obj}; + } else if (is_array) { + ObjectRef obj = rv; + const ArrayNode* array = obj.as(); + return GetRef>(array); + } +} + +/********** PythonAPICall **********/ + +inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) { + if (const auto* str = obj.as()) { + os << str->data; + } else if (const auto* int_imm = obj.as()) { + os << int_imm->value; + } else if (const auto* float_imm = obj.as()) { + os.precision(17); + os << float_imm->value; + } else if (const auto* array = obj.as()) { + os << '['; + bool is_first = true; + for (const ObjectRef& e : *array) { + if (is_first) { + is_first = false; + } else { + os << ", "; + } + AsPythonString(e, os); + } + os << ']'; + } else { + LOG(FATAL) << "ValueError: Cannot translate type '" << obj->GetTypeKey() + << "' to python. Its value is: " << obj; + throw; + } +} + +void PythonAPICall::Input(String arg_name, int arg) { + arg_names_.emplace_back(std::move(arg_name)); + args_.push_back(std::to_string(arg)); +} + +void PythonAPICall::Input(String arg_name, int64_t arg) { + arg_names_.emplace_back(std::move(arg_name)); + args_.push_back(std::to_string(arg)); +} + +void PythonAPICall::Input(String arg_name, double arg) { + arg_names_.emplace_back(std::move(arg_name)); + std::ostringstream os; + os.precision(17); + os << arg; + args_.push_back(os.str()); +} + +void PythonAPICall::Input(String arg_name, String arg) { + arg_names_.emplace_back(std::move(arg_name)); + args_.emplace_back(std::move(arg)); +} + +void PythonAPICall::Input(String arg_name, ObjectRef arg) { + arg_names_.emplace_back(std::move(arg_name)); + std::ostringstream os; + AsPythonString(arg, os); + args_.push_back(os.str()); +} + +void PythonAPICall::Decision(ObjectRef decision) { + if (decision.defined()) { + this->Input("decision", decision); + } +} + +void PythonAPICall::SingleOutput(Array unit_array) { + ICHECK_EQ(unit_array.size(), 1); + this->output_ = unit_array[0]; +} + +void PythonAPICall::OutputList(Array outputs) { + if (outputs.empty()) { + return; + } + if (outputs.size() == 1) { + this->output_ = outputs[0] + ","; + return; + } + std::ostringstream os; + os << outputs[0]; + for (int i = 1, n = outputs.size(); i < n; ++i) { + os << ", " << outputs[i]; + } + this->output_ = os.str(); +} + +String PythonAPICall::Str() const { + std::ostringstream os; + if (output_.defined()) { + os << output_.value() << " = "; + } + os << "sch." << method_name_ << '('; + int n = args_.size(); + for (int i = 0; i < n; ++i) { + if (i > 0) { + os << ", "; + } + if (arg_names_[i].empty()) { + os << args_[i]; + } else { + os << arg_names_[i] << '=' << args_[i]; + } + } + os << ')'; + return os.str(); +} + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index ab8299e38169..22e25f1c54a7 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -24,9 +24,49 @@ namespace tvm { namespace tir { -/******** Schedule: loops manipulation ********/ +/******** Schedule: Sampling ********/ +/******** Schedule: Get blocks & loops ********/ +/*! + * \brief Retrieves blocks in a specific function with its name + * \param self The schedule state + * \param name The name of the blocks to be retrieved + * \param func_name The name of the function + * \return A list of blocks with the specific name + */ +Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); +/*! + * \brief Gets the parent loops of the block in its scope, from outer to inner + * \param self The schedule state + * \param block_sref The query block + * \return A list of loops above the given block in its scope, from outer to inner + */ +Array GetLoops(const StmtSRef& block_sref); +/******** Schedule: Transform loops ********/ -/******** Schedule: compute location ********/ +/*! + * Split a loop into a list of consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param self The state of the schedule + * \param loop_sref The sref to the loop being split + * \param factors The splitting factors + * \return An array of srefs to the loops after splitting + */ +TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors); +/*! + * \brief Fuse a list of consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The inner loop must be the only child of the outer loop. + * 3) All loops must start with 0. + * \param self The state of the schedule + * \param loop_srefs An array of srefs to the loops to be fused + * \return The sref to the fused loop + */ +TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); +/******** Schedule: Manipulate ForKind ********/ +/******** Schedule: Insert cache stages ********/ +/******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: * 1) The block is a complete non-root block, which only produces one buffer @@ -52,14 +92,21 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); * \param block_sref The sref to the block to be inlined to its producer */ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); - -/******** Schedule: loop binding/annotation ********/ - -/******** Schedule: cache read/write ********/ - -/******** Schedule: reduction ********/ - -/******** Schedule: blockize & tensorize ********/ +/******** Schedule: Reduction ********/ +/*! + * \brief Factor a reduction block by the specified loop + * \details See python/tvm/tir/schedule/schedule.py + * \param loop_sref The loop outside block for which we want to do rfactor + * \param factor_axis The position where the new dimension is placed in the new introduced rfactor + * buffer. Suppose the original reduction block writes to buffer `B` with + * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, + * ndim(B)]`, and the negative index will be normalized to a non-negative one + * \return The sref of the rfactor block + */ +TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); +/******** Schedule: Blockize & Tensorize ********/ +/******** Schedule: Annotation ********/ +/******** Schedule: Misc ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 6bd6388fafff..2583b21227e4 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -622,7 +622,8 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, producer_block_sref); + StmtSRef scope_root_sref = + GetScopeRoot(self, producer_block_sref, /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckCompleteBlock(self, producer_block_sref, scope_root_sref); // Step 3. Analyze the block body @@ -649,7 +650,8 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre Block consumer_block = GetRef(_consumer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, consumer_block_sref); + StmtSRef scope_root_sref = + GetScopeRoot(self, consumer_block_sref, /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. Check if the consumer has a single complete producer @@ -673,5 +675,56 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); } +/******** Instruction Registration ********/ + +struct ComputeInlineTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ComputeInline"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->ComputeInline(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("compute_inline"); + py.Input("block", block_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReverseComputeInlineTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReverseComputeInline"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->ReverseComputeInline(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("reverse_compute_inline"); + py.Input("block", block_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ComputeInlineTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeInlineTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc new file mode 100644 index 000000000000..a8d9c5a69dc9 --- /dev/null +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { + struct Finder : public StmtVisitor { + explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} + + void VisitStmt_(const BlockNode* block) override { + if (block->name_hint == name_) { + auto it = self_->stmt2ref.find(block); + ICHECK(it != self_->stmt2ref.end()); + results_.push_back(it->second); + } + StmtVisitor::VisitStmt_(block); + } + + const ScheduleState& self_; + const String& name_; + Array results_; + }; + + BaseFunc func = self->mod->Lookup(func_name); + const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); + Finder finder(self, name); + finder(prim_func->body); + return std::move(finder.results_); +} + +Array GetLoops(const StmtSRef& block_sref) { + std::vector result; + for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); + parent = parent->parent) { + result.push_back(GetRef(parent)); + } + return {result.rbegin(), result.rend()}; +} + +/******** Instruction Registration ********/ + +struct GetBlockTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetBlock"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) { + return sch->GetBlock(name, func_name); + } + + static String UnpackedAsPython(Array outputs, String name, String func_name) { + PythonAPICall py("get_block"); + py.Input("name", name); + py.Input("func_name", func_name); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct GetLoopsTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetLoops"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetLoops(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_loops"); + py.Input("block", block_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc new file mode 100644 index 000000000000..d1875df61ac7 --- /dev/null +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */ +class BlockPredicateAppender : public StmtMutator { + public: + /*! + * \brief Constructor + * \param to_append The predicate to be appended to BlockRealizeNode + */ + explicit BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {} + + private: + // For each direct child of type BlockRealizeNode, append the predicate + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // We do not recursively do this + ObjectPtr n = CopyOnWrite(realize); + n->predicate = n->predicate && to_append_; + return BlockRealize(n); + } + + /*! \brief The predicate to be appended */ + const PrimExpr& to_append_; +}; + +/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ +class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { + public: + explicit SubstituteVarAndCollectOpaqueBlock(std::function(const Var&)> vmap, + Map* opaque_blocks) + : vmap_(vmap), opaque_blocks_(opaque_blocks) {} + + private: + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (Optional ret = vmap_(var)) { + return ret.value(); + } else { + return std::move(var); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + if (realize->block->iter_vars.empty()) { + opaque_blocks_->Set(op->block, realize->block); + } + return std::move(realize); + } + + /*! \brief The substitute function */ + std::function(const Var&)> vmap_; + /*! \brief The reuse mapping of opaque blocks */ + Map* opaque_blocks_; +}; + +/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ +class IterMapSimplifyBlockBinding : public StmtExprMutator { + public: + explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map loop_var2extent) + : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {} + + static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, + MapNode* opaque_blocks) { + Map loop_var2extent; + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return Downcast( + IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt))); + } + + private: + Stmt VisitStmt_(const ForNode* op) final { + loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + Stmt res = StmtMutator::VisitStmt_(op); + loop_var2extent_.erase(op->loop_var); + return res; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // skip opaque block and update mapping + if (op->iter_values.empty()) { + Block block = op->block; + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + for (const std::pair& entry : *opaque_blocks_) { + if (entry.second.same_as(block)) { + opaque_blocks_->at(entry.first) = realize->block; + break; + } + } + return std::move(realize); + } + Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, + /*input_iters=*/loop_var2extent_, + /*input_pred=*/op->predicate, + /*require_bijective=*/false); + if (v.same_as(op->iter_values)) { + return GetRef(op); + } else { + ObjectPtr n = CopyOnWrite(op); + n->iter_values = std::move(v); + return Stmt(n); + } + } + + /*! \brief The reuse mapping */ + MapNode* opaque_blocks_; + /*! \brief The range of loops */ + Map loop_var2extent_; +}; + +class HasAnnotationOrThreadBindingError : public ScheduleError { + public: + explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) + : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has annotation or " + "thread binding"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has annotation or thread binding"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class OuterNotInnerParent : public ScheduleError { + public: + explicit OuterNotInnerParent(IRModule mod, For outer, For inner) + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + + String FastErrorString() const final { + return "ScheduleError: The outer loop is not the parent of the inner loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the outer loop {0} is not the parent of the inner " + "loop {1}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class NotOnlyChildError : public ScheduleError { + public: + explicit NotOnlyChildError(IRModule mod, For outer, For inner) + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + + String FastErrorString() const final { + return "ScheduleError: The inner loop is not the only child of outer loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the inner loop {1} is not the only child of outer " + "loop {0}."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class LoopNotStartWithZeroError : public ScheduleError { + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive only supports loop starting with 0"; + } + + String DetailRenderTemplate() const final { + return "The loop {0} does not start with 0, which is not supported"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; +}; + +class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors) { + // Invariance + // - The total repeat number has not changed for each direct child block with updating predicate. + // - The execution order has not changed. (The block executes with the same args and the same + // order with before. + // Step 1. Check correctness + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + // Currently, loops not starting with 0 are not supported + arith::Analyzer analyzer; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + // Step 2. Replace all occurrences of the original loop var with new variables + int n = factors.size(); + PrimExpr substitute_value = 0; + std::vector new_loop_vars; + new_loop_vars.reserve(n); + for (int i = 0; i < n; i++) { + const PrimExpr& factor = factors[i]; + Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); + substitute_value = substitute_value * factor + var; + analyzer.Bind(var, Range::FromMinExtent(0, factor)); + new_loop_vars.emplace_back(std::move(var)); + } + Map opaque_block_reuse; + Stmt new_stmt = loop->body; + new_stmt = SubstituteVarAndCollectOpaqueBlock( + [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }, + &opaque_block_reuse)(std::move(new_stmt)); + // Step 3. Update predicate to guard the loop + PrimExpr predicate = substitute_value < loop->extent; + if (!analyzer.CanProve(predicate)) { + new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); + } + // Step 4. Generate nested loops to replace the original loop and simplify the binding + for (int i = n - 1; i >= 0; i--) { + new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); + } + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), + opaque_block_reuse.CopyOnWrite()); + self->Replace(loop_sref, new_stmt, opaque_block_reuse); + Array result_srefs; + result_srefs.reserve(n); + for (int i = 0; i < n; i++) { + result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); + const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); + new_stmt = outer_loop->body; + } + return result_srefs; +} + +StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { + // Invariance + // - The total repeat number has not changed for each direct child block. + // - The execution order has not changed. (The block executes with the same + // args and the same order with before.) + std::vector loops; + loops.reserve(loop_srefs.size()); + StmtSRef outer_loop_sref{nullptr}; + const ForNode* outer_loop = nullptr; + arith::Analyzer analyzer; + // Step 1. check correctness + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); + } + if (outer_loop_sref.defined()) { + if (sref->parent != outer_loop_sref.get()) { + throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); + } + if (!outer_loop->body.same_as(GetRef(loop))) { + throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); + } + } + outer_loop_sref = sref; + outer_loop = loop; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + loops.push_back(loop); + } + // Step 2. Create fused loop var and replace the original loop vars + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Array substitute_value; + substitute_value.resize(loops.size()); + PrimExpr tot = fused_var; + for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { + substitute_value.Set(i, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + Stmt new_stmt = loops.back()->body; + Map opaque_block_reuse; + auto f_substitute = [&](const Var& v) -> Optional { + for (int i = 0; i < n; i++) { + if (v.same_as(loops[i]->loop_var)) { + return substitute_value[i]; + } + } + return NullOpt; + }; + new_stmt = + SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); + // Step 3. Generate a loop to replace the original loops + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + fused_extent = analyzer.Simplify(fused_extent); + new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( + std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite()); + self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); + return self->stmt2ref.at(new_stmt.get()); +} + +/******** Instruction Registration ********/ + +struct SplitTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Split"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + thread_local ObjectRef loop_rv{nullptr}; + thread_local Array factors{nullptr}; + loop_rv = inputs[0]; + factors = Array{inputs.begin() + 1, inputs.end()}; + setter(delta, loop_rv); + setter(delta + 1, factors); + } + + static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, + Array> factors) { + return sch->Split(loop_rv, factors); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Array factors) { + PythonAPICall py("split"); + py.Input("loop", loop_rv); + py.Input("factors", factors); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct FuseTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Fuse"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + setter(delta, inputs); + } + + static LoopRV UnpackedApplyToSchedule(Schedule sch, Array loop_rvs) { + return sch->Fuse(loop_rvs); + } + + static String UnpackedAsPython(Array outputs, Array loop_rvs) { + PythonAPICall py("fuse"); + for (const String& loop_rv : loop_rvs) { + py.Input("", loop_rv); + } + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); +TVM_REGISTER_INST_KIND_TRAITS(FuseTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc new file mode 100644 index 000000000000..bf29ceb1ef9f --- /dev/null +++ b/src/tir/schedule/primitive/reduction.cc @@ -0,0 +1,992 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Commutative Reducer ********/ + +/*! + * \brief A structure used for registering new commutative reducers, and store all the registered + * reducers. The reducers are preserved in a list, in the form of "reducer-getter function". When + * invoking a reducer-getter function with a specific datatype, the reducer-getter will return the + * CommReducer of the corresponding reduction pattern and the specific datatype + */ +struct ReducerRegistry { + ReducerRegistry() + : reducer_getters{CreateReducerGetter([](const Var& x, const Var& y) { return x + y; }, + [](DataType dtype) { return make_const(dtype, 0); }), + CreateReducerGetter([](const Var& x, const Var& y) { return x * y; }, + [](DataType dtype) { return make_const(dtype, 1); }), + CreateReducerGetter([](const Var& x, const Var& y) { return min(x, y); }, + [](DataType dtype) { return max_value(dtype); }), + CreateReducerGetter([](const Var& x, const Var& y) { return max(x, y); }, + [](DataType dtype) { return min_value(dtype); })} {} + + static void RegisterReducer(TypedPackedFunc combiner_getter, + TypedPackedFunc identity_getter) { + ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( + std::move(combiner_getter), std::move(identity_getter))); + } + + static TypedPackedFunc CreateReducerGetter( + TypedPackedFunc combiner_getter, + TypedPackedFunc identity_getter) { + return [combiner_getter = std::move(combiner_getter), + identity_getter = std::move(identity_getter)](DataType dtype) -> CommReducer { + Var lhs("x", dtype); + Var rhs("y", dtype); + return CommReducer({lhs}, {rhs}, {combiner_getter(lhs, rhs)}, {identity_getter(dtype)}); + }; + } + + static ReducerRegistry* Global() { + static ReducerRegistry instance; + return &instance; + } + + std::vector> reducer_getters; +}; + +std::vector> GetReducerGetters() { + return ReducerRegistry::Global()->reducer_getters; +} + +class NotSerialLoopKindError : public ScheduleError { + public: + explicit NotSerialLoopKindError(IRModule mod, For loop) + : mod_(std::move(mod)), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The input loop of rfactor is required to be `kSerial`"; + } + + String DetailRenderTemplate() const final { + String str_kind = ForKind2String(loop_->kind); + std::ostringstream os; + os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. However, the " + "kind of {0} is `" + << str_kind << "`"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class InitBodyNotBufferStoreError : public ScheduleError { + public: + explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, + bool body_is_bufferstore) + : mod_(std::move(mod)), + block_(std::move(block)), + init_is_bufferstore_(init_is_bufferstore), + body_is_bufferstore_(body_is_bufferstore) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of reduction block are required to be both " + "BufferStore"; + } + + String DetailRenderTemplate() const final { + if (!init_is_bufferstore_ && !body_is_bufferstore_) { + return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor " + "can be applied"; + } else if (!init_is_bufferstore_) { + return "The `init` of block {0} is required to be BufferStore so that rfactor can be applied"; + } else { + ICHECK(!body_is_bufferstore_); + return "The `body` of block {0} is required to be BufferStore so that rfactor can be applied"; + } + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + bool init_is_bufferstore_; + bool body_is_bufferstore_; +}; + +class InitBodyNotSameBufferAccessError : public ScheduleError { + public: + explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The `init` and `body` of the reduction block are required to have the " + "same buffer access pattern"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + const auto* init = block_->init.as(); + const auto* update = block_->body.as(); + os << "The `init` and `body` of the block {0} is required to have the same buffer access " + "pattern. However, in block {0} the `init` writes to " + << init->buffer->name << init->indices << ", and the `body` writes to " + << update->buffer->name << update->indices; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; +}; + +class FactorAxisOutOfRangeError : public ScheduleError { + public: + explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis) + : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {} + + String FastErrorString() const final { + return "ScheduleError: The input `factor_axis` is out of range. It is required to be in range " + "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + int ndim = static_cast(buffer_->shape.size()); + os << "The write buffer " << buffer_->name << " has " << ndim + << " dimension(s), so `factor_axis` is required to be in [" << -(ndim + 1) << ", " << ndim + << "] for rfactor. However, the input `factor_axis` is " << factor_axis_ + << ", which is out of the expected range"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) { + int ndim = static_cast(buffer->shape.size()); + if (factor_axis < -(ndim + 1) || factor_axis > ndim) { + throw FactorAxisOutOfRangeError(mod, buffer, factor_axis); + } + // If factor_axis is negative, convert it to a non-negative one. + if (factor_axis < 0) { + factor_axis += ndim + 1; + } + return factor_axis; + } + + IRModule mod_; + Buffer buffer_; + int factor_axis_; +}; + +class NoMatchedReducerError : public ScheduleError { + public: + explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) + : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} + + String FastErrorString() const final { + return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " + "block. So rfactor cannot be applied."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ + << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " + "default reducers or registering new reducers."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + PrimExpr identity_; + BufferStore combiner_; +}; + +class LoopPropertyError : public ScheduleError { + public: + enum ErrorType { + kDataParIterTouchRFactorLoop = 0, + kLoopTouchedByBothKindsOfBlockIters = 1, + kNotFirstChildBlockOfOutermostLoop = 2, + kUnboundLoopUnderReductionLoop = 3 + }; + + explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type) + : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {} + + String FastErrorString() const final { + switch (error_type_) { + case kDataParIterTouchRFactorLoop: + return "ScheduleError: The loop to be applied rfactor is required not to be touched by any " + "data parallel block iter of the block"; + case kLoopTouchedByBothKindsOfBlockIters: + return "ScheduleError: The loops outside of the reduction block are required not to be " + "touched by both data parallel block iters and reduction block iters"; + case kNotFirstChildBlockOfOutermostLoop: + return "ScheduleError: The reduction block should be the first child block of the " + "outermost loop outside of it"; + case kUnboundLoopUnderReductionLoop: + return "ScheduleError: A loop who has extent greater than one and is not bound to any " + "block iter should not appear under a reduction loop"; + } + ICHECK(false) << "Unreachable"; + throw; + } + + String DetailRenderTemplate() const final { + switch (error_type_) { + case kDataParIterTouchRFactorLoop: + return "The loop to be applied rfactor is {0}, which is required not to be touched by any " + "data parallel block iter of the block below. However, some of the block's data " + "parallel block iters touch this loop"; + case kLoopTouchedByBothKindsOfBlockIters: + return "It is not allowed that the loop {0} is touched by both some data parallel block " + "iters and some reduction block iters"; + case kNotFirstChildBlockOfOutermostLoop: + return "The first child block of the outermost loop {0} is not the reduction block."; + case kUnboundLoopUnderReductionLoop: + return "The loop {0} has extent greater than one, and is not bound to any block iter. " + "Therefore it shouldn't appear under a reduction loop"; + } + ICHECK(false) << "Unreachable"; + throw; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + static void CheckLoopProperty(const ScheduleState& self, const Array& loops, + const ForNode* rf_loop, const Block& block, + const std::unordered_set& data_par_loop_vars, + const std::unordered_set& reduce_loop_vars) { + Array children_of_outermost_loop = + GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); + if (!children_of_outermost_loop[0]->block.same_as(block)) { + throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); + } + + bool meet_reduction_loop = false; + for (const For& loop : loops) { + bool data_par_touched = data_par_loop_vars.count(loop->loop_var.get()); + bool reduction_touched = reduce_loop_vars.count(loop->loop_var.get()); + + if (data_par_touched && reduction_touched) { + throw LoopPropertyError(self->mod, loop, kLoopTouchedByBothKindsOfBlockIters); + } else if (data_par_touched) { + if (loop.get() == rf_loop) { + throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop); + } + continue; + } else if (reduction_touched) { + if (!meet_reduction_loop) { + CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); + meet_reduction_loop = true; + } + continue; + } else if (meet_reduction_loop && !is_one(loop->extent)) { + throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); + } + } + } + + IRModule mod_; + For loop_; + ErrorType error_type_; +}; + +/*! + * \brief Convert the `init` and `body` of the input block to BufferStores + * \param self The schedule state + * \param block The block to be analyzed + * \return The BufferStores of the `init` and `body` of the input block + * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same + * buffer + */ +std::pair GetBufferStoreNodes(const ScheduleState& self, + const Block& block) { + const auto* init = block->init.as(); + const auto* body = block->body.as(); + if (!(init && body)) { + throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); + } + if (!init->buffer.same_as(body->buffer)) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } + int ndim = static_cast(init->buffer->shape.size()); + for (int i = 0; i < ndim; ++i) { + if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } + } + return std::make_pair(GetRef(init), GetRef(body)); +} + +/*! + * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative + * reducer, and extract the combiner lhs and combiner rhs + * \param self The schedule state + * \param identity The reduction identity to be analyzed + * \param combiner The reduction combiner to be analyzed + * \return The corresponding CommReducer, the combiner lhs and the combiner rhs + * \throw ScheduleError If no corresponding commutative reducer can be matched + */ +std::tuple GetReducerAndCombinerLhsRhs( + const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) { + CommReducer reducer{nullptr}; + PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; + bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); + if (!matched) { + throw NoMatchedReducerError(self->mod, identity, combiner); + } + return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); +} + +/*! + * \brief For each loop in the given array of loop, associate its loop var with the loop itself + * using a mapping + * \param loops The loops to be analyzed + * \return A mapping from loops to their corresponding loop vars + */ +std::unordered_map GetLoopVar2LoopMap(const Array& loops) { + std::unordered_map loop_vars2loop; + loop_vars2loop.reserve(loops.size()); + for (const For& loop : loops) { + loop_vars2loop[loop->loop_var.get()] = loop; + } + return loop_vars2loop; +} + +/*! + * \brief Create the intermediate rfactor buffer, which the rfactor block writes to and the + * write-back block reads from + * \param buffer The buffer written by the reduction block + * \param factor_axis The `factor_axis` parameter of rfactor + * \param rf_loop The rfactor loop + * \return The new created intermediate rfactor buffer + */ +Buffer CreateRFactorBuffer(const Buffer& buffer, int factor_axis, const ForNode* rf_loop) { + Array rf_shape = buffer->shape; + rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); + + ObjectPtr n = make_object(*buffer.get()); + n->shape = rf_shape; + n->name = buffer->name + ".rf"; + n->data = buffer->data.copy_with_suffix(".rf"); + return Buffer(n); +} + +/*! + * \brief The base class of the rfactor/write-back block creator, which creates the blocks in four + * steps: + * 1) Create the new block iters and the their iter bindings + * 2) Create the reduction update of the new block + * 3) Create the read/write regions of the new block + * 4) Create the new block and the new block-realize + */ +class BaseBlockCreator { + public: + explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, Buffer rf_buffer, + bool is_rf_block) + : old_block_realize_(std::move(old_block_realize)), + rf_loop_(std::move(rf_loop)), + old_reduction_update_(std::move(old_reduction_update)), + reducer_(std::move(reducer)), + rf_buffer_(std::move(rf_buffer)), + is_rf_block_(is_rf_block) { + n_block_iters_ = static_cast(old_block_realize_->iter_values.size()); + } + + void CreateBlock() { + CreateAdditionalIter(); + for (int i = 0; i < n_block_iters_; ++i) { + CreateNormalIters(i); + } + CreateReductionUpdate(); + CreateReadWriteRegions(); + + String new_block_name = old_block_realize_->block->name_hint; + PrimExpr predicate = Bool(true); + if (is_rf_block_) { + new_block_name = new_block_name + "_rf"; + predicate = old_block_realize_->predicate; + } + new_block_ = Block( + /*iter_vars=*/iter_vars_, + /*reads=*/read_regions_, + /*writes=*/write_regions_, + /*name_hint=*/new_block_name, + /*body=*/new_reduction_update_, + /*init=*/ + BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0], + new_reduction_update_->indices)); + new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_); + } + + private: + virtual void CreateAdditionalIter() = 0; + virtual void CreateNormalIters(int idx) = 0; + virtual void CreateReductionUpdate() = 0; + virtual void CreateReadWriteRegions() = 0; + + public: + /*! \brief The new created block */ + Block new_block_; + /*! \brief The new created block-realize */ + BlockRealize new_block_realize_; + /*! \brief The indices used to access the intermediate rfactor buffer */ + Array rf_buf_access_indices_; + + protected: + /*! \brief The old block-realize */ + BlockRealize old_block_realize_; + /*! \brief The number of block iters in the old block */ + int n_block_iters_; + /*! \brief The rfactor loop */ + For rf_loop_; + /*! \brief The update BufferStore of the old block */ + BufferStore old_reduction_update_; + /*! \brief The matched commutative reducer */ + CommReducer reducer_; + /*! \brief The intermediate rfactor buffer */ + Buffer rf_buffer_; + + /*! \brief Whether we are creating the rfactor block or the write-back block */ + bool is_rf_block_; + /*! \brief The new block iters of the new created block */ + std::vector iter_vars_; + /*! \brief The new block iter bindings of the new created block-realize */ + std::vector iter_values_; + /*! + * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced + * by the expressions in future substitution for the two blocks + */ + Map var_map_; + /*! \brief The update BufferStore of the new created block */ + BufferStore new_reduction_update_; + /*! \brief The read regions of the new created block */ + Array read_regions_; + /*! \brief The write regions of the new created block */ + Array write_regions_; +}; + +/*! + * \brief The derived class of the rfactor block creator, which implements all virtual methods in + * the base creator + * \details Start constructing the rfactor block. The main difficulty to construct the rfactor block + * is to create its block iters. So here we introduce the algorithm to create the block iters. + * 1. Create a block iter for the rfactor loop. The block binding of this iter is the loop var, and + * the block iter is data parallel. + * 2. For all the old block's block iters, there are two cases: + * (a) If it is data parallel block iter, or a reduction block iter which doesn't touch the + * rfactor loop, we keep it and its block binding in the rfactor block. + * (b) Otherwise it is a reduction block iter which touches the rfactor loop. In this case, we + * "split" the block iter into one or more new block iters and do not keep the old block + * var. More specifically, we create a new reduction block iter for each loop var that + * appears in the reduction block iter's binding (except for the rfactor loop), and the + * binding of the new block iter is exactly the loop var. (Note that for each loop var, we + * create at most one block iter, even if there are multiple old block iters which touch + * both this loop and the rfactor loop). + * Then we substitute the appearances of the old block iter with the new created block + * iters by recording two mappings: one maps loops vars to new created block iters which + * is used for binding substitution, and another maps old block iters to new expressions + * which is used for substitutions of the old block iters. + */ +class RFactorBlockCreator : public BaseBlockCreator { + public: + explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, + Buffer rf_buffer, + std::unordered_map loop_vars2loop, + int factor_axis, PrimExpr combiner_rhs) + : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), + std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), + true), + loop_vars2loop_(std::move(loop_vars2loop)), + factor_axis_(factor_axis), + combiner_rhs_(std::move(combiner_rhs)) {} + + private: + void CreateAdditionalIter() final { + // Create a new data parallel block iter for the rfactor loop. + additional_iter_ = + IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar); + loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var; + iter_vars_.push_back(additional_iter_); + iter_values_.push_back(rf_loop_->loop_var); + } + + void CreateNormalIters(int idx) final { + IterVar old_iter = old_block_realize_->block->iter_vars[idx]; + PrimExpr old_binding = old_block_realize_->iter_values[idx]; + if (old_iter->iter_type == IterVarType::kDataPar || + !UsesVar(old_binding, + [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) { + // The old block iter is either a data parallel block iter, or a reduction block iter that + // doesn't touch the rfactor loop. In this case reuse the old reduction block iter and its + // corresponding binding. + iter_vars_.push_back(old_iter); + iter_values_.push_back(old_binding); + return; + } + ICHECK(old_iter->iter_type == kCommReduce); + // This block iter is a reduction block iter that touches the rfactor loop. So next we try to + // create a new block iter for all loop vars that appear in the old binding. + Array vars_in_old_binding = UndefinedVars(old_binding); + for (const Var& var : vars_in_old_binding) { + auto it = loop_vars2loop_.find(var.get()); + if (it == loop_vars2loop_.end()) { + // `var` is not a loop var. So skip. + continue; + } + const For& loop = it->second; + if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) { + // We haven't created the new block iter for `var`. So here we create it, append it + // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. + IterVar new_iter_var = + IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce); + loop_var2block_binding_[var.get()] = new_iter_var->var; + iter_vars_.push_back(new_iter_var); + iter_values_.push_back(var); + } + } + // Substitute the original binding with new block iters. Store the result expression + // in `rf_var_map` for future substitution. + var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); + } + + void CreateReductionUpdate() final { + rf_buf_access_indices_ = old_reduction_update_->indices; + rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, + additional_iter_->var); + new_reduction_update_ = BufferStore( + rf_buffer_, + (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0], + rf_buf_access_indices_); + new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); + } + + void CreateReadWriteRegions() final { + const Block& old_block = old_block_realize_->block; + read_regions_ = CreateRegions(old_block->reads); + write_regions_ = CreateRegions(old_block->writes); + } + + Array CreateRegions(const Array& old_regions) { + Array new_regions; + new_regions.reserve(old_regions.size()); + for (const BufferRegion& buffer_region : old_regions) { + if (buffer_region->buffer.same_as(old_reduction_update_->buffer)) { + Array region = buffer_region->region; + region.insert(region.begin() + factor_axis_, + Range::FromMinExtent(additional_iter_->var, 1)); + new_regions.push_back(BufferRegion(rf_buffer_, Substitute(region, var_map_))); + } else { + new_regions.push_back( + BufferRegion(buffer_region->buffer, Substitute(buffer_region->region, var_map_))); + } + } + return new_regions; + } + + public: + /*! \brief The generated additional block iter in rfactor block for the rfactor loop */ + IterVar additional_iter_; + + private: + /*! + * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction + * block's outer loops + */ + std::unordered_map loop_vars2loop_; + /*! \brief The factor_axis specified for rfactor */ + int factor_axis_; + /*! \brief The rhs of the combiner in the reduction update of the old block */ + PrimExpr combiner_rhs_; + /*! + * \brief A mapping which maps loop vars to new created block iters. This map is used to + * substitute the loop vars which appear in the bindings of some old block iters with the new + * created block iters + */ + std::unordered_map loop_var2block_binding_; +}; + +/*! + * \brief The derived class of the write-back block creator, which implements all virtual methods in + * the base creator + */ +class WriteBackBlockCreator : public BaseBlockCreator { + public: + explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, + Buffer rf_buffer, IterVar rf_additional_iter, + PrimExpr combiner_lhs, Array rf_buf_access_indices) + : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), + std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), + false), + rf_additional_iter_(std::move(rf_additional_iter)), + combiner_lhs_(std::move(combiner_lhs)) { + iter_vars_.reserve(n_block_iters_); + iter_values_.reserve(n_block_iters_); + rf_buf_access_indices_ = std::move(rf_buf_access_indices); + } + + private: + void CreateAdditionalIter() final { + // Create a new reduction block iter for the rfactor loop. + IterVar wb_new_block_iter = + IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); + iter_vars_.push_back(wb_new_block_iter); + iter_values_.push_back(rf_loop_->loop_var); + var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var); + } + + void CreateNormalIters(int idx) final { + IterVar old_block_iter = old_block_realize_->block->iter_vars[idx]; + if (old_block_iter->iter_type == IterVarType::kDataPar) { + iter_vars_.emplace_back(old_block_iter->dom, old_block_iter->var.copy_with_suffix(""), + kDataPar); + iter_values_.push_back(old_block_realize_->iter_values[idx]); + var_map_.Set(old_block_iter->var, iter_vars_.back()); + } + } + + void CreateReductionUpdate() final { + wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); + wb_rhs_ = + Downcast(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_)); + new_reduction_update_ = + BufferStore(old_reduction_update_->buffer, (*reducer_.get())({wb_lhs_}, {wb_rhs_})[0], + old_reduction_update_->indices); + new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); + } + + void CreateReadWriteRegions() final { + read_regions_.push_back(CreateRegion(wb_lhs_)); + read_regions_.push_back(CreateRegion(wb_rhs_)); + write_regions_.push_back(read_regions_[0]); + } + + static BufferRegion CreateRegion(const BufferLoad& load) { + Array region; + region.reserve(load->indices.size()); + for (const PrimExpr& index : load->indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + return BufferRegion(load->buffer, std::move(region)); + } + + private: + /*! \brief The new created additional block iter of the rfactor block */ + IterVar rf_additional_iter_; + /*! \brief The lhs of the combiner in the reduction update of the old block */ + PrimExpr combiner_lhs_; + /*! \brief The lhs of the combiner of the write-back block */ + BufferLoad wb_lhs_; + /*! \brief The rhs of the combiner of the write-back block */ + BufferLoad wb_rhs_; +}; + +/*! + * \brief Create new outer loops for the rfactor block, meanwhile update the rfactor block's iter + * bindings to use the new created loop vars + * \param rf_block_realize The BlockRealize of the rfactor block + * \param loops The loops to be wrapped over the rfactor block + * \return A Stmt which is the wrapping result + */ +Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array& loops) { + int n_loops = static_cast(loops.size()); + + // Step 1. Create new loop vars. + Array new_loops; + std::unordered_map new_loop_var_map; + new_loops.reserve(n_loops); + new_loop_var_map.reserve(n_loops); + for (const For& old_loop : loops) { + Var new_loop_var = old_loop->loop_var.copy_with_suffix(""); + new_loop_var_map[old_loop->loop_var.get()] = new_loop_var; + } + + // Step 2. Update the iter bindings of the rfactor block. + Array new_bindings; + new_bindings.reserve(rf_block_realize->iter_values.size()); + for (const PrimExpr& old_binding : rf_block_realize->iter_values) { + new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); + } + rf_block_realize.CopyOnWrite()->iter_values = new_bindings; + + // Step 3. Wrap `rf_block_realize` with outer loops. + Stmt rf_body = rf_block_realize; + for (int i = n_loops - 1; i >= 0; --i) { + ObjectPtr p_loop = make_object(*loops[i].get()); + p_loop->loop_var = Downcast(new_loop_var_map[loops[i]->loop_var.get()]); + p_loop->body = rf_body; + rf_body = For(std::move(p_loop)); + } + + return rf_body; +} + +class BlockReplacer : public StmtMutator { + public: + /*! + * \brief The replace takes the old scope root block as input, and does four things: + * 1) replace the reduction block with the write-back block, + * 2) remove loops outside the write-back block that are touched by reduction block iters, except + * for the rfactor loop + * 3) combine the rfactor block (wrapped with outer loops) and the transformed outermost loop + * into a SeqStmt, and + * 4) insert the rfactor buffer into the scope root block's `alloc_buffers` + * After transformation, the function returns the new scope root block + * \param scope_root_block The old scope root block + * \param rf_body The rfactor block, which is already wrapped with outer loops + * \param outermost_loop The loop that is outermost among all loops outside the reduction block + * \param wb_block_realize The new created BlockRealize of the write-back block + * \param old_block_realize The BlockRealize of the reduction block + * \param rf_loop The rfactor loop, which should be kept outside the write-back block + * \param reduce_loop_vars The loops that are touched by reduction block iters, used to remove + * loops outside the write-back block + * \param loop_vars2loop The mapping from loop vars to loops that are outside the reduction block, + * which is used to reduce redundant recursive visits + * \param rf_buffer The rfactor buffer to be added into the scope root's `alloc_buffers` + * \return The transformed new scope root block + */ + static Block Replace(Block scope_root_block, Stmt rf_body, For outermost_loop, + BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, + std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop, + const Buffer& rf_buffer) { + BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), + std::move(wb_block_realize), std::move(old_block_realize), + std::move(rf_loop), std::move(reduce_loop_vars), + std::move(loop_vars2loop)); + Block new_scope_root = Downcast(replacer(std::move(scope_root_block))); + BlockNode* p = new_scope_root.CopyOnWrite(); + p->alloc_buffers.push_back(rf_buffer); + return new_scope_root; + } + + private: + explicit BlockReplacer(Stmt rf_body, For outermost_loop, BlockRealize wb_block_realize, + BlockRealize old_block_realize, For rf_loop, + std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop) + : rf_body_(std::move(rf_body)), + outermost_loop_(std::move(outermost_loop)), + wb_block_realize_(std::move(wb_block_realize)), + old_block_realize_(std::move(old_block_realize)), + rf_loop_(std::move(rf_loop)), + reduce_loop_vars_(std::move(reduce_loop_vars)), + loop_vars2loop_(std::move(loop_vars2loop)) {} + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Check whether this loop is outside the reduction block. Given that we've made sure + // that the scope root block has stage-pipeline property, if this loop is not outside the + // reduction block, there's no need to recursively mutate. + if (!loop_vars2loop_.count(loop->loop_var.get())) { + return GetRef(loop); + } + + // Step 2. Recursively mutate. + Stmt body = StmtMutator::VisitStmt(loop->body); + + // Step 3. If this loop is the rfactor loop and isn't touched by any reduction block iter, it + // should be kept outside the write-back block. Otherwise it shouldn't. + if (loop == rf_loop_.get() || !reduce_loop_vars_.count(loop->loop_var.get())) { + ObjectPtr p_loop = CopyOnWrite(loop); + p_loop->body = body; + body = Stmt(p_loop); + } + + // Step 4. If this loop is the outermost loop of the reduction block, return the combination of + // `rf_body_` and the mutation result `body`. Otherwise return the mutation result. + return loop == outermost_loop_.get() ? SeqStmt({rf_body_, body}) : body; + } + + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's + // block-realize. And we directly return the new `wb_block_realize`. + ICHECK_EQ(block_realize, old_block_realize_.get()); + return wb_block_realize_; + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + Array new_stmts; + new_stmts.reserve(static_cast(seq->seq.size())); + + for (const Stmt old_stmt : seq->seq) { + new_stmts.push_back(VisitStmt(old_stmt)); + } + return SeqStmt::Flatten(new_stmts); + } + + private: + Stmt rf_body_; + For outermost_loop_; + BlockRealize wb_block_realize_; + BlockRealize old_block_realize_; + For rf_loop_; + std::unordered_set reduce_loop_vars_; + std::unordered_map loop_vars2loop_; +}; + +StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_axis) { + // ***************************************************** + // * Condition Checks and Information Collection * + // ***************************************************** + + // Step 1. Check some basic conditions for rfactor. Get the block and block-realize. + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref); + const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get()); + const Block& block = block_realize->block; + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + CheckReductionBlock(self, block_sref, scope_root); + const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); + if (rf_loop->kind != ForKind::kSerial) { + throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); + } + + // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block + // iters, respectively. + std::unordered_set data_par_loop_vars; + std::unordered_set reduce_loop_vars; + GetVarsTouchedByBlockIters(block_realize, &data_par_loop_vars, &reduce_loop_vars); + + // Step 3. Collect the loops of the reduction block. Construct a mapping from loops to + // corresponding loop vars. + Array loops = LoopSRefs2Loops(GetLoops(block_sref)); + std::unordered_map loop_vars2loop = GetLoopVar2LoopMap(loops); + + // Step 4. Check four properties that the loops should have: + // - the rfactor loop cannot be touched by any data parallel block iter; + // - all the loops cannot be touched by both data parallel block iters and reduction block iters; + // - the outermost loop should have the reduction block as its first child block; + // - the outermost loop that is touched by some reduction block iters can only have one child + // block. + LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars, + reduce_loop_vars); + + // Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the + // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the + // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs + // will be used when constructing the rfactor block. + BufferStore init; + BufferStore update; + CommReducer reducer; + PrimExpr combiner_lhs, combiner_rhs; + std::tie(init, update) = GetBufferStoreNodes(self, block); + std::tie(reducer, combiner_lhs, combiner_rhs) = + GetReducerAndCombinerLhsRhs(self, init->value, update); + + // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it + // is negative. + factor_axis = FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, update->buffer, factor_axis); + + // ***************************************************** + // * IR Manipulation * + // ***************************************************** + // Since rfactor splits the reduction block into two, we call the first one "rfactor block", and + // the latter one "write-back block", and the intermediate buffer is called "rfactor buffer". + + // Step 1. Create the intermediate buffer (a.k.a. rfactor buffer), which has an additional + // dimension that specified by `factor_axis` and `rf_loop`. + Buffer rf_buffer = CreateRFactorBuffer(update->buffer, factor_axis, rf_loop); + + // Step 2. Create the rfactor block. + RFactorBlockCreator rf_block_creator(block_realize, GetRef(rf_loop), update, reducer, + rf_buffer, loop_vars2loop, factor_axis, + std::move(combiner_rhs)); + rf_block_creator.CreateBlock(); + + // Step 3. Create the write-back block. + WriteBackBlockCreator wb_block_creator(block_realize, GetRef(rf_loop), update, reducer, + rf_buffer, std::move(rf_block_creator.additional_iter_), + std::move(combiner_lhs), + std::move(rf_block_creator.rf_buf_access_indices_)); + wb_block_creator.CreateBlock(); + + // Step 4. Wrap the rfactor block with loops. + Stmt rf_body = CreateLoopOutsideRfactorBlock(rf_block_creator.new_block_realize_, loops); + + // ***************************************************** + // * Schedule Replacement & Update * + // ***************************************************** + + // Step 1. Substitute the old scope root block with the new scope root block. + Block old_scope_root_block = GetRef(scope_root->StmtAs()); + Block new_scope_root_block = BlockReplacer::Replace( + old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, + GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffer); + self->Replace(scope_root, new_scope_root_block, {{old_scope_root_block, new_scope_root_block}}); + + // Step 2. Update scope information. + std::vector new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()), + self->stmt2ref.at(wb_block_creator.new_block_.get())}; + for (const StmtSRef& new_block_sref : new_block_srefs) { + BlockInfo& info = self->block_info[new_block_sref]; + info.affine_binding = true; + info.region_cover = true; + info.scope->stage_pipeline = true; + } + return new_block_srefs[0]; +} + +/******** Instruction Registration ********/ + +struct RFactorTraits : public UnpackedInstTraits { + static constexpr const char* kName = "RFactor"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) { + return sch->RFactor(loop_rv, factor_axis->value); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Integer factor_axis) { + PythonAPICall py("rfactor"); + py.Input("loop", loop_rv); + py.Input("factor_axis", factor_axis->value); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits); + +/******** FFI ********/ + +TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer") + .set_body_typed([](PackedFunc combiner_getter, PackedFunc identity_getter) { + ReducerRegistry::RegisterReducer(std::move(combiner_getter), std::move(identity_getter)); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 115f7936f64e..eda6ac27d283 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include "./utils.h" namespace tvm { namespace tir { @@ -55,17 +55,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // /**************** (FFI) Constructor ****************/ +TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); +TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule { - IRModule mod{nullptr}; - if (const auto* func = obj.as()) { - mod = IRModule({{GlobalVar("main"), GetRef(func)}}); - } else if (const auto* p_mod = obj.as()) { - mod = GetRef(p_mod); - } else { - LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " - << obj->GetTypeKey(); - } + .set_body_typed([](IRModule mod, int debug_mode, int error_render_level) -> Schedule { return Schedule::Concrete(mod, debug_mode, static_cast(error_render_level)); }); @@ -116,22 +109,30 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") throw; }); -/***** (FFI) Block/Loop relation *****/ - +/******** (FFI) Sampling ********/ +/******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); -/******** (FFI) loops manipulation ********/ -/******** (FFI) compute location ********/ +/******** (FFI) Transform loops ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); +/******** (FFI) Manipulate ForKind ********/ +/******** (FFI) Insert cache stages ********/ +/******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") .set_body_method(&ScheduleNode::ReverseComputeInline); -/******** (FFI) loop binding/annotation ********/ -/******** (FFI) cache read/write ********/ -/******** (FFI) reduction ********/ -/******** (FFI) blockize & tensorize ********/ +/******** (FFI) Reduction ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor") + .set_body_method(&ScheduleNode::RFactor); +/******** (FFI) Blockize & Tensorize ********/ +/******** (FFI) Annotation ********/ +/******** (FFI) Misc ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") + .set_body_method(&ScheduleNode::EnterPostproc); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index ca61dfea2768..8f0284f2901e 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -43,7 +43,7 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, AsIntSet(LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)))); + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())))); } /*! @@ -67,7 +67,7 @@ Array AnalyzeRegionLowerBound(const BlockRealize& realize, LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)), + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())), /*predicate=*/realize->predicate, /*analyzer=*/analyzer)) { return result.value(); } @@ -161,34 +161,6 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new sref->stmt = new_stmt; } -/*! - * \brief Get PrimFunc and GlobalVar that the root block belongs to - * \param mod The IRModule - * \param root_block The root block of the PrimFunc - * \param result_g_var The result GlobalVar - * \return The result PrimFunc where the root block belongs to - * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write - */ -const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, - GlobalVar* result_g_var) { - for (const auto& kv : mod->functions) { - const GlobalVar& g_var = kv.first; - const BaseFunc& base_func = kv.second; - if (const auto* func = base_func.as()) { - if (const auto* realize = func->body.as()) { - if (realize->block.get() == root_block) { - *result_g_var = g_var; - return func; - } - } - } - } - LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the " - "statement:\n" - << GetRef(root_block); - throw; -} - /**************** Creation ****************/ /*! \brief A helper class to create a new ScheduleStateNode from an IRModule */ @@ -444,9 +416,6 @@ ScheduleState::ScheduleState(IRModule mod, int debug_mode) { data_ = StateCreator::Create(mod, debug_mode); } -ScheduleState::ScheduleState(PrimFunc func, int debug_mode) - : ScheduleState(IRModule({{GlobalVar("main"), func}}), debug_mode) {} - /**************** Replace ****************/ /* @@ -737,7 +706,7 @@ class SRefUpdater : public StmtVisitor { void UpdateBlockInfo(const StmtSRef& block_sref) { using TIter = std::unordered_map::iterator; // The caller is responsible for correcting the flags - BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref))); + BlockInfo new_info((BlockScope(GetChildBlockSRefOnSRefTree(self_, block_sref)))); std::pair insert_result = self_->block_info.emplace(block_sref, new_info); bool inserted = insert_result.second; BlockInfo& info = insert_result.first->second; @@ -1045,7 +1014,7 @@ void ScheduleStateNode::DebugVerify() const { /**************** BlockInfo-related ****************/ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" @@ -1063,16 +1032,10 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState").set_body_typed([](ObjectRef obj, int debug_mode) { - if (const auto* func = obj.as()) { - return ScheduleState(GetRef(func), debug_mode); - } - if (const auto* mod = obj.as()) { - return ScheduleState(GetRef(mod), debug_mode); - } - LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey(); - throw; -}); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState") + .set_body_typed([](IRModule mod, int debug_mode) -> ScheduleState { + return ScheduleState(mod, debug_mode); + }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") .set_body_method(&ScheduleStateNode::GetBlockScope); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc new file mode 100644 index 000000000000..d8c18f0de0d6 --- /dev/null +++ b/src/tir/schedule/trace.cc @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +/**************** Constructors ****************/ + +Trace::Trace() { data_ = make_object(); } + +Trace::Trace(Array insts, Map decisions) { + ObjectPtr n = make_object(); + n->insts = std::move(insts); + n->decisions = std::move(decisions); + data_ = std::move(n); +} + +/**************** Utilities ****************/ + +bool IsPostproc(const InstructionKind& inst_kind) { + static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); + return inst_kind.same_as(inst_enter_postproc); +} + +int GetNumValidInstructions(const Array& insts, bool remove_postproc) { + if (!remove_postproc) { + return insts.size(); + } + int n_insts = 0; + for (const Instruction& inst : insts) { + if (!IsPostproc(inst->kind)) { + ++n_insts; + } else { + break; + } + } + return n_insts; +} + +/**************** TranslateInputRVs ****************/ + +Array TranslateInputRVs(const Array& inputs, + const std::unordered_map& rv_map) { + Array result; + result.reserve(inputs.size()); + for (const ObjectRef& input : inputs) { + if (!input.defined() || // constant: nullptr + input->IsInstance() || // constant: string + input->IsInstance() || // constant: integer + input->IsInstance()) { // constant: float + result.push_back(input); + } else if (input->IsInstance() || // RV: block + input->IsInstance() || // RV: loop + input->IsInstance()) { // RV: var + auto it = rv_map.find(input.get()); + ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; + result.push_back(GetRef(it->second)); + } else if (const auto* expr = input.as()) { // RV: Expr + result.push_back( + Substitute(GetRef(expr), [&rv_map](const Var& var) -> Optional { + auto it = rv_map.find(var.get()); + if (it == rv_map.end()) { + return NullOpt; + } + const Object* dst = it->second; + ICHECK(dst->IsInstance()) + << "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey(); + return GetRef(static_cast(dst)); + })); + } else { + ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: " + << input->GetTypeKey(); + throw; + } + } + return result; +} + +Array TranslateInputRVs( + const Array& inputs, + const std::unordered_map& rv_names) { + Array results; + results.reserve(inputs.size()); + for (const ObjectRef& input : inputs) { + if (!input.defined()) { + // Case 0. nullptr => None + results.push_back(String("None")); + continue; + } + auto it = rv_names.find(input); + if (it != rv_names.end()) { + // Case 1. BlockRV, LoopRV, VarRV + results.push_back(it->second); + } else if (const auto* str_obj = input.as()) { + // Case 2. string => "content" + results.push_back(String('"' + std::string(str_obj->data) + '"')); + } else if (input->IsInstance() || input->IsInstance()) { + // Case 3. integer or floating-point number + results.push_back(input); + } else if (input->IsInstance() || inputs->IsInstance() || + inputs->IsInstance()) { + LOG(FATAL) << "IndexError: Random variable is not defined " << input; + throw; + } else { + LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input->GetTypeKey(); + throw; + } + } + return results; +} + +Array TranslateInputRVs(const Array& inputs, + const std::unordered_map& named_rvs) { + Array results; + results.reserve(inputs.size()); + for (const ObjectRef& input : inputs) { + // Case 3. integer or floating-point number + if (input->IsInstance() || input->IsInstance()) { + results.push_back(input); + continue; + } + const auto* str = input.as(); + CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey(); + CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; + const char* name = str->data; + int64_t size = str->size; + // Case 2. string + if (size > 2 && name[0] == '"' && name[size - 1] == '"') { + results.push_back(String(std::string(name + 1, size - 2))); + continue; + } + // Case 0 & 1. None, BlockRV, LoopRV, VarRV + auto it = named_rvs.find(name); + CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; + results.push_back(it->second); + } + return results; +} + +/**************** TranslateAddOutputRVs ****************/ + +void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, + std::unordered_map* rv_map) { + ICHECK_EQ(old_outputs.size(), new_outputs.size()); + int n = old_outputs.size(); + const ObjectRef* p_old = old_outputs.GetArrayNode()->begin(); + const ObjectRef* p_new = new_outputs.GetArrayNode()->begin(); + for (int i = 0; i < n; ++i) { + (*rv_map)[p_old[i].get()] = p_new[i].get(); + } +} + +Array TranslateAddOutputRVs( + const Array& outputs, + std::unordered_map* rv_names) { + Array results; + results.reserve(outputs.size()); + for (const ObjectRef& output : outputs) { + int i = rv_names->size(); + ICHECK(!rv_names->count(output)) + << "ValueError: The random variable has been produced once: " << rv_names->at(output); + String result{ObjectPtr{nullptr}}; + if (output->IsInstance()) { + result = "b" + std::to_string(i); + } else if (output->IsInstance()) { + result = "l" + std::to_string(i); + } else if (output->IsInstance()) { + result = "v" + std::to_string(i); + } else { + LOG(FATAL) << "TypeError: Cannot recognize the type of the random variable: " + << output->GetTypeKey(); + throw; + } + results.push_back(result); + rv_names->emplace(output, std::move(result)); + } + return results; +} + +void TranslateAddOutputRVs(const Array& old_outputs, const Array& new_outputs, + std::unordered_map* named_rvs) { + ICHECK_EQ(old_outputs.size(), new_outputs.size()); + int n = old_outputs.size(); + const ObjectRef* p_old = old_outputs.GetArrayNode()->begin(); + const ObjectRef* p_new = new_outputs.GetArrayNode()->begin(); + for (int i = 0; i < n; ++i) { + const auto* name = static_cast(p_old[i].get()); + named_rvs->emplace(std::string(name->data, name->size), p_new[i]); + } +} + +/**************** Add/Remove/Get ****************/ + +Optional TraceNode::GetDecision(const Instruction& inst) const { + auto it = this->decisions.find(inst); + return it == this->decisions.end() ? Optional(NullOpt) : (*it).second; +} + +void TraceNode::Append(Instruction inst) { insts.push_back(std::move(inst)); } + +void TraceNode::Append(Instruction inst, ObjectRef decision) { + decisions.Set(inst, std::move(decision)); + insts.push_back(std::move(inst)); +} + +Optional TraceNode::Pop() { + if (insts.empty()) { + return NullOpt; + } + Instruction inst = insts.back(); + insts.pop_back(); + if (decisions.count(inst)) { + decisions.erase(inst); + } + return inst; +} + +/**************** Interfacing with InstructionKind ****************/ + +void TraceNode::ApplyToSchedule( + Schedule sch, bool remove_postproc, + runtime::TypedPackedFunc& inputs, // + const Array& attrs, // + const Optional& decision)> + decision_provider) const { + std::unordered_map rv_map; + for (const Instruction& inst : this->insts) { + if (remove_postproc && IsPostproc(inst->kind)) { + break; + } + Array inputs = TranslateInputRVs(inst->inputs, rv_map); + Array attrs = inst->attrs; + Optional decision = this->GetDecision(inst); + if (decision_provider != nullptr) { + decision = decision_provider(inst, inputs, attrs, decision); + } + Array outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); + TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); + } +} + +ObjectRef TraceNode::AsJSON(bool remove_postproc) const { + std::unordered_map rv_names; + Array json_insts; + Array json_decisions; + json_insts.reserve(this->insts.size()); + json_decisions.reserve(this->insts.size()); + + int i = 0; + for (const Instruction& inst : this->insts) { + const InstructionKind& kind = inst->kind; + if (remove_postproc && IsPostproc(kind)) { + break; + } + json_insts.push_back(Array{ + /* 0: inst name */ kind->name, + /* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names), + /* 2: attrs */ kind->f_attrs_as_json != nullptr ? kind->f_attrs_as_json(inst->attrs) + : ObjectRef(inst->attrs), + /* 3: outputs */ TranslateAddOutputRVs(inst->outputs, &rv_names), + }); + if (Optional decision = this->GetDecision(inst)) { + json_decisions.push_back(Array{ + /* 0: index */ Integer(i), + /* 1: decision */ decision.value(), + }); + } + ++i; + } + return Array{ + /* 0: trace */ std::move(json_insts), + /* 1: decision */ std::move(json_decisions), + }; +} + +Array TraceNode::AsPython(bool remove_postproc) const { + std::unordered_map rv_names; + Array py_trace; + py_trace.reserve(this->insts.size()); + for (const Instruction& inst : this->insts) { + if (remove_postproc && IsPostproc(inst->kind)) { + break; + } + Array attrs; + attrs.reserve(inst->attrs.size()); + for (const ObjectRef& obj : inst->attrs) { + if (const auto* str = obj.as()) { + attrs.push_back(String('"' + std::string(str->data) + '"')); + } else { + attrs.push_back(obj); + } + } + py_trace.push_back( + inst->kind->f_as_python(/*inputs=*/TranslateInputRVs(inst->inputs, rv_names), + /*attrs=*/attrs, + /*decision=*/this->GetDecision(inst), + /*outputs=*/TranslateAddOutputRVs(inst->outputs, &rv_names))); + } + return py_trace; +} + +void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { + Array json_insts{nullptr}; + Array json_decisions{nullptr}; + // Parse `json` into `json_insts` and `json_decisions` + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 2); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + json_insts = GetRef>(arr0); + json_decisions = GetRef>(arr1); + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of " + "instructions and an array of decisions, but gets: " + << json; + throw; + } + // Parse `json_decisions` + std::vector> decisions(json_insts.size(), NullOpt); + for (const ObjectRef& decision_entry : json_decisions) { + int index = -1; + ObjectRef decision{nullptr}; + try { + const ArrayNode* arr = decision_entry.as(); + ICHECK(arr && arr->size() == 2); + const IntImmNode* arr0 = arr->at(0).as(); + ICHECK(arr0); + index = arr0->value; + decision = arr->at(1); + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " + "decision], but gets: " + << decision_entry; + throw; + } + decisions[index] = std::move(decision); + } + // Parse `json_insts` + std::unordered_map named_rvs{{"None", ObjectRef{nullptr}}}; + int i = 0; + for (const ObjectRef& inst_entry : json_insts) { + InstructionKind kind{nullptr}; + Array inputs{nullptr}; + Array attrs{nullptr}; + Array outputs{ObjectPtr{nullptr}}; + // Parse the entry + try { + const auto* arr = inst_entry.as(); + ICHECK(arr && arr->size() == 4); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + const auto* arr2 = arr->at(2).as(); + const auto* arr3 = arr->at(3).as(); + ICHECK(arr0 && arr1 && arr2 && arr3); + for (const ObjectRef& str : *arr3) { + ICHECK(str->IsInstance()); + } + kind = InstructionKind::Get(arr0->data); + inputs = GetRef>(arr1); + attrs = GetRef>(arr2); + outputs = GetRef>(arr3); + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " + "inputs, attrs, outputs], but gets: " + << inst_entry; + throw; + } + // Parse inputs + inputs = TranslateInputRVs(inputs, named_rvs); + // Parse attrs + if (kind->f_attrs_from_json != nullptr) { + attrs = kind->f_attrs_from_json(attrs); + } + // Apply to the schedule + Array new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); + // Parse outputs + TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); + ++i; + } +} + +/**************** Creation ****************/ + +Trace TraceNode::WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const { + int n_insts = GetNumValidInstructions(this->insts, remove_postproc); + Array new_insts = + Array{this->insts.begin(), this->insts.begin() + n_insts}; + Map new_decisions{this->decisions.begin(), this->decisions.end()}; + new_decisions.Set(std::move(inst), std::move(decision)); + return Trace(new_insts, new_decisions); +} + +Trace TraceNode::Simplified(bool remove_postproc) const { + int n_insts = GetNumValidInstructions(this->insts, remove_postproc); + std::unordered_set used_rvs; + std::vector new_insts; + std::unordered_map new_decisions; + new_insts.reserve(n_insts); + new_decisions.reserve(this->decisions.size()); + for (int inst_idx = n_insts - 1; inst_idx >= 0; --inst_idx) { + const Instruction& inst = this->insts[inst_idx]; + // Check if all the variables the instruction defined are dead + // If so, and the instruction is pure, we can safely remove this instruction + bool all_defs_dead = inst->kind->is_pure; + if (all_defs_dead) { + for (const ObjectRef& obj : inst->outputs) { + if (used_rvs.count(obj.get())) { + all_defs_dead = false; + break; + } + } + } + // Remove this instruction + if (all_defs_dead) { + continue; + } + // Otherwise this instruction is not dead + new_insts.push_back(inst); + if (Optional decision = this->GetDecision(inst)) { + new_decisions.emplace(inst, std::move(decision)); + } + // Add its inputs as "used" ones + for (const ObjectRef& obj : inst->inputs) { + if (obj->IsInstance() || obj->IsInstance() || + obj->IsInstance()) { + used_rvs.insert(obj.get()); + continue; + } else if (obj->IsInstance()) { + PostOrderVisit(obj, [&used_rvs](const ObjectRef& obj) -> void { + if (obj->IsInstance()) { + used_rvs.insert(obj.get()); + } + }); + } + } + } + return Trace(Array(new_insts.rbegin(), new_insts.rend()), + Map(new_decisions)); +} + +/**************** Repr ****************/ + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const auto* self = obj.as(); + ICHECK_NOTNULL(self); + Array repr = self->AsPython(/*remove_postproc=*/false); + bool is_first = true; + for (const String& line : repr) { + if (is_first) { + is_first = false; + } else { + p->stream << std::endl; + } + p->stream << line; + } + }); + +/**************** Instruction Registration ****************/ + +struct EnterPostprocTraits : public UnpackedInstTraits { + static constexpr const char* kName = "EnterPostproc"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch) { return sch->EnterPostproc(); } + + static String UnpackedAsPython(Array outputs) { + PythonAPICall py("enter_postproc"); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); + +/**************** FFI ****************/ + +TVM_REGISTER_NODE_TYPE(TraceNode); +TVM_REGISTER_GLOBAL("tir.schedule.Trace") + .set_body_typed([](Optional> insts, + Optional> decisions) { + return Trace(insts.value_or(Array()), + decisions.value_or(Map())); + }); +TVM_REGISTER_GLOBAL("tir.schedule.TraceGetDecision") + .set_body_method(&TraceNode::GetDecision); +TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend") + .set_body_typed([](Trace self, Instruction inst, Optional decision) { + if (decision.defined()) { + return self->Append(inst, decision.value()); + } else { + return self->Append(inst); + } + }); +TVM_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); +TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") + .set_body_method(&TraceNode::ApplyToSchedule); +TVM_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); +TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision") + .set_body_method(&TraceNode::WithDecision); +TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); +TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") + .set_body_typed(Trace::ApplyJSONToSchedule); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 19ed995ac8cc..8ccf8da731b5 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -25,18 +25,22 @@ #include #include #include +#include #include #include +#include #include #include #include +#include "../../node/attr_registry.h" #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/array.h" #include "./analysis.h" #include "./error.h" +#include "./instruction_traits.h" #include "./primitive.h" namespace tvm { @@ -98,6 +102,21 @@ namespace tir { << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +/*! + * \brief Convert an array of loop StmtSRefs to an array of loops + * \param loop_srefs The loop StmtSRefs to be converted + * \return The conversion result loops + */ +inline Array LoopSRefs2Loops(const Array& loop_srefs) { + Array loops; + loops.reserve(loop_srefs.size()); + for (StmtSRef loop_sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + loops.push_back(GetRef(loop)); + } + return loops; +} + /******** Storage scope ********/ /*! @@ -143,6 +162,18 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { return SeqStmt::Flatten(new_stmts); } +/*! + * \brief Create a new IterVar for the input For loop, with specified name and type + * \param loop The loop to be created from + * \param name The name of the new IterVar + * \param iter_var_type The type of the new IterVar + * \return The newly created IterVar + */ +inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) { + return IterVar(Range::FromMinExtent(loop->min, loop->extent), + Var(std::move(name), loop->loop_var.dtype()), iter_var_type); +} + /******** Integer set ********/ /*! diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 9cd29357f8c7..293c990d2745 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -88,7 +88,7 @@ void ArgBinder::BindArray(const Array& arg, const Array& val void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match) { - ICHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; + ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg_name << " Buffer bind scope mismatch"; ICHECK_EQ(arg->dtype, value->dtype) << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 7a8789457923..76845cbebd2a 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -323,8 +323,8 @@ class BF16LowerRewriter : public StmtExprMutator { DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, - oldbuf->name, oldbuf->scope, oldbuf->data_alignment, - oldbuf->offset_factor, oldbuf->buffer_type); + oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor, + oldbuf->buffer_type); buffer_remap_[oldbuf] = newbuf; var_remap_[oldbuf->data] = buffer_var; changes.emplace_back(itr.first, newbuf); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index edbafe27cf13..bd1fa9bce836 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -203,7 +203,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { std::unordered_map dom_map; for (const ForNode* loop : ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); - if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer->scope))) { + if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer.scope()))) { dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } } @@ -362,6 +362,7 @@ class BufferCompactor : public StmtExprMutator { BlockNode* n = block.CopyOnWrite(); RewriteBufferRegions(&n->reads); RewriteBufferRegions(&n->writes); + RewriteMatchBuffers(&n->match_buffers); n->alloc_buffers = std::move(alloc_buffers); return std::move(block); } @@ -434,6 +435,18 @@ class BufferCompactor : public StmtExprMutator { *regions = std::move(new_regions); } + void RewriteMatchBuffers(Array* match_buffers) const { + Array result; + result.reserve(match_buffers->size()); + for (const auto& match_buffer : *match_buffers) { + const BufferRegion& buffer_region = match_buffer->source; + auto p = make_object(*buffer_region.get()); + RewriteBufferRegion(&p->buffer, &p->region); + result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p))); + } + *match_buffers = std::move(result); + } + /*! \brief The allocation information about each buffer. */ std::unordered_map buffer_info_; }; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 07f7b42fe2eb..f1f914fa2f5c 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -127,13 +127,9 @@ class BufferFlattener : public StmtExprMutator { } static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { - String storage_scope = buffer->scope; - if (storage_scope.empty()) { - storage_scope = "global"; - } + String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); - body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); return body; } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index f7443c74c0f7..f99cbd5b5a05 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -29,6 +29,7 @@ #include #include "../../arith/pattern_match.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -42,10 +43,7 @@ class CopyIntrinInjector : public StmtMutator { flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = op->value.as()->value; - } else if (op->attr_key == pragma_key_) { + if (op->attr_key == pragma_key_) { Stmt ret; ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; @@ -148,30 +146,18 @@ class CopyIntrinInjector : public StmtMutator { dst_strides.push_back(make_const(DataType::Int(32), 1)); } Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + store_strides[loop_var_size], store->buffer_var->name_hint, 0, 0, kDefault); Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, - load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0, - kDefault); + load->buffer_var->name_hint, 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); ICHECK(out->defined()) << "flower function did not return correct stmt"; return true; } - // Get storage scope - std::string GetStorageScope(const VarNode* var) const { - auto it = storage_scope_.find(var); - if (it != storage_scope_.end()) { - return it->second; - } else { - return ""; - } - } + // pragma key std::string pragma_key_; // function to lower copy intrinsics. const PackedFunc& flower_copy_fromto_; - // Storage scope - std::unordered_map storage_scope_; // arith analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 7a16c06d8058..0b45bde28dfe 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -95,16 +95,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto it = dbuffer_info_.find(buf); - if (it != dbuffer_info_.end()) { - it->second.scope = op->value.as()->value; - return this->VisitStmt(op->body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } else if (op->attr_key == attr::double_buffer_scope) { + if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -112,8 +103,10 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - auto it = dbuffer_info_.find(op->buffer_var.get()); + const VarNode* buf = op->buffer_var.as(); + auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { + it->second.scope = GetPtrStorageScope(op->buffer_var); it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); @@ -125,8 +118,6 @@ class DoubleBufferInjector : public StmtExprMutator { } ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back( - AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); alloc_nest.emplace_back( Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..7248bd4e663f 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -23,6 +23,7 @@ */ #include "ir_utils.h" +#include #include #include @@ -172,16 +173,6 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { if (const VarNode* v = op->node.as()) { - if (op->attr_key == attr::storage_scope) { - const AllocateNode* alloc = op->body.as(); - if (alloc && op->node.same_as(alloc->buffer_var)) { - Stmt new_alloc = this->VisitStmt(op->body); - if (new_alloc.same_as(op->body)) return GetRef(op); - alloc = new_alloc.as(); - ICHECK(alloc); - return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); - } - } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { @@ -201,5 +192,57 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } +String GetPtrStorageScope(Var buffer_var) { + const auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(indices.size(), target->shape.size()); + + arith::Analyzer analyzer; + Array result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - indices.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& range = source->region[i]; + ICHECK(analyzer.CanProve(range->extent == 1)); + result.push_back(range->min); + } + for (size_t i = 0; i < indices.size(); ++i) { + const Range& range = source->region[i + offset]; + const PrimExpr& index = indices[i]; + result.push_back(range->min + index); + } + return result; +} + +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { + const Buffer& target = match_buffer->buffer; + const BufferRegion& source = match_buffer->source; + ICHECK_EQ(region.size(), target->shape.size()); + + arith::Analyzer analyzer; + Region result; + result.reserve(source->region.size()); + size_t offset = source->region.size() - region.size(); + for (size_t i = 0; i < offset; ++i) { + const Range& source_range = source->region[i]; + ICHECK(analyzer.CanProve(source_range->extent == 1)); + result.push_back(Range::FromMinExtent(source_range->min, 1)); + } + for (size_t i = 0; i < region.size(); ++i) { + const Range& source_range = source->region[i + offset]; + const Range& target_range = region[i]; + result.push_back( + Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); + } + return result; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 906ff8a38b6c..79c5f0609243 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -191,6 +191,28 @@ inline PrimExpr StackAlloca(std::string type, size_t num) { */ Stmt ConvertSSA(Stmt stmt); +/*! + * \brief Return the storage scope associated with a buffer variable. + * \param buffer_var The input buffer variable. + * \return A string representing the storage scope of this buffer variable. + */ +String GetPtrStorageScope(Var buffer_var); + +/*! + * \brief Convert match buffer target buffer access indices to original one. + * \param indices The indices of the target buffer + * \return The indices of source buffer. + */ +Array ConvertIndices(const MatchBufferRegion& match_buffer, + const Array& indices); + +/*! + * \brief Convert match buffer target buffer region to original one. + * \param region The sub-region of the target buffer + * \return The region of source buffer. + */ +Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc index 424da1e817b6..cb2b50260326 100644 --- a/src/tir/transforms/legalize_packed_calls.cc +++ b/src/tir/transforms/legalize_packed_calls.cc @@ -109,7 +109,7 @@ Pass LegalizePackedCalls() { inputs[i] = true; } n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body)); - return std::move(f); + return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {}); } diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index f1d816f0baef..97f5b6f90a70 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -84,19 +85,6 @@ using Partition = std::unordered_map; -bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { - bool success = false; - PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { - if (const VarNode* v = node.as()) { - if (vars.count(v)) { - success = true; - return; - } - } - }); - return success; -} - // Select potential candidate IRs that can be partitioned. // Rule: // - the range should not be const @@ -200,7 +188,8 @@ class PartitionFinder : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) final { - if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; + auto f_vset_contains = [this](const VarNode* var) { return out_vars_.count(var); }; + if (UsesVar(op->min, f_vset_contains) || UsesVar(op->extent, f_vset_contains)) return; const VarNode* var = op->loop_var.get(); hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)}); @@ -230,7 +219,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { + if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 829b7d822d11..b4ec91ba5012 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -41,39 +41,22 @@ using runtime::StorageScope; class StorageAccessInfoLower : public StmtExprMutator { public: Stmt VisitStmt_(const AllocateNode* op) final { - // Lower allocate to device allocate when needed. - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - // For special memory, remove allocate, or use head expr - auto it = storage_info_.find(op->buffer_var.get()); - if (it != storage_info_.end() && it->second.info.defined()) { - const MemoryInfo& info = it->second.info; - ++it->second.alloc_count; - ICHECK_LE(it->second.alloc_count, 1) - << "Double allocation of " << it->second.scope.to_string(); + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (scope.tag.length() != 0 && scope.tag != ".dyn") { + auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); + ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); + ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) + << "Double allocation of " << scope.to_string(); + storage_info_[op->buffer_var.get()] = info; + // Lower allocate to device allocate when needed. + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); if (info->head_address.defined()) { return LetStmt(op->buffer_var, info->head_address, op->body); } else { return op->body; } - } else { - return stmt; - } - } - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - StorageEntry e; - e.scope = scope; - if (scope.tag.length() != 0) { - e.info = GetMemoryInfo(op->value.as()->value); - ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); - } - storage_info_[buf] = e; - return StmtExprMutator::VisitStmt_(op); - } else { return StmtExprMutator::VisitStmt_(op); } @@ -99,8 +82,8 @@ class StorageAccessInfoLower : public StmtExprMutator { Var buffer_var = Downcast(op->args[1]); PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); - if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); + if (it != storage_info_.end() && it->second.defined()) { + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second); } ICHECK(op->dtype.is_handle()); // Change to address_of @@ -118,17 +101,8 @@ class StorageAccessInfoLower : public StmtExprMutator { return cast(ptr_type, analyzer_.Simplify( offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } - // The storage entry. - struct StorageEntry { - // Whether it is tagged memory. - StorageScope scope; - // The memory info if any. - MemoryInfo info; - // Allocation counter - int alloc_count{0}; - }; // The storage scope of each buffer - std::unordered_map storage_info_; + std::unordered_map storage_info_; // analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc new file mode 100644 index 000000000000..2f8fbe0ea6e7 --- /dev/null +++ b/src/tir/transforms/lower_match_buffer.cc @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_match_buffer.cc + * \brief The pass for lowering match_buffer. + */ + +#include +#include +#include +#include +#include + +#include "../ir/functor_common.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { +class MatchBufferLower : public StmtExprMutator { + public: + explicit MatchBufferLower(const PrimFunc& func) { + for (const Var& param : func->params) { + // Mark input var as const variable. + if (!param.dtype().is_handle()) var_map_.Set(param, param); + } + } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + CheckAndUpdateVarMap(match_buffer); + } + + Stmt stmt = StmtExprMutator ::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + Array reads = MutateArray( + op->reads, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + + if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) { + return stmt; + } else { + auto n = CopyOnWrite(op); + n->match_buffers = {}; + n->reads = std::move(reads); + n->writes = std::move(writes); + return Stmt(n); + } + } + + Stmt VisitStmt_(const ForNode* op) final { + analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + auto it = var_map_.find(v); + if (it != var_map_.end()) { + return (*it).second; + } else { + return std::move(v); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + auto it = match_buffers_.find(op->buffer); + if (it == match_buffers_.end()) { + return stmt; + } else { + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; + + auto n = CopyOnWrite(op); + n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + n->buffer = source->buffer; + return Stmt(n); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); + + auto it = match_buffers_.find(op->buffer); + if (it == match_buffers_.end()) { + return expr; + } else { + const Buffer& buffer = (*it).first; + const BufferRegion& source = (*it).second; + Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); + return BufferLoad(source->buffer, indices); + } + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Load from buffer created by match_buffer is not allowed, but got: " << expr; + return expr; + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + CHECK(var_map_.find(op->buffer_var) == var_map_.end()) + << "Store from buffer created by match_buffer is not allowed, but got: " << stmt; + return stmt; + } + + BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + auto it = match_buffers_.find(buffer); + if (it == match_buffers_.end()) { + return buffer_region; + } else { + const BufferRegion& source = (*it).second; + Region region = ConvertRegion(MatchBufferRegion(buffer, source), buffer_region->region); + return BufferRegion(source->buffer, std::move(region)); + } + } + + private: + void CheckAndUpdateVarMap(const MatchBufferRegion& match_buffer) { + // Step.1. Check + const Buffer& buffer = match_buffer->buffer; + const BufferRegion& source = VisitBufferRegion(match_buffer->source); + const Buffer& source_buffer = source->buffer; + + // Step.1.1. Check scope & dtype + ICHECK_EQ(buffer.scope(), source_buffer.scope()) + << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs." + << source_buffer.scope(); + ICHECK_EQ(buffer->dtype, source_buffer->dtype) + << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs." + << source_buffer->dtype; + + // Step.1.2. Check data alignment + if (source_buffer->data_alignment % buffer->data_alignment != 0) { + LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " + << " required_alignment=" << buffer->data_alignment + << ", provided_alignment=" << source_buffer->data_alignment; + } + if (is_zero(buffer->elem_offset)) { + ICHECK(is_zero(source_buffer->elem_offset)) + << "Trying to bind a Buffer with offset into one without offset " + << " required elem_offset=" << buffer->elem_offset + << ", provided elem_offset=" << source_buffer->elem_offset; + } + + // Step.2. Update + match_buffers_.Set(buffer, source); + // Step.2.1. Update buffer data + Bind(buffer->data, source_buffer->data, buffer->name + ".data"); + + // Step.2.2. Update element offset + // Note we create Load via vload and try to reuse index calculate. + { + Array indices; + indices.reserve(source->region.size()); + for (const Range& range : source->region) { + indices.push_back(range->min); + } + + Load load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); + Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset"); + CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + << "The source elem_offset " << buffer->elem_offset + << " does not satisfy the offset_factor " << buffer->offset_factor << "."; + } + + // Step 2.3. Check and update strides + // Check if target buffer strides are defined + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + PrimExpr stride = make_const(DataType::Int(32), 1); + for (size_t i = buffer->shape.size(); i > 0; --i) { + const PrimExpr& shape = source_buffer->shape[i - 1]; + Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1)); + stride *= shape; + } + } + + // Step 2.4. Check and update shape + ICHECK(source->region.size() >= buffer->shape.size()); + size_t offset = source->region.size() - buffer->shape.size(); + for (size_t i = 0; i < buffer->shape.size(); ++i) { + const Range& range = source->region[i + offset]; + Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i)); + } + } + + void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { + CHECK_EQ(arg.dtype(), value.dtype()) + << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; + // Handle recursive case + value = Substitute(std::move(value), var_map_); + if (arg->IsInstance()) { + Var v = Downcast(arg); + auto it = var_map_.find(v); + if (it == var_map_.end()) { + var_map_.Set(v, value); + analyzer_.Bind(v, value); + } else { + AssertBinding((*it).second, value, arg_name); + } + } else { + AssertBinding(arg, value, arg_name); + } + } + + void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, + const std::string& arg_name = "argument") { + CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name + << " unmet: " << lhs << "==" << rhs << "."; + } + + private: + /*! \brief Buffer region mapping. */ + Map match_buffers_; + /*! \brief Var mapping for buffer signature (data, strides, element_offset, etc.) */ + Map var_map_; + /*! \brief The analyzer */ + arith::Analyzer analyzer_; +}; + +PrimFunc LowerMatchBuffer(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + fptr->body = MatchBufferLower(func)(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass LowerMatchBuffer() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return LowerMatchBuffer(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9e536814fa12..481b1bfd4b19 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,10 +33,32 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { +class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope { + public: + explicit UpdatePointerStorageScopeAllReduce( + const std::unordered_map& new_storage_scopes) + : UpdatePointerStorageScope(new_storage_scopes) {} + + Stmt VisitStmt_(const AllocateNode* op) final { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + auto new_scope = GetPtrStorageScope(remapped); + if (new_scope != GetPtrStorageScope(op->buffer_var)) { + Stmt body = StmtExprMutator::VisitStmt(op->body); + if (new_scope == "shared") { + // use volatile access to shared buffer. + body = AttrStmt(remapped, attr::volatile_scope, 1, body); + } + return Allocate(remapped, op->dtype, op->extents, op->condition, body); + } + return StmtExprMutator::VisitStmt_(op); + } +}; + class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) @@ -48,15 +70,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); return ret; - } else if (op->attr_key == attr::storage_scope) { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - const VarNode* v = op->node.as(); - if (alloc_remap_.count(v)) { - return op->body; - } else { - return ret; - } } else if (op->attr_key == attr::reduce_scope) { const CommReducerNode* combiner = op->node.as(); ICHECK(combiner); @@ -86,12 +99,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { - // use volatile access to shared buffer. - stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + new_storage_scopes_[repl->buffer_var.get()] = "shared"; } return stmt; } else { @@ -108,6 +119,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } + std::unordered_map new_storage_scopes_; + private: // Thread entry struct ThreadEntry { @@ -366,7 +379,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } } @@ -590,7 +603,10 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); - n->body = ThreadAllreduceBuilder(target_node)(n->body); + ThreadAllreduceBuilder thread_all_reduce(target_node); + auto reduce_body = thread_all_reduce(n->body); + n->body = + UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936ca..f3966eb93b6c 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -40,6 +40,8 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { @@ -250,7 +252,7 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id - ICHECK(!ExprUseVar(local_index, warp_index_)) + ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); @@ -356,34 +358,21 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } + std::unordered_map new_storage_scopes_; + private: Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); - if (warp_buffer_.count(op->buffer_var.get())) { + if (GetPtrStorageScope(op->buffer_var) == "warp") { + new_storage_scopes_[op->buffer_var.get()] = "local"; WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; } - Stmt VisitStmt_(const AttrStmtNode* op) { - using runtime::StorageScope; - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - if (scope.rank == runtime::StorageRank::kWarp) { - warp_buffer_.insert(buf); - Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); - } - } - return StmtMutator::VisitStmt_(op); - } - int warp_size_{0}; - std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain std::unordered_map var_dom_; @@ -397,7 +386,9 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); - n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body)); + WarpMemoryRewriter warp_memory_rewriter(warp_size); + auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); + n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index ee52a6fc0988..393ce6c286b4 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -119,7 +119,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); ICHECK_LE(num_unpacked_args, num_args); - + bool pack_args = (num_unpacked_args == -1) || (num_args > num_unpacked_args); + if (num_unpacked_args == -1) { + // reset to zero + num_unpacked_args = 0; + } + ICHECK_GE(num_unpacked_args, 0); int num_packed_args = num_args - num_unpacked_args; // Data field definitions // The packed fields @@ -154,11 +159,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } return res; }; - // --------------------------- // start of logics // add signiture for packed arguments. - if (num_packed_args != 0) { + if (pack_args) { args.push_back(v_packed_args); args.push_back(v_packed_arg_type_ids); args.push_back(v_num_packed_args); @@ -214,13 +218,13 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } // allow return value if the function is packed. - if (num_packed_args != 0) { + if (pack_args) { args.push_back(v_out_ret_value); args.push_back(v_out_ret_tcode); args.push_back(v_resource_handle); } - size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 6 : 0); + size_t expected_nargs = num_unpacked_args + (pack_args ? 6 : 0); ICHECK_EQ(args.size(), expected_nargs); // Arg definitions are defined before buffer binding to avoid the use before @@ -282,6 +286,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { namespace transform { Pass MakePackedAPI(int num_unpacked_args) { + // packed arguments anyway while `num_unpacked_args` is -1 auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) { IRModuleNode* mptr = m.CopyOnWrite(); std::vector > updates; diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc new file mode 100644 index 000000000000..e8865b260dc1 --- /dev/null +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file merge_dynamic_shared_memory_allocations.cc + * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation. + * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation. + */ +#include +#include +#include +#include + +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +bool IsDynamicSharedMemory(Var buffer_var) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn"; +} + +class AllocateCollector : public StmtExprVisitor { + public: + void VisitStmt_(const AllocateNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + dyn_shmem_allocs_.insert(op); + } + StmtExprVisitor::VisitStmt_(op); + } + + std::unordered_set dyn_shmem_allocs_; +}; + +class DynamicSharedMemoryRewriter : public StmtExprMutator { + public: + explicit DynamicSharedMemoryRewriter( + const std::unordered_set& dyn_shmem_allocs) + : dyn_shmem_allocs_{dyn_shmem_allocs} {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent && !allocated) { + // Allocate one dynamic shared memory allocation at the beginning of thread scope + int align = 1; + for (const auto& alloc : dyn_shmem_allocs_) { + ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; + align = std::max(align, alloc->dtype.bytes()); + } + for (const auto& alloc : dyn_shmem_allocs_) { + ICHECK_EQ(alloc->extents.size(), 1); + buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; + merged_alloc_size_ += alloc->extents[0] * align; + } + + allocated = true; + auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, + const_true(), StmtExprMutator::VisitStmt(op->body)); + return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + return StmtExprMutator::VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + auto offset = GetBufferOffset(op->buffer_var, op->dtype); + auto index = StmtExprMutator::VisitExpr(op->index); + return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); + } + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const StoreNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + auto offset = GetBufferOffset(op->buffer_var, op->value->dtype); + auto index = StmtExprMutator::VisitExpr(op->index); + auto value = StmtExprMutator::VisitExpr(op->value); + return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); + } + return StmtExprMutator::VisitStmt_(op); + } + + private: + PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { + auto it = buffer_byte_offsets_.find(buffer_var.get()); + ICHECK(it != buffer_byte_offsets_.end()); + return indexdiv(it->second, dtype.bytes()); + } + + Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; + std::unordered_set dyn_shmem_allocs_; + PrimExpr merged_alloc_size_{0}; + std::unordered_map buffer_byte_offsets_; + bool allocated{false}; +}; + +Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { + AllocateCollector collector; + collector(stmt); + if (collector.dyn_shmem_allocs_.size() > 1) { + return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); + } + return stmt; +} + +namespace transform { + +Pass MergeDynamicSharedMemoryAllocations() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations") + .set_body_typed(MergeDynamicSharedMemoryAllocations); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index f01d98707586..795ae9d6a73a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -33,6 +33,9 @@ #include +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -89,6 +92,17 @@ class VarUseDefAnalysis : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { this->HandleDef(op->buffer_var.get()); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + dyn_shmem_size_ = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + dyn_shmem_size_ *= op->extents[i]; + } + dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); + use_dyn_shmem_ = true; + } return StmtExprMutator::VisitStmt_(op); } @@ -175,6 +189,8 @@ class VarUseDefAnalysis : public StmtExprMutator { Array undefined_; Array thread_axis_; Array thread_extent_; + PrimExpr dyn_shmem_size_{0}; + bool use_dyn_shmem_{false}; std::unordered_map use_count_; std::unordered_map def_count_; @@ -262,6 +278,10 @@ class HostDeviceSplitter : public StmtMutator { WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); + if (m.use_dyn_shmem_) { + device_func = + WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); + } (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); // generate calls to the device function @@ -273,6 +293,9 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } + if (m.use_dyn_shmem_) { + call_args.push_back(m.dyn_shmem_size_); + } return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); } diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 00002d3587db..0567c8613fcd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -35,7 +35,7 @@ namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; @@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { AccessEntry e; e.threads = env_threads(); @@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { } void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::double_buffer_write) { + if (op->attr_key == attr::double_buffer_write) { ICHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); @@ -176,6 +172,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { scope_.pop_back(); if (op->else_case.defined()) { scope_.push_back(std::vector()); + this->VisitStmt(op->else_case); auto v = Summarize(std::move(scope_.back()), nullptr); scope_.pop_back(); s.access.insert(s.access.end(), v.begin(), v.end()); @@ -208,7 +205,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); - StorageScope scope = GetScope(buffer); + StorageScope scope = GetScope(GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { ICHECK(allow_append_); @@ -244,12 +241,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } -StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; +StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { + if (buffer_var->type_annotation.as()) { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); + } + return StorageScope(); // global by default } } // namespace tir diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 663c570fd15c..9dc4c923b054 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. */ - StorageScope GetScope(const VarNode* buf) const; + StorageScope GetScope(Var buffer_var) const; // access scope std::vector > scope_; @@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor { StmtEntry curr_stmt_; // The involving threads Array env_threads_; - // The storage scope of each buffer - std::unordered_map storage_scope_; }; } // namespace tir diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 43fc1f1ec53f..38b3a77b1a0c 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -78,11 +78,7 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); @@ -156,10 +152,8 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. - auto it = storage_scope_.find(op->buffer.get()); - ICHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; - const std::string& strkey = it->second; + std::string strkey = GetPtrStorageScope(op->buffer->data); if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); @@ -167,7 +161,6 @@ class StorageFlattener : public StmtExprMutator { } else { skey = StorageScope::Create(strkey); } - // use small alignment for small arrays auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); @@ -200,9 +193,12 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), - op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, - skey.to_string(), align, 0, kDefault); + auto* ptr_type = op->buffer->data->type_annotation.as(); + ICHECK(ptr_type); + auto new_var = + Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); + e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -228,7 +224,6 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, @@ -491,8 +486,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map buf_map_; // Dimension alignment std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; - // Storage scope - std::unordered_map storage_scope_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 36eeddb17d89..592a6a33375e 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -37,6 +37,7 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../ir/buffer_common.h" #include "ir_utils.h" namespace tvm { @@ -75,8 +76,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { }; // The scope of each allocation struct AllocEntry { - // Scope used for allocation. - StorageScope storage_scope; // scope level size_t level{0}; // allocation stmt @@ -86,13 +85,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { size_t level = scope_.size(); const VarNode* buf = op->buffer_var.get(); - auto it = alloc_info_.find(buf); - ICHECK(it != alloc_info_.end()) << "Could not find buffer `" << buf->name_hint - << "` in the list of allocated buffers. Perhaps you are " - "missing a storage_scope attr for this buffer."; - ICHECK(it->second.alloc == nullptr); - it->second.alloc = op; - it->second.level = level; + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const StoreNode* op) final { @@ -180,10 +174,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::virtual_thread) { VisitNewScope(op); - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } @@ -337,7 +327,14 @@ class InplaceOpVerifier : public StmtExprVisitor { const StoreNode* store_{nullptr}; }; -// Planner to plan and rewrite memory allocation. +/* \brief Rewrite and merge memory allocation. + * + * Using LinearAccessPatternFinder, determines which buffers could share an + * allocation. This includes both sequential usage of the same buffer and + * merging small allocations at the same scope into a single larger allocation. + * The merging of small allocations requires the codegen to cast the resulting + * value from the storage type to the output type after access. + */ class StoragePlanRewriter : public StmtExprMutator { public: using StmtEntry = LinearAccessPatternFinder::StmtEntry; @@ -409,10 +406,8 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || - attr::IsPragmaKey(op->attr_key)) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; @@ -496,8 +491,6 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } @@ -506,7 +499,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Remap the index PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) { if (e->bits_offset == 0) return index; - uint64_t elem_bits = dtype.bits() * dtype.lanes(); + uint64_t elem_bits = dtype.bits(); ICHECK_EQ(e->bits_offset % elem_bits, 0U); return make_const(index.dtype(), e->bits_offset / elem_bits) + index; } @@ -523,7 +516,7 @@ class StoragePlanRewriter : public StmtExprMutator { // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { @@ -557,7 +550,7 @@ class StoragePlanRewriter : public StmtExprMutator { make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -598,7 +591,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = analyzer_.Simplify(combo_size); e->new_alloc = Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); - if (e->scope.tag.length() != 0) { + if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -716,7 +709,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); - const AllocEntry& ae = alloc_info.at(var); + const AllocateNode* alloc = alloc_info.at(var).alloc; + auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -726,13 +720,12 @@ class StoragePlanRewriter : public StmtExprMutator { if (!inplace_flag.count(src) && alloc_map_.count(src)) { InplaceOpVerifier visitor; StorageEntry* src_entry = alloc_map_.at(src); - if (src_entry->scope == ae.storage_scope && + if (src_entry->scope == storage_scope && src_entry->attach_scope_ == thread_scope_ && - src_entry->elem_type == ae.alloc->dtype.element_of() && + src_entry->elem_type == alloc->dtype.element_of() && visitor.Check(s.stmt, var, src)) { - uint64_t const_nbits = - static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); + uint64_t const_nbits = static_cast(alloc->constant_allocation_size()) * + alloc->dtype.bits() * alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = src_entry; @@ -744,9 +737,9 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope); + dst_entry = FindAlloc(alloc, thread_scope_, storage_scope); } - dst_entry->allocs.emplace_back(ae.alloc); + dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; } } @@ -896,107 +889,547 @@ class StoragePlanRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; -// Turn alloc into vector alloc -// if all its access is the same vector type. -class VectorAllocRewriter : public StmtExprMutator { +/* Helper struct containing information on how a buffer is declared and used + * + */ +struct BufferVarInfo { + enum DeclarationLocation { + kPrimFuncParam = (1 << 0), + kPrimFuncBufferMap = (1 << 1), + kAllocateNode = (1 << 2), + kLetNode = (1 << 3), + }; + + // The tir::Var that represents this buffer. + Var var; + + // The data type of an element of the buffer. + DataType element_dtype; + + /* The extent of the buffer. + * + * If multidimensional, the extent of the last dimension of the buffer. If the + * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding + * entry in buffer_map), then extent is zero. + */ + PrimExpr extent; + + // Where the buffer was declared + DeclarationLocation declaration_location; + + // When accessed, which element type is it accessed as. This may + // differ both in base type (e.g. int32* cast to float32* after + // packing in StorageRewrite) or in number of lanes (e.g. float16* + // cast to float16x4*). + std::unordered_set access_dtype; + + DataType get_preferred_dtype() const { + std::unordered_set base_access_dtype; + for (auto dtype : access_dtype) { + base_access_dtype.insert(dtype.element_of()); + } + // If the array is accessed as multiple base types within a + // function, no point in changing the declared type. CodeGenC can + // handle this with a type-cast prior to indexing. Vulkan will + // raise an error at code-gen time, if a later pass doesn't split + // it out. + if (base_access_dtype.size() != 1) { + return element_dtype; + } + + DataType preferred_base_type = *base_access_dtype.begin(); + + // If there is only one vectorizable size used to access the + // buffer, and if that access size is compatible with the array + // size, then the buffer is vectorizable. In the future, this + // could be improved to allow vectorized buffer access of size + // GCD(*lanes_used), if necessary. + int preferred_lanes = element_dtype.lanes(); + if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { + arith::Analyzer analyzer_; + arith::ModularSet me = analyzer_.modular_set(extent); + + int lanes = access_dtype.begin()->lanes(); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + preferred_lanes = lanes; + } + } + + return preferred_base_type.with_lanes(preferred_lanes); + } +}; + +/* Checks whether buffers are accessed as scalar or vector parameters in a + * function. + * + */ +class VectorTypeAccessChecker : public StmtExprVisitor { public: - PrimExpr VisitExpr_(const LoadNode* op) final { - UpdateTypeMap(op->buffer_var.get(), op->dtype); - return StmtExprMutator::VisitExpr_(op); + /* Constructor + * + * @param params The parameters passed to a PrimFunc + * + * @param buffer_map The buffer_map associated with a PrimFunc + * + * @param allow_untyped_handles If a buffer or pointer variable is + * missing a type annotation, assume that it has the same underlying + * type as it is later accessed, with scalar element types. + */ + VectorTypeAccessChecker(const Array& params, const Map& buffer_map, + bool allow_untyped_pointers = false) + : allow_untyped_pointers_(allow_untyped_pointers) { + // If a parameter is in the buffer map, we want to track the + // version in the map. + for (auto it : buffer_map) { + Buffer& buffer = it.second; + Var buffer_var = buffer->data; + DataType dtype = buffer->dtype; + PrimExpr extent = buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0; + OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncParam); + } + + // If a pointer parameter isn't in the buffer map, then we want to + // track the parameter itself. + for (Var buffer_var : params) { + auto pointer_type = GetPointerType(buffer_var->type_annotation); + if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) { + DataType dtype = pointer_type.second; + PrimExpr extent = 0; + OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap); + } + } } - Stmt VisitStmt_(const StoreNode* op) final { - UpdateTypeMap(op->buffer_var.get(), op->value.dtype()); - return StmtExprMutator::VisitStmt_(op); + void VisitExpr_(const LoadNode* op) final { + OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate); + StmtExprVisitor::VisitExpr_(op); } - PrimExpr VisitExpr_(const CallNode* op) final { + + void VisitStmt_(const StoreNode* op) final { + OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index, op->predicate); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); - UpdateTypeMap(buffer, dtype); + PrimExpr index = op->args[2]; + OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes())); } - return StmtExprMutator::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } - Stmt VisitStmt_(const AllocateNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - const auto& tvec = acc_map_[op->buffer_var.get()]; - - if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() && - tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { - int factor = tvec[0].lanes() / op->dtype.lanes(); - Array extents = op->extents; - arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); - if (me->base % factor == 0 && me->coeff % factor == 0) { - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - // create a new buffer var - DataType new_dtype = tvec[0]; - Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype))); - // update the remap req. - var_remap_.Set(op->buffer_var, new_buffer_var); - return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body); + void VisitStmt_(const AllocateNode* op) final { + const Array& extents = op->extents; + PrimExpr extent = extents[extents.size() - 1]; + OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const LetNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const LetStmtNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitStmt_(op); + } + + void HandleLetNode(Var let_var) { + if (let_var->dtype.is_handle()) { + auto pointer_type = GetPointerType(let_var->type_annotation); + if (pointer_type.first) { + OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode); + } else if (allow_untyped_pointers_) { + OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); + } else { + LOG(FATAL) << "Let statement of variable " << let_var->name_hint + << " is missing a type annotation, " + << "or type annotation is not a pointer to primitive"; } } - return stmt; } - void UpdateTypeMap(const VarNode* buffer, DataType t) { - auto& tvec = acc_map_[buffer]; - if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) { - tvec.push_back(t); + /* Update the type map for a buffer based on its declaration + * + * @param buffer The VarNode representing the buffer. + * + * @param element_dtype The dtype of a single element of the buffer. + * If unknown, when used with the allow_untyped_handles option, + * should be a handle dtype. + * + * @param extent The extent of the buffer. Zero if size is unknown. + * + * @param declaration_location How the buffer was allocated, so that + * some locations can be rewritten without others. + */ + void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent, + BufferVarInfo::DeclarationLocation declaration_location) { + ICHECK(info_map_.find(buffer.get()) == info_map_.end()) + << "Array declaration of " << buffer->name_hint << " occurred multiple times."; + + if (element_dtype == DataType::Bool()) { + element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); + } + + info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location}; + } + + /* Update the type map for a buffer based on its usage + * + * @param value_dtype The dtype of the value being stored to or + * loaded from the buffer. + * + * @param buffer The VarNode representing the buffer. + * + * @param index The index at which the value is being stored/loaded. + * + * @param predicate The predicate used for the store/load. + */ + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const PrimExpr& index, + const PrimExpr& predicate) { + auto it = info_map_.find(buffer); + ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer + << ") occurred before its declaration."; + BufferVarInfo& var_info = it->second; + + if (value_dtype.element_of() == DataType::Bool()) { + value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes()); + } + + if (var_info.element_dtype.is_handle()) { + ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint + << " was missing a type annotation in its declaration"; + var_info.element_dtype = value_dtype.element_of(); + } + + DataType access_dtype = value_dtype; + + int lanes_used = var_info.element_dtype.lanes(); + + // This can happen due to a previous pass that had rewrite_store_load = + // false. This occurs from the StorageRewrite in tvm::lower, followed by the + // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false is + // necessary because the C-based codegens do not yet support vectorized + // pointer types (e.g. float16x4*). Once they do, this if statement should + // instead be replaced by the below ICHECK_EQ. + if (index.dtype().lanes() * var_info.element_dtype.lanes() != value_dtype.lanes()) { + ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes()); + lanes_used = 1; + var_info.element_dtype = var_info.element_dtype.with_lanes(1); + } + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes()) + // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " + // << index.dtype().lanes() << " indices into an array whose elements have " + // << var_info.element_dtype.lanes() << " lanes. " + // << "Expected output with " << index.dtype().lanes() * var_info.element_dtype.lanes() + // << " lanes."; + + // If the index is a RampNode with stride of 1 and offset + // divisible by the number of number of lanes, and the predicate + // does not apply any masking, then this array access could be + // vectorized. + const RampNode* ramp_index = index.as(); + if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) { + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { + lanes_used = ramp_index->lanes; + } } + + var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); } - // Internal access map - std::unordered_map > acc_map_; - // Variables to remap - Map var_remap_; + // Map of buffer variable information determined + std::unordered_map info_map_; + + // + bool allow_untyped_pointers_{false}; + // internal analyzer arith::Analyzer analyzer_; }; -PrimFunc PointerValueTypeRewrite(PrimFunc f) { - auto* n = f.CopyOnWrite(); - VectorAllocRewriter rewriter; - n->body = rewriter(std::move(n->body)); +/* \brief Rewrites buffer/pointer variables from scalar types to vectorized + * types. + * + * Some runtimes do not allow casting between composite types and the underlying + * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*). + * In these cases, in order to have vectorized load/store on an array, the + * element type of that array must be vectorized. This is in contrast to C-style + * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is + * valid. + * + * By default, VectorTypeRewriter will attempt to rewrite all buffer variables to + * vectorized access, if the load/store occurring in the PrimFunc are all + * vectorized. This includes adjusting the indices being used to access the + * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4* + * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to + * `vec_arr[offset/4]`.) + * + * Currently, several of the C-style runtimes do not support buffers whose + * elements are vectorized types, or rely on the presence of the Ramp nodes to + * identify vectorized loads. The boolean parameters in the constructor are to + * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these + * runtimes. Once all runtimes support vectorized buffer elements, these + * parameters can be removed. + */ +class VectorTypeRewriter : public StmtExprMutator { + public: + /* Constructor + * + * @param checker The VectorTypeAccessChecker that has previously read out + * information from the PrimFunc + * + * @param rewrite_params Whether pointer-type parameters passed into the + * function should be rewritten from scalar types to vectorized types. + * + * @param rewrite_buffer_map Whether buffers present in the buffer_map should + * have their data variable be rewritten from scalar types to vectorized types. + * + * @param rewrite_allocate_node Whether the buffer variable associated with + * AllocateNodes should be rewritten from scalar types to vectorized types. + * + * @param rewrite_indices Whether the indices to the Load and Store nodes + * should be rewritten to correspond to the new buffer_var type. + * + * @param rewrite_let_node Whether pointer declarations in let nodes + * should be re-written. + */ + VectorTypeRewriter(const std::unordered_map& info_map, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) + : rewrite_indices_(rewrite_indices) { + int rewrite_mask = 0; + if (rewrite_params) { + rewrite_mask |= BufferVarInfo::kPrimFuncParam; + } + if (rewrite_buffer_map) { + rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap; + } + if (rewrite_allocate_node) { + rewrite_mask |= BufferVarInfo::kAllocateNode; + } + if (rewrite_let_node) { + rewrite_mask |= BufferVarInfo::kLetNode; + } + + // Rewrite any buffer variables whose preferred type isn't their current type. + for (const auto& pair : info_map) { + const auto& var_info = pair.second; + DataType preferred = var_info.get_preferred_dtype(); + if (preferred != var_info.element_dtype && (rewrite_mask & var_info.declaration_location)) { + Var old_buffer_var = var_info.var; + Var new_buffer_var(old_buffer_var->name_hint, + PointerType(PrimType(preferred), GetPtrStorageScope(old_buffer_var)), + old_buffer_var->span); + + rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var, var_info.element_dtype, + preferred}; + } + } + } - Map var_remap = std::move(rewriter.var_remap_); - Array args; + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); - // rewrite paramters if needed. - for (Var var : f->params) { - if (var.dtype().is_handle()) { - const auto& tvec = rewriter.acc_map_[var.get()]; + if (!rewrite_indices_) { + return expr; + } - if (tvec.size() == 1) { - tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0]))); - args.push_back(new_var); - var_remap.Set(var, new_var); - } else { - // always set data type to be non vectorized so - // load/store can still work via scalarization - if (tvec.size() != 0 && !var->type_annotation.defined()) { - tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1)))); - args.push_back(new_var); - var_remap.Set(var, new_var); - } else { - args.push_back(var); - } + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return expr; + } + const auto& info = it->second; + + DataType out_dtype_base = info.new_element_dtype.element_of(); + + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index, + const_true(new_index.dtype().lanes()), op->span); + } else { + return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate); + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + if (!rewrite_indices_) { + return stmt; + } + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + const auto& info = it->second; + + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), + op->span); + } else { + return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + if (!rewrite_indices_) { + return expr; + } + + const VarNode* buffer_var = op->args[1].as(); + auto it = rewrite_map_.find(buffer_var); + if (it == rewrite_map_.end()) { + return expr; } + const auto& info = it->second; + + PrimExpr index = op->args[2]; + PrimExpr extent = op->args[3]; + PrimExpr flag = op->args[4]; + + PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); + PrimExpr factor = make_const(extent.dtype(), info.new_element_dtype.lanes()); + extent = extent / factor; + index = index / factor; + Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; + return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); + } else { - args.push_back(var); + return StmtExprMutator::VisitExpr_(op); + } + } + + Stmt VisitStmt_(const AllocateNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + + const auto& info = it->second; + + Var new_buffer_var = info.new_buffer_var; + + int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); + + Array extents = op->extents; + extents.Set(extents.size() - 1, + extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); + return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); + } + + /* Update the parameters and all remaining variable references + * + * Should be called after calling operator() on the body of the + * function. + * + * @param func A pointer to the PrimFunc being modified. + */ + void Finalize(PrimFunc* func_ptr) const { + ICHECK(func_ptr) << "Finalize expects a non-null pointer"; + auto& func = *func_ptr; + auto* n = func.CopyOnWrite(); + + // Remap any remaining references to the old buffer variables + Map var_remap; + for (const auto& pair : rewrite_map_) { + const auto& info = pair.second; + var_remap.Set(info.old_buffer_var, info.new_buffer_var); + } + n->body = Substitute(n->body, var_remap); + + // Remap the argument list to use the new buffer variables. + Array new_params; + for (const auto& old_param : n->params) { + auto it = rewrite_map_.find(old_param.get()); + if (it == rewrite_map_.end()) { + new_params.push_back(old_param); + } else { + const auto& info = it->second; + new_params.push_back(info.new_buffer_var); + } + } + n->params = new_params; + + // Remap the Buffer objects in so that the buffers use the new buffer variables + Map new_buffer_map; + for (const auto& pair : n->buffer_map) { + Var key = pair.first; + Buffer old_buffer = pair.second; + Var old_var = old_buffer->data; + + auto it = rewrite_map_.find(old_var.get()); + if (it == rewrite_map_.end()) { + new_buffer_map.Set(key, old_buffer); + } else { + auto& info = it->second; + int factor = info.new_element_dtype.lanes() / info.old_element_dtype.lanes(); + ICHECK_EQ(factor * info.new_element_dtype.lanes(), info.old_element_dtype.lanes()); + + auto* buffer_cow = old_buffer.CopyOnWrite(); + buffer_cow->data = info.new_buffer_var; + buffer_cow->dtype = info.new_element_dtype; + size_t ndim = buffer_cow->shape.size(); + const auto& last_dim = buffer_cow->shape[ndim - 1]; + buffer_cow->shape.Set(ndim - 1, last_dim / make_const(last_dim.dtype(), factor)); + new_buffer_map.Set(key, old_buffer); + } } + n->buffer_map = new_buffer_map; } - // no variable remap is needed. - if (var_remap.size() == 0) return f; + private: + struct RewriteInfo { + Var old_buffer_var; + Var new_buffer_var; + DataType old_element_dtype; + DataType new_element_dtype; + }; + + bool rewrite_indices_{true}; + std::unordered_map rewrite_map_; +}; + +// Rewrite allocates, pointer parameters, and buffer map into vectorized versions +// if each access into a buffer is the same vector type. +PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); + checker(f->body); + + VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, + rewrite_allocate_node, rewrite_indices, rewrite_let_node); + PrimFuncNode* n = f.CopyOnWrite(); + n->body = rewriter(std::move(n->body)); + rewriter.Finalize(&f); - // remap the variables. - ICHECK_EQ(args.size(), n->params.size()); - n->params = args; - n->body = Substitute(n->body, var_remap); return f; } @@ -1006,7 +1439,7 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - return PointerValueTypeRewrite(std::move(f)); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index d0f58074ada0..1836b8ecec0d 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -69,7 +69,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -102,7 +102,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); // Only wmma.accumulator can use tvm_fill_fragment ICHECK_EQ(scope, "wmma.accumulator"); if (fragments.count(buffer_var)) { @@ -119,16 +119,9 @@ class FragmentGetter : public StmtExprVisitor { // Get memory scope void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buffer = op->node.as(); - ICHECK(buffer); - scopes[buffer] = op->value.as()->value; - } StmtExprVisitor::VisitStmt_(op); } - // Memory scope for allocations - std::unordered_map scopes; // Fragment metadata for all fragments std::unordered_map fragments; }; diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 8f757171afbd..35e4563b8f58 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -223,14 +223,14 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const StoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } return StmtExprMutator::VisitStmt_(op); @@ -250,10 +250,6 @@ class ThreadSyncInserter : public StmtExprMutator { is_lead_ = PrimExpr(); } return ret; - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -264,16 +260,15 @@ class ThreadSyncInserter : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); ICHECK_EQ(op->args.size(), 5U); - const VarNode* buffer_var = op->args[1].as(); - Var var(GetRef(buffer_var)); + Var buffer_var(GetRef(op->args[1].as())); const IntImmNode* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].read_count; + ++rw_stats_[buffer_var].read_count; } if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].write_count; + ++rw_stats_[buffer_var].write_count; } return expr; } else { @@ -287,14 +282,12 @@ class ThreadSyncInserter : public StmtExprMutator { int read_count{0}; int write_count{0}; }; + // Get current storage scope. - StorageScope GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } + // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); @@ -337,8 +330,6 @@ class ThreadSyncInserter : public StmtExprMutator { // data structure. StorageScope sync_scope_; const std::unordered_set& syncs_; - // The storage scope of each buffer - std::unordered_map storage_scope_; // The read write statistics of storage std::unordered_map rw_stats_; // The statistics for global barrier diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc new file mode 100644 index 000000000000..4143577a0b17 --- /dev/null +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file update_pointer_storage_scope.cc + * \brief A pass to update storage scopes for buffer variables. + */ +#include "update_pointer_storage_scope.h" + +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + +UpdatePointerStorageScope::UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes) { + for (auto& kv : new_storage_scopes) { + new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); + } +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { + auto it = new_var_remap_.find(op); + if (it == new_var_remap_.end()) { + return GetRef(op); + } + return it->second; +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), + StmtExprMutator::VisitExpr(op->predicate)); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + StmtExprMutator::VisitStmt(op->body)); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), + StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h new file mode 100644 index 000000000000..f310194a4a51 --- /dev/null +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file update_pointer_storage_scope.h + * \brief A pass to update storage scopes for buffer variables. + */ +#ifndef TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +class UpdatePointerStorageScope : public StmtExprMutator { + public: + explicit UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes); + + virtual PrimExpr VisitExpr_(const VarNode*); + virtual PrimExpr VisitExpr_(const LoadNode*); + virtual Stmt VisitStmt_(const AllocateNode*); + virtual Stmt VisitStmt_(const StoreNode*); + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ diff --git a/tests/cpp/topi_ewise_test.cc b/tests/cpp/topi_ewise_test.cc index 22ef8c7dffaa..e5ba92a89058 100644 --- a/tests/cpp/topi_ewise_test.cc +++ b/tests/cpp/topi_ewise_test.cc @@ -25,7 +25,7 @@ namespace tvm { namespace topi { TEST(Tensor, Basic) { using namespace tvm; - Var m("m"), n("n"), l("l"); + Var m("m"), l("l"); Tensor A = placeholder({m, l}, DataType::Float(32), "A"); auto C = topi::exp(A); } diff --git a/tests/crt/aot_memory_test.cc b/tests/crt/aot_memory_test.cc index abda7bebf766..06565dda68d1 100644 --- a/tests/crt/aot_memory_test.cc +++ b/tests/crt/aot_memory_test.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ + #include #include @@ -24,83 +25,126 @@ // Check with LIFO checks enabled for stack allocator #define TVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK + +// Number of memory misalignment in bytes +#define NUM_MEMORY_MISALIGNMENT_BYTES 1 + +/*! + * Align memory pointer. + * This function modifies memory_ptr to adjust alignment. + * \return Number of memory offset. + */ +static uint32_t align_pointer(uint8_t** memory_ptr) { + uint32_t extra = (uintptr_t)(*memory_ptr) % TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES; + uint32_t offset = + (TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - extra) & (TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES - 1); + *memory_ptr += offset; + return offset; +} + +/*! + * Add misalignment to memory pointer. + * This function modifies memory_ptr. + * \return Number of memory offset. + */ +static uint32_t misalign_pointer(uint8_t** memory_ptr) { + uint32_t extra = (uintptr_t)(*memory_ptr) % TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES; + if (extra == 0) { + *memory_ptr += NUM_MEMORY_MISALIGNMENT_BYTES; + return 1; + } + return 0; +} + /* - * Tests allocations are properly aligned when allocated + * Tests allocations are properly aligned when allocated. */ TEST(AOTMemory, Allocate) { - static uint8_t model_memory[96]; + static uint8_t model_memory[128]; tvm_workspace_t tvm_runtime_workspace; + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 96), kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 2, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); void* two_blocks = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 24, &two_blocks, 1), kTvmErrorNoError); - ASSERT_EQ(two_blocks, &model_memory[32 + 2 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(two_blocks, &model_memory_ptr[32 + 2 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); void* block_three = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_three, 1), kTvmErrorNoError); - ASSERT_EQ(block_three, &model_memory[64 + 3 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_three, &model_memory_ptr[64 + 3 * STACK_ALLOCATOR_TAG_SIZE_BYTES]); } /* - * Tests resetting the stack after dealloc + * Tests resetting the stack after dealloc. */ TEST(AOTMemory, Free) { static uint8_t model_memory[80]; tvm_workspace_t tvm_runtime_workspace; - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError); + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); ASSERT_EQ(kTvmErrorNoError, StackMemoryManager_Free_Body(&tvm_runtime_workspace, block_two, 1)); void* two_blocks = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 2, &two_blocks, 1), kTvmErrorNoError); - ASSERT_EQ(two_blocks, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(two_blocks, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); ASSERT_EQ(kTvmErrorNoError, StackMemoryManager_Free_Body(&tvm_runtime_workspace, two_blocks, 1)); void* block_three = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_three, 1), kTvmErrorNoError); - ASSERT_EQ(block_three, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_three, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); } /* - * Tests we return NULL if we over allocate + * Tests we return NULL if we over allocate. */ TEST(AOTMemory, OverAllocate) { - static uint8_t model_memory[72]; + static uint8_t model_memory[80]; tvm_workspace_t tvm_runtime_workspace; - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError); + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); void* two_blocks = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 64, &two_blocks, 1), @@ -109,27 +153,54 @@ TEST(AOTMemory, OverAllocate) { } /* - * Test for out-of-order memory deallocation + * Test for out-of-order memory deallocation. */ TEST(AOTMemory, FreeOutOfOrder) { static uint8_t model_memory[80]; tvm_workspace_t tvm_runtime_workspace; - ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory, 80), kTvmErrorNoError); + uint8_t* model_memory_ptr = model_memory; + uint32_t offset = align_pointer(&model_memory_ptr); + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); void* block_one = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_one, 1), kTvmErrorNoError); - ASSERT_EQ(block_one, &model_memory[0]); + ASSERT_EQ(block_one, &model_memory_ptr[0]); void* block_two = NULL; ASSERT_EQ(StackMemoryManager_Allocate_Body(&tvm_runtime_workspace, 1, &block_two, 1), kTvmErrorNoError); - ASSERT_EQ(block_two, &model_memory[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); + ASSERT_EQ(block_two, &model_memory_ptr[16 + STACK_ALLOCATOR_TAG_SIZE_BYTES]); ASSERT_EQ(StackMemoryManager_Free_Body(&tvm_runtime_workspace, block_one, 1), kTvmErrorPlatformStackAllocBadFree); } +/* + * Test for initial memory misalignment. + */ +TEST(AOTMemory, InitialMemoryMisAlignment) { + static uint8_t model_memory[80]; + tvm_workspace_t tvm_runtime_workspace; + uint8_t* model_memory_ptr = model_memory; + + // Add misaslignment to memory pointer + uint32_t offset = misalign_pointer(&model_memory_ptr); + + // Calculate expected offset + uint8_t* misaligned_ptr = model_memory_ptr; + uint32_t alignment_offset = align_pointer(&misaligned_ptr); + + ASSERT_EQ(StackMemoryManager_Init(&tvm_runtime_workspace, model_memory_ptr, + sizeof(model_memory) - offset), + kTvmErrorNoError); + + ASSERT_EQ(tvm_runtime_workspace.next_alloc, &model_memory_ptr[alignment_offset]); + ASSERT_EQ(tvm_runtime_workspace.workspace_size, sizeof(model_memory) - offset - alignment_offset); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index a7fb76b323c1..f31630c2a705 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -136,6 +136,7 @@ "apps/microtvm/zephyr/qemu-hack/qemu-system-arm", "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32", "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64", + "apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64", "apps/microtvm/zephyr/host_driven/prj.conf", "apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf", "apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf", @@ -144,6 +145,8 @@ "apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf", "apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf", "apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf", + "apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf", + "apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf", "apps/microtvm/zephyr/host_driven/qemu-hack", "apps/microtvm/zephyr/aot_demo/prj.conf", "apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf", @@ -153,6 +156,8 @@ "apps/microtvm/zephyr/aot_demo/boards/nucleo_f746zg.conf", "apps/microtvm/zephyr/aot_demo/boards/stm32f746g_disco.conf", "apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf", + "apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf", + "apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf", "apps/microtvm/zephyr/aot_demo/qemu-hack", # microTVM Virtual Machines "apps/microtvm/reference-vm/zephyr/Vagrantfile", diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 6ca5a530be9d..0b50ecd12ec5 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -24,10 +24,12 @@ "host": ("host", "qemu_x86"), "host_riscv32": ("host", "qemu_riscv32"), "host_riscv64": ("host", "qemu_riscv64"), - "stm32f746xx_nucleo": ("stm32f746xx", "nucleo_f746zg"), - "stm32f746xx_disco": ("stm32f746xx", "stm32f746g_disco"), - "nrf5340dk": ("nrf5340dk", "nrf5340dk_nrf5340_cpuapp"), "mps2_an521": ("mps2_an521", "mps2_an521-qemu"), + "nrf5340dk": ("nrf5340dk", "nrf5340dk_nrf5340_cpuapp"), + "stm32f746xx_disco": ("stm32f746xx", "stm32f746g_disco"), + "stm32f746xx_nucleo": ("stm32f746xx", "nucleo_f746zg"), + "stm32l4r5zi_nucleo": ("stm32l4r5zi", "nucleo_l4r5zi"), + "zynq_mp_r5": ("zynq_mp_r5", "qemu_cortex_r5"), } diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index cf9447bc2b10..18587acd46ae 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -33,6 +33,7 @@ import tvm import tvm.rpc import tvm.micro +import tvm.testing import tvm.relay as relay from tvm.micro.contrib import zephyr @@ -124,6 +125,7 @@ def _make_add_sess(model, zephyr_board, west_cmd, build_config): # The same test code can be executed on both the QEMU simulation and on real hardware. +@tvm.testing.requires_micro def test_compile_runtime(platform, west_cmd, skip_build, tvm_debug): """Test compiling the on-device runtime.""" @@ -147,6 +149,7 @@ def test_basic_add(sess): test_basic_add(sess) +@tvm.testing.requires_micro def test_platform_timer(platform, west_cmd, skip_build, tvm_debug): """Test compiling the on-device runtime.""" @@ -175,6 +178,7 @@ def test_basic_add(sess): test_basic_add(sess) +@tvm.testing.requires_micro def test_relay(platform, west_cmd, skip_build, tvm_debug): """Testing a simple relay graph""" model, zephyr_board = PLATFORMS[platform] @@ -204,6 +208,7 @@ def test_relay(platform, west_cmd, skip_build, tvm_debug): tvm.testing.assert_allclose(result, x_in * x_in + 1) +@tvm.testing.requires_micro def test_onnx(platform, west_cmd, skip_build, tvm_debug): """Testing a simple ONNX model.""" model, zephyr_board = PLATFORMS[platform] @@ -334,6 +339,7 @@ def check_result( tvm.testing.assert_allclose(out.numpy(), results[idx], rtol=TOL, atol=TOL) +@tvm.testing.requires_micro def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): """This is a simple test case to check BYOC capabilities of microTVM""" model, zephyr_board = PLATFORMS[platform] @@ -410,6 +416,7 @@ def _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config pytest.param((16 * 1024,), id="(16*1024)"), ], ) +@tvm.testing.requires_micro def test_rpc_large_array(platform, west_cmd, skip_build, tvm_debug, shape): """Test large RPC array transfer.""" model, zephyr_board = PLATFORMS[platform] diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index afdbdc590de0..d1c9d393770a 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -29,11 +29,13 @@ import tvm import tvm.rpc import tvm.micro +import tvm.testing import tvm.relay as relay from tvm.micro.contrib import zephyr from tvm.contrib import utils from tvm.contrib.download import download_testdata +from tvm.micro.interface_api import generate_c_interface_header import conftest @@ -152,6 +154,7 @@ def _get_message(fd, expr: str): return data +@tvm.testing.requires_micro def test_tflite(platform, west_cmd, skip_build, tvm_debug): """Testing a TFLite model.""" model, zephyr_board = PLATFORMS[platform] @@ -181,7 +184,9 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1 ": "float32"} ) - target = tvm.target.target.micro(model, options=["-link-params=1", "--executor=aot"]) + target = tvm.target.target.micro( + model, options=["-link-params=1", "--executor=aot", "--unpacked-api=1", "--interface-api=c"] + ) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lowered = relay.build(relay_mod, target, params=params) @@ -192,6 +197,7 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): ) sample = np.load(sample_path) model_files_path = os.path.join(runtime_path, "include") + generate_c_interface_header(lowered.libmod_name, ["input_1"], ["output"], model_files_path) _create_header_file((f"input_data"), sample, model_files_path) _create_header_file( "output_data", np.zeros(shape=output_shape, dtype="float32"), model_files_path @@ -213,7 +219,11 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): assert result == 8 +@tvm.testing.requires_micro def test_qemu_make_fail(platform, west_cmd, skip_build, tvm_debug): + if platform not in ["host", "mps2_an521"]: + pytest.skip(msg="Only for QEMU targets.") + """Testing QEMU make fail.""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} diff --git a/tests/python/contrib/test_bnns/test_conv2d_patterns.py b/tests/python/contrib/test_bnns/test_conv2d_patterns.py index b81e74b6d8fa..5fc9e9522fbd 100644 --- a/tests/python/contrib/test_bnns/test_conv2d_patterns.py +++ b/tests/python/contrib/test_bnns/test_conv2d_patterns.py @@ -57,7 +57,7 @@ def test_pattern_conv2d_with_bias_add(): res = relay.nn.bias_add(res, b, axis=axis) mod = partition(res) - bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "nn.bias_add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_main_0"], "nn.bias_add") assert bias_is_fused if axis == 1 else not bias_is_fused @@ -73,7 +73,7 @@ def test_pattern_conv2d_with_add(): res = relay.add(res, b) mod = partition(res) - bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_main_0"], "add") assert bias_is_fused == should_be_fused @@ -102,6 +102,6 @@ def test_pattern_conv2d_with_non_cons_bias(): res = relay.nn.bias_add(res, b, axis=1) mod = partition(res) - bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_0"], "nn.bias_add") + bias_is_fused = is_op_fused(mod["tvmgen_default_bnns_main_0"], "nn.bias_add") assert not bias_is_fused diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 069f3f3769b5..839c6f50c070 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -98,6 +98,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): @tvm.testing.requires_gpu +@requires_cudnn def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) @@ -171,6 +172,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): @tvm.testing.requires_gpu +@requires_cudnn def test_conv3d(): verify_conv3d("float32", "float32", tensor_format=0) verify_conv3d("float32", "float32", tensor_format=0, groups=2) diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index 6ff8011cf4d7..65d2738447cc 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -122,13 +122,13 @@ def test_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"5d3cee6ecc488c40ecf533c5cbacc534"} + _compile_hash = {"1fd4ef29a1ea9f3a015cab87c0b8014a"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"896c28b4f06341ea638ead3a593e1aed"} + _compile_hash = {"b879dfbff1f907eaf6129dfd41b44ece"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"9298b6c51e2a82f70e91dd11dd6af412"} + _compile_hash = {"9c9f63b30824f5b223cdb27d2f22c857"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"407eb47346c8afea2d15e8f0d1c079f2"} + _compile_hash = {"cd13279061df2319124a7aac81581d81"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", @@ -148,13 +148,13 @@ def test_inception_v3(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"1bc66e83c3de5a9773a719b179c65b1a"} + _compile_hash = {"b90ed315639c6a0e97584c2dbc42a55c"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"551cde850c6ef960d19be4f317fb8e68"} + _compile_hash = {"5693569055695e581a8739194d0301aa"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"d44eece5027ff56e5e7fcf014367378d"} + _compile_hash = {"46ccafc840633633aca441645e41b444"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"1ba555b4bc60c428018a0f2de9d90532"} + _compile_hash = {"4a33f397ac3e15c0f9869f7b8286fc2f"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/tflite_11_05_08/inception_v3_quant.tgz", @@ -173,13 +173,13 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"578b8ee279911b49912a77a64f5ff620"} + _compile_hash = {"b36877d2386d9f9c37a11772e3c4072c"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"30f078bd42757e8686eafa1f28d0d352"} + _compile_hash = {"b5046a6f56d78af0b4f51960bf2deeda"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"53f126cf654d4cf61ebb23c767f6740b"} + _compile_hash = {"4a1a56393078367dd27915a188d6a6af"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"851665c060cf4719248919d17325ae02"} + _compile_hash = {"905caf389dd6b868aeff6acbca1fecef"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/inception_v4_299_quant_20181026.tgz", @@ -198,13 +198,13 @@ def test_ssd_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"cd335229a2052f30273f127a233bd319", "95dedc29d911cdc6b28207ca08e42470"} + _compile_hash = {"956caf9e7fe5cfd5c042bd17857f7407", "4313033d14328e2aa022b1bd71b27b1c"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"deee52e136327436411fc725624ae2ea", "6526509d3cbee014e38c79e22bb29d7f"} + _compile_hash = {"dc60cc687d892cd2877873094e9dfc0b", "6b3deeec16c24c0dcef23df0db5fb162"} if tei.get_ethosn_api_version() == 2011: - _compile_hash = {"6e8c4586bdd26527c642a4f016f52284", "057c5efb094c79fbe4483b561147f1d2"} + _compile_hash = {"10826406ae724e52f360a06c35ced09d", "9a484d5ecec7acb18c9d6bc6058be031"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"dc687e60a4b6750fe740853f22aeb2dc", "1949d86100004eca41099c8e6fa919ab"} + _compile_hash = {"425b38830f34b6eb448fa77dbfe9ac96", "de49128643cbf1c659a9a63aad1cba62"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py index 27a8ec6df357..81115b6c0238 100644 --- a/tests/python/contrib/test_miopen.py +++ b/tests/python/contrib/test_miopen.py @@ -19,9 +19,17 @@ from tvm import te from tvm.contrib import miopen import numpy as np +import pytest + + +requires_miopen = pytest.mark.skipif( + tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True) is None, + reason="MIOpen is not enabled", +) @tvm.testing.requires_rocm +@requires_miopen def test_conv2d(): in_channel = 3 out_channel = 64 @@ -35,9 +43,6 @@ def test_conv2d(): dilation_w = 1 xshape = [1, in_channel, 128, 128] - if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True): - print("skip because miopen is not enabled...") - return wshape = (out_channel, in_channel, filter_h, filter_w) X = te.placeholder(xshape, name="X") @@ -72,5 +77,60 @@ def verify(): verify() +def verify_softmax(shape, axis, dtype="float32", log_softmax=False): + miopen_op = miopen.log_softmax if log_softmax else miopen.softmax + testing_op = ( + tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python + ) + + A = te.placeholder(shape, dtype=dtype, name="A") + B = miopen_op(A, axis) + s = te.create_schedule([B.op]) + + dev = tvm.rocm(0) + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = testing_op(a_np) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax") + f(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3) + + +def verify_softmax_4d(shape, dtype="float32", log_softmax=False): + miopen_op = miopen.log_softmax if log_softmax else miopen.softmax + testing_op = ( + tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python + ) + + A = te.placeholder(shape, dtype=dtype, name="A") + B = miopen_op(A, axis=1) + s = te.create_schedule([B.op]) + + dev = tvm.rocm(0) + n, c, h, w = shape + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c)) + b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax") + f(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3) + + +@tvm.testing.requires_rocm +@requires_miopen +def test_softmax(): + verify_softmax((32, 10), -1) + verify_softmax((3, 4), -1) + verify_softmax_4d((1, 16, 256, 256)) + verify_softmax_4d((1, 16, 256, 256)) + + verify_softmax((32, 10), -1, log_softmax=True) + verify_softmax((3, 4), -1, log_softmax=True) + verify_softmax_4d((1, 16, 256, 256), log_softmax=True) + + if __name__ == "__main__": test_conv2d() diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index 3636409f8a06..1c2d00aed866 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -27,6 +27,7 @@ import tvm from tvm import relay from tvm.contrib.target.onnx import to_onnx +from tvm.relay.testing import run_infer_type def func_to_onnx(func, name): @@ -174,6 +175,60 @@ def verify_conv2d( verify_conv2d("float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(4, 4)) +def test_conv2d_transpose(): + """Conv2d_Transpose unit tests.""" + + def verify_conv2d_transpose( + dtype, scale, dshape, kshape, padding=(1, 1), groups=1, dilation=(1, 1), **attrs + ): + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", shape=kshape, dtype=dtype) + y = relay.nn.conv2d_transpose( + x, w, padding=padding, dilation=dilation, groups=groups, **attrs + ) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + verify_results(func, [data, kernel], "test_conv2d_transpose", rtol=1e-5, atol=1e-5) + + dshape = (1, 3, 224, 224) + kshape = (3, 10, 3, 3) + verify_conv2d_transpose( + "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3) + ) + + dshape = (1, 3, 224, 224) + kshape = (3, 10, 3, 3) + verify_conv2d_transpose( + "float32", 1, dshape, kshape, padding=(2, 2), channels=10, kernel_size=(3, 3) + ) + + dshape = (1, 3, 18, 18) + kshape = (3, 10, 2, 2) + verify_conv2d_transpose( + "float32", + 1, + dshape, + kshape, + padding=(2, 2), + channels=10, + kernel_size=(2, 2), + dilation=(1, 1), + ) + + dshape = (1, 3, 18, 18) + kshape = (3, 10, 4, 4) + verify_conv2d_transpose( + "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(4, 4) + ) + + dshape = (1, 3, 18, 18) + kshape = (3, 10, 4, 4) + verify_conv2d_transpose( + "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(4, 4) + ) + + def test_reshape(): def verify_reshape(shape, newshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -270,14 +325,16 @@ def verify_batch_norm(axis=1): def test_pad(): + """Pad unit test.""" + def verify_pad(): - for dtype in ["float16", "float32"]: - dshape = (4, 10, 7, 7) - x = relay.var("x", shape=dshape, dtype=dtype) - y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4))) - func = relay.Function([x], y) - x_data = np.random.uniform(size=dshape).astype(dtype) - verify_results(func, [x_data], "test_pad", rtol=1e-5, atol=1e-5) + dshape = (4, 10, 7, 7) + x = relay.var("x", shape=dshape, dtype="int32") + y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4))) + func = relay.Function([x], y) + func = run_infer_type(func) + x_data = np.random.randint(low=-255, high=255, size=dshape).astype(np.int32) + verify_results(func, [x_data], "test_pad", rtol=1e-5, atol=1e-5) verify_pad() @@ -516,6 +573,8 @@ def verify_expand_dims(dshape, axis, num_newaxis, dtype="float32"): def test_lrn(): + """LRN unit test.""" + def verify_lrn(xshape, size, dtype="float32"): x = relay.var("x", relay.ty.TensorType(xshape, dtype)) y = relay.nn.lrn(x, size=size, axis=1, alpha=1.0, beta=1.0, bias=1.0) @@ -530,10 +589,121 @@ def verify_lrn(xshape, size, dtype="float32"): verify_lrn(i, s) +def test_sigmoid(): + """Sigmoid unit test.""" + + def verify_sigmoid(dshape, dtype="float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.sigmoid(x) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], "test_sigmoid", rtol=1e-4, atol=1e-4) + + isize = [(1, 3, 480, 640), (1, 3, 224, 224)] + + for i in isize: + verify_sigmoid(i) + + +def test_copy(): + """Copy unit test.""" + + def verify_copy(dshape, dtype="float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.copy(x) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], "test_copy", rtol=1e-4, atol=1e-4) + + isize = [(1, 3, 480, 640), (1, 3, 224, 224)] + + for i in isize: + verify_copy(i) + + +def test_round(): + """Round unit test.""" + + def verify_round(dshape, dtype="float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.round(x) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], "test_round", rtol=1e-4, atol=1e-4) + + isize = [(1, 3, 480, 640), (1, 3, 224, 224)] + + for i in isize: + verify_round(i) + + +def test_cast(): + """Cast unit test.""" + + def verify_cast(dshape, dtype): + x = relay.var("x", relay.ty.TensorType(dshape, "float32")) + y = relay.cast(x, dtype) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype("float32") + verify_results(func, [x_data], "test_cast", rtol=1e-4, atol=1e-4) + + isize = [(1, 3, 480, 640), (1, 3, 224, 224)] + out_dtypes = ["int8", "int16", "uint8", "uint16"] + + for i in isize: + for o_dtype in out_dtypes: + verify_cast(i, o_dtype) + + +def test_resize(): + """Resize unit test.""" + + def verify_resize(dshape, outsize, method, coord_trans, rounding_method, dtype="float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.image.resize2d( + x, + outsize, + layout="NCHW", + method=method, + coordinate_transformation_mode=coord_trans, + rounding_method=rounding_method, + ) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], "test_resize", rtol=1e-4, atol=1e-4) + + method = ["nearest_neighbor", "linear", "cubic"] + coord_trans = ["half_pixel", "align_corners", "asymmetric"] + rounding_method = ["round", "floor", "ceil"] + + isize = (1, 3, 480, 640) + + # Downsample + osize = (240, 320) + for i in method: + for j in coord_trans: + for k in rounding_method: + if (i == "nearest_neighbor" and j == "align_corners") or ( + i == "cubic" and j in ["half_pixel", "align_corners"] + ): + continue + verify_resize(isize, osize, method=i, coord_trans=j, rounding_method=k) + + # Upsample + osize = (960, 1280) + for i in method: + for j in coord_trans: + for k in rounding_method: + if (i == "nearest_neighbor" and j == "align_corners") or (i == "cubic"): + continue + verify_resize(isize, osize, method=i, coord_trans=j, rounding_method=k) + + if __name__ == "__main__": test_add() test_bias_add() test_conv2d() + test_conv2d_transpose() test_reshape() test_transpose() test_dense() @@ -554,3 +724,8 @@ def verify_lrn(xshape, size, dtype="float32"): test_clip() test_expand_dims() test_lrn() + test_sigmoid() + test_copy() + test_round() + test_cast() + test_resize() diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 59f1c3aa4d68..3f57df5a5f4a 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -1251,33 +1251,35 @@ def test_tensorrt_dynamic_batch_conv(): x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") k_shape = (16, 32, 3, 3) params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")} - result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))] - for use_trt in [True, False]: - x = relay.var("x", shape=x_shape, dtype="float32") - kernel = relay.var("kernel", shape=k_shape, dtype="float32") - out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1) - f = relay.Function([x, kernel], out) - mod = tvm.IRModule() - mod["main"] = f - if use_trt: - mod, _ = tensorrt.partition_for_tensorrt(mod, params) - + for use_implicit_batch in [True, False]: + result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))] + for use_trt in [True, False]: + x = relay.var("x", shape=x_shape, dtype="float32") + kernel = relay.var("kernel", shape=k_shape, dtype="float32") + out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1) + f = relay.Function([x, kernel], out) + mod = tvm.IRModule() + mod["main"] = f + if use_trt: + mod, config = tensorrt.partition_for_tensorrt( + mod, params, use_implicit_batch=use_implicit_batch + ) + if not skip_runtime_test(): + for target in ["llvm", "cuda"]: + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + relay_exec = relay.create_executor( + "vm", mod=mod, device=tvm.device(target), target=target + ) + for i, batch_size in enumerate(batches_to_test): + result_arr[i][target][use_trt] = relay_exec.evaluate()( + x_data[:batch_size, ...], **params + ) if not skip_runtime_test(): - for target in ["llvm", "cuda"]: - with relay.build_config(opt_level=3): - relay_exec = relay.create_executor( - "vm", mod=mod, device=tvm.cpu(0), target="llvm" - ) - - for i, batch_size in enumerate(batches_to_test): - result_arr[i][target][use_trt] = relay_exec.evaluate()( - x_data[:batch_size, ...], **params - ) - - if not skip_runtime_test(): - for i in range(len(batches_to_test)): - for target in ["llvm", "cuda"]: - assert_result_dict_holds(result_arr[i][target]) + for i in range(len(batches_to_test)): + for target in ["llvm", "cuda"]: + assert_result_dict_holds(result_arr[i][target]) def test_maskrcnn_resnet50() -> None: diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py index 18c57d485d76..2e16792542ca 100644 --- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py +++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py @@ -288,8 +288,8 @@ def expected(): func0 = relay.Function( [data0, weight0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0], bn.astuple() ) - func0 = set_func_attr(func0, "vitis_ai", "tvmgen_default_vitis_ai_0") - gv0 = relay.GlobalVar("tvmgen_default_vitis_ai_0") + func0 = set_func_attr(func0, "vitis_ai", "tvmgen_default_vitis_ai_main_0") + gv0 = relay.GlobalVar("tvmgen_default_vitis_ai_main_0") mod = tvm.IRModule() mod[gv0] = func0 mod = relay.transform.InferType()(mod) diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 4669fab916a6..0426f5678153 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -18,6 +18,7 @@ import pytest import os import shlex +import sys import tvm from tvm.driver import tvmc @@ -130,3 +131,7 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT executor." assert tvmc_package.params is not None, ".params must be set in the MLF archive." assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index cb6b82a32937..31fa688ad717 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -177,6 +177,11 @@ def test_shape_parser(): shape_dict = tvmc.common.parse_shape_string(shape_string) # Convert to strings to allow comparison with Any. assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" + # Check that multiple valid gpu inputs are parsed correctly. + shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" + assert str(shape_dict) == expected # Check that invalid pattern raises expected error. shape_string = "input:[a,10]" @@ -186,6 +191,22 @@ def test_shape_parser(): shape_string = "input:5,10 input2:10,10" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 /:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 data/:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 /data:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with invalid slashes raises error. + shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) def test_target_from_cli__error_duplicate(): diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 72dac9b2501f..8892018a79e9 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -30,6 +30,11 @@ import coremltools as cm import model_zoo import tvm.testing +import tempfile +from os import path +import tvm +import tvm.relay as relay +import tensorflow.keras as keras def get_tvm_output( @@ -206,12 +211,15 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): dtype = "float32" a_np = np.full(input_dim, 1, dtype=dtype) + if mode == "NN": - b_np = tvm.topi.testing.upsampling_python(a_np, (scale, scale)) + method = "nearest_neighbor" + coord_trans = "asymmetric" else: - new_h = input_dim[2] * scale - new_w = input_dim[3] * scale - b_np = tvm.topi.testing.bilinear_resize_python(a_np, (new_h, new_w), "NCHW") + method = "linear" + coord_trans = "align_corners" + + b_np = tvm.topi.testing.resize2d_python(a_np, (scale, scale), "NCHW", method, coord_trans) input = [("input", datatypes.Array(*input_dim))] output = [("output", datatypes.Array(*b_np.shape))] @@ -780,6 +788,44 @@ def test_forward_convolution(): verify_convolution((1, 3, 224, 224), filter=(32, 3, 3, 3), padding="SAME") +def test_can_build_keras_to_coreml_to_relay(): + """Test multiple conversion paths and importing from + a saved file.""" + model = keras.models.Sequential() + model.add( + keras.layers.Conv2D( + filters=6, + kernel_size=(1, 1), + activation="relu", + padding="same", + input_shape=(3, 3, 1), + data_format="channels_first", + ) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + kmodel_fn = path.join(tmpdir, "c1mdl.h5") + model.save(kmodel_fn) + + mdl = cm.convert(kmodel_fn) + model_file = path.join(tmpdir, "c1.mlmodel") + mdl.save(model_file) + + mdl = cm.models.MLModel(model_file) + desc = mdl.get_spec().description + iname = desc.input[0].name + ishape = desc.input[0].type.multiArrayType.shape + shape_dict = {} + for i in mdl.get_spec().description.input: + iname = i.name + ishape = i.type.multiArrayType.shape + shape_dict[iname] = ishape + mod, params = relay.frontend.from_coreml(mdl, shape_dict) + + with tvm.transform.PassContext(opt_level=3): + relay.build(mod, "llvm", params=params) + + if __name__ == "__main__": test_forward_AddLayerParams() test_forward_ConcatLayerParams() @@ -798,3 +844,4 @@ def test_forward_convolution(): test_resnet50_checkonly() test_forward_image_scaler() test_forward_convolution() + test_can_build_keras_to_coreml_to_relay() diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 52c3346e5807..8b633c18977a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -376,7 +376,7 @@ def verify_depth_to_space(inshape, outshape, mode, blockSize): @tvm.testing.uses_gpu def test_depth_to_space(): # current onnx.checker use OpSet-1 version of DepthToSpace, which doesn't have a mode argument. - # TO-DO, we can add mode arguement to test CRD mode and DCR mode + # TO-DO, we can add mode argument to test CRD mode and DCR mode # in the future when we update to a newer onnx version. verify_depth_to_space((1, 8, 2, 3), (1, 2, 4, 6), mode="CRD", blockSize=2) @@ -1117,7 +1117,14 @@ def verify_gemm(a_shape, b_shape, c_shape=None, freeze_params=False, dtype="floa ) model = helper.make_model(graph, producer_name="gemm_test") - verify_with_ort_with_inputs(model, input_values, freeze_params=freeze_params, dtype=dtype) + atol = 1e-5 + rtol = 1e-5 + if dtype == "float16": + atol = 1e-3 + rtol = 1e-3 + verify_with_ort_with_inputs( + model, input_values, freeze_params=freeze_params, dtype=dtype, atol=atol, rtol=rtol + ) @tvm.testing.uses_gpu @@ -1173,8 +1180,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, target, dev): verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target]) -# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed -@tvm.testing.parametrize_targets("llvm") +@tvm.testing.uses_gpu def test_batch_matmul(target, dev): verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, dev) verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, dev) @@ -1183,6 +1189,8 @@ def test_batch_matmul(target, dev): verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, dev) verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4), target, dev) verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4), target, dev) + verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32), target, dev) + verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16), target, dev) def verify_simple_dynamic_model(a_shape, b_shape, target, dev): @@ -1221,7 +1229,6 @@ def verify_model(ex, a_shape, b_shape): b_anys = [relay.Any()] * len(b_shape) mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys}) - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) verify_model(ex, a_shape, b_shape) verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) @@ -1366,11 +1373,12 @@ def verify_upsample3d_trilinear(): y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear") scales = [1.0, 1.0, 2.0, 2.0, 2.0] in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.trilinear_resize3d_python( + out_array = tvm.topi.testing.resize3d_python( in_array, - (3 * scale, 3 * scale, 3 * scale), + (scale, scale, scale), "NCDHW", - coordinate_transformation_mode="half_pixel", + "linear", + coordinate_transformation_mode="asymmetric", ) ref_array = np.array(scales) @@ -2253,7 +2261,7 @@ def verify_where(condition, x, y, dtype, outdata, dynamic=False): @tvm.testing.uses_gpu def test_where(): - condition = np.array([[1, 0], [1, 1]], dtype=np.bool) + condition = np.array([[1, 0], [1, 1]], dtype=bool) x = np.array([[1, 2], [3, 4]], dtype=np.int64) y = np.array([[9, 8], [7, 6]], dtype=np.int64) outdata = np.where(condition, x, y) @@ -2274,7 +2282,7 @@ def test_where(): outdata = np.where(condition, x, y) verify_where(condition, x, y, TensorProto.FLOAT, outdata) - condition = np.array(1, dtype=np.bool) + condition = np.array(1, dtype=bool) x = np.array([[1, 2], [3, 4]], dtype=np.float32) y = np.array([[5, 6], [7, 8]], dtype=np.float32) outdata = np.where(condition, x, y) @@ -2489,7 +2497,7 @@ def repeat(N, D): repeat(1, D), repeat(1, D), ) - # Convolution with assymetric padding + # Convolution with asymmetric padding verify_conv( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), @@ -2582,7 +2590,6 @@ def repeat(N, D): def verify_convtranspose_with_padding( x_shape, w_shape, - y_shape, padding, kernel_shape, strides, @@ -2618,12 +2625,12 @@ def verify_convtranspose_with_padding( helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)), helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_shape)), ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["?"] * len(x_shape))], ) model = helper.make_model(graph, producer_name="convtranspose_pad_test") - verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, convert_to_static=True) + verify_with_ort(model, [x_shape, w_shape], use_vm=True, convert_to_static=True) def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): @@ -2651,7 +2658,7 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): ) model = helper.make_model(graph, producer_name="convtranspose_test") - verify_with_ort(model, [x_shape, w_shape], y_shape) + verify_with_ort(model, [x_shape, w_shape], y_shape, opset=11) @tvm.testing.uses_gpu @@ -2668,14 +2675,12 @@ def test_convtranspose(): def repeat(N, D): return tuple([N for _ in range(D)]) - # TODO(mbrookhart): onnxruntime in CI only supports 2D, - # find something else to test 1D and 3D against - for D in [2]: + # Once onnxruntime update is complete + for D in [1, 2, 3]: # Convolution with padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), 2 * repeat(1, D), repeat(3, D), repeat(1, D), @@ -2685,50 +2690,45 @@ def repeat(N, D): verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), 2 * repeat(0, D), repeat(3, D), repeat(1, D), repeat(1, D), ) - # Convolution with autopadding + # Convolution with unset padding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(5, D), - None, + 2 * repeat(0, D), repeat(3, D), repeat(1, D), repeat(1, D), - auto_pad="SAME_UPPER", + True, ) - # Convolution with valid autopadding + # Convolution with autopadding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), None, repeat(3, D), repeat(1, D), repeat(1, D), - auto_pad="VALID", + auto_pad="SAME_UPPER", ) - # Convolution with unset padding + # Convolution with valid autopadding verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(7, D), - 2 * repeat(0, D), + None, repeat(3, D), repeat(1, D), repeat(1, D), - True, + auto_pad="VALID", ) # Convolution with non uniform stride verify_convtranspose_with_padding( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), - (1, 1) + repeat(9, D), None, repeat(3, D), repeat(2, D), @@ -2740,7 +2740,6 @@ def repeat(N, D): # verify_convtranspose_with_padding( # (1, 1) + repeat(5, D), # (1, 1) + repeat(3, D), - # (1, 1) + repeat(5, D), # 2 * repeat(2, D), # repeat(3, D), # repeat(1, D), @@ -3548,7 +3547,7 @@ def test_gru(): @tvm.testing.uses_gpu def test_resize(): - def verify(ishape, oshape, scales, mode, coord_trans): + def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, exclude=False): nodes = [ make_constant_node("roi", onnx.TensorProto.FLOAT, (0,), []), make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), @@ -3566,6 +3565,8 @@ def verify(ishape, oshape, scales, mode, coord_trans): outputs=["Y"], mode=mode, coordinate_transformation_mode=coord_trans, + cubic_coeff_a=alpha, + exclude_outside=exclude, ) ) @@ -3582,29 +3583,66 @@ def verify(ishape, oshape, scales, mode, coord_trans): verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) - # upsampling - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "align_corners") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "half_pixel") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel") - - # downsampling - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "align_corners") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "half_pixel") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel") - - # scales are specified instead of sizes - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "linear", "asymmetric") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "align_corners") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "linear", "align_corners") - verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel") - verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "nearest", "half_pixel") + for ndim in [1, 2, 3]: + method = "nearest" + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method, coord_trans) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method, coord_trans) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method, coord_trans) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method, coord_trans) + + method = "linear" + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) + + if ndim == 2: + # ONNX Runtime only supports cubic interpolation for 2D images + method = "cubic" + for alpha in [0.5, 0.75]: + for exclude in [True, False]: + # upsampling + verify( + [1, 16] + [32] * ndim, + [1, 16] + [64] * ndim, + [], + method, + alpha=alpha, + exclude=exclude, + ) + # downsampling + verify( + [1, 16] + [32] * ndim, + [1, 16] + [16] * ndim, + [], + method, + alpha=alpha, + exclude=exclude, + ) + # scales are specified instead of sizes + verify( + [1, 16] + [32] * ndim, + [], + [1, 1] + [0.5] * ndim, + method, + alpha=alpha, + exclude=exclude, + ) + verify( + [1, 16] + [32] * ndim, + [], + [1, 1] + [2] * ndim, + method, + alpha=alpha, + exclude=exclude, + ) def verify_opset_10(ishape, scales, mode): nodes = [ @@ -3922,7 +3960,7 @@ def verify_cond_loop(): trip_count = np.array(5).astype(np.int64) res_y = np.array([13]).astype(np.float32) - cond = np.array(1).astype(np.bool) + cond = np.array(1).astype(bool) loop_graph = onnx.helper.make_graph( [loop_node], "loop_outer", @@ -3940,9 +3978,9 @@ def verify_cond_loop(): # Set a high trip count so that condition trips first. trip_count = np.array(40).astype(np.int64) - cond = np.array(1).astype(np.bool) + cond = np.array(1).astype(bool) input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True, opset=11) def verify_count_loop(): @@ -3978,7 +4016,7 @@ def verify_count_loop(): trip_count = np.array(5).astype(np.int64) res_y = np.array([13]).astype(np.float32) - cond = np.array(1).astype(np.bool) + cond = np.array(1).astype(bool) loop_graph = onnx.helper.make_graph( [loop_node], "loop_outer", @@ -3995,12 +4033,12 @@ def verify_count_loop(): loop_model = onnx.helper.make_model(loop_graph) trip_count = np.array(5).astype(np.int64) - cond = np.array(1).astype(np.bool) + cond = np.array(1).astype(bool) input_vals = [trip_count, cond, y] - verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) + verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True, opset=11) -def verify_tensor_loop(): +def verify_tensor_loop(shapeless_output=False): y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) @@ -4032,7 +4070,14 @@ def verify_tensor_loop(): ) trip_count = np.array(5).astype(np.int64) - cond = np.array(1).astype(np.bool) + cond = np.array(1).astype(bool) + + # Allow testing of malformed nodes since pytorch likes to create these. + if shapeless_output: + scan_shape = None + else: + scan_shape = [5, 3, 3, 3, 3] + loop_graph = onnx.helper.make_graph( [loop_node], "loop_outer", @@ -4043,16 +4088,16 @@ def verify_tensor_loop(): ], outputs=[ onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, scan_shape), ], ) loop_model = onnx.helper.make_model(loop_graph) trip_count = np.array(5).astype(np.int64) - cond = np.array(1).astype(np.bool) + cond = np.array(1).astype(bool) input_vals = [trip_count, cond, y] verify_with_ort_with_inputs( - loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True + loop_model, input_vals, use_vm=True, freeze_params=True, convert_to_static=True, opset=11 ) @@ -4063,31 +4108,45 @@ def test_loop(): verify_count_loop() # Test a loop that uses an array output. verify_tensor_loop() + # Test a loop that is malformed and has no output shape defined. + verify_tensor_loop(shapeless_output=True) -def verify_if(cond_array): +def verify_if(cond_array, num_outputs): # Given a bool scalar input cond. # return constant tensor x if cond is True, otherwise return constant tensor y. - then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) - else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) - x = np.array([1, 2, 3, 4, 5]).astype(np.float32) - y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + def append_constant_nodes(nodes, outputs, expected, name): + outputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [5])) - then_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x) - ) + expected.append(np.random.randn(5).astype("float32")) - else_const_node = onnx.helper.make_node( - "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y) - ) + nodes.append( + onnx.helper.make_node( + "Constant", inputs=[], outputs=[name], value=numpy_helper.from_array(expected[-1]) + ) + ) + + if_outputs = [] + graph_outputs = [] + + then_nodes, then_outs, then_expected = [], [], [] + else_nodes, else_outs, else_expected = [], [], [] + + for i in range(num_outputs): + append_constant_nodes(then_nodes, then_outs, then_expected, "then_out{}".format(i)) + append_constant_nodes(else_nodes, else_outs, else_expected, "else_out{}".format(i)) - then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) + if_outputs.append("res{}".format(i)) + graph_outputs.append( + onnx.helper.make_tensor_value_info("res{}".format(i), onnx.TensorProto.FLOAT, [5]), + ) - else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out]) + then_body = onnx.helper.make_graph(then_nodes, "then_body", [], then_outs) + else_body = onnx.helper.make_graph(else_nodes, "else_body", [], else_outs) if_node = onnx.helper.make_node( - "If", inputs=["cond"], outputs=["res"], then_branch=then_body, else_branch=else_body + "If", inputs=["cond"], outputs=if_outputs, then_branch=then_body, else_branch=else_body ) if_graph = onnx.helper.make_graph( @@ -4096,9 +4155,7 @@ def verify_if(cond_array): inputs=[ onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), ], - outputs=[ - onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]), - ], + outputs=graph_outputs, ) if_model = onnx.helper.make_model(if_graph) @@ -4106,12 +4163,14 @@ def verify_if(cond_array): cond = np.array([1]).astype("bool") else: cond = np.array(1).astype("bool") - correct_out = x if cond else y + correct_out = then_expected if cond else else_expected # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with # verify_with_ort once we update versions. for target, dev in tvm.testing.enabled_targets(): tvm_out = get_tvm_output_with_vm(if_model, [cond], target, dev, freeze_params=True) + if not isinstance(tvm_out, list): + tvm_out = [tvm_out] for i in range(len(tvm_out)): tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) @@ -4119,8 +4178,10 @@ def verify_if(cond_array): @tvm.testing.uses_gpu def test_if(): # Confirm that if works with cond as an array or scalar. - verify_if(cond_array=False) - verify_if(cond_array=True) + verify_if(cond_array=False, num_outputs=1) + verify_if(cond_array=False, num_outputs=2) + verify_if(cond_array=True, num_outputs=1) + verify_if(cond_array=True, num_outputs=2) @tvm.testing.uses_gpu @@ -4363,15 +4424,36 @@ def verify_eyelike(indata): onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/")) unsupported_onnx_tests = [ - "test_basic_convinteger/", + "test_adagrad/", + "test_adagrad_multiple/", + "test_adam/", + "test_adam_multiple/", + "test_argmax_default_axis_example_select_last_index/", + "test_argmax_default_axis_random_select_last_index/", + "test_argmax_keepdims_example_select_last_index/", + "test_argmax_keepdims_random_select_last_index/", + "test_argmax_negative_axis_keepdims_example_select_last_index/", + "test_argmax_negative_axis_keepdims_random_select_last_index/", + "test_argmax_no_keepdims_example_select_last_index/", + "test_argmax_no_keepdims_random_select_last_index/", + "test_argmin_default_axis_example_select_last_index/", + "test_argmin_default_axis_random_select_last_index/", + "test_argmin_keepdims_example_select_last_index/", + "test_argmin_keepdims_random_select_last_index/", + "test_argmin_negative_axis_keepdims_example_select_last_index/", + "test_argmin_negative_axis_keepdims_random_select_last_index/", + "test_argmin_no_keepdims_example_select_last_index/", + "test_argmin_no_keepdims_random_select_last_index/", + "test_cast_BFLOAT16_to_FLOAT/", "test_cast_DOUBLE_to_FLOAT16/", + "test_cast_FLOAT_to_BFLOAT16/", "test_cast_FLOAT_to_STRING/", "test_cast_STRING_to_FLOAT/", + "test_celu/", "test_compress_0/", "test_compress_1/", "test_compress_default_axis/", "test_compress_negative_axis/", - "test_convinteger_with_padding/", "test_convtranspose_dilations/", "test_convtranspose_output_shape/", "test_cumsum_1d/", @@ -4383,17 +4465,109 @@ def verify_eyelike(indata): "test_cumsum_2d_negative_axis/", "test_det_2d/", "test_det_nd/", + "test_dropout_default/", + "test_dropout_default_mask/", + "test_dropout_default_mask_ratio/", + "test_dropout_default_ratio/", + "test_einsum_batch_diagonal/", + "test_einsum_batch_matmul/", + "test_einsum_inner_prod/", + "test_einsum_sum/", + "test_einsum_transpose/", + "test_greater_equal/", + "test_greater_equal_bcast/", + "test_hardmax_axis_0/", + "test_hardmax_axis_1/", + "test_hardmax_default_axis/", + "test_if_seq/", + "test_less_equal/", + "test_less_equal_bcast/", + "test_logsoftmax_axis_0/", + "test_logsoftmax_axis_0_expanded/", + "test_logsoftmax_axis_1/", + "test_logsoftmax_axis_1_expanded/", + "test_logsoftmax_axis_2_expanded/", + "test_logsoftmax_default_axis/", + "test_logsoftmax_default_axis_expanded/", + "test_logsoftmax_example_1_expanded/", + "test_logsoftmax_large_number_expanded/", + "test_logsoftmax_negative_axis_expanded/", + "test_loop11/", + "test_loop13_seq/", "test_matmulinteger/", "test_maxpool_2d_same_lower/", "test_maxpool_2d_same_upper/", "test_maxpool_with_argmax_2d_precomputed_pads/", "test_maxpool_with_argmax_2d_precomputed_strides/", "test_maxunpool_export_with_output_shape/", + "test_momentum/", + "test_momentum_multiple/", "test_mvn/", + "test_nesterov_momentum/", + "test_nllloss_NC/", + "test_nllloss_NC_expanded/", + "test_nllloss_NCd1/", + "test_nllloss_NCd1_expanded/", + "test_nllloss_NCd1_ii/", + "test_nllloss_NCd1_ii_expanded/", + "test_nllloss_NCd1_mean_weight_negative_ii/", + "test_nllloss_NCd1_mean_weight_negative_ii_expanded/", + "test_nllloss_NCd1_weight/", + "test_nllloss_NCd1_weight_expanded/", + "test_nllloss_NCd1_weight_ii/", + "test_nllloss_NCd1_weight_ii_expanded/", + "test_nllloss_NCd1d2/", + "test_nllloss_NCd1d2_expanded/", + "test_nllloss_NCd1d2_no_weight_reduction_mean_ii/", + "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded/", + "test_nllloss_NCd1d2_reduction_mean/", + "test_nllloss_NCd1d2_reduction_mean_expanded/", + "test_nllloss_NCd1d2_reduction_sum/", + "test_nllloss_NCd1d2_reduction_sum_expanded/", + "test_nllloss_NCd1d2_with_weight/", + "test_nllloss_NCd1d2_with_weight_expanded/", + "test_nllloss_NCd1d2_with_weight_reduction_mean/", + "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded/", + "test_nllloss_NCd1d2_with_weight_reduction_sum/", + "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded/", + "test_nllloss_NCd1d2_with_weight_reduction_sum_ii/", + "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded/", + "test_nllloss_NCd1d2d3_none_no_weight_negative_ii/", + "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded/", + "test_nllloss_NCd1d2d3_sum_weight_high_ii/", + "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded/", + "test_nllloss_NCd1d2d3d4d5_mean_weight/", + "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded/", + "test_nllloss_NCd1d2d3d4d5_none_no_weight/", + "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded/", + "test_pow_types_float/", + "test_pow_types_float32_int32/", + "test_pow_types_float32_int64/", + "test_pow_types_float32_uint32/", + "test_pow_types_float32_uint64/", + "test_pow_types_int/", + "test_pow_types_int32_float32/", + "test_pow_types_int32_int32/", + "test_pow_types_int64_float32/", + "test_pow_types_int64_int64/", "test_qlinearmatmul_2D/", "test_qlinearmatmul_3D/", + "test_reduce_sum_default_axes_keepdims_example/", + "test_reduce_sum_default_axes_keepdims_random/", + "test_reduce_sum_do_not_keepdims_example/", + "test_reduce_sum_do_not_keepdims_random/", + "test_reduce_sum_empty_axes_input_noop_example/", + "test_reduce_sum_empty_axes_input_noop_random/", + "test_reduce_sum_keepdims_example/", + "test_reduce_sum_keepdims_random/", + "test_reduce_sum_negative_axes_keepdims_example/", + "test_reduce_sum_negative_axes_keepdims_random/", + "test_resize_downsample_sizes_cubic/", + "test_resize_downsample_sizes_linear_pytorch_half_pixel/", + "test_resize_downsample_sizes_nearest/", "test_resize_tf_crop_and_resize/", - ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 + "test_resize_upsample_sizes_cubic/", + "test_resize_upsample_sizes_nearest/", "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", @@ -4401,8 +4575,94 @@ def verify_eyelike(indata): "test_round/", "test_scan9_sum/", "test_scan_sum/", + "test_sce_NCd1_mean_weight_negative_ii/", + "test_sce_NCd1_mean_weight_negative_ii_expanded/", + "test_sce_NCd1_mean_weight_negative_ii_log_prob/", + "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob/", + "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded/", + "test_sce_NCd1d2d3_sum_weight_high_ii/", + "test_sce_NCd1d2d3_sum_weight_high_ii_expanded/", + "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob/", + "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded/", + "test_sce_NCd1d2d3d4d5_mean_weight/", + "test_sce_NCd1d2d3d4d5_mean_weight_expanded/", + "test_sce_NCd1d2d3d4d5_mean_weight_log_prob/", + "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded/", + "test_sce_NCd1d2d3d4d5_none_no_weight/", + "test_sce_NCd1d2d3d4d5_none_no_weight_expanded/", + "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob/", + "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded/", + "test_sce_mean/", + "test_sce_mean_3d/", + "test_sce_mean_3d_expanded/", + "test_sce_mean_3d_log_prob/", + "test_sce_mean_3d_log_prob_expanded/", + "test_sce_mean_expanded/", + "test_sce_mean_log_prob/", + "test_sce_mean_log_prob_expanded/", + "test_sce_mean_no_weight_ii/", + "test_sce_mean_no_weight_ii_3d/", + "test_sce_mean_no_weight_ii_3d_expanded/", + "test_sce_mean_no_weight_ii_3d_log_prob/", + "test_sce_mean_no_weight_ii_3d_log_prob_expanded/", + "test_sce_mean_no_weight_ii_4d/", + "test_sce_mean_no_weight_ii_4d_expanded/", + "test_sce_mean_no_weight_ii_4d_log_prob/", + "test_sce_mean_no_weight_ii_4d_log_prob_expanded/", + "test_sce_mean_no_weight_ii_expanded/", + "test_sce_mean_no_weight_ii_log_prob/", + "test_sce_mean_no_weight_ii_log_prob_expanded/", + "test_sce_mean_weight/", + "test_sce_mean_weight_expanded/", + "test_sce_mean_weight_ii/", + "test_sce_mean_weight_ii_3d/", + "test_sce_mean_weight_ii_3d_expanded/", + "test_sce_mean_weight_ii_3d_log_prob/", + "test_sce_mean_weight_ii_3d_log_prob_expanded/", + "test_sce_mean_weight_ii_4d/", + "test_sce_mean_weight_ii_4d_expanded/", + "test_sce_mean_weight_ii_4d_log_prob/", + "test_sce_mean_weight_ii_4d_log_prob_expanded/", + "test_sce_mean_weight_ii_expanded/", + "test_sce_mean_weight_ii_log_prob/", + "test_sce_mean_weight_ii_log_prob_expanded/", + "test_sce_mean_weight_log_prob/", + "test_sce_mean_weight_log_prob_expanded/", + "test_sce_none/", + "test_sce_none_expanded/", + "test_sce_none_log_prob/", + "test_sce_none_log_prob_expanded/", + "test_sce_none_weights/", + "test_sce_none_weights_expanded/", + "test_sce_none_weights_log_prob/", + "test_sce_none_weights_log_prob_expanded/", + "test_sce_sum/", + "test_sce_sum_expanded/", + "test_sce_sum_log_prob/", + "test_sce_sum_log_prob_expanded/", + "test_sequence_insert_at_back/", + "test_sequence_insert_at_front/", "test_simple_rnn_defaults/", "test_simple_rnn_with_initial_bias/", + "test_softmax_axis_0/", + "test_softmax_axis_0_expanded/", + "test_softmax_axis_1/", + "test_softmax_axis_1_expanded/", + "test_softmax_axis_2_expanded/", + "test_softmax_default_axis/", + "test_softmax_default_axis_expanded/", + "test_softmax_example_expanded/", + "test_softmax_large_number_expanded/", + "test_softmax_negative_axis_expanded/", + "test_split_variable_parts_1d/", + "test_split_variable_parts_2d/", + "test_split_variable_parts_default_axis/", + "test_split_zero_size_splits/", + "test_squeeze/", + "test_squeeze_negative_axes/", "test_strnormalizer_export_monday_casesensintive_lower/", "test_strnormalizer_export_monday_casesensintive_nochangecase/", "test_strnormalizer_export_monday_casesensintive_upper/", @@ -4416,9 +4676,22 @@ def verify_eyelike(indata): "test_tfidfvectorizer_tf_onlybigrams_levelempty/", "test_tfidfvectorizer_tf_onlybigrams_skip5/", "test_tfidfvectorizer_tf_uniandbigrams_skip5/", + "test_training_dropout/", + "test_training_dropout_default/", + "test_training_dropout_default_mask/", + "test_training_dropout_mask/", + "test_training_dropout_zero_ratio/", + "test_training_dropout_zero_ratio_mask/", "test_unique_sorted_with_axis/", "test_unique_sorted_with_axis_3d/", "test_unique_sorted_with_negative_axis/", + "test_unsqueeze_axis_0/", + "test_unsqueeze_axis_1/", + "test_unsqueeze_axis_2/", + "test_unsqueeze_negative_axes/", + "test_unsqueeze_three_axes/", + "test_unsqueeze_two_axes/", + "test_unsqueeze_unsorted_axes/", "test_upsample_nearest/", ] @@ -4692,7 +4965,7 @@ def repeat(N, D): bias=True, ) - # Convolution with assymetric padding + # Convolution with asymmetric padding verify_qlinearconv( (1, 1) + repeat(5, D), (1, 1) + repeat(3, D), @@ -4815,6 +5088,221 @@ def test_qlinearadd(): verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7]) +def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None, target="llvm"): + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + node = helper.make_node( + "RandomUniform", [], ["out"], shape=shape, dtype=ONNX_DTYPE, high=high, low=low + ) + if seed is not None: + seed_attr = helper.make_attribute("seed", seed) + node.attribute.append(seed_attr) + + graph = helper.make_graph( + [node], + "random_uniform_test", + inputs=[], + outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)], + ) + model = helper.make_model(graph, producer_name="random_uniform_test") + return get_tvm_output_with_vm(model, [], target=target, device=tvm.device(target, 0)) + + +def test_random_uniform(): + targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] + for target in targets: + # Check that function runs and produces proper shape. + vals = get_random_uniform([10], dtype="float32", target=target) + assert list(vals.shape) == [10] + assert vals.dtype == "float32" + + # Test N-D tensor generation. + vals = get_random_uniform([1, 3, 100, 100], dtype="float32", target=target) + assert list(vals.shape) == [1, 3, 100, 100] + + # Check that bounds aren't exceeded. + vals = get_random_uniform(shape=[100], high=100, low=-100) + assert list(vals.shape) == [100] + assert all(vals >= -100) and all(vals <= 100) + + # Check that a fixed seed produces the same values when run twice. + vals_1 = get_random_uniform(shape=[10], seed=1) + vals_2 = get_random_uniform(shape=[10], seed=1) + assert all(vals_1 == vals_2) + + # Test against an expected output with a fixed seed. + real = get_random_uniform(shape=[10], seed=5) + expected = np.asarray( + [ + 0.8614111, + 0.46572232, + 0.6007328, + 0.21619737, + 0.6361222, + 0.7298056, + 0.13094282, + 0.03556716, + 0.32997167, + 0.2977605, + ] + ) + tvm.testing.assert_allclose(real, expected, rtol=1e-5) + + +def verify_convinteger( + x_shape, + w_shape, + y_shape, + padding, + kernel_shape, + strides, + dilations, + auto_pad="NOTSET", + dtype="uint8", +): + + x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype) + w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype) + x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) + w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) + + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + input_nodes = [ + helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)), + helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)), + helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []), + helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []), + ] + input_names = [ + "x", + "w", + "x_zero_point", + "w_zero_point", + ] + input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array] + + if padding is None: + ## autopadding with unset default attributes + kwargs = {} + if not all([s == 1 for s in strides]): + kwargs["strides"] = strides + if not all([d == 1 for d in dilations]): + kwargs["dilations"] = dilations + + node = helper.make_node( + "ConvInteger", + inputs=input_names, + outputs=["y"], + # Default values for other attributes: + auto_pad=auto_pad, + **kwargs, + ) + else: + node = helper.make_node( + "ConvInteger", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + pads=padding, + ) + + graph = helper.make_graph( + [node], + "convinteger_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))], + ) + model = helper.make_model(graph, producer_name="convinteger_test") + # opt_level=1 will cause error + verify_with_ort_with_inputs(model, input_values, opt_level=2) + + +def test_convinteger(): + def repeat(N, D): + return tuple([N for _ in range(D)]) + + # only support 2D ConvInteger because we only support qnn.conv2d for now. + D = 2 + + # Convolution with padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + + # Convolution with asymmetric padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(4, D), + repeat(0, D) + repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution without padding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + ) + # Convolution with autopadding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with valid autopadding + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="VALID", + ) + # Convolution with non uniform stride + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER", + ) + # Convolution with dilation + verify_convinteger( + (1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(2, D), + repeat(3, D), + repeat(1, D), + repeat(2, D), + ) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4898,3 +5386,6 @@ def test_qlinearadd(): test_reverse_sequence() test_eyelike() test_qlinearconv() + test_random_uniform() + test_convinteger() + test_batch_matmul() diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2ec281094080..f76ea9a5d324 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3893,6 +3893,25 @@ def test_forward_nll_loss(): verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets]) +@tvm.testing.uses_gpu +def test_forward_flip(): + torch.set_grad_enabled(False) + + class Flip(Module): + def __init__(self, axis=0): + super().__init__() + self.axis = axis + + def forward(self, x): + return x.flip([self.axis]) + + input = torch.randn(2, 3, 4) + verify_model(Flip(axis=0), input_data=input) + verify_model(Flip(axis=1), input_data=input) + verify_model(Flip(axis=2), input_data=input) + verify_model(Flip(axis=-1), input_data=input) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4035,6 +4054,7 @@ def test_forward_nll_loss(): test_hard_swish() test_hard_sigmoid() test_forward_nll_loss() + test_forward_flip() # Model tests test_resnet18() diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py new file mode 100644 index 000000000000..967245e1ef9d --- /dev/null +++ b/tests/python/frontend/pytorch/test_lstms.py @@ -0,0 +1,363 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +import numpy as np +import torch +import onnx +import io +import sys +import pytest + +from tvm import relay +from tvm.contrib import graph_executor + +from torch import nn + +## Model parameters +model_feature_size = 16 +model_hidden_size = 32 +model_num_layers = 2 +seqs_length = 2 +projection_size = 20 +batch_size = 2 + + +def check_torch_version_for_proj_in_lstm(): + """ + proj_size parameter is supported in torch.nn.LSTM layer started from 1.8.0 torch version + """ + me = False + + version = torch.__version__ + major, minor, micro = version.split(".") + + if int(major) > 1: + me = True + elif int(major) == 1: + if int(minor) >= 8: + me = True + + return me + + +class LSTM_Model(nn.Module): + def __init__( + self, + device, + batch_first=False, + layer_num=1, + bidirectional=False, + proj_size=0, + use_bias=True, + rnd_weights_init=False, + ): + super().__init__() + + self.device = device + self.batch_first = batch_first + self.use_bias = use_bias + + if check_torch_version_for_proj_in_lstm(): + self.lstm = nn.LSTM( + input_size=model_feature_size, + hidden_size=model_hidden_size, + num_layers=layer_num, + bidirectional=bidirectional, + proj_size=proj_size, + batch_first=batch_first, + bias=use_bias, + ).to(device) + else: + if proj_size > 0: + print( + "WARNING: projection is not supported for torch version less than 1.8.0! ", + "LSTM was constructed without projection!", + ) + # sys.exit() + self.lstm = nn.LSTM( + input_size=model_feature_size, + hidden_size=model_hidden_size, + num_layers=layer_num, + bidirectional=bidirectional, + batch_first=batch_first, + bias=use_bias, + ).to(device) + + if rnd_weights_init: + self.gen_rnd_weights() + + def forward(self, input, hidden_init=None): + """ + Computes the output tensor after input inference along LSTM layer. + + :param input: batch of data as a tensor of shape (seqs_length, batch_size, model_feature_size) or (batch_size, seqs_length, model_feature_size) if self.batch_first = True + :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None. + :return: the output tensor of shape (batch_size, model_hidden_size) + """ + # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state + # and the final cell state. + out, (hidden, cell) = self.lstm(input, hidden_init) + + return out + + def gen_rnd_weights(self): + """ + Generate random weigths for the model with biases + Without projection: + For first weights group: + Wi (4*model_hidden_size, model_feature_size) + Wh (4*model_hidden_size, model_hidden_size) + Bi (4*model_hidden_size) + Bh (4*model_hidden_size) + For first bidirectional weights group: + Wi (4*model_hidden_size, model_feature_size) + Wh (4*model_hidden_size, model_hidden_size) + Bi (4*model_hidden_size) + Bh (4*model_hidden_size) + For other weights group: + Wi (4*model_hidden_size, model_hidden_size) + Wh (4*model_hidden_size, model_hidden_size) + Bi (4*model_hidden_size) + Bh (4*model_hidden_size) + With projection: + For first weights group: + Wi (4*model_hidden_size, model_feature_size) + Wh (4*model_hidden_size, proj_size) + Bi (4*model_hidden_size) + Bh (4*model_hidden_size) + P (proj_size, model_hidden_size) + For first bidirectional weights group: + Wi (4*model_hidden_size, model_feature_size) + Wh (4*model_hidden_size, proj_size) + Bi (4*model_hidden_size) + Bh (4*model_hidden_size) + P (proj_size, model_hidden_size) + For other weights group: + Wi (4*model_hidden_size, proj_size * num_directions) + Wh (4*model_hidden_size, proj_size) + Bi (4*model_hidden_size) + Bh (4*model_hidden_size) + P (proj_size, model_hidden_size) + For generation of random weigths for the model without biases Bi and Bh are skipped + """ + for weight_group in self.lstm.all_weights: + for weight in weight_group: + weight.data = torch.rand(weight.shape) + + def get_dummy_input(self): + shape = [seqs_length, batch_size, model_feature_size] + if self.batch_first: + shape = [batch_size, seqs_length, model_feature_size] + res = torch.rand(shape) + + return res, shape + + +def compare(input, gold_data, rtol=1e-5, atol=1e-5): + tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol) + + +def check_lstm_with_type( + lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0) +): + has_proj = "p" in lstm_type + + device = torch.device("cpu") + hidden_layers_num = 1 + model = None + for batch_first in (True, False): + for use_bias in (True, False): + for rnd_weights in [True]: # (True, False): + if lstm_type == "uni": + model = LSTM_Model( + device, + batch_first=batch_first, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + elif lstm_type == "b": + model = LSTM_Model( + device, + batch_first=batch_first, + bidirectional=True, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = 2 + elif lstm_type == "p": + model = LSTM_Model( + device, + batch_first=batch_first, + proj_size=projection_size, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + elif lstm_type == "s": + model = LSTM_Model( + device, + batch_first=batch_first, + layer_num=model_num_layers, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = model_num_layers + elif lstm_type == "sb": + model = LSTM_Model( + device, + batch_first=batch_first, + bidirectional=True, + layer_num=model_num_layers, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = 2 * model_num_layers + elif lstm_type == "sp": + model = LSTM_Model( + device, + batch_first=batch_first, + layer_num=model_num_layers, + proj_size=projection_size, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = model_num_layers + elif lstm_type == "bp": + model = LSTM_Model( + device, + batch_first=batch_first, + bidirectional=True, + proj_size=projection_size, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = 2 + elif lstm_type == "sbp": + model = LSTM_Model( + device, + batch_first=batch_first, + bidirectional=True, + layer_num=model_num_layers, + proj_size=projection_size, + rnd_weights_init=rnd_weights, + use_bias=use_bias, + ) + hidden_layers_num = 2 * model_num_layers + else: + print("WARNING: LSTM type {} is not supported here!".format(lstm_type)) + return + + model.eval() + + # Get golden output from original model + input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size) + input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size) + dummy_input, input_shape = model.get_dummy_input() + golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy() + + dtype = "float32" + h_zeros = np.zeros(input_hidden_shape, dtype=dtype) + if has_proj: + h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype) + c_zeros = np.zeros(input_hidden_shape, dtype=dtype) + + tvm_output = None + for format in ["ts"]: # ["ts", "onnx"]: + if format == "ts": + # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. + traced_script_module = torch.jit.trace(model, dummy_input).eval() + + # Import model to Relay + shape_list = [("input", input_shape)] + mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list) + + # Model compilation by tvm + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + elif format == "onnx": + if has_proj: + print( + "WARNING: torch.onnx.export does not support conversion LSTM with projection " + "from pytorch! TODO: waiting for the support and correct test after that." + ) + continue + onnx_io = io.BytesIO() + with torch.no_grad(): + h0 = torch.rand(input_hidden_shape) + if has_proj: + h0 = torch.rand(input_hidden_shape_with_proj) + c0 = torch.rand(input_hidden_shape) + input_names = ["input", "h0", "c0"] + + # default export (without dynamic input) + torch.onnx.export( + model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names + ) + onnx_io.seek(0, 0) + onnx_model = onnx.load_model(onnx_io) + + # Import model to Relay + shape_dict = { + "input": input_shape, + "h0": input_hidden_shape, + "c0": input_hidden_shape, + } + if has_proj: + shape_dict = { + "input": input_shape, + "h0": input_hidden_shape_with_proj, + "c0": input_hidden_shape, + } + mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) + + # Model compilation by tvm + with tvm.transform.PassContext(opt_level=1): + lib = relay.build(mod, target=target, params=params) + + # Inference of the model with given input data + m = graph_executor.GraphModule(lib["default"](dev)) + + # Set inputs + m.set_input( + input=tvm.nd.array(dummy_input.numpy().astype(dtype)), + h0=tvm.nd.array(h_zeros), + c0=tvm.nd.array(c_zeros), + ) + # Execute + m.run() + # Get outputs (converted to numpy array) + tvm_output = m.get_output(0).numpy() + + compare(tvm_output, golden_output_batch) + + +@tvm.testing.uses_gpu +def test_lstms(): + for target, dev in tvm.testing.enabled_targets(): + check_lstm_with_type("uni", target, dev) + # check_lstm_with_type("p", target, dev) + check_lstm_with_type("s", target, dev) + check_lstm_with_type("b", target, dev) + # check_lstm_with_type("bp", target, dev) + # check_lstm_with_type("sp", target, dev) + check_lstm_with_type("sb", target, dev) + # check_lstm_with_type("sbp", target, dev) + + +if __name__ == "__main__": + test_lstms() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 583014f657ad..6733b326c395 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -29,6 +29,13 @@ import tensorflow.compat.v1 as tf except ImportError: import tensorflow as tf + +# Only allow TF to run on half the GPU RAM to save the other half +# For TVM +gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) +sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) +sess.close() + from tensorflow.python.framework import constant_op from tensorflow.python.framework import graph_util from tensorflow.python.ops import nn_ops @@ -117,7 +124,7 @@ def run_tvm_graph( disabled_pass=None, ignore_in_shape=False, serialize=False, - use_dense_op=True, + convert_config=None, ): """Generic function to compile on relay and execute on tvm""" input_data = convert_to_list(input_data) @@ -136,7 +143,7 @@ def run_tvm_graph( layout=layout, shape=shape_dict, outputs=out_names, - use_dense_op=use_dense_op, + convert_config=convert_config, ) dev = tvm.device(target, 0) if mode == "debug": @@ -218,7 +225,7 @@ def compare_tf_with_tvm( add_shapes_to_graph_def=True, targets=None, ignore_in_shape=False, - use_dense_op=True, + convert_config=None, ): """Generic function to generate and compare tensorflow and TVM output""" @@ -266,7 +273,7 @@ def name_without_num(name): mode=mode, cuda_layout=cuda_layout, ignore_in_shape=ignore_in_shape, - use_dense_op=use_dense_op, + convert_config=convert_config, ) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared @@ -318,6 +325,8 @@ def _test_pooling(input_shape, **kwargs): if is_gpu_available(): if len(input_shape) == 4: input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] + if isinstance(kwargs["padding"], list): + kwargs["padding"] = [kwargs["padding"][ii] for ii in (0, 3, 1, 2)] kwargs["data_format"] = "NCHW" _test_pooling_iteration(input_shape, **kwargs) @@ -1802,8 +1811,12 @@ def _test_matmul(i, j, k, dtype, outer=None): A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=True) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=False) + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": True} + ) + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": False} + ) def test_forward_matmul(): @@ -1821,7 +1834,18 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + convert_config={"use_nt_batch_matmul": True}, + ) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + convert_config={"use_nt_batch_matmul": False}, + ) def _test_batch_matmul_dynamic( @@ -1834,10 +1858,23 @@ def _test_batch_matmul_dynamic( A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype) - # for now, in TOPI, only cublas's implementation support dynamic shape + # for now, in TOPI, only llvm & cublas's implementation support dynamic shape # TODO add more backends support in TOPI compare_tf_with_tvm( - [A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"] + [A_np, B_np], + [A.name, B.name], + result.name, + mode="vm", + targets=["llvm", "cuda -libs=cublas"], + convert_config={"use_nt_batch_matmul": True}, + ) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + mode="vm", + targets=["llvm", "cuda -libs=cublas"], + convert_config={"use_nt_batch_matmul": False}, ) @@ -1856,7 +1893,6 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) -@tvm.testing.requires_cuda def test_forward_batch_matmul_dynamic(): _test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32") _test_batch_matmul_dynamic( @@ -2553,7 +2589,9 @@ def test_forward_stridedslice(): _test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1) _test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([4], [-1], [0], [1], "float32", shrink_axis_mask=1) _test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([2, 3, 4], [-2], [0], [1], "float32", shrink_axis_mask=8) _test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8) _test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32") _test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index b3504ff38328..001ba6de1967 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -354,7 +354,8 @@ def get_input(self): @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) def func(self, x): a, b, c = tf.split(x, 3, axis=1) - return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1) + axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0, dtype="int32")) + return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis) run_all(ConcatV2) @@ -447,5 +448,142 @@ def func(self, x): run_model_graph(StatelessWhile2Var, outputs=["Identity:output:0"]) +def test_tensorlist(): + def run_test(elem_shape): + class TensorList(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3), dtype="float32") + in_tens[1, :] = np.zeros((3,), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :]) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=1, item=x[1, :]) + output = tf.raw_ops.TensorListGetItem( + input_handle=tl, index=0, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorList) + run_func_graph(TensorList, runtime="vm") + + run_test((3,)) + run_test((-1,)) + + +def test_tensorlist_stack(): + def run_test(elem_shape): + class TensorListStack(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3), dtype="float32") + in_tens[1] = np.zeros((3,), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=elem_shape) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorListStack) + run_func_graph(TensorListStack, runtime="vm") + + run_test((3,)) + run_test((-1,)) + + +def test_tensorlist_2d(): + def run_test(elem_shape): + class TensorList2D(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3, 4), dtype="float32") + in_tens[1, :, :] = np.zeros((3, 4), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :]) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=1, item=x[1, :, :]) + output = tf.raw_ops.TensorListGetItem( + input_handle=tl, index=0, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorList2D) + run_func_graph(TensorList2D, runtime="vm") + + run_test((3, 4)) + run_test((-1, -1)) + + +def test_tensorlist_stack_2d(): + def run_test(elem_shape): + class TensorListStack2D(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3, 4), dtype="float32") + in_tens[1, :, :] = np.zeros((3, 4), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=elem_shape) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorListStack2D) + run_func_graph(TensorListStack2D, runtime="vm") + + run_test((3, 4)) + run_test((-1, -1)) + + +def test_tensorlist_stack_unpack(): + def run_test(elem_shape): + class TensorListStack2D(tf.Module): + def get_input(self): + in_tens = np.ones((1, 3, 4), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=1, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :]) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype, num_elements=1 + ) + output = tf.raw_ops.Unpack(value=output, num=1, axis=0) + return output + + run_model_graph(TensorListStack2D) + run_func_graph(TensorListStack2D, runtime="vm") + + run_test((3, 4)) + run_test((-1, -1)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py index 394a49d0f2e9..1b5a6342f07d 100644 --- a/tests/python/frontend/tensorflow2/test_sequential_models.py +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -109,5 +109,60 @@ def maxpool_batchnorm_model(input_shape, pool_size=(2, 2)): run_sequential_model(maxpool_batchnorm_model, input_shape=(1, 32, 32, 3)) +def test_tensorlist_stack_model(): + def tensorlist_stack_model(input_shape): + class TensorArrayStackLayer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + + def call(self, inputs): + inputs = tf.squeeze(inputs) + outputs = tf.TensorArray( + tf.float32, + size=inputs.shape[0], + infer_shape=False, + element_shape=inputs.shape[1:], + ) + outputs = outputs.unstack(inputs) + + return outputs.stack() + + input_shape = (3, 32) + model = tf.keras.Sequential( + [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayStackLayer()] + ) + return model + + run_sequential_model(tensorlist_stack_model, input_shape=(3, 32)) + + +def test_tensorlist_read_model(): + def tensorlist_read_model(input_shape): + class TensorArrayReadLayer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + + def call(self, inputs): + inputs = tf.squeeze(inputs) + outputs = tf.TensorArray( + tf.float32, + size=inputs.shape[0], + infer_shape=False, + element_shape=inputs.shape[1:], + ) + for i in range(inputs.shape[0]): + outputs = outputs.write(i, inputs[i, :]) + + return outputs.read(0) + + input_shape = (3, 32) + model = tf.keras.Sequential( + [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayReadLayer()] + ) + return model + + run_sequential_model(tensorlist_read_model, input_shape=(3, 32)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7b5377b3363d..5b5c7fda1de9 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4412,7 +4412,7 @@ def test_forward_mediapipe_hand_landmark(): # -------------- def test_prevent_tensorflow_dynamic_range(): """ - Should prevent runnung "dynamic range quantization" optimized TFLite graph + Should prevent running "dynamic range quantization" optimized TFLite graph """ data_array = np.random.randint(0, 2, (1, 1024, 1024)).astype(dtype=np.float32) filter_array = np.random.randint(0, 2, (1024, 1024)).astype(dtype=np.float32) diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py new file mode 100644 index 000000000000..3fa4795870d5 --- /dev/null +++ b/tests/python/integration/test_lower.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument +"""Test workload for lowering and build""" +import tvm +from tvm import tir +from tvm.script import ty +import tvm.testing +import numpy as np + + +@tvm.script.tir +def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + # match buffer + A = tir.match_buffer(a, [1024, 1024], "float16") + B = tir.match_buffer(b, [1024, 1024], "float16") + C = tir.match_buffer(c, [1024, 1024], "float32") + + # body + for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"): + for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"): + with tir.block([16, 8]) as [bx, by]: + tir.bind(bx, blockIdx_x) + tir.bind(by, blockIdx_y) + shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared") + wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") + wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b") + wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator") + for ty in tir.thread_binding(0, 2, "threadIdx.y"): + for tz in tir.thread_binding(0, 2, "threadIdx.z"): + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads([]) + tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + C0 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_fill_fragment( + C0.data, + 16, + 16, + 16, + i * 4 + j, + tir.float32(0), + dtype="handle", + ) + ) + + for ko in range(0, 32): + # copy data from global to shared + for tx in tir.thread_binding(0, 32, "threadIdx.x"): + for i0, j0 in tir.grid(1, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, bx * 64 + ty * 32 + tx + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_A[vi, vj + 8] = A[vi, vj] + + for i0, j0 in tir.grid(2, 4): + for j1 in tir.vectorized(0, 4): + with tir.block([1024, 1024]) as [vi, vj]: + tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) + tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + shared_B[vi, vj + 8] = B[vi, vj] + + for ki in range(0, 2): + for i in range(0, 2): + with tir.block([64, 64]) as [vi, vk]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ] + ) + tir.writes( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + A0 = tir.match_buffer( + shared_A[ + vi * 16 : vi * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ], + (16, 16 + 8), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_A0 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_a", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + wmma_A0.data, + 16, + 16, + 16, + i, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset + 8, + A0.strides[0], + 1, + dtype="handle", + ), + A0.strides[0], + "row_major", + dtype="handle", + ) + ) + for j in range(0, 4): + with tir.block([64, 64]) as [vj, vk]: + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ] + ) + tir.writes( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16] + ) + s0 = tir.var("int32") + s1 = tir.var("int32") + B0 = tir.match_buffer( + shared_B[ + vj * 16 : vj * 16 + 16, + vk * 16 : vk * 16 + 16 + 8, + ], + (16, 16 + 8), + "float16", + strides=[s0, s1], + scope="shared", + offset_factor=1, + ) + wmma_B0 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_b", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + wmma_B0.data, + 16, + 16, + 16, + j, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + B0.data, + B0.elem_offset + 8, + B0.strides[0], + 1, + dtype="handle", + ), + B0.strides[0], + "col_major", + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [ + vi, + vj, + vk, + ]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.bind(vk, ko * 2 + ki) + tir.reads( + [ + wmma_A[ + vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_B[ + vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16 + ], + wmma_C[ + vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16 + ], + ] + ) + tir.writes( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16] + ) + wmma_A1 = tir.match_buffer( + wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_a", + offset_factor=1, + ) + wmma_B1 = tir.match_buffer( + wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + "float16", + strides=[16, 1], + scope="wmma.matrix_b", + offset_factor=1, + ) + wmma_C1 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + tir.evaluate( + tir.tvm_mma_sync( + wmma_C1.data, + i * 4 + j, + wmma_A1.data, + i, + wmma_B1.data, + j, + wmma_C1.data, + i * 4 + j, + dtype="handle", + ) + ) + for i, j in tir.grid(2, 4): + with tir.block([64, 64]) as [vi, vj]: + tir.bind(vi, bx * 4 + ty * 2 + i) + tir.bind(vj, by * 8 + tz * 4 + j) + tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + s0 = tir.var("int32") + s1 = tir.var("int32") + wmma_C2 = tir.match_buffer( + wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[16 * 4, 1], + scope="wmma.accumulator", + offset_factor=1, + ) + C1 = tir.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + "float32", + strides=[s0, s1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_store_matrix_sync( + wmma_C2.data, + 16, + 16, + 16, + i * 4 + j, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float32"), + C1.data, + C1.elem_offset, + C1.strides[0], + 1, + dtype="handle", + ), + C1.strides[0], + "row_major", + dtype="handle", + ) + ) + + +@tvm.testing.requires_cuda +def test_gemm_tensorcore(): + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(1024, 1024)).astype("float16") + b_np = np.random.uniform(size=(1024, 1024)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev) + f = tvm.build(tensorcore_gemm, target="cuda", name="dense") + f(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + evaluator = f.time_evaluator(f.entry_name, dev, number=100) + t = evaluator(a, b, c).mean + num_flops = 2 * 1024 * 1024 * 1024 + gflops = num_flops / (t * 1e3) / 1e6 + print("gemm with tensor core: %f ms" % (t * 1e3)) + print("GFLOPS: %f" % gflops) + + +if __name__ == "__main__": + test_gemm_tensorcore() diff --git a/tests/python/relay/aot/aot_test.mk b/tests/python/relay/aot/aot_test.mk index 2426d9fd2963..81e31762611f 100644 --- a/tests/python/relay/aot/aot_test.mk +++ b/tests/python/relay/aot/aot_test.mk @@ -34,7 +34,8 @@ PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ -I$(DMLC_CORE)/include \ -I$(TVM_ROOT)/3rdparty/dlpack/include \ -I$(AOT_ROOT)\ - -I$(build_dir) + -I$(build_dir) \ + -I$(CODEGEN_ROOT)/host/include $(ifeq VERBOSE,1) QUIET ?= diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 836ff4b22b20..900eb67e2b48 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -16,24 +16,20 @@ # under the License. import os -import io -import struct -import numpy as np +import itertools import pathlib -import shutil import subprocess -import tempfile import tarfile import json +import pytest +import numpy as np import tvm from tvm import relay -from tvm.relay import transform from tvm.contrib import utils, graph_executor from tvm.relay.backend import compile_engine from tvm.relay.backend.utils import mangle_module_name -from tvm.contrib import utils from tvm.micro import export_model_library_format @@ -82,6 +78,26 @@ def convert_to_list(x): return mod, params +def parametrize_aot_options(test): + """Parametrize over valid option combinations""" + + interface_api = ["packed", "c"] + use_unpacked_api = [True, False] + use_calculated_workspaces = [True, False] + + all_combinations = itertools.product(interface_api, use_unpacked_api, use_calculated_workspaces) + # Filter out packed operators with c interface + valid_combinations = filter( + lambda parameters: not (parameters[0] == "c" and parameters[1] == False), + all_combinations, + ) + + return pytest.mark.parametrize( + ["interface_api", "use_unpacked_api", "use_calculated_workspaces"], + valid_combinations, + )(test) + + def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): """ This method runs a process and logs the output to both a log file and stdout @@ -102,12 +118,11 @@ def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): print(text, end="") -def emit_main_network_definition(main_file, mod_name): - main_file.write(f'extern tvm_model_t {mangle_name(mod_name,"network")};\n') - - def emit_main_prologue(main_file, workspace_bytes): - main_file.write(f"#define WORKSPACE_SIZE ({workspace_bytes})\n") + # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment. + main_file.write( + f"#define WORKSPACE_SIZE ({workspace_bytes} + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n" + ) main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") main_file.write("tvm_workspace_t app_workspace;\n") main_file.write( @@ -125,57 +140,133 @@ def emit_main_prologue(main_file, workspace_bytes): void TVMLogf(const char* msg, ...) { } TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {} -int main(){\n - - """ +int main(){\n +""" ) -def emit_main_data(main_file, input_list, output_list, mod_name): - for i in range(0, len(input_list)): - main_file.write(f'#include "{mangle_name(mod_name,"input_data")}{i}.h"\n') +def emit_main_data(main_file, input_map, output_list, mod_name): + for key in input_map: + main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n') for i in range(0, len(output_list)): main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n') main_file.write(f'#include "{mangle_name(mod_name,"output_data")}{i}.h"\n') -def emit_main_run(main_file, input_list, output_list, mod_name): +def emit_main_data_structs(main_file, input_map, output_list, mod_name): + main_file.write( + f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{" + ) + for key in input_map: + main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n") + main_file.write("};\n") + + main_file.write( + f"struct {mangle_name(mod_name, 'outputs')} {mangle_name(mod_name, 'outputs')} = {{" + ) + num_outputs = len(output_list) + if num_outputs == 1: + main_file.write(f"\t.output = {mangle_name(mod_name, 'output_data')}0,\n") + else: + for i in range(0, num_outputs): + main_file.write(f"\t.output{i} = {mangle_name(mod_name, 'output_data')}{i},\n") + main_file.write("};\n") + + +def emit_main_data_setup(main_file, input_map, output_list, mod_name): num_outputs = len(output_list) - num_inputs = len(input_list) + num_inputs = len(input_map) main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ') - - for i in range(0, len(input_list)): - main_file.write(f'{mangle_name(mod_name,"input_data")}{i}, ') + for key in input_map: + main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ') main_file.write("};\n") main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ') - for i in range(0, len(output_list)): + for i in range(0, num_outputs): main_file.write(f'{mangle_name(mod_name,"output_data")}{i}, ') main_file.write("};\n") + + +def emit_main_c_interface_call(main_file, mod_name): main_file.write( - f'tvm_runtime_run(&{mangle_name(mod_name,"network")}, {mangle_name(mod_name,"inputs")}, {mangle_name(mod_name,"outputs")});' + f'{mangle_name(mod_name,"run")}(&{mangle_name(mod_name,"inputs")}, &{mangle_name(mod_name,"outputs")});\n' ) +def emit_main_fake_packed_values(main_file): + main_file.write( + """ + static DLDevice fake_device = {kDLCPU, 0}; + static int64_t fake_dims = 0; + static int64_t fake_shape = {0}; + """ + ) + + +def emit_main_packed_call(main_file, input_map, output_list, mod_name): + tensors_name = mangle_name(mod_name, "tensors") + values_name = mangle_name(mod_name, "values") + typeids_name = mangle_name(mod_name, "typeids") + + def fake_tensor(source, source_index, packed_index): + main_file.write( + f""" + {tensors_name}[{packed_index}].device = fake_device; + {tensors_name}[{packed_index}].data = {source}[{source_index}]; + {tensors_name}[{packed_index}].shape = &fake_shape; + {tensors_name}[{packed_index}].ndim = fake_dims; + {tensors_name}[{packed_index}].byte_offset = 0; + {tensors_name}[{packed_index}].strides = NULL; + {values_name}[{packed_index}].v_handle = &{tensors_name}[{packed_index}]; + """ + ) + + num_outputs = len(output_list) + num_inputs = len(input_map) + num_tensors = num_inputs + num_outputs + main_file.write( + f""" + DLTensor {tensors_name}[{num_tensors}]; + TVMValue {values_name}[{num_tensors}]; + int32_t {typeids_name}[{num_tensors}]; + """ + ) + + for i in range(0, num_inputs): + fake_tensor(mangle_name(mod_name, "inputs"), i, i) + for i in range(0, num_outputs): + fake_tensor(mangle_name(mod_name, "outputs"), i, i + num_inputs) + + main_file.write( + f'{mangle_name(mod_name, "run")}({values_name}, {typeids_name}, 0, NULL, 0, NULL);\n' + ) + main_file.write("\n") + + def emit_main_compare(main_file, output_list, mod_name): - for i in range(0, len(output_list)): + num_outputs = len(output_list) + actual_data_name = mangle_name(mod_name, "output_data") + expected_data_name = mangle_name(mod_name, "expected_output_data") + + for i in range(0, num_outputs): is_float_dtype = output_list[i].dtype == "float32" - main_file.write(f'for (int i = 0; i<{mangle_name(mod_name,"output_data")}{i}_len; i++){{\n') + main_file.write(f"for (int i = 0; i<{actual_data_name}{i}_len; i++){{\n") if is_float_dtype: main_file.write( - f'if (fabs({mangle_name(mod_name,"output_data")}{i}[i]-{mangle_name(mod_name,"expected_output_data")}{i}[i]) > 0.001f){{printf("ko\\n");return -1;}}\n' + f'if (fabs({actual_data_name}{i}[i]-{expected_data_name}{i}[i]) > 0.001f){{\n\tprintf("ko\\n");\n\treturn -1;}}\n' ) else: main_file.write( - f'if ({mangle_name(mod_name,"output_data")}{i}[i]!={mangle_name(mod_name, "expected_output_data")}{i}[i]){{printf("ko\\n");return -1;}}\n' + f'if ({actual_data_name}{i}[i]!={expected_data_name}{i}[i]){{\n\tprintf("ko\\n");\n\treturn -1;}}\n' ) main_file.write("}\n") def emit_main_init_memory_manager(main_file): main_file.write("StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);") + main_file.write("\n") def emit_main_epilogue(main_file): @@ -187,33 +278,48 @@ def emit_main_epilogue(main_file): def emit_main_common_includes(main_file): main_file.write("#include \n") main_file.write("#include \n") - main_file.write('#include "tvm/runtime/crt/internal/aot_executor/aot_executor.h"\n') + main_file.write('#include "tvm/runtime/c_runtime_api.h"\n') main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') -def create_main(test_name, input_list_map, output_list_map, output_path, workspace_bytes): +def emit_main_micro_include(main_file, mod_name): + main_file.write(f"#include <{mangle_module_name(mod_name)}.h>\n") + + +def create_main(test_name, input_map, output_list_map, output_path, interface_api, workspace_bytes): file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() # create header file raw_path = file_path.with_suffix(".c").resolve() with open(raw_path, "w") as main_file: emit_main_common_includes(main_file) - for k in input_list_map: - emit_main_network_definition(main_file, k) + if interface_api == "c": + for mod_name in input_map: + emit_main_micro_include(main_file, mod_name) emit_main_prologue(main_file, workspace_bytes) - - for k in input_list_map: - emit_main_data(main_file, input_list_map[k], output_list_map[k], k) - + for mod_name in input_map: + emit_main_data(main_file, input_map[mod_name], output_list_map[mod_name], mod_name) emit_main_init_memory_manager(main_file) - for k in input_list_map: - emit_main_run(main_file, input_list_map[k], output_list_map[k], k) - - for k in input_list_map: - emit_main_compare(main_file, output_list_map[k], k) - + if interface_api == "c": + for mod_name in input_map: + emit_main_data_structs( + main_file, input_map[mod_name], output_list_map[mod_name], mod_name + ) + emit_main_c_interface_call(main_file, mod_name) + else: + emit_main_fake_packed_values(main_file) + for mod_name in input_map: + emit_main_data_setup( + main_file, input_map[mod_name], output_list_map[mod_name], mod_name + ) + emit_main_packed_call( + main_file, input_map[mod_name], output_list_map[mod_name], mod_name + ) + + for mod_name in input_map: + emit_main_compare(main_file, output_list_map[mod_name], mod_name) emit_main_epilogue(main_file) @@ -254,19 +360,22 @@ def extract_main_workspace_sizebytes(extract_dir): def compile_and_run( mod, - input_list, + inputs, output_list, - target_options, + interface_api, + use_unpacked_api, use_calculated_workspaces, params=None, workspace_byte_alignment=8, - mod_name=None, + mod_name="default", enable_op_fusion=True, ): """ This method verifies the generated source """ - target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment} {target_options}" + base_target = "c -runtime=c --link-params --executor=aot" + extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" + target = f"{base_target} {extra_target}" cflags = f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " # The calculated workspaces will not account for stack allocator tags used for debugging @@ -296,8 +405,8 @@ def compile_and_run( else: workspace_bytes = 16384 * 1024 - for i in range(len(input_list)): - create_header_file((f'{mangle_name(mod_name, "input_data")}{i}'), input_list[i], build_path) + for key in inputs: + create_header_file(f'{mangle_name(mod_name, "input_data")}_{key}', inputs[key], build_path) for i in range(len(output_list)): create_header_file( @@ -310,16 +419,23 @@ def compile_and_run( ) create_main( - "test.c", {mod_name: input_list}, {mod_name: output_list}, build_path, workspace_bytes + "test.c", + {mod_name: inputs}, + {mod_name: output_list}, + build_path, + interface_api, + workspace_bytes, ) # Verify that compiles fine file_dir = os.path.dirname(os.path.abspath(__file__)) + codegen_path = os.path.join(base_path, "codegen") makefile = os.path.join(file_dir, "aot_test.mk") make_cmd = ( f"make CFLAGS='{cflags}' -f {makefile} build_dir=" + build_path + f" TVM_ROOT={file_dir}/../../../.." + + f" CODEGEN_ROOT={codegen_path}" ) compile_log_path = os.path.join(build_path, "test_compile.log") @@ -333,12 +449,21 @@ def compile_and_run( def compile_and_run_multiple_models( - mod_map, input_list_map, output_list_map, target_options, param_map + mod_map, + input_list_map, + output_list_map, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + param_map, + workspace_byte_alignment=8, ): """ This method verifies the generated source """ - target = f"c -runtime=c --link-params --executor=aot {target_options}" + base_target = "c -runtime=c --link-params --executor=aot" + extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" + target = f"{base_target} {extra_target}" tmp_path = utils.tempdir() tmp_dir = tmp_path.temp_dir @@ -360,9 +485,9 @@ def compile_and_run_multiple_models( input_list = input_list_map[mod_name] output_list = output_list_map[mod_name] - for i in range(len(input_list_map[mod_name])): + for key in input_list: create_header_file( - (f'{mangle_name(mod_name,"input_data")}{i}'), input_list[i], build_path + (f'{mangle_name(mod_name,"input_data")}_{key}'), input_list[key], build_path ) for i in range(len(output_list_map[mod_name])): @@ -375,12 +500,25 @@ def compile_and_run_multiple_models( (f'{mangle_name(mod_name,"expected_output_data")}{i}'), output_list[i], build_path ) - create_main("test.c", input_list_map, output_list_map, build_path, workspace_bytes=16384 * 1024) + create_main( + "test.c", + input_list_map, + output_list_map, + build_path, + interface_api, + workspace_bytes=16384 * 1024, + ) # Verify that compiles fine file_dir = os.path.dirname(os.path.abspath(__file__)) + codegen_path = os.path.join(base_path, "codegen") makefile = os.path.join(file_dir, "aot_test.mk") - make_cmd = f"make -f {makefile} build_dir=" + build_path + f" TVM_ROOT={file_dir}/../../../.." + make_cmd = ( + f"make -f {makefile} build_dir=" + + build_path + + f" TVM_ROOT={file_dir}/../../../.." + + f" CODEGEN_ROOT={codegen_path}" + ) compile_log_path = os.path.join(build_path, "test_compile.log") ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 13cbfa71b6ae..26eca2688436 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -15,37 +15,48 @@ # specific language governing permissions and limitations # under the License. -import os -import io -import struct +from collections import OrderedDict + import numpy as np -import pathlib -import shutil -import subprocess -import tempfile -import tarfile import pytest import tvm from tvm import relay -from tvm.relay import transform -from tvm.relay.op.contrib import get_pattern_table -from tvm.contrib import utils -from tvm.relay.backend import compile_engine -from tvm.contrib import utils -from tvm.contrib import graph_executor -from tvm.micro import export_model_library_format -from tvm.relay import testing +from tvm.relay import testing, transform from tvm.relay.op.annotation import compiler_begin, compiler_end -from tvm.contrib import utils from tvm.relay.expr_functor import ExprMutator +from aot_test_utils import ( + generate_ref_data, + convert_to_relay, + compile_and_run, + compile_and_run_multiple_models, + parametrize_aot_options, +) -from aot_test_utils import * +def test_error_c_interface_with_packed_api(): + interface_api = "c" + use_unpacked_api = False + use_calculated_workspaces = True -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_conv_with_params(use_calculated_workspaces, target_options): + two = relay.add(relay.const(1), relay.const(1)) + func = relay.Function([], two) + output_list = generate_ref_data(func, {}) + input_list = [] + + with pytest.raises(tvm.TVMError, match="Packed interface required for packed operators"): + compile_and_run( + func, + input_list, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) + + +@parametrize_aot_options +def test_conv_with_params(interface_api, use_unpacked_api, use_calculated_workspaces): RELAY_MODEL = """ #[version = "0.0.5"] def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), int8]) { @@ -73,13 +84,19 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), inputs = {"data": input_data} output_list = generate_ref_data(mod, inputs, params) - input_list = [input_data] - compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces, params) + compile_and_run( + mod, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + params, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_add_with_params(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_add_with_params(interface_api, use_unpacked_api, use_calculated_workspaces): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) @@ -92,15 +109,19 @@ def test_add_with_params(use_calculated_workspaces, target_options): inputs = {"y": y_in} output_list = generate_ref_data(func, inputs, params) - input_list = [y_in] compile_and_run( - func, input_list, output_list, target_options, use_calculated_workspaces, params + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + params, ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_conv2d(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_conv2d(use_calculated_workspaces, interface_api, use_unpacked_api): """Test a subgraph with a single conv2d operator.""" def conv2d_direct(): @@ -119,7 +140,8 @@ def conv2d_direct(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - return mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14) + inputs = OrderedDict([("data", i_data), ("weight", w1_data)]) + return mod, inputs, (1, 32, 14, 14) def group_conv2d(): dtype = "float32" @@ -137,17 +159,23 @@ def group_conv2d(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w2shape).astype(dtype) - return mod, {"data": i_data, "weight": w_data}, (1, 32, 14, 14) + inputs = OrderedDict([("data", i_data), ("weight", w_data)]) + return mod, inputs, (1, 32, 14, 14) for mod, inputs, out_shape in [conv2d_direct(), group_conv2d()]: output_list = generate_ref_data(mod, inputs) - input_list = [inputs["data"], inputs["weight"]] - compile_and_run(mod, input_list, output_list, target_options, use_calculated_workspaces) - - -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_concatenate(use_calculated_workspaces, target_options): + compile_and_run( + mod, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) + + +@parametrize_aot_options +def test_concatenate(interface_api, use_unpacked_api, use_calculated_workspaces): dtype = "float32" x = relay.var("x", shape=(10, 5), dtype=dtype) y = relay.var("y", shape=(10, 5), dtype=dtype) @@ -159,16 +187,21 @@ def test_concatenate(use_calculated_workspaces, target_options): x_data = np.random.rand(10, 5).astype(dtype) y_data = np.random.rand(10, 5).astype(dtype) t_data = np.random.uniform(size=()).astype(dtype) - inputs = {"x": x_data, "y": y_data, "z": t_data} + inputs = OrderedDict([("x", x_data), ("y", y_data), ("z", t_data)]) output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"], inputs["y"], inputs["z"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_nested_tuples(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_nested_tuples(interface_api, use_unpacked_api, use_calculated_workspaces): x = relay.var("x", shape=(10,)) x1 = x + relay.const(1.0) x2 = x1 + relay.const(1.0) @@ -180,71 +213,109 @@ def test_nested_tuples(use_calculated_workspaces, target_options): x_data = np.random.uniform(size=(10,)).astype(np.float32) inputs = {"x": x_data} output_list = generate_ref_data(func, inputs) - input_list = [x_data] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_tuple_getitem(use_calculated_workspaces, target_options): + +@parametrize_aot_options +def test_tuple_getitem(interface_api, use_unpacked_api, use_calculated_workspaces): func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0)) output_list = generate_ref_data(func, {}) - input_list = [] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + inputs = {} + + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_id(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_id(interface_api, use_unpacked_api, use_calculated_workspaces): x = relay.var("x", "float32") ident = relay.Function([x], x) one = np.array(1.0, "float32") inputs = {"x": one} output_list = generate_ref_data(ident, inputs) - input_list = [one] - compile_and_run(ident, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + ident, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_add_const(use_calculated_workspaces, target_options): + +@parametrize_aot_options +def test_add_const(interface_api, use_unpacked_api, use_calculated_workspaces): two = relay.add(relay.const(1), relay.const(1)) func = relay.Function([], two) output_list = generate_ref_data(func, {}) - input_list = [] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + inputs = {} + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_mul_param(use_calculated_workspaces, target_options): + +@parametrize_aot_options +def test_mul_param(interface_api, use_unpacked_api, use_calculated_workspaces): x = relay.var("x", shape=(10, 10)) y = relay.var("y", shape=(1, 10)) func = relay.Function([x, y], relay.multiply(x, y)) x_data = np.random.rand(10, 10).astype("float32") y_data = np.random.rand(1, 10).astype("float32") - inputs = {"x": x_data, "y": y_data} + + inputs = OrderedDict([("x", x_data), ("y", y_data)]) output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"], inputs["y"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_subtract(use_calculated_workspaces, target_options): + +@parametrize_aot_options +def test_subtract(interface_api, use_unpacked_api, use_calculated_workspaces): i = relay.var("i", shape=[], dtype="int32") sub = relay.subtract(i, relay.const(1, dtype="int32")) func = relay.Function([i], sub, ret_type=relay.TensorType([], "int32")) i_data = np.array(1, dtype="int32") inputs = {"i": i_data} output_list = generate_ref_data(func, inputs) - input_list = [inputs["i"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) -@pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_tuple_output(use_calculated_workspaces, target_options): +@parametrize_aot_options +def test_tuple_output(interface_api, use_unpacked_api, use_calculated_workspaces): x = relay.var("x", shape=(6, 9)) y = relay.split(x, 3).astuple() a = relay.TupleGetItem(y, 0) @@ -255,29 +326,34 @@ def test_tuple_output(use_calculated_workspaces, target_options): x_data = np.random.rand(6, 9).astype("float32") inputs = {"x": x_data} output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"]] - compile_and_run(func, input_list, output_list, target_options, use_calculated_workspaces) + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + ) @pytest.mark.parametrize( - "use_calculated_workspaces_and_alignment", [(True, 1), (True, 16), (False, 1)] + ["use_calculated_workspaces", "workspace_byte_alignment"], [(True, 1), (True, 16), (False, 1)] ) -@pytest.mark.parametrize("target_options", ["--unpacked-api"]) -def test_mobilenet(use_calculated_workspaces_and_alignment, target_options): - use_calculated_workspaces = use_calculated_workspaces_and_alignment[0] - workspace_byte_alignment = use_calculated_workspaces_and_alignment[1] +def test_mobilenet(use_calculated_workspaces, workspace_byte_alignment): + use_unpacked_api = True + interface_api = "c" mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} output_list = generate_ref_data(mod, inputs, params) - input_list = [inputs["data"]] compile_and_run( mod, - input_list, + inputs, output_list, - target_options, + interface_api, + use_unpacked_api, use_calculated_workspaces, params, workspace_byte_alignment, @@ -339,9 +415,11 @@ def visit_call(self, call): @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) -@pytest.mark.parametrize("target_options", [""]) -def test_byoc_microtvm(use_calculated_workspaces, target_options): +def test_byoc_microtvm(use_calculated_workspaces): """This is a simple test case to check BYOC capabilities of AOT""" + use_unpacked_api = False + interface_api = "packed" + x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) w1 = relay.var("w1", shape=(10, 10)) @@ -379,18 +457,23 @@ def test_byoc_microtvm(use_calculated_workspaces, target_options): for _ in range(8): w_data.append(np.random.rand(10, 10).astype("float32")) - map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} - map_inputs["x"] = x_data + map_inputs = OrderedDict([("x", x_data)] + [("w{}".format(i), w_data[i]) for i in range(8)]) output_list = generate_ref_data(mod, map_inputs) input_list = [map_inputs["x"]] input_list.extend([map_inputs["w{}".format(i)] for i in range(8)]) compile_and_run( - mod, input_list, output_list, target_options, use_calculated_workspaces, mod_name="my_mod" + mod, + map_inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + mod_name="my_mod", ) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_add_name_mangling_with_params(target_options): +@parametrize_aot_options +def test_add_name_mangling_with_params(interface_api, use_unpacked_api, use_calculated_workspaces): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) z = relay.add(x, y) @@ -403,27 +486,26 @@ def test_add_name_mangling_with_params(target_options): inputs = {"y": y_in} output_list = generate_ref_data(func, inputs, params) - input_list = [y_in] compile_and_run( func, - input_list, + inputs, output_list, - target_options, - use_calculated_workspaces=False, + interface_api, + use_unpacked_api, + use_calculated_workspaces, params=params, mod_name="my_mod", ) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_multiple_models(target_options): +@parametrize_aot_options +def test_multiple_models(interface_api, use_unpacked_api, use_calculated_workspaces): # Identity model without params x = relay.var("x", "float32") mod1 = relay.Function([x], x) one = np.array(1.0, "float32") inputs1 = {"x": one} output_list1 = generate_ref_data(mod1, inputs1) - input_list1 = [one] params1 = None # Convolution model @@ -453,15 +535,20 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5), params2 = {"weight": weight_data} inputs2 = {"data": input_data} output_list2 = generate_ref_data(mod2, inputs2, params2) - input_list2 = [input_data] - input_list_map = {"mod1": input_list1, "mod2": input_list2} + input_list_map = {"mod1": inputs1, "mod2": inputs2} output_list_map = {"mod1": output_list1, "mod2": output_list2} mod_map = {"mod1": mod1, "mod2": mod2} param_map = {"mod1": params1, "mod2": params2} compile_and_run_multiple_models( - mod_map, input_list_map, output_list_map, target_options, param_map + mod_map, + input_list_map, + output_list_map, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + param_map, ) @@ -473,6 +560,10 @@ def test_quant_mobilenet_tfl(): import tvm.relay.testing.tf as tf_testing + interface_api = "packed" + use_unpacked_api = False + use_calculated_workspaces = True + tflite_model_file = tf_testing.get_workload_official( "https://storage.googleapis.com/download.tensorflow.org/" "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", @@ -486,12 +577,19 @@ def test_quant_mobilenet_tfl(): mod, params = convert_to_relay(tflite_model_buf, data, "input") inputs = {"input": data} output_list = generate_ref_data(mod, inputs, params) - input_list = [inputs["input"]] - compile_and_run(mod, input_list, output_list, "--unpacked-api=0", True, params) + compile_and_run( + mod, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + params=params, + ) -@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"]) -def test_transpose(target_options): +@parametrize_aot_options +def test_transpose(interface_api, use_unpacked_api, use_calculated_workspaces): """Test that non-inpleaceable operations (e.g., transpose) do not happen in-place.""" dtype = "float32" @@ -506,11 +604,18 @@ def test_transpose(target_options): x_data = np.random.rand(10, 5).astype(dtype) y_data = np.random.rand(10, 5).astype(dtype) t_data = np.random.uniform(size=()).astype(dtype) - inputs = {"x": x_data, "y": y_data, "z": t_data} + inputs = {"x": x_data, "y": y_data, "z": t_data} output_list = generate_ref_data(func, inputs) - input_list = [inputs["x"], inputs["y"], inputs["z"]] - compile_and_run(func, input_list, output_list, target_options, True, enable_op_fusion=False) + compile_and_run( + func, + inputs, + output_list, + interface_api, + use_unpacked_api, + use_calculated_workspaces, + enable_op_fusion=False, + ) if __name__ == "__main__": diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index dca5dd6d4384..a6ea609be1e2 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -40,12 +40,13 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa (n, h, w, c) = dshape x_data = np.random.uniform(size=(n, h, w, c)).astype("float32") - if method == "nearest_neighbor": - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout) - else: - ref_res = tvm.topi.testing.bilinear_resize_python( - x_data, (int(round(h * scale_h)), int(round(w * scale_w))), layout - ) + ref_res = tvm.topi.testing.resize2d_python( + x_data, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "align_corners" if align_corners else "asymmetric", + ) x = relay.Var("x", relay.TensorType(dshape, "float32")) scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) scale_w_var = relay.var("scale_h", relay.TensorType((), "float32")) @@ -87,7 +88,7 @@ def test_dyn_upsampling_infer_type_const(): @tvm.testing.uses_gpu def test_dyn_upsampling3d_run(): def verify_upsampling3d( - dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel" + dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="asymmetric" ): if layout == "NCDHW": @@ -98,16 +99,14 @@ def verify_upsampling3d( (n, d, h, w, c) = dshape x_data = np.random.uniform(size=(n, d, h, w, c)).astype("float32") - if method == "nearest_neighbor": - ref_res = tvm.topi.testing.upsampling3d_python( - x_data, (scale_d, scale_h, scale_w), layout - ) - else: - ref_res = tvm.topi.testing.trilinear_resize3d_python( - x_data, - (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))), - layout, - ) + ref_res = tvm.topi.testing.resize3d_python( + x_data, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + coord_trans, + ) + x = relay.Var("x", relay.TensorType(dshape, "float32")) scale_d_var = relay.var("scale_d", relay.TensorType((), "float32")) scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 78e2c232c08e..d3459afaab06 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -27,39 +27,40 @@ import tvm.testing -def test_resize_infer_type(): +def test_resize2d_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) size = relay.var("size", relay.TensorType((2,), "int8")) - z = relay.image.resize(x, size) + z = relay.image.resize2d(x, size) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") @tvm.testing.uses_gpu -def test_resize(): - def verify_resize(dshape, scale, method, layout): +def test_resize2d(): + def verify_resize2d(dshape, scale, method, layout): if layout == "NHWC": size = (dshape[1] * scale, dshape[2] * scale) else: size = (dshape[2] * scale, dshape[3] * scale) size = np.array(size).astype("int64") x_data = np.random.uniform(size=dshape).astype("float32") - if method == "bilinear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + x = relay.var("x", relay.TensorType(dshape, "float32")) size_var = relay.var("size", relay.TensorType((2,), "int64")) coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners" - z = relay.image.resize( + z = relay.image.resize2d( x, size_var, layout, method, coordinate_transformation_mode=coord_trans ) zz = run_infer_type(z) func = relay.Function([x, size_var], z) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) + for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) @@ -67,10 +68,10 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate()(x_data, size) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["bilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: - verify_resize((1, 4, 4, 4), 2, method, layout) - verify_resize((2, 8, 17, 20), 7, method, layout) + verify_resize2d((1, 4, 4, 4), 2, method, layout) + verify_resize2d((2, 8, 17, 20), 7, method, layout) if __name__ == "__main__": diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 13f5525bfee8..3f53c11fa36a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -496,13 +496,24 @@ def verify_any_conv2d( dilation, static_data_shape, ref_out_shape, + data_layout="NCHW", + kernel_layout="OIHW", use_cudnn=False, ): mod = tvm.IRModule() dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype) - y = relay.nn.conv2d(data, kernel, strides, padding, dilation, kernel_size=kernel_shape[2:4]) + y = relay.nn.conv2d( + data, + kernel, + strides, + padding, + dilation, + kernel_size=kernel_shape[2:4] if kernel_layout == "OIHW" else kernel_shape[0:2], + data_layout=data_layout, + kernel_layout=kernel_layout, + ) mod["main"] = relay.Function([data, kernel], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) @@ -545,6 +556,28 @@ def test_any_conv2d(): (1, 64, 224, 224), use_cudnn=True, ) + verify_any_conv2d( + (relay.Any(), 224, 224, 64), + (3, 3, 64, 64), + (1, 1), + (1, 1), + (1, 1), + (1, 224, 224, 64), + (1, 224, 224, 64), + data_layout="NHWC", + kernel_layout="HWIO", + ) + verify_any_conv2d( + (relay.Any(), 224, 224, 64), + (3, 3, 64, 64), + (1, 1), + (1, 1), + (2, 2), + (2, 224, 224, 64), + (2, 222, 222, 64), + data_layout="NHWC", + kernel_layout="HWIO", + ) def verify_any_conv2d_NCHWc( @@ -610,6 +643,63 @@ def test_any_conv2d_NCHWc(): ) +def verify_any_conv1d_transpose_ncw( + data_shape, + kernel_shape, + strides, + padding, + dilation, + groups, + static_data_shape, + ref_out_shape, + output_padding, +): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype) + y = relay.nn.conv1d_transpose( + data, + kernel, + strides, + padding, + dilation, + groups, + kernel_size=kernel_shape[2:], + output_padding=output_padding, + ) + mod["main"] = relay.Function([data, kernel], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) + check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_conv1d_transpose_ncw(): + verify_any_conv1d_transpose_ncw( + (relay.Any(), 64, 224), + (64, 192, 3), + (1,), + (1,), + (1,), + 1, + (2, 64, 224), + (2, 192, 224), + (0, 0), + ) + verify_any_conv1d_transpose_ncw( + (relay.Any(), 32, 224), + (32, 64, 3), + (2,), + (1,), + (1,), + 1, + (1, 32, 224), + (1, 64, 448), + (1, 1), + ) + + def verify_any_conv2d_transpose_nchw( data_shape, kernel_shape, @@ -830,6 +920,98 @@ def test_any_dense_dynamic_batch(): verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50), use_cublas=True) +def verify_any_batch_matmul( + x_shape, + y_shape, + out_shape, + x_var_shape, + y_var_shape, + dtype="float32", + trans_x=False, + trans_y=True, +): + x = relay.var("x", relay.TensorType(x_var_shape, dtype)) + y = relay.var("y", relay.TensorType(y_var_shape, dtype)) + z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y) + + func = relay.Function([x, y], z) + x_np = np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y) + + for target, dev in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + z = intrp.evaluate()(x_np, y_np) + tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) + + +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu +def test_any_batch_matmul(): + verify_any_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul( + (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),) * 3 + ) + + verify_any_batch_matmul( + (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 20, 32), (5, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) + ) + verify_any_batch_matmul( + (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) + ) + + verify_any_batch_matmul( + (1, 32, 16), (1, 16, 32), (1, 16, 16), (1, 32, 16), (relay.Any(),) * 3, trans_x=True + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 32, 16), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3, trans_y=False + ) + verify_any_batch_matmul( + (5, 32, 16), + (5, 32, 20), + (5, 16, 20), + (5, 32, 16), + (relay.Any(),) * 3, + trans_x=True, + trans_y=False, + ) + verify_any_batch_matmul( + (1, 32, 16), + (1, 16, 32), + (1, 16, 16), + (relay.Any(), 32, 16), + (relay.Any(), 16, 32), + trans_x=True, + ) + verify_any_batch_matmul( + (5, 16, 32), + (5, 32, 16), + (5, 16, 16), + (relay.Any(), 16, 32), + (relay.Any(), 32, 16), + trans_y=False, + ) + verify_any_batch_matmul( + (5, 32, 16), + (5, 32, 20), + (5, 16, 20), + (relay.Any(), 32, 16), + (relay.Any(), 32, 20), + trans_x=True, + trans_y=False, + ) + + @tvm.testing.uses_gpu def verify_any_pad(data_shape, pad_width, static_data_shape): mod = tvm.IRModule() @@ -899,6 +1081,72 @@ def test_any_softmax(): verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) +def verify_any_relu(data_shape, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + y = relay.nn.relu(data) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_relu(): + verify_any_relu(any_dims(3), (1, 2, 3), (1, 2, 3)) + verify_any_relu(any_dims(4), (13, 11, 3, 1), (13, 11, 3, 1)) + + +def verify_any_prelu(data_shape, alpha, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + alpha = relay.const(np.array([alpha]), dtype=dtype) + y = relay.nn.prelu(data, alpha) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_prelu(): + verify_any_prelu(any_dims(3), 1, (1, 2, 3), (1, 2, 3)) + verify_any_prelu(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) + + +def verify_any_leaky_relu(data_shape, alpha, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + y = relay.nn.leaky_relu(data, alpha) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_leaky_relu(): + verify_any_leaky_relu(any_dims(3), 0.1, (1, 2, 3), (1, 2, 3)) + verify_any_leaky_relu(any_dims(4), 0.2, (13, 11, 3, 1), (13, 11, 3, 1)) + + +def verify_any_bias_add(data_shape, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + bias = relay.const(np.random.randn(1), dtype=dtype) + y = relay.nn.bias_add(data, bias) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_bias_add(): + verify_any_bias_add(any_dims(3), (1, 2, 3), (1, 2, 3)) + verify_any_bias_add(any_dims(4), (13, 11, 3, 1), (13, 11, 3, 1)) + + def verify_any_topk(data_shape, kval, np_dshape, dtype, ret_type="indices", const_k=False): mod = tvm.IRModule() data = relay.var("data", shape=data_shape, dtype=dtype) @@ -1275,7 +1523,7 @@ def test_any_ndarray_size(): verify_any_ndarray_size((1, 2, 3, 4)) -def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape): +def verify_any_resize2d(data_shape, scale, layout, static_data_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) @@ -1283,7 +1531,7 @@ def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shap size = (data_shape[1] * scale, data_shape[2] * scale) else: size = (data_shape[2] * scale, data_shape[3] * scale) - y = relay.image.resize(data, size, layout) + y = relay.image.resize2d(data, size, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) check_result([data_np], mod, ref_out_shape, assert_shape=True) @@ -1291,14 +1539,14 @@ def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shap @tvm.testing.uses_gpu def test_any_resize(): - verify_any_resize( + verify_any_resize2d( data_shape=(relay.Any(), 4, 4, 4), scale=2, layout="NHWC", static_data_shape=(1, 4, 4, 4), ref_out_shape=(1, 8, 8, 4), ) - verify_any_resize( + verify_any_resize2d( data_shape=(relay.Any(), 8, 17, 20), scale=3, layout="NCHW", diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py index 6a9d9a5cf0ad..5b0125b452c5 100644 --- a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py +++ b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py @@ -188,7 +188,7 @@ def test_conv2d(): def test_conv2d_winograd(): - mod, data, weight = get_relay_conv2d(kh=3, kw=3) + mod, data, weight = get_relay_conv2d(outc=128, kh=3, kw=3) tune_and_check(mod, data, weight) diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index cfbca40cf379..39596186d211 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -96,51 +96,61 @@ def get_network(name, batch_size=1, layout="NHWC"): @tvm.testing.requires_cuda -def test_task_extraction_cuda(): +@pytest.mark.parametrize( + "params", + [ + ("mlp", "NHWC", 1, 2), + ("resnet-18", "NHWC", 24, 25), + ("resnet-18", "NCHW", 24, 25), + ("mobilenet", "NHWC", 22, 30), + ("mobilenet", "NCHW", 22, 30), + ("resnet3d-18", "NCDHW", 23, 24), + ("resnet3d-18", "NDHWC", 23, 24), + ], +) +def test_task_extraction_cuda(params): target = tvm.target.Target("cuda") + network, layout, expected_task, expected_weights = params - mod, params = get_network("mlp") + mod, params = get_network(network, layout=layout) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - assert len(tasks) == 1 - assert sum(task_weights) == 2 - - for layout in ["NHWC", "NCHW"]: - mod, params = get_network("resnet-18", layout=layout) - tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 24 - assert sum(task_weights) == 25 - - mod, params = get_network("mobilenet", layout=layout) - tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 22 - assert sum(task_weights) == 30 - - for layout in ["NCDHW", "NDHWC"]: - mod, params = get_network("resnet3d-18", layout=layout) - tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) - - assert len(tasks) == 23 - assert sum(task_weights) == 24, sum(task_weights) - - -def test_task_extraction(): + for task, weight in zip(tasks, task_weights): + print(task.desc, task.workload_key, weight) + + assert len(tasks) == expected_task + assert sum(task_weights) == expected_weights + + +@pytest.mark.parametrize( + "params", + [ + # Relay FuseOps puts two conv2ds to separate functions and results in two tasks. + ("basic_func", 2, False), + # Relay FuseOps will not break the primitive function and result in one task. + ("fused_func", 1, False), + # The Relay function without complex ops will not form a task by default. + ("simple_func", 0, False), + # Every Relay function becomes a task regardless what ops in its body. + ("simple_func", 1, True), + # The Relay function without any reduce op is considered as a simple task. + ("shape_of_func", 0, False), + ("shape_of_func", 1, True), + # The Relay function with dynamic shape inputs/outputs will not be extracted. + ("dyn_shape_func", 0, False), + # The Conv2D in the Relay function with control flow could still be a task. + # Also, two identical Conv2D should only be one task with weight=2. + ("control_flow_func", 1, False), + # The first function with unsupported op (NMS) will not be extracted. + ("func_w_unsupported_op", 1, True), + ], +) +def test_task_extraction_cpu(params): ishape = (1, 3, 224, 224) w1shape = (32, 3, 3, 3) w2shape = (32, 32, 3, 3) dtype = "float32" target = tvm.target.Target("llvm") - def verify_task_extraction(func, expected_task, include_simple_tasks=False): - mod = tvm.IRModule.from_expr(func) - tasks, task_weights = auto_scheduler.extract_tasks( - mod["main"], None, target, include_simple_tasks=include_simple_tasks - ) - - assert len(tasks) == expected_task - assert len(task_weights) == expected_task - def get_func(): data = relay.var("data", shape=(ishape), dtype=dtype) weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype) @@ -182,13 +192,16 @@ def get_func_with_dynamic_shape(): def get_func_with_control_flow(): data = relay.var("data", shape=(1, 3, 224, 224)) - weight = relay.var("weight", shape=(32, 3, 3, 3)) + weight = relay.var("weight", shape=(3, 3, 3, 3)) eq1 = relay.var("e1", shape=[], dtype="float32") eq2 = relay.var("e2", shape=[], dtype="float32") eq = relay.equal(eq1, eq2) - true_branch = relay.zeros(shape=(1, 32, 222, 222), dtype="float32") - false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=32) + true_branch = relay.zeros(shape=(1, 3, 224, 224), dtype="float32") + false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=3, padding=(1, 1)) + false_branch = relay.nn.conv2d( + false_branch, weight, kernel_size=(3, 3), channels=3, padding=(1, 1) + ) ife = relay.If(eq, true_branch, false_branch) out = relay.erf(ife) return relay.Function([data, weight, eq1, eq2], out) @@ -212,32 +225,28 @@ def get_postproc_func(): out = relay.Call(get_postproc_func(), [nms]) return relay.Function([cls_prob, loc_pred, anchors], out) - # Relay FuseOps puts two conv2ds to separate functions and results in two tasks. - verify_task_extraction(get_func(), 2) - - # By setting the function to primitive, Relay FuseOps will not break it and result in one task. - verify_task_extraction(get_fused_func(), 1) - - # The Relay function without complex ops will not form a task by default. - verify_task_extraction(get_simple_func(), 0) - - # Every Relay function becomes a task regardless what ops in its body. - verify_task_extraction(get_simple_func(), 1, True) - - # The Relay function without any reduce op is considered as a simple task. - verify_task_extraction(get_shape_of_func(), 0) - verify_task_extraction(get_shape_of_func(), 1, True) - - # The Relay function with dynamic shape inputs/outputs will not be extracted. - verify_task_extraction(get_func_with_dynamic_shape(), 0) + func_map = { + "basic_func": get_func, + "fused_func": get_fused_func, + "simple_func": get_simple_func, + "shape_of_func": get_shape_of_func, + "dyn_shape_func": get_func_with_dynamic_shape, + "control_flow_func": get_func_with_control_flow, + "func_w_unsupported_op": get_func_with_unsupported_op, + } + + def verify_task_extraction(func_name, expected_task, include_simple_tasks=False): + func = func_map[func_name]() + mod = tvm.IRModule.from_expr(func) + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], None, target, include_simple_tasks=include_simple_tasks + ) - # The Conv2D in the Relay function with control flow could still be a task. - verify_task_extraction(get_func_with_control_flow(), 1) + assert len(tasks) == expected_task + assert len(task_weights) == expected_task - # Func1 (with NMS) -> Func2 (injective). - verify_task_extraction(get_func_with_unsupported_op(), 1, True) + verify_task_extraction(*params) if __name__ == "__main__": - test_task_extraction_cuda() - test_task_extraction() + pytest.main([__file__]) diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 4ec1c21467fc..234095f67864 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import json @@ -50,17 +51,40 @@ def check_rts(expr, args, expected_result, mod=None): def test_add_op_scalar(): """ - Program: + test_add_op_scalar: fn (x, y) { return x + y; } """ - x = relay.var("x", shape=()) - y = relay.var("y", shape=()) + x = relay.var("x", shape=()) # Default to float32 + y = relay.var("y", shape=()) # Default to float32 func = relay.Function([x, y], add(x, y)) - x_data = np.array(10.0, dtype="float32") - y_data = np.array(1.0, dtype="float32") - check_rts(func, [x_data, y_data], x_data + y_data) + x_y_data = [ + (np.array(10.0, dtype="float32"), np.array(1.0, dtype="float32")), + (np.float32(10.0), np.float32(1.0)), + (10.0, 1.0), + ] + for (x_data, y_data) in x_y_data: + check_rts(func, [x_data, y_data], x_data + y_data) + + +def test_add_op_scalar_int(): + """ + test_add_op_scalar_int: + fn (x, y) { + return x + y; + } + """ + x = relay.var("x", shape=(), dtype="int32") + y = relay.var("y", shape=(), dtype="int32") + func = relay.Function([x, y], add(x, y)) + x_y_data = [ + (np.array(10.0, dtype="int32"), np.array(1.0, dtype="int32")), + (np.int32(10), np.int32(1)), + (10, 1), + ] + for (x_data, y_data) in x_y_data: + check_rts(func, [x_data, y_data], x_data + y_data) def test_add_op_tensor(): @@ -130,22 +154,22 @@ def test_plan_memory(): mod = relay.transform.FuseOps(0)(mod) func = mod["main"] mod = relay.transform.InferType()(mod) - smap = relay.backend._backend.GraphPlanMemory(func) + memory_plan = relay.backend._backend.GraphPlanMemory(func) storage_ids = set() device_types = set() storage_sizes = {} - for k, v in smap.items(): - assert len(v) == 3 - for x in v[0]: - storage_ids.add(x.value) - storage_sizes[x.value] = v[2] - for x in v[1]: - device_types.add(x.value) + + for k, v in memory_plan.expr_to_storage_info.items(): + for x in v.storage_ids: + storage_ids.add(x) + storage_sizes[x] = v.storage_sizes + for x in v.device_types: + device_types.add(x) # Current rule requires vars have unique storage id # because we don't do inplace, we will need another # two alternating temporary space. - assert len(storage_ids) == 4 + assert len(storage_ids) == 4, f"found storage_ids: {storage_ids}" assert len(device_types) == 1 assert len(storage_sizes) == 4 @@ -288,11 +312,4 @@ def test_graph_executor_nested_tuples(): if __name__ == "__main__": - test_reshape_nop() - test_plan_memory() - test_with_params() - test_add_op_scalar() - test_add_op_tensor() - test_add_op_broadcast() - test_gru_like() - test_compile_nested_tuples() + pytest.main([__file__]) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 84e2fa305bfe..645f86fadac0 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -15,21 +15,19 @@ # specific language governing permissions and limitations # under the License. """Unit tests for graph partitioning.""" + import os import sys +from collections import OrderedDict import numpy as np +import pytest import tvm -from tvm import te -import tvm.relay.testing -import tvm.relay.transform - -from tvm import relay -from tvm import runtime -from tvm.relay import transform +from tvm import relay, runtime from tvm.contrib import utils from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.annotation import compiler_begin, compiler_end +from aot.aot_test_utils import compile_and_run def update_lib(lib): @@ -48,37 +46,39 @@ def update_lib(lib): return lib -def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()): - if sys.platform == "win32": - print("Skip test on Windows for now") - return - - def check_vm_result(): - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - exe = relay.vm.compile(mod, target=target) - code, lib = exe.save() - lib = update_lib(lib) - exe = runtime.vm.Executable.load_exec(code, lib) - vm = runtime.vm.VirtualMachine(exe, device) - out = vm.run(**map_inputs) - tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) +def check_vm_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + exe = relay.vm.compile(mod, target=target) + code, lib = exe.save() + lib = update_lib(lib) + exe = runtime.vm.Executable.load_exec(code, lib) + vm = runtime.vm.VirtualMachine(exe, device) + out = vm.run(**map_inputs) + tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) + + +def check_graph_executor_result( + mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu() +): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, _ = relay.build(mod, target=target) + lib = update_lib(lib) + rt_mod = tvm.contrib.graph_executor.create(json, lib, device) - def check_graph_executor_result(): - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - json, lib, _ = relay.build(mod, target=target) - lib = update_lib(lib) - rt_mod = tvm.contrib.graph_executor.create(json, lib, device) + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + rt_mod.run() + out = tvm.nd.empty(out_shape, device=device) + out = rt_mod.get_output(0, out) - for name, data in map_inputs.items(): - rt_mod.set_input(name, data) - rt_mod.run() - out = tvm.nd.empty(out_shape, device=device) - out = rt_mod.get_output(0, out) + tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) - tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol) - check_vm_result() - check_graph_executor_result() +def check_aot_executor_result( + mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu() +): + use_calculated_workspaces = True + compile_and_run(mod, map_inputs, [result], "packed", 0, use_calculated_workspaces) def set_external_func_attr(func, compiler, ext_symbol): @@ -88,7 +88,11 @@ def set_external_func_attr(func, compiler, ext_symbol): return func -def test_multi_node_subgraph(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_multi_node_subgraph(check_result): x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) w1 = relay.var("w1", shape=(10, 10)) @@ -138,8 +142,7 @@ def test_multi_node_subgraph(): for _ in range(8): w_data.append(np.random.rand(10, 10).astype("float32")) - map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} - map_inputs["x"] = x_data + map_inputs = OrderedDict([("x", x_data)] + [("w{}".format(i), w_data[i]) for i in range(8)]) check_result( mod, map_inputs, @@ -155,7 +158,11 @@ def test_multi_node_subgraph(): ) -def test_extern_gcc_single_op(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_extern_gcc_single_op(check_result): x = relay.var("x", shape=(8, 8)) y = relay.var("y", shape=(8, 8)) @@ -172,7 +179,11 @@ def test_extern_gcc_single_op(): check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) -def test_extern_gcc_single_op_int(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_extern_gcc_single_op_int(check_result): x = relay.var("x", shape=(8, 8), dtype="int32") y = relay.var("y", shape=(8, 8), dtype="int32") @@ -189,7 +200,11 @@ def test_extern_gcc_single_op_int(): check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) -def test_extern_gcc(): +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.parametrize( + "check_result", [check_vm_result, check_graph_executor_result, check_aot_executor_result] +) +def test_extern_gcc(check_result): x = relay.var("x", shape=(2, 2)) y = relay.var("y", shape=(2, 2)) @@ -221,9 +236,17 @@ def test_extern_gcc(): x_data = np.random.rand(2, 2).astype("float32") y_data = np.random.rand(2, 2).astype("float32") - check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) + inputs = OrderedDict( + [ + ("y", y_data), + ("x", x_data), + ] + ) + + check_result(mod, inputs, (2, 2), (y_data * y_data) - (x_data + x_data)) +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") def test_extern_gcc_consts(): @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") def constant_updater(expr, symbol): @@ -257,11 +280,13 @@ def constant_updater(expr, symbol): tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") -def test_extern_dnnl(): - if not tvm.get_global_func("relay.ext.dnnl", True): - print("skip because DNNL codegen is not available") - return - +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.skipif( + not tvm.get_global_func("relay.ext.dnnl", True), + reason="skip because DNNL codegen is not available", +) +@pytest.mark.parametrize("check_result", [check_vm_result, check_graph_executor_result]) +def test_extern_dnnl(check_result): dtype = "float32" ishape = (1, 32, 14, 14) w1shape = (32, 1, 3, 3) @@ -297,11 +322,13 @@ def test_extern_dnnl(): ) -def test_extern_dnnl_const(): - if not tvm.get_global_func("relay.ext.dnnl", True): - print("skip because DNNL codegen is not available") - return - +@pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") +@pytest.mark.skipif( + not tvm.get_global_func("relay.ext.dnnl", True), + reason="skip because DNNL codegen is not available", +) +@pytest.mark.parametrize("check_result", [check_vm_result, check_graph_executor_result]) +def test_extern_dnnl_const(check_result): dtype = "float32" ishape = (1, 32, 14, 14) w1shape = (32, 1, 3, 3) @@ -349,7 +376,7 @@ def test_load_params_with_constants_in_ext_codegen(): zce = compiler_end(z, "ccompiler") mod["main"] = relay.Function([x, y], zce) mod["main"] = bind_params_by_name(mod["main"], params) - mod = transform.PartitionGraph()(mod) + mod = relay.transform.PartitionGraph()(mod) graph_module = relay.build(mod, target="llvm", params=params) # Params will be stored in metadata module. @@ -360,11 +387,4 @@ def test_load_params_with_constants_in_ext_codegen(): if __name__ == "__main__": - test_multi_node_subgraph() - test_extern_gcc_single_op() - test_extern_gcc_single_op_int() - test_extern_gcc() - test_extern_gcc_consts() - test_extern_dnnl() - test_extern_dnnl_const() - test_load_params_with_constants_in_ext_codegen() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index e2145f77b366..8d961eb60b18 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -62,10 +62,24 @@ def test_checkpoint(): check_grad(relay.Function(inputs, out_single)) +def verify_batch_matmul_grad(a_shape, b_shape, transpose_a, transpose_b): + tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32")) + tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32")) + check_grad( + relay.Function( + [tensor_a, tensor_b], + relay.op.nn.batch_matmul( + tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b + ), + ) + ) + + def test_batch_matmul_grad(): - x = relay.var("x", shape=(2, 3, 5), dtype="float64") - y = relay.var("y", shape=(2, 4, 5), dtype="float64") - check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y))) + verify_batch_matmul_grad((2, 3, 5), (2, 5, 4), False, False) + verify_batch_matmul_grad((2, 3, 5), (2, 4, 5), False, True) + verify_batch_matmul_grad((2, 5, 3), (2, 5, 4), True, False) + verify_batch_matmul_grad((2, 5, 3), (2, 4, 5), True, True) def test_reverse_reshape_grad(): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 24f0ed6642b5..eda7eac1b025 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -17,6 +17,7 @@ """ Support level10 operator test cases. """ import numpy as np +import pytest import tvm import tvm.testing import tvm.topi.testing @@ -325,17 +326,17 @@ def verify_reverse_reshape(shape, newshape, oshape): verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) -def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): +def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32", trans_x=False, trans_y=True): x = relay.var("x", relay.TensorType(x_shape, dtype)) y = relay.var("y", relay.TensorType(y_shape, dtype)) - z = relay.nn.batch_matmul(x, y) + z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y) zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(out_shape, dtype) func = relay.Function([x, y], z) x_np = np.random.uniform(size=x_shape).astype(dtype) y_np = np.random.uniform(size=y_shape).astype(dtype) - z_np = tvm.topi.testing.batch_matmul(x_np, y_np) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: @@ -353,37 +354,13 @@ def test_batch_matmul(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((b, m, n), "float32") - verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) - verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) - verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) - verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) - - -def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): - x = relay.var("x", relay.TensorType(x_shape, dtype)) - y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype)) - z = relay.nn.batch_matmul(x, y) - - func = relay.Function([x, y], z) - x_np = np.random.uniform(size=x_shape).astype(dtype) - y_np = np.random.uniform(size=y_shape).astype(dtype) - z_np = tvm.topi.testing.batch_matmul(x_np, y_np) - - for target, dev in tvm.testing.enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - z = intrp.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) - - -# TODO(mbrookhart): enable once VM supports heterogenous execution -# @tvm.testing.uses_gpu -def test_dynamic_batch_matmul(): - verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) - verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) - verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) - verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) + verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), trans_x=False, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), trans_x=False, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), trans_x=False, trans_y=True) + verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20), trans_x=False, trans_y=True) + verify_batch_matmul((1, 32, 16), (1, 16, 32), (1, 16, 16), trans_x=True, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 32, 16), (5, 16, 16), trans_x=False, trans_y=False) + verify_batch_matmul((5, 32, 16), (5, 32, 20), (5, 16, 20), trans_x=True, trans_y=False) @tvm.testing.uses_gpu @@ -616,15 +593,4 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 if __name__ == "__main__": - test_adaptive_pool() - test_collapse_sum_like() - test_broadcast_to() - test_broadcast_to_like() - test_slice_like() - test_reverse_reshape() - test_batch_matmul() - test_shape_of() - test_sequence_mask() - test_one_hot() - test_ndarray_size() - test_matrix_set_diag() + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 50fc0622ee6e..f05c5054415d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1448,13 +1448,15 @@ def get_shape(): align_corners=align_corners, ) func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) - if method == "nearest_neighbor": - ref = tvm.topi.testing.upsampling_python(data, (scale_h, scale_w), layout) - else: - ref = tvm.topi.testing.bilinear_resize_python( - data, (int(round(h * scale_h)), int(round(w * scale_w))), layout - ) + ref = tvm.topi.testing.resize2d_python( + data, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "align_corners" if align_corners else "asymmetric", + ) for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) @@ -1518,15 +1520,15 @@ def get_shape(): coordinate_transformation_mode=coordinate_transformation_mode, ) func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) - if method == "nearest_neighbor": - ref = tvm.topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout) - else: - ref = tvm.topi.testing.trilinear_resize3d_python( - data, - (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))), - layout, - ) + ref = tvm.topi.testing.resize3d_python( + data, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + coordinate_transformation_mode, + ) for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) @@ -1535,9 +1537,9 @@ def get_shape(): @tvm.testing.uses_gpu def test_upsampling3d(): - _test_upsampling3d("NCDHW", "nearest_neighbor") + _test_upsampling3d("NCDHW", "nearest_neighbor", "asymmetric") _test_upsampling3d("NCDHW", "trilinear", "align_corners") - _test_upsampling3d("NDHWC", "nearest_neighbor") + _test_upsampling3d("NDHWC", "nearest_neighbor", "asymmetric") _test_upsampling3d("NDHWC", "trilinear", "align_corners") diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index fc67f0b90295..95b0dfe96304 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1884,7 +1884,8 @@ def verify_scatter_nd_with_stack( ): data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices_vars = [ - relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) + relay.var("ind%d" % i, shape=v.shape, dtype=str(v.dtype)) + for i, v in enumerate(indices_np) ] updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype)) @@ -1926,7 +1927,7 @@ def verify_scatter_nd_with_stack( out[1, :] += updates[0, :] out[0, :] += updates[1, :] out[0, :] += updates[2, :] - verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd(data, indices, updates, out, mode="add") verify_scatter_nd_with_stack(data, indices, updates, out) for mode in ["add", "update"]: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c4d26a1811b1..b59325aea2f9 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -189,7 +189,7 @@ def verify(x_np, y_np, cond_np): x_np = np.array(1.0, dtype) y_np = np.array(-1.0, dtype) - cond_np = np.array([1, 0, 1], dtype=np.bool) + cond_np = np.array([1, 0, 1], dtype=bool) verify(x_np, y_np, cond_np) @@ -201,7 +201,7 @@ def verify(x_np, y_np, cond_np): x_np = np.array([[1, 2], [3, 4]], dtype) y_np = np.array([[5, 6], [7, 8]], dtype) - cond_np = np.array([[1], [0]], dtype=np.bool) + cond_np = np.array([[1], [0]], dtype=bool) verify(x_np, y_np, cond_np) verify(x_np, y_np, cond_np.T) @@ -213,7 +213,7 @@ def verify(x_np, y_np, cond_np): verify(x_np, y_np, cond_np) x_np, y_np = np.ogrid[:3, :4] - cond_np = np.where(x_np < y_np, x_np, 10 + y_np).astype(np.bool) + cond_np = np.where(x_np < y_np, x_np, 10 + y_np).astype(bool) verify(x_np.astype(dtype), y_np.astype(dtype), cond_np) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e27520339f36..d93de5419f56 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -26,23 +26,72 @@ from tvm.relay.testing import run_infer_type -def test_resize_infer_type(): +def test_resize1d_infer_type(): + n, c, w = te.size_var("n"), te.size_var("c"), te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + tw = te.var("tw") + z = relay.image.resize1d(x, (tw,)) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, tw), "int8") + + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + z = relay.image.resize1d(x, (200,), "NCW", "linear", "align_corners") + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, 200), "int8") + + +@tvm.testing.uses_gpu +def test_resize1d(): + def verify_resize(dshape, scale, method, layout, coord_trans): + if layout == "NWC": + size = (dshape[1] * scale,) + else: + size = (dshape[2] * scale,) + + x_data = np.random.uniform(size=dshape).astype("float32") + + ref_res = tvm.topi.testing.resize1d_python(x_data, (scale,), layout, method, coord_trans) + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.image.resize1d( + x, size, layout, method, coordinate_transformation_mode=coord_trans + ) + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") + func = relay.Function([x], z) + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, device=dev, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) + + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NWC", "NCW"]: + verify_resize((1, 4, 4), 2, method, layout, coord_trans) + verify_resize((2, 8, 17), 3, method, layout, coord_trans) + verify_resize((2, 8, 17), 3, method, layout, coord_trans) + verify_resize((3, 4, 5), 5, method, layout, coord_trans) + + +def test_resize2d_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) th, tw = te.var("th"), te.var("tw") - z = relay.image.resize(x, (th, tw)) + z = relay.image.resize2d(x, (th, tw)) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - z = relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners") + z = relay.image.resize2d(x, (100, 200), "NCHW", "linear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @tvm.testing.uses_gpu -def test_resize(): +def test_resize2d(): def verify_resize(dshape, scale, method, layout, coord_trans): if layout == "NHWC": size = (dshape[1] * scale, dshape[2] * scale) @@ -51,25 +100,25 @@ def verify_resize(dshape, scale, method, layout, coord_trans): x_data = np.random.uniform(size=dshape).astype("float32") - if method == "bilinear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout, coord_trans) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.image.resize(x, size, layout, method, coordinate_transformation_mode=coord_trans) + z = relay.image.resize2d( + x, size, layout, method, coordinate_transformation_mode=coord_trans + ) assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) - for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) @@ -92,7 +141,7 @@ def test_resize3d_infer_type(): assert zz.checked_type == relay.TensorType((n, c, td, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) - z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners") + z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "linear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8") @@ -107,10 +156,9 @@ def verify_resize(dshape, scale, method, layout): size = (dshape[2] * scale, dshape[3] * scale, dshape[4] * scale) x_data = np.random.uniform(size=dshape).astype("float32") - if method == "trilinear": - ref_res = tvm.topi.testing.trilinear_resize3d_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale, scale, scale), layout) + ref_res = tvm.topi.testing.resize3d_python( + x_data, (scale, scale, scale), layout, method, "align_corners" + ) x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize3d(x, size, layout, method, "align_corners") assert "size=" in z.astext() @@ -123,9 +171,10 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["trilinear", "nearest_neighbor"]: - for layout in ["NDHWC", "NCDHW"]: - verify_resize((1, 4, 4, 4, 4), 2, method, layout) + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NDHWC", "NCDHW"]: + verify_resize((1, 4, 4, 4, 4), 2, method, layout) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_op_qnn_batch_matmul.py b/tests/python/relay/test_op_qnn_batch_matmul.py new file mode 100644 index 000000000000..91648aca3dbc --- /dev/null +++ b/tests/python/relay/test_op_qnn_batch_matmul.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.contrib import graph_executor +from tvm.relay.testing.temp_op_attr import TempOpAttr + +# We use llvm target for testing functionality. `llvm` points to an older Intel +# generation machine, that legalizes to a simple lowering. Therefore, the +# legalization is overwritten such that it can be skipped and we use the +# QNNCanonicalizeOps lowering for the testing. +def legalize_qnn_batch_matmul(attrs, inputs, types): + return None + + +def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + "input_scale": input_scale, + "output_scale": output_scale, + "output_zero_point": output_zero_point, + "out_dtype": out_dtype, + } + return config + + +def make_configuration( + quantized_x, + quantized_y, + dtype, + x_shape, + y_shape, + x_zero_point, + y_zero_point, + x_scale, + y_scale, + output, + out_dtype="int32", + requantize=None, +): + config = { + "quantized_x": quantized_x, + "quantized_y": quantized_y, + "dtype": dtype, + "x_shape": x_shape, + "y_shape": y_shape, + "x_zero_point": x_zero_point, + "y_zero_point": y_zero_point, + "x_scale": x_scale, + "y_scale": y_scale, + "output": output, + "out_dtype": out_dtype, + "requantize": requantize, + } + return config + + +def make_int_configuration( + xzero_point_zero=True, yzero_point_zero=True, requantize_output=False, per_channel=False +): + x_shape, y_shape, output_shape = (1, 4, 5), (1, 3, 5), (1, 4, 3) + if xzero_point_zero == True: + x_zero_point = 0 + else: + x_zero_point = -123 + + if yzero_point_zero == True: + y_zero_point = 0 + else: + y_zero_point = -123 + + in_dtype = "int8" + out_dtype = "int32" if not requantize_output else "int8" + quantized_x_np = ( + np.array( + [ + 1, + 3, + 5, + 7, + 9, # sum = 25 + 11, + 13, + 15, + -19, + -21, # sum = -1 + 1, + 3, + 5, + 7, + 9, # sum = 25 + 11, + 13, + -17, + 17, + -21, + ] + ) # sum = 3 + .astype(in_dtype) + .reshape(x_shape) + ) + quantized_y_np = ( + np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9]) + .astype(in_dtype) + .reshape(y_shape) + ) + x_scale = 0.5 + y_scale = 0.5 + output_scale = 2.0 + + if requantize_output: + assert xzero_point_zero is True + assert yzero_point_zero is True + output = np.array([20, 51, 20, -26, -27, -26, 20, 51, 20, -14, -10, -14]) + elif xzero_point_zero is False and yzero_point_zero is False: + output = np.array( + [81960, 88360, 81960, 78400, 84540, 78400, 81960, 88360, 81960, 78984, 85164, 78984] + ) + elif xzero_point_zero is True and yzero_point_zero is False: + output = np.array([3240, 3490, 3240, -320, -330, -320, 3240, 3490, 3240, 264, 294, 264]) + elif xzero_point_zero is False and yzero_point_zero is True: + output = np.array([3240, 9640, 3240, 2878, 9018, 2878, 3240, 9640, 3240, 2970, 9150, 2970]) + else: + output = np.array([165, 415, 165, -197, -207, -197, 165, 415, 165, -105, -75, -105]) + + requant_params = ( + make_requantize_params(x_scale * y_scale, output_scale, -1, "int8") + if requantize_output + else None + ) + + output = output.astype(out_dtype).reshape(output_shape) + return make_configuration( + quantized_x=quantized_x_np, + quantized_y=quantized_y_np, + dtype=in_dtype, + x_shape=x_shape, + y_shape=y_shape, + x_zero_point=x_zero_point, + y_zero_point=y_zero_point, + x_scale=x_scale, + y_scale=y_scale, + output=output, + requantize=requant_params, + ) + + +def qnn_batch_matmul_driver(test_configuration): + in_dtype = test_configuration["dtype"] + out_dtype = test_configuration["out_dtype"] + quantized_x_name = "quantized_x" + quantized_y_name = "quantized_y" + expected_out_dtype = test_configuration["out_dtype"] + quantized_x = relay.var(quantized_x_name, shape=test_configuration["x_shape"], dtype=in_dtype) + quantized_y = relay.var(quantized_y_name, shape=test_configuration["y_shape"], dtype=in_dtype) + mod = relay.qnn.op.batch_matmul( + quantized_x, + quantized_y, + relay.const(test_configuration["x_zero_point"], "int32"), + relay.const(test_configuration["y_zero_point"], "int32"), + relay.const(test_configuration["x_scale"], "float32"), + relay.const(test_configuration["y_scale"], "float32"), + ) + if test_configuration["requantize"] is not None: + requantize_config = test_configuration["requantize"] + mod = relay.qnn.op.requantize( + mod, + input_scale=relay.const(requantize_config["input_scale"], "float32"), + input_zero_point=relay.const(0, "int32"), + output_scale=relay.const(requantize_config["output_scale"], "float32"), + output_zero_point=relay.const(requantize_config["output_zero_point"], "int32"), + out_dtype=requantize_config["out_dtype"], + ) + expected_out_dtype = requantize_config["out_dtype"] + + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = tvm.IRModule.from_expr(mod) + mod = relay.transform.InferType()(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + with tvm.transform.PassContext(opt_level=2): + graph, lib, params = relay.build(mod, "llvm", params=None) + mod = graph_executor.create(graph, lib, device=tvm.cpu(0)) + mod.set_input(quantized_x_name, test_configuration[quantized_x_name]) + mod.set_input(quantized_y_name, test_configuration[quantized_y_name]) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).numpy() + np.testing.assert_equal(res, test_configuration["output"]) + assert res.dtype == expected_out_dtype + + +def test_qnn_batch_matmul_xzp0_yzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=True) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_xzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=False) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_yzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=True) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=False) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_with_requantized_output(): + with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int8_requantized_output_params = make_int_configuration(requantize_output=True) + qnn_batch_matmul_driver(int8_requantized_output_params) + + +if __name__ == "__main__": + test_qnn_batch_matmul_xzp0_yzp0() + test_qnn_batch_matmul_xzp0() + test_qnn_batch_matmul_yzp0() + test_qnn_batch_matmul() + test_qnn_batch_matmul_with_requantized_output() diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 345e8b815da1..322382ca002c 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -88,6 +88,20 @@ def test_float32_to_int8(): ) +def test_scalar_float32_to_int8(): + data = np.array(-63.5).astype("float32") + output = np.array(-128).astype("int8") + quant_args = {"out_zero_point": np.int32(-1), "out_scale": np.float32(0.5)} + quantize_test_driver( + in_dtype="float32", + quant_args=quant_args, + axis=-1, + out_dtype="int8", + in_data=data, + verify_output_data=output, + ) + + def test_channelwise_axis_0(): data = ( np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) @@ -163,6 +177,7 @@ def test_dynamic_quantize(): if __name__ == "__main__": test_float32_to_uint8() test_float32_to_int8() + test_scalar_float32_to_int8() test_channelwise_axis_0() test_channelwise_axis_1() test_dynamic_quantize() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index ad9805e74929..0f512df25cdf 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -92,6 +92,24 @@ def test_same_scale(): verify(mod, (golden_data, golden_output)) +def test_scalar_same_scale(): + # Have same scales, everything within range + golden_data = np.array(-10).astype("int32") + golden_output = golden_data + + for rounding in roundings: + mod = get_mod( + data_shape=(), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) + + def test_downscale(): for rounding in roundings: mod = get_mod( @@ -437,6 +455,7 @@ def test_per_channel_different_scale(): if __name__ == "__main__": test_same_scale() + test_scalar_same_scale() test_downscale() test_upscale() test_non_power_of_two() diff --git a/tests/python/relay/test_op_qnn_simulated_dequantize.py b/tests/python/relay/test_op_qnn_simulated_dequantize.py index 64c70be9d7a7..e15fdf770c64 100644 --- a/tests/python/relay/test_op_qnn_simulated_dequantize.py +++ b/tests/python/relay/test_op_qnn_simulated_dequantize.py @@ -77,8 +77,8 @@ def verify_simulated_dequantize_simple(dtype): ) input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[]) - zp = relay.var("zp", shape=[]) - dtype = relay.var("dtype", shape=[]) + zp = relay.var("zp", shape=[], dtype="int32") + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_dequantize(input_data, scale, zp, dtype) sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) np.testing.assert_allclose(sim_dq_out.numpy(), dq_out, rtol=1e-5) @@ -109,7 +109,7 @@ def test_dynamic_channels(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_dequantize(input_data, scale, zp, dtype, axis=0) sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) np.testing.assert_allclose(sim_dq_out.numpy(), dq_out, rtol=1e-5) @@ -150,7 +150,7 @@ def test_dynamic_dtype(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_dequantize(input_data, scale, zp, dtype) sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) np.testing.assert_allclose(sim_dq_out.numpy(), dq_out, rtol=1e-5) diff --git a/tests/python/relay/test_op_qnn_simulated_quantize.py b/tests/python/relay/test_op_qnn_simulated_quantize.py index 14014f2e4605..69ce261f6b09 100644 --- a/tests/python/relay/test_op_qnn_simulated_quantize.py +++ b/tests/python/relay/test_op_qnn_simulated_quantize.py @@ -85,8 +85,8 @@ def verify_simulated_quantize_simple(dtype): ) input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[]) - zp = relay.var("zp", shape=[]) - dtype = relay.var("dtype", shape=[]) + zp = relay.var("zp", shape=[], dtype="int32") + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_quantize(input_data, scale, zp, dtype) sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) allclose_with_rounding(sim_q_out.numpy(), q_out) @@ -117,7 +117,7 @@ def test_dynamic_channels(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_quantize(input_data, scale, zp, dtype, axis=0) sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) allclose_with_rounding(sim_q_out.numpy(), q_out) @@ -159,7 +159,7 @@ def test_dynamic_dtype(): input_data = relay.var("input_data", shape=data.shape, dtype="float32") scale = relay.var("scale", shape=[relay.Any()], dtype="float32") zp = relay.var("zp", shape=[relay.Any()], dtype="int32") - dtype = relay.var("dtype", shape=[]) + dtype = relay.var("dtype", shape=[], dtype="int32") vm = build_simulated_quantize(input_data, scale, zp, dtype) sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) allclose_with_rounding(sim_q_out.numpy(), q_out) diff --git a/tests/python/relay/test_pass_annotate_spans_defuse.py b/tests/python/relay/test_pass_annotate_spans_defuse.py new file mode 100644 index 000000000000..def4a1da1b55 --- /dev/null +++ b/tests/python/relay/test_pass_annotate_spans_defuse.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for annotating spans.""" + + +import tvm +import tvm.relay as relay +from tvm.relay import testing +import tvm.testing + + +def test_annotate_spans_compatibility(): + data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32")) + weight = relay.var("weight") + + bn_gamma = relay.var("bn_gamma") + bn_beta = relay.var("bn_beta") + bn_mmean = relay.var("bn_mean") + bn_mvar = relay.var("bn_var") + + simple_net = relay.nn.conv2d( + data=data, weight=weight, kernel_size=(3, 3), channels=3, padding=(1, 1) + ) + simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0] + simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net) + + module, params = testing.create_workload(simple_net) + + # Apply some simple passes to legalize the IR. + with tvm.transform.PassContext(opt_level=0): + module, params = relay.optimize(module, tvm.testing.enabled_targets()[0][0], params) + + seq = tvm.transform.Sequential([relay.transform.AnnotateSpans(), relay.transform.DefuseOps()]) + with tvm.transform.PassContext(opt_level=3): + module = seq(module) + + +if __name__ == "__main__": + test_annotate_spans_compatibility() diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index f0949ab19f9c..c33bd5792242 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -42,14 +42,14 @@ def check_graph_executor( target, ref_res, device, func, params, config, opt_level, expected_index=None ): with tvm.transform.PassContext(opt_level=opt_level, config=config): - graph, lib, new_params = relay.build(func, target, params=params) + graph_executor_factory = relay.build(func, target, params=params) + contexts = [tvm.cpu(0), tvm.device(device)] - graph_json = json.loads(graph) + graph_json = json.loads(graph_executor_factory.graph_json) if "device_index" in graph_json["attrs"]: device_index = graph_json["attrs"]["device_index"][1] assert device_index == expected_index - mod = graph_executor.create(graph, lib, contexts) - mod.set_input(**new_params) + mod = graph_executor.GraphModule(graph_executor_factory["default"](*contexts)) mod.run() res = mod.get_output(0).numpy() tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) @@ -272,12 +272,14 @@ def check_storage_and_device_types(): smap = relay.backend._backend.GraphPlanMemory(func) storage_ids = [] device_types = [] - for _, storage_dev_type in smap.items(): - assert len(storage_dev_type) == 3 - for sid in storage_dev_type[0]: + for _, storage_info in smap.expr_to_storage_info.items(): + + for sid in storage_info.storage_ids: storage_ids.append(sid.value) - for did in storage_dev_type[1]: + + for did in storage_info.device_types: device_types.append(did.value) + assert len(storage_ids) == 10 assert len(set(storage_ids)) == 8 assert len(set(device_types)) == 2 @@ -350,16 +352,16 @@ def expected(): assert tvm.ir.structural_equal(annotated_expr, expected_expr) smap = relay.backend._backend.GraphPlanMemory(annotated_expr) - for expr, storage_dev_type in smap.items(): + for expr, storage_info in smap.expr_to_storage_info.items(): # x is dev1 as output is dev1 if isinstance(expr, tvm.relay.expr.Var): - assert storage_dev_type[1][0] == dev1.device_type + assert storage_info.device_types[0] == dev1.device_type else: # device_copy op should be its dst_dev_type if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs): - assert storage_dev_type[1][0] == expr.attrs.dst_dev_type + assert storage_info.device_types[0] == expr.attrs.dst_dev_type else: - assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type + assert storage_info.device_types[0] == expected_dev_type[expr.op.name].device_type def run_fusible_network(dev, tgt): diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 88590c946e88..fafab3ee3584 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1797,24 +1797,24 @@ def expected(): _test_conv_reduce_convert_layout2() -def test_image_resize_convert_layout(): +def test_image_resize2d_convert_layout(): def _test_image_resize_convert_layout_nchw_to_nhwc(): def before(): x = relay.var("x", shape=(1, 2, 4, 4)) - y = relay.image.resize(x, (8, 8)) + y = relay.image.resize2d(x, (8, 8)) y = relay.Function([x], y) return y def expected(): x = relay.var("x", shape=(1, 2, 4, 4)) x = relay.layout_transform(x, "NCHW", "NHWC") - y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.image.resize2d(x, (8, 8), layout="NHWC") y = relay.layout_transform(y, "NHWC", "NCHW") y = relay.Function(relay.analysis.free_vars(y), y) return y a = before() - a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]})) + a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NHWC"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -1822,20 +1822,20 @@ def expected(): def _test_image_resize_convert_layout_nhwc_to_nchw(): def before(): x = relay.var("x", shape=(1, 4, 4, 2)) - y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.image.resize2d(x, (8, 8), layout="NHWC") y = relay.Function([x], y) return y def expected(): x = relay.var("x", shape=(1, 4, 4, 2)) x = relay.layout_transform(x, "NHWC", "NCHW") - y = relay.image.resize(x, (8, 8), layout="NCHW") + y = relay.image.resize2d(x, (8, 8), layout="NCHW") y = relay.layout_transform(y, "NCHW", "NHWC") y = relay.Function(relay.analysis.free_vars(y), y) return y a = before() - a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]})) + a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NCHW"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -1844,7 +1844,7 @@ def expected(): _test_image_resize_convert_layout_nhwc_to_nchw() -def test_conv_image_resize_convert_layout(): +def test_conv_image_resize2d_convert_layout(): """Check that layout transforms are propagated through image resize.""" def before(): @@ -1859,7 +1859,7 @@ def before(): data_layout="NHWC", kernel_layout="HWIO", ) - y = relay.image.resize(y, (112, 112), layout="NHWC") + y = relay.image.resize2d(y, (112, 112), layout="NHWC") y = relay.Function(analysis.free_vars(y), y) return y @@ -1869,7 +1869,7 @@ def expected(): x = relay.layout_transform(x, "NHWC", "NCHW") w = relay.layout_transform(w, "HWIO", "OIHW") y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) - y = relay.image.resize(y, (112, 112), layout="NCHW") + y = relay.image.resize2d(y, (112, 112), layout="NCHW") y = relay.layout_transform(y, "NCHW", "NHWC") y = relay.Function(analysis.free_vars(y), y) return y diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 9f7f3deebeb8..962b7bebb12b 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -248,7 +248,7 @@ def verify_ones_zeros(shape, dtype): @tvm.testing.uses_gpu -def test_dynamic_to_static_resize(): +def test_dynamic_to_static_resize2d(): def verify_resize(shape, scale, method, layout): if layout == "NHWC": size = (shape[1] * scale, shape[2] * scale) @@ -258,7 +258,7 @@ def verify_resize(shape, scale, method, layout): x = relay.var("x", relay.TensorType(shape, "float32")) size_var = relay.const(np.array(size).astype("float32")) coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners" - z = relay.image.resize( + z = relay.image.resize2d( x, size_var, layout, method, coordinate_transformation_mode=coord_trans ) @@ -267,17 +267,14 @@ def verify_resize(shape, scale, method, layout): zz = func2.body assert isinstance(zz, relay.Call) - assert zz.op == relay.op.get("image.resize") + assert zz.op == relay.op.get("image.resize2d") x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) - if method == "bilinear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) - verify_func(func2, [x_data], ref_res, rtol=1e-4, atol=1e-6) - - for method in ["bilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: verify_resize((1, 4, 4, 4), 2, method, layout) @@ -347,7 +344,9 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): assert zz.op == relay.op.get("nn.upsampling") x_data = np.random.uniform(size=data_shape).astype(dtype) - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW") + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale_h_val, scale_w_val), "NCHW", "nearest_neighbor", "asymmetric" + ) verify_func(func2, [x_data], ref_res) verify_upsampling((1, 16, 32, 32), 2, 2, "int8") @@ -371,8 +370,12 @@ def verify_upsampling3d(data_shape, scale_d_val, scale_h_val, scale_w_val, dtype assert zz.op == relay.op.get("nn.upsampling3d") x_data = np.random.uniform(size=data_shape).astype(dtype) - ref_res = tvm.topi.testing.upsampling3d_python( - x_data, (scale_d_val, scale_h_val, scale_w_val), "NCDHW" + ref_res = tvm.topi.testing.resize3d_python( + x_data, + (scale_d_val, scale_h_val, scale_w_val), + "NCDHW", + "nearest_neighbor", + "asymmetric", ) verify_func(func2, [x_data], ref_res) diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3271379cf3ef..1e7d749ff418 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -22,6 +22,25 @@ from tvm import relay +def compare_fq_to_int(expr, args, allow_rounding_error=False): + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod_int) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(*args).numpy() + + ex = relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") + result_int = ex.evaluate()(*args).numpy() + + if allow_rounding_error: + assert np.all(np.abs(result - result_int) <= 1) + else: + assert np.array_equal(result, result_int) + + def test_fake_quantize_conv(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") @@ -35,23 +54,29 @@ def test_fake_quantize_conv(): ) op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np, w_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np).asnumpy() +def test_fake_quantize_dense(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) - assert np.array_equal(result, result2) + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) def test_fake_transpose_quantize_conv(): @@ -65,23 +90,10 @@ def test_fake_transpose_quantize_conv(): op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) op = relay.qnn.op.quantize(op, one, zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) - - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np).asnumpy() - - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np).asnumpy() - - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np, w_np]) def test_fake_transpose_quantize_conv_bias_add(): @@ -97,24 +109,32 @@ def test_fake_transpose_quantize_conv_bias_add(): op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) op = relay.qnn.op.quantize(op, one, zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np, w_np, bias_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np, bias_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np, bias_np).asnumpy() +def test_fake_transpose_quantize_conv_bias_add_mismatch(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + two = relay.const(2.0) + zero = relay.const(0) - assert np.array_equal(result, result2) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, two, zero)) + op = relay.qnn.op.quantize(op, one, zero) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") + + compare_fq_to_int(op, [x_np, w_np, bias_np]) def test_fake_quantize_maxpool(): @@ -125,101 +145,121 @@ def test_fake_quantize_maxpool(): op = relay.op.nn.max_pool2d(x, [3, 3]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_avgpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np], True) -def test_fake_quantize_avgpool(): +def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.op.reshape(x, [1, 3, -1]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_expand_dims(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.expand_dims(x, axis=1) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_squeeze(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.squeeze(x, axis=[0]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.all(np.abs(result - result2) <= 1) + compare_fq_to_int(op, [x_np]) -def test_fake_quantize_reshape(): +def test_fake_quantize_strided_slice(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.reshape(x, [1, 3, -1]) + op = relay.op.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 1, 112, 112]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_split(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.split(x, axis=3, indices_or_sections=2) + op = relay.qnn.op.quantize(op[0], relay.const(2.0), zero) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() + op = relay.op.split(x, axis=3, indices_or_sections=[56, 112, 168]) + op = relay.qnn.op.quantize(op[1], relay.const(2.0), zero) - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) -def test_fake_quantize_transpose_reshape(): +def test_fake_quantize_batch_flatten(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.transpose(x, [1, 0, 2, 3]) - op = relay.op.reshape(op, [3, -1]) + op = relay.op.nn.batch_flatten(x) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_transpose_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.transpose(x, [1, 0, 2, 3]) + op = relay.op.reshape(op, [3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) def test_fake_quantize_concat(): @@ -234,24 +274,11 @@ def test_fake_quantize_concat(): concat = relay.op.concatenate(inputs, axis=1) out = relay.qnn.op.quantize(concat, relay.const(3.5), zero) - mod = tvm.IRModule.from_expr(out) - mod = tvm.relay.transform.InferType()(mod) - inputs_np = [] for i in range(4): inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8")) - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) - - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(*inputs_np).asnumpy() - - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(*inputs_np).asnumpy() - - assert np.array_equal(result, result2) + compare_fq_to_int(out, inputs_np) def test_fake_quantize_clip(): @@ -261,19 +288,67 @@ def test_fake_quantize_clip(): op = relay.op.clip(x, 0, 6) op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +@pytest.mark.parametrize( + "operator", + [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], +) +def test_fake_quantize_binary(operator): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(0)) + + y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8") + y = relay.qnn.op.dequantize(y, relay.const(0.2), relay.const(0)) + + op = operator(x, y) + if operator == relay.op.multiply: + out_scale = relay.const(20.0) + else: + out_scale = relay.const(0.1) + + op = relay.qnn.op.quantize(op, out_scale, relay.const(0), out_dtype="int8") + + x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np, y_np]) + + +@pytest.mark.parametrize( + "operator", + [ + relay.op.add, + relay.op.multiply, + relay.op.subtract, + relay.op.subtract, + relay.op.minimum, + relay.op.maximum, + ], +) +def test_fake_quantize_binary_const(operator): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(10)) + + y = relay.const(1.0) + + op = operator(x, y) + op = relay.qnn.op.quantize(op, relay.const(0.1), relay.const(10), out_dtype="int8") + + x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_pad(): + x = relay.var("x", shape=[1, 383, 128], dtype="int8") + x = relay.qnn.op.dequantize(x, relay.const(1.0), relay.const(10)) + op = relay.op.nn.pad(x, [[0, 0], [0, 1], [0, 0]], 0.0) + op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(10), out_dtype="int8") + + x_np = np.random.randint(-25, 25, size=[1, 383, 128], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 610d4e4e491b..321c74f4bbd8 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -251,9 +251,6 @@ def __init__(self, id): def exit_pass_ctx(self): events.append(self.id + " exit ctx") - def exit_pass_ctx(self): - events.append(self.id + " exit ctx") - @pass_instrument class PIBroken(PI): def __init__(self, id): diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py index 1312b396fe4c..bcd69f7253ef 100644 --- a/tests/python/relay/test_pass_legalize_tensorcore.py +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -206,7 +206,7 @@ def expected(): @tvm.testing.uses_gpu def test_legalize_dense(): - def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True): + def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): """test legalize dense to enable tensorcore""" M, K = data_shape N, _ = kernel_shape @@ -214,8 +214,8 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True): dm, dk, dn = pad_shape def before(): - x = relay.var("x", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=kernel_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=kernel_shape, dtype=dtype) y = relay.nn.dense(x, weight) y = relay.Function([x, weight], y) return y @@ -227,12 +227,12 @@ def legalize_dense(attrs, inputs, types): def expected(): if not do_pad: return before() - x = relay.var("x", shape=data_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) if dm or dk: x_pad = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) else: x_pad = x - weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) if dn or dk: weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk))) else: @@ -255,18 +255,28 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) # dense - _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), False) - _test_legalize_dense((7, 16), (32, 16), (1, 0, 0)) - _test_legalize_dense((8, 15), (32, 15), (0, 1, 0)) - _test_legalize_dense((8, 16), (31, 16), (0, 0, 1)) - _test_legalize_dense((7, 15), (31, 15), (1, 1, 1)) - _test_legalize_dense((3, 16), (32, 16), (5, 0, 0)) - _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), False) + for dtype in ["float16", "int8"]: + _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), dtype, False) + _test_legalize_dense((7, 16), (32, 16), (1, 0, 0), dtype) + _test_legalize_dense((8, 15), (32, 15), (0, 1, 0), dtype) + _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype) + _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype) + _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype) + _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False) + + _test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False) + _test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4") + _test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4") + _test_legalize_dense((8, 32), (31, 32), (0, 0, 1), "int4") + _test_legalize_dense((7, 31), (31, 31), (1, 1, 1), "int4") + _test_legalize_dense((3, 32), (32, 32), (5, 0, 0), "int4") + _test_legalize_dense((8, 16), (32, 16), (0, 16, 0), "int4") + _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), "int4", False) @tvm.testing.uses_gpu def test_legalize_batch_matmul(): - def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True): + def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): """test legalize dense to enable tensorcore""" B, M, _ = data_shape _, N, _ = kernel_shape @@ -274,8 +284,8 @@ def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True dm, dk, dn = pad_shape def before(): - x = relay.var("x", shape=data_shape, dtype="float16") - weight = relay.var("weight", shape=kernel_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) + weight = relay.var("weight", shape=kernel_shape, dtype=dtype) y = relay.nn.batch_matmul(x, weight) y = relay.Function([x, weight], y) return y @@ -287,12 +297,12 @@ def legalize_batch_matmul(attrs, inputs, types): def expected(): if not do_pad: return before() - x = relay.var("x", shape=data_shape, dtype="float16") + x = relay.var("x", shape=data_shape, dtype=dtype) if dm or dk: x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) else: x_pad = x - weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + weight = relay.var("weight", shape=(kernel_shape), dtype=dtype) if dn or dk: weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk))) else: @@ -314,13 +324,23 @@ def expected(): b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) - _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), False) - _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0)) - _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0)) - _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1)) - _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1)) - _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0)) - _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), False) + for dtype in ["float16", "int8"]: + _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), dtype, False) + _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0), dtype) + _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0), dtype) + _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1), dtype) + _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1), dtype) + _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0), dtype) + _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), dtype, False) + + _test_legalize_batch_matmul((16, 8, 32), (16, 32, 32), (0, 0, 0), "int4", False) + _test_legalize_batch_matmul((16, 7, 32), (16, 32, 32), (1, 0, 0), "int4") + _test_legalize_batch_matmul((16, 8, 31), (16, 32, 31), (0, 1, 0), "int4") + _test_legalize_batch_matmul((16, 8, 32), (16, 31, 32), (0, 0, 1), "int4") + _test_legalize_batch_matmul((16, 7, 31), (16, 31, 31), (1, 1, 1), "int4") + _test_legalize_batch_matmul((16, 3, 32), (16, 32, 32), (5, 0, 0), "int4") + _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 16, 0), "int4") + _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), "int4", False) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 98d7161ae36c..29d420def184 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -339,8 +339,8 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([x0, y0], add) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func add_call = relay.Call(glb_0, [x, y]) # Function that uses default compiler. Ops are fused in this function. @@ -367,6 +367,86 @@ def expected(): mod["main"] = f mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) mod = transform.PartitionGraph()(mod) + fused_mod = transform.FuseOps(2)(mod) + expected_mod = expected() + assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + + x_data = np.random.rand(8, 8).astype("float32") + y_data = np.random.rand(8, 8).astype("float32") + np_add = x_data + y_data + res = np.concatenate([np.log(np_add), np.exp(np_add)]) + check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res) + + +def test_extern_ccompiler_multiple_functions(): + def expected(): + mod = tvm.IRModule() + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + add = x0 + y0 + # Function that uses C compiler + func = relay.Function([x0, y0], add) + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [x, y]) + # Function that uses default compiler. Ops are fused in this function. + p0 = relay.var("p0", shape=(8, 8)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + main = relay.Function([x, y], fused_call) + mod["main"] = main + # define the second one + a = relay.var("a", shape=(16, 16)) + b = relay.var("b", shape=(16, 16)) + a0 = relay.var("a0", shape=(16, 16)) + b0 = relay.var("b0", shape=(16, 16)) + add = a0 + b0 + # Function that uses C compiler + func = relay.Function([a0, b0], add) + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_subfunction_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_subfunction_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [a, b]) + # Function that uses default compiler. Ops are fused in this function. + p0 = relay.var("p0", shape=(16, 16)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + sunfunction = relay.Function([a, b], fused_call) + mod["subfunction"] = sunfunction + mod = transform.InferType()(mod) + return mod + + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + add = x + y + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f = relay.Function([x, y], concat) + mod = tvm.IRModule() + mod["main"] = f + # define second function + a = relay.var("a", shape=(16, 16)) + b = relay.var("b", shape=(16, 16)) + add = a + b + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f2 = relay.Function([a, b], concat) + mod["subfunction"] = f2 + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) + mod = transform.PartitionGraph()(mod) fused_mod = transform.FuseOps(2)(mod) expected_mod = expected() @@ -416,8 +496,8 @@ def expected(): out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) func = relay.Function([data0, input0], out) - func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") - glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_main_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_main_0") mod = tvm.IRModule() mod[glb_var] = func mod = transform.InferType()(mod) @@ -532,8 +612,8 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_2") - gv0 = relay.GlobalVar("tvmgen_default_test_compiler_2") + func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_main_2") + gv0 = relay.GlobalVar("tvmgen_default_test_compiler_main_2") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -544,8 +624,8 @@ def expected(): data=data1, weight=weight1, kernel_size=(3, 3), channels=16, padding=(1, 1) ) func1 = relay.Function([data1, weight1], conv) - func1 = set_func_attr(func1, "test_compiler", "tvmgen_default_test_compiler_0") - gv1 = relay.GlobalVar("tvmgen_default_test_compiler_0") + func1 = set_func_attr(func1, "test_compiler", "tvmgen_default_test_compiler_main_0") + gv1 = relay.GlobalVar("tvmgen_default_test_compiler_main_0") mod[gv1] = func1 mod = transform.InferType()(mod) @@ -613,7 +693,7 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "tvmgen_default_test_compiler_main_0") # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -643,8 +723,8 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([y0], add) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func mod = relay.transform.InferType()(mod) add_call = relay.Call(glb_0, [y]) @@ -733,8 +813,8 @@ def expected(): tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) - func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_0") - gv0 = relay.GlobalVar("tvmgen_default_test_target_0") + func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_main_0") + gv0 = relay.GlobalVar("tvmgen_default_test_target_main_0") mod[gv0] = func0 mod = relay.transform.InferType()(mod) @@ -796,8 +876,8 @@ def expected(): f1_O_2 = relay.nn.relu(f1_O_1) f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) - func1 = set_func_attr(func1, "test_target", "tvmgen_default_test_target_0") - gv1 = relay.GlobalVar("tvmgen_default_test_target_0") + func1 = set_func_attr(func1, "test_target", "tvmgen_default_test_target_main_0") + gv1 = relay.GlobalVar("tvmgen_default_test_target_main_0") mod[gv1] = func1 mod = relay.transform.InferType()(mod) @@ -806,8 +886,8 @@ def expected(): f2_cb4 = relay.var("test_target_1_i1", shape=(10, 10)) f2_O_3 = relay.add(f2_cb3, f2_cb4) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) - func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_1") - gv0 = relay.GlobalVar("tvmgen_default_test_target_1") + func0 = set_func_attr(func0, "test_target", "tvmgen_default_test_target_main_1") + gv0 = relay.GlobalVar("tvmgen_default_test_target_main_1") mod[gv0] = func0 mod = relay.transform.InferType()(mod) @@ -955,8 +1035,8 @@ def expected_same_output_region(): mul = log * sub # The partitioned graph contains log, subtract, and multiply func = relay.Function([x0, y0], mul) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func mod = transform.InferType()(mod) @@ -977,8 +1057,8 @@ def expected_different_output_region(): i0 = relay.var("i0", shape=(8, 8)) log = relay.log(i0) func = relay.Function([i0], log) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_0") - glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_0") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_0") + glb_0 = relay.GlobalVar("tvmgen_default_ccompiler_main_0") mod[glb_0] = func mod = transform.InferType()(mod) @@ -987,8 +1067,8 @@ def expected_different_output_region(): y0 = relay.var("y0", shape=(8, 8)) sub = x0 - y0 func = relay.Function([x0, y0], sub) - func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_1") - glb_1 = relay.GlobalVar("tvmgen_default_ccompiler_1") + func = set_func_attr(func, "ccompiler", "tvmgen_default_ccompiler_main_1") + glb_1 = relay.GlobalVar("tvmgen_default_ccompiler_main_1") mod[glb_1] = func mod = transform.InferType()(mod) @@ -1063,8 +1143,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") - gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_main_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_main_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1140,8 +1220,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") - gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_main_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_main_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1216,7 +1296,7 @@ def create_graph(): partitioned = seq(create_graph()) - concat = partitioned["tvmgen_default_const_tuples_0"].body + concat = partitioned["tvmgen_default_const_tuples_main_0"].body assert type(concat.args[1]) == relay.Tuple assert type(concat.args[2]) == relay.Tuple assert type(concat.args[3]) == relay.Constant @@ -1266,8 +1346,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_0") - gv0 = relay.GlobalVar("tvmgen_default_" + target + "_0") + func0 = func0.with_attr("global_symbol", "tvmgen_default_" + target + "_main_0") + gv0 = relay.GlobalVar("tvmgen_default_" + target + "_main_0") mod[gv0] = func0 mod = transform.InferType()(mod) @@ -1349,9 +1429,9 @@ def Optimize(mod): mod = transform.PartitionGraph()(mod) try: - t0 = mod["tvmgen_default_test_target_0"] + t0 = mod["tvmgen_default_test_target_main_0"] except: - raise KeyError("test_target_0 not found") + raise KeyError("test_target_main_0 not found") assert isinstance(t0.body, relay.Constant) expected = np.empty([2, 2]) @@ -1359,10 +1439,39 @@ def Optimize(mod): tvm.testing.assert_allclose(t0.body.data.numpy(), expected, rtol=1e-5, atol=1e-5) +def test_preserve_type_import(): + """Test to make sure type definition and imports are preserved during the BYOC pipeline.""" + from tvm.relay.prelude import Prelude, StaticTensorArrayOps + + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + tensor_array = p.get_global_var_static("tensor_array", dtype, shape) + tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape) + write = p.get_global_var_static("tensor_array_write", dtype, shape) + gather = p.get_global_var_static("tensor_array_gather", dtype, shape) + v = relay.var("v") + indice = relay.var("indice") + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) + out = gather(tensor_array3, indice) + mod["main"] = relay.Function([v, indice], out) + mod = transform.RemoveUnusedFunctions()(mod) + mod = transform.PartitionGraph()(mod) + + run("float32", [2, 3]) + + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() test_extern_ccompiler_default_ops() + test_extern_ccompiler_multiple_functions() test_extern_ccompiler() test_extern_dnnl() test_extern_dnnl_mobilenet() @@ -1379,3 +1488,4 @@ def Optimize(mod): test_flatten_tuple_output() test_tuple_output_exec() test_extern_opt() + test_static_tensor_array_gather_partition() diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 151e5ecc160b..6c229064b094 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -647,13 +647,39 @@ def test_add_op_scalar(): } """ mod = tvm.IRModule() - x = relay.var("x", shape=()) - y = relay.var("y", shape=()) + x = relay.var("x", shape=()) # Default to float32 + y = relay.var("y", shape=()) # Default to float32 func = relay.Function([x, y], relay.op.add(x, y)) - x_data = np.array(10.0, dtype="float32") - y_data = np.array(1.0, dtype="float32") - mod["main"] = func - check_result([x_data, y_data], x_data + y_data, mod=mod) + x_y_data = [ + (np.array(10.0, dtype="float32"), np.array(1.0, dtype="float32")), + (np.float32(10.0), np.float32(1.0)), + (10.0, 1.0), + ] + for (x_data, y_data) in x_y_data: + mod["main"] = func + check_result([x_data, y_data], x_data + y_data, mod=mod) + + +@tvm.testing.uses_gpu +def test_add_op_scalar_int(): + """ + test_add_op_scalar_int: + fn (x, y) { + return x + y; + } + """ + mod = tvm.IRModule() + x = relay.var("x", shape=(), dtype="int32") + y = relay.var("y", shape=(), dtype="int32") + func = relay.Function([x, y], relay.op.add(x, y)) + x_y_data = [ + (np.array(10.0, dtype="int32"), np.array(1.0, dtype="int32")), + (np.int32(10), np.int32(1)), + (10, 1), + ] + for (x_data, y_data) in x_y_data: + mod["main"] = func + check_result([x_data, y_data], x_data + y_data, mod=mod) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py index 31a7e85113ab..eb657a329889 100644 --- a/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py +++ b/tests/python/topi/python/test_topi_batch_matmul_tensorcore.py @@ -30,33 +30,71 @@ } -def verify_batch_matmul(x_batch, y_batch, M, N, K): - x = te.placeholder((x_batch, M, K), name="x") - y = te.placeholder((y_batch, N, K), name="y") - dtype = x.dtype +def convert_int32_into_int4(a_int32): + """convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + B, K, L = a_int32.shape + assert L % 8 == 0 + a_int4 = np.zeros(shape=(B, K, L // 8), dtype=np.int32) + for b in range(B): + for k in range(K): + for l in range(L // 8): + for m in range(min(8, L - l * 8)): + a_int4[b, k, l] = a_int4[b, k, l] | ( + (a_int32[b, k, l * 8 + m] & 0xF) << ((7 - m) * 4) + ) + return a_int4 + + +def verify_batch_matmul(x_batch, y_batch, M, N, K, dtype): + x = te.placeholder((x_batch, M, K), name="x", dtype=dtype) + y = te.placeholder((y_batch, N, K), name="y", dtype=dtype) + + assert dtype in ["int4", "int8", "float16"] + + out_dtype = "float32" + if dtype in ["int8", "int4"]: + out_dtype = "int32" # use memoize to pickle the test data for next time use @memoize("topi.tests.test_topi_batch_matmul_tensorcore") def get_ref_data(): - a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype) - b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype) - c_np = tvm.topi.testing.batch_matmul(a_np, b_np) + if dtype == "int4": + a_np = np.random.randint(low=-8, high=7, size=(x_batch, M, K)) + b_np = np.random.randint(low=-8, high=7, size=(y_batch, N, K)) + elif dtype == "int8": + a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype) + b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype) + else: + a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype) + b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype) + c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype) return (a_np, b_np, c_np) # get the test data a_np, b_np, c_np = get_ref_data() + if dtype == "int4": + a_np = convert_int32_into_int4(a_np) + b_np = convert_int32_into_int4(b_np) def check_device(device): dev = tvm.device(device, 0) print("Running on target: %s" % device) with tvm.target.Target(device): fcompute, fschedule = tvm.topi.testing.dispatch(device, _batch_matmul_implement) - out = fcompute(x, y) + out = fcompute(x, y, None, out_dtype) s = fschedule([out]) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), dev) - f = tvm.build(s, [x, y, out], device, name="dense") + c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev) + f = tvm.build(s, [x, y, out], device, name="batch_matmul") f(a, b, c) tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) @@ -65,10 +103,11 @@ def check_device(device): @tvm.testing.requires_tensorcore def test_batch_matmul(): - verify_batch_matmul(1, 1, 16, 16, 32) - verify_batch_matmul(5, 5, 16, 16, 32) - verify_batch_matmul(5, 5, 16, 32, 32) - verify_batch_matmul(30, 30, 16, 32, 32) + for dtype in ["float16", "int8", "int4"]: + verify_batch_matmul(1, 1, 16, 16, 32, dtype) + verify_batch_matmul(5, 5, 16, 16, 32, dtype) + verify_batch_matmul(5, 5, 16, 32, 32, dtype) + verify_batch_matmul(30, 30, 16, 32, 32, dtype) if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index 8dbe94b45a2f..2a4865c6dd8d 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -16,13 +16,15 @@ # under the License. """Example code to do convolution.""" +import sys + +import pytest import numpy as np + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi +from tvm import autotvm, te, topi import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import cudnn from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.topi.nn.conv2d import _get_workload @@ -30,238 +32,272 @@ import tvm.testing +dtype = tvm.testing.parameter("float32") -def verify_conv2d_nchw( - batch, - in_channel, - in_size, - num_filter, - kernel, - stride, - padding, - dilation=1, - add_bias=False, - add_relu=False, - use_cudnn=False, -): - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - print( - "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation) - ) +@tvm.testing.fixture +def input_shape(batch, in_channel, in_size): + return (batch, in_channel, in_size, in_size) - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name="A") - W = te.placeholder((num_filter, in_channel, kernel, kernel), name="W") - bias = te.placeholder((num_filter, 1, 1), name="bias") - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - bias_shape = get_const_tuple(bias.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = np.random.uniform(size=bias_shape).astype(dtype) - dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - if add_bias: - c_np += b_np - if add_relu: - c_np = np.maximum(c_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def verify_workload_padding(): - _, _, out_height, out_width = get_const_tuple(c_np.shape) - wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype) - - # check if tile_ow candidates are the factors of the right output weight. - cfg = autotvm.get_config() - _fallback_schedule(cfg, wkl) - ow_tile = np.prod(cfg["tile_ow"].size) - tvm.testing.assert_allclose(ow_tile, out_width) +@tvm.testing.fixture +def weight_shape(num_filter, in_channel, kernel): + return (num_filter, in_channel, kernel, kernel) - def check_target(target): - dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - print("Running on target: %s" % target) - if "cudnn" in target: - fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn - else: - fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target) +@tvm.testing.fixture +def bias_shape(num_filter): + return (num_filter, 1, 1) - with tvm.target.Target(target): - if "cudnn" in target: - C = fcompute( - A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype - ) + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + input_shape, + weight_shape, + bias_shape, + dtype, + stride, + padding, + dilation, + add_bias, + apply_relu, +): + a_np = np.random.uniform(size=input_shape).astype(dtype) + w_np = np.random.uniform(size=weight_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + + if add_bias: + c_np = c_np + b_np + if apply_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + +class BaseConv2DTests: + add_bias = tvm.testing.parameter(False) + apply_relu = tvm.testing.parameter(False) + dilation = tvm.testing.parameter(1) + batch = tvm.testing.parameter(1) + + def test_conv2d_nchw( + self, + target, + dev, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dtype, + ref_data, + dilation, + add_bias, + apply_relu, + ): + target = tvm.target.Target(target) + is_cudnn_target = target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []) + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + + a_np, w_np, b_np, c_np = ref_data + + A = te.placeholder(a_np.shape, name="A", dtype=dtype) + W = te.placeholder(w_np.shape, name="W", dtype=dtype) + bias = te.placeholder(b_np.shape, name="bias") + + with autotvm.tophub.context(target): # load tophub pre-tuned parameters + if is_cudnn_target: + fcompute, fschedule = topi.cuda.conv2d_cudnn, topi.cuda.schedule_conv2d_cudnn else: - C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) - if add_bias: - C = topi.add(C, bias) - if add_relu: - C = topi.nn.relu(C) - s = fschedule([C]) - - if "llvm" in target: - verify_workload_padding() - - a = tvm.nd.array(a_np, dev) - w = tvm.nd.array(w_np, dev) - b = tvm.nd.array(b_np, dev) - - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) - if add_bias: + fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(target) + + with target: + if is_cudnn_target: + C = fcompute( + A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype + ) + else: + C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) + if add_bias: + C = topi.add(C, bias) + if apply_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(b_np, dev) + + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) func = tvm.build( s, [A, W, bias, C], target, - name="relu_%d_%d_%d_%d_%d_%d_%d_%d" + name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), ) func(a, w, b, c) - else: - func = tvm.build( - s, - [A, W, C], - target, - name="relu_%d_%d_%d_%d_%d_%d_%d_%d" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), - ) - func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + + @tvm.testing.parametrize_targets("llvm") + def test_workload_padding( + self, + target, + input_shape, + weight_shape, + stride, + padding, + dilation, + dtype, + ref_data, + ): + a_np, w_np, b_np, c_np = ref_data + _, _, out_height, out_width = c_np.shape + + A = te.placeholder(input_shape, name="A", dtype=dtype) + W = te.placeholder(weight_shape, name="W", dtype=dtype) - for target, dev in tvm.testing.enabled_targets(): - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - check_target(target) - - if use_cudnn: - check_target("cuda -model=unknown -libs=cudnn") - if ("opencl", tvm.device("opencl")) in tvm.testing.enabled_targets(): - check_target("opencl -device=intel_graphics") - - -@tvm.testing.uses_gpu -def test_conv2d_nchw(): - # ResNet18 workloads - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - - # bias, relu - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True) - - # dilation = 2 - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2) - - # batch size - verify_conv2d_nchw(4, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(9, 64, 56, 64, 3, 1, 1) - - # weird workloads - verify_conv2d_nchw(2, 2, 2, 2, 2, 2, 2) - verify_conv2d_nchw(3, 3, 3, 3, 3, 3, 3) - verify_conv2d_nchw(4, 4, 4, 4, 4, 4, 4) - verify_conv2d_nchw(5, 5, 5, 5, 5, 5, 5) - verify_conv2d_nchw(6, 6, 6, 6, 6, 6, 6) - - # disable these tests due to some bugs of llvm with nvptx - # verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1) - # verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2) - # verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1) - - # inception v3 workloads - verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0) - verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0) - verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0) - verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0) - verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1) - verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1) - verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0) - verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0) - verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0) - verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0) - verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0) - verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0) - verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3) - verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0) - # disable these tests due to some bugs of llvm with nvptx - # verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0) - verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3) - verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0) - verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3) - verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0) - verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0) - verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0) - verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1) - verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0) - verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1) - verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0) - verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0) - verify_conv2d_nchw(1, 1024, 19, 84, 3, 1, 1) - verify_conv2d_nchw(1, 2048, 10, 126, 3, 1, 1) - verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1) - verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1) - - # Asymmetric padding - verify_conv2d_nchw(1, 3, 35, 64, 7, 2, (0, 0, 1, 1)) - verify_conv2d_nchw(1, 64, 8, 128, 3, 1, (3, 3, 2, 2)) - verify_conv2d_nchw(1, 64, 8, 64, 1, 1, (1, 2, 2, 1)) - verify_conv2d_nchw(1, 64, 17, 192, 1, 1, (1, 2)) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (3, 1)) - verify_conv2d_nchw(1, 128, 8, 384, 3, 1, (0, 2)) - verify_conv2d_nchw(1, 64, 35, 64, 3, 1, (1, 2), use_cudnn=True) - verify_conv2d_nchw(1, 64, 8, 64, 1, 1, "VALID") - verify_conv2d_nchw(1, 388, 8, 64, 3, 1, "VALID") - verify_conv2d_nchw(1, 64, 10, 48, 3, 1, "VALID", use_cudnn=True) - verify_conv2d_nchw(1, 512, 19, 64, 1, 1, "SAME") - verify_conv2d_nchw(1, 64, 5, 32, 2, 1, "SAME") - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "SAME", use_cudnn=True) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, (1, 2, 2, 1), add_relu=True) - verify_conv2d_nchw(1, 64, 8, 64, 5, 2, (1, 3), add_bias=True) - verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "VALID", add_bias=True, add_relu=True) - verify_conv2d_nchw(1, 64, 8, 64, 24, 1, "SAME", add_bias=True, add_relu=True) - verify_conv2d_nchw(1, 32, 35, 64, 7, 2, (0, 0, 2, 2)) + with tvm.target.Target(target): + wkl = _get_workload(A, W, (stride, stride), padding, dilation, dtype) + + # check if tile_ow candidates are the factors of the right output weight. + cfg = autotvm.get_config() + _fallback_schedule(cfg, wkl) + ow_tile = np.prod(cfg["tile_ow"].size) + + tvm.testing.assert_allclose(ow_tile, out_width) + + +class TestResNet18Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 224, 64, 7, 2, 3), + (64, 56, 64, 3, 1, 1), + (64, 56, 64, 1, 1, 0), + (64, 56, 128, 3, 2, 1), + (64, 56, 128, 1, 2, 0), + (128, 28, 128, 3, 1, 1), + (128, 28, 256, 3, 2, 1), + (128, 28, 256, 1, 2, 0), + (256, 14, 256, 3, 1, 1), + (256, 14, 512, 3, 2, 1), + (256, 14, 512, 1, 2, 0), + (512, 7, 512, 3, 1, 1), + ) + + +class TestInceptionV3Workloads(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 299, 32, 3, 2, 0), + (32, 149, 32, 3, 1, 0), + (32, 147, 64, 3, 1, 1), + (64, 73, 80, 1, 1, 0), + (80, 73, 192, 3, 1, 0), + (192, 35, 64, 1, 1, 0), + (192, 35, 48, 1, 1, 0), + (48, 35, 64, 5, 1, 2), + (64, 35, 96, 3, 1, 1), + (96, 35, 96, 3, 1, 1), + (192, 35, 32, 1, 1, 0), + (256, 35, 64, 1, 1, 0), + (256, 35, 48, 1, 1, 0), + (288, 35, 64, 1, 1, 0), + (288, 35, 48, 1, 1, 0), + (288, 35, 384, 3, 2, 0), + (96, 35, 96, 3, 2, 0), + (768, 17, 192, 1, 1, 0), + (768, 17, 128, 1, 1, 0), + (128, 17, 128, 1, 1, 0), + (128, 17, 192, 7, 1, 3), + (128, 17, 128, 7, 1, 3), + (128, 17, 192, 1, 1, 0), + (768, 17, 160, 1, 1, 0), + # disable these tests due to some bugs of llvm with nvptx + # (160, 17, 160, 1, 1, 0), + (160, 17, 192, 7, 1, 3), + (160, 17, 160, 7, 1, 3), + (160, 17, 192, 1, 1, 0), + (192, 17, 192, 1, 1, 0), + (192, 17, 192, 7, 1, 3), + (192, 17, 320, 3, 2, 0), + (192, 17, 192, 3, 2, 0), + (1280, 8, 320, 1, 1, 0), + (1280, 8, 384, 1, 1, 0), + (384, 8, 384, 1, 1, 0), + (384, 8, 384, 3, 1, 1), + (1280, 8, 448, 1, 1, 0), + (448, 8, 384, 3, 1, 1), + (1280, 8, 192, 1, 1, 0), + (2048, 8, 320, 1, 1, 0), + (2048, 8, 384, 1, 1, 0), + (2048, 8, 448, 1, 1, 0), + (2048, 8, 192, 1, 1, 0), + (1024, 19, 84, 3, 1, 1), + (2048, 10, 126, 3, 1, 1), + (512, 5, 126, 3, 1, 1), + (256, 3, 126, 3, 1, 1), + ) + + +class TestWeirdWorkloads(BaseConv2DTests): + batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (2, 2, 2, 2, 2, 2, 2), + (3, 3, 3, 3, 3, 3, 3), + (4, 4, 4, 4, 4, 4, 4), + (5, 5, 5, 5, 5, 5, 5), + (6, 6, 6, 6, 6, 6, 6), + # disable these tests due to some bugs of llvm with nvptx + # (1, 1, 1, 1, 1, 1, 1), + # (2, 13, 71, 59, 3, 1, 1), + ) + + +class TestAsymmetricPadding(BaseConv2DTests): + dilation = tvm.testing.parameter(1, 2) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (3, 35, 64, 7, 2, (0, 0, 1, 1)), + (64, 8, 128, 3, 1, (3, 3, 2, 2)), + (64, 8, 64, 1, 1, (1, 2, 2, 1)), + (64, 17, 192, 1, 1, (1, 2)), + (64, 8, 64, 3, 1, (3, 1)), + (128, 8, 384, 3, 1, (0, 2)), + (64, 35, 64, 3, 1, (1, 2)), + (64, 8, 64, 1, 1, "VALID"), + (388, 8, 64, 3, 1, "VALID"), + (64, 10, 48, 3, 1, "VALID"), + (512, 19, 64, 1, 1, "SAME"), + (64, 5, 32, 2, 1, "SAME"), + (64, 8, 64, 3, 1, "SAME"), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (64, 8, 64, 24, 1, "SAME"), + (32, 35, 64, 7, 2, (0, 0, 2, 2)), + ) + + +class TestBatchSize(BaseConv2DTests): + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + ) + batch = tvm.testing.parameter(1, 4, 9) + + +class TestBiasRelu(BaseConv2DTests): + add_relu = tvm.testing.parameter(True, False) + add_bias = tvm.testing.parameter(True, False) + in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( + (64, 56, 64, 3, 1, 1), + (64, 8, 64, 3, 1, (1, 2, 2, 1)), + (64, 8, 64, 5, 2, (1, 3)), + (64, 8, 64, 3, 1, "VALID"), + (64, 8, 64, 24, 1, "SAME"), + ) if __name__ == "__main__": - test_conv2d_nchw() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py index cdb7c0e8d4aa..eb4c5a343b58 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py @@ -34,6 +34,14 @@ topi.arm_cpu.conv2d_nhwc_spatial_pack, topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack, ), + "mali": ( + topi.mali.conv2d_nhwc_spatial_pack, + topi.mali.schedule_conv2d_nhwc_spatial_pack, + ), + "bifrost": ( + topi.mali.conv2d_nhwc_spatial_pack, + topi.mali.schedule_conv2d_nhwc_spatial_pack, + ), "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc), } @@ -58,25 +66,21 @@ def get_ref_data(): a_np, w_np, b_np = get_ref_data() - def check_device(device): - if not tvm.testing.device_enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.Target(device): - fcompute, fschedule = tvm.topi.testing.dispatch(device, _conv2d_nhwc_implement) + def check_device(target, dev): + print("Running on target: %s" % target) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_nhwc_implement) B = fcompute(A, W, stride, padding, dilation, dtype) s = fschedule([B]) - dev = tvm.device(device, 0) a = tvm.nd.array(a_np, dev) w = tvm.nd.array(w_np, dev) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) - func = tvm.build(s, [A, W, B], device) + func = tvm.build(s, [A, W, B], target) func(a, w, b) tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) - for device in ["llvm", "cuda"]: - check_device(device) + for target, dev in tvm.testing.enabled_targets(): + check_device(target, dev) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 235a09400387..964c1621fa47 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -135,6 +135,7 @@ def test_dense( @pytest.mark.parametrize("target,in_dtype,out_dtype", [("cuda", "int8", "int32")]) +@tvm.testing.requires_gpu def test_dense_cuda_int8( target, dev, diff --git a/tests/python/topi/python/test_topi_dense_tensorcore.py b/tests/python/topi/python/test_topi_dense_tensorcore.py index a3657af2c1ca..7e7d3f2209d3 100644 --- a/tests/python/topi/python/test_topi_dense_tensorcore.py +++ b/tests/python/topi/python/test_topi_dense_tensorcore.py @@ -29,40 +29,94 @@ _dense_implement = {"gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]} -def verify_dense(batch, in_dim, out_dim, use_bias=True): +def convert_int32_into_int4(a_int32): + """convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + K, L = a_int32.shape + assert L % 8 == 0 + a_int4 = np.zeros(shape=(K, L // 8), dtype=np.int32) + for k in range(K): + for l in range(L // 8): + for m in range(min(8, L - l * 8)): + a_int4[k, l] = a_int4[k, l] | ((a_int32[k, l * 8 + m] & 0xF) << ((7 - m) * 4)) + return a_int4 + + +def convert_int32_into_int4_bias(a_int32): + """convert int32 values into int4 + Parameters + ---------- + a_int32 : int + + Return + ------ + a_int4 : int + """ + (L,) = a_int32.shape + assert L % 8 == 0 + a_int4 = np.zeros(shape=(L // 8), dtype=np.int32) + for l in range(L // 8): + for m in range(min(8, L - l * 8)): + a_int4[l] = a_int4[l] | ((a_int32[l * 8 + m] & 0xF) << ((7 - m) * 4)) + return a_int4 + + +def verify_dense(batch, in_dim, out_dim, dtype, use_bias=True): """Dense tensorcore verify function""" - A = te.placeholder((batch, in_dim), name="A") - B = te.placeholder((out_dim, in_dim), name="B") - C = te.placeholder((out_dim,), name="C") - dtype = A.dtype + A = te.placeholder((batch, in_dim), name="A", dtype=dtype) + B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype) + C = te.placeholder((out_dim,), name="C", dtype=dtype) + + assert dtype in ["int4", "int8", "float16"] + + out_dtype = "float32" + if dtype in ["int8", "int4"]: + out_dtype = "int32" # use memoize to pickle the test data for next time use @memoize("topi.tests.test_topi_dense_tensorcore") def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) + if dtype == "int4": + a_np = np.random.randint(low=-8, high=7, size=(batch, in_dim)) + b_np = np.random.randint(low=-8, high=7, size=(out_dim, in_dim)) + c_np = np.random.randint(low=-8, high=7, size=(out_dim,)) + elif dtype == "int8": + a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype) + b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype) + c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(dtype) else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) + a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) + b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) + c_np = np.random.uniform(size=(out_dim,)).astype(dtype) + d_np = tvm.topi.testing.dense(a_np, b_np, c_np, use_bias, True, out_dtype) return (a_np, b_np, c_np, d_np) # get the test data a_np, b_np, c_np, d_np = get_ref_data() + if dtype == "int4": + a_np = convert_int32_into_int4(a_np) + b_np = convert_int32_into_int4(b_np) + c_np = convert_int32_into_int4_bias(c_np) def check_device(device): dev = tvm.device(device, 0) print("Running on target: %s" % device) for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement): with tvm.target.Target(device): - D = fcompute(A, B, C if use_bias else None) + D = fcompute(A, B, C if use_bias else None, out_dtype) D = topi.nn.relu(D) s = fschedule([D]) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) c = tvm.nd.array(c_np, dev) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), dev) + d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev) f = tvm.build(s, [A, B, C, D], device, name="dense") f(a, b, c, d) tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-3) @@ -73,11 +127,17 @@ def check_device(device): @tvm.testing.requires_tensorcore def test_dense_tensorcore(): """Test cases""" - verify_dense(8, 16, 32, use_bias=True) - verify_dense(16, 32, 16, use_bias=True) - verify_dense(256, 1024, 1024, use_bias=True) - verify_dense(1000, 1024, 1024, use_bias=False) - verify_dense(256, 2048, 1000, use_bias=False) + for dtype in ["float16", "int8"]: + verify_dense(8, 16, 32, "float16", use_bias=True) + verify_dense(16, 32, 16, dtype, use_bias=True) + verify_dense(256, 1024, 1024, dtype, use_bias=True) + verify_dense(1000, 1024, 1024, dtype, use_bias=False) + verify_dense(256, 2048, 1000, dtype, use_bias=False) + # TODO: need fix int4 use_bias=True, wyc-ruiker + verify_dense(16, 32, 16, "int4", use_bias=False) + verify_dense(256, 1024, 1024, "int4", use_bias=False) + verify_dense(1000, 1024, 1024, "int4", use_bias=False) + verify_dense(256, 2048, 1000, "int4", use_bias=False) if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 76093c51b4c8..092ac9df5f9a 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -14,561 +14,375 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import sys + +import numpy as np +import pytest + import tvm -from tvm import te -from tvm import autotvm -from tvm import topi +import tvm.testing import tvm.topi.testing -import numpy as np + +from tvm import autotvm, te, topi from tvm.topi.utils import get_const_tuple from tvm.topi.nn.utils import get_pad_tuple from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.depthwise_conv2d import _get_workload from tvm.topi.x86.depthwise_conv2d import _fallback_schedule -import tvm.testing -_depthwise_conv2d_nchw_implement = { - "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], - "arm_cpu": [ - (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw), - ( - topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack, - topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack, - ), - ], - "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)], - "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)], - "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)], - "intel_graphics": [ - ( - topi.intel_graphics.depthwise_conv2d_nchw, - topi.intel_graphics.schedule_depthwise_conv2d_nchw, - ) - ], +_depthwise_conv2d_implement = { + "NCHW": { + "generic": [(topi.nn.depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nchw)], + "arm_cpu": [ + (topi.arm_cpu.depthwise_conv2d_nchw, topi.arm_cpu.schedule_depthwise_conv2d_nchw), + ( + topi.arm_cpu.depthwise_conv2d_nchw_spatial_pack, + topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack, + ), + ], + "gpu": [(topi.cuda.depthwise_conv2d_nchw, topi.cuda.schedule_depthwise_conv2d_nchw)], + "mali": [(topi.mali.depthwise_conv2d_nchw, topi.mali.schedule_depthwise_conv2d_nchw)], + "bifrost": [(topi.nn.depthwise_conv2d_nchw, topi.bifrost.schedule_depthwise_conv2d_nchw)], + "intel_graphics": [ + ( + topi.intel_graphics.depthwise_conv2d_nchw, + topi.intel_graphics.schedule_depthwise_conv2d_nchw, + ) + ], + }, + "NHWC": { + "generic": [(topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc)], + "arm_cpu": [ + ( + topi.arm_cpu.compute_depthwise_conv2d_nhwc, + topi.arm_cpu.schedule_depthwise_conv2d_nhwc, + ) + ], + "gpu": [(topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc)], + }, + "NCHWc": { + "generic": [(topi.x86.depthwise_conv2d_NCHWc, topi.x86.schedule_depthwise_conv2d_NCHWc)], + }, } -_depthwise_conv2d_nhwc_implement = { - "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc), - "arm_cpu": ( - topi.arm_cpu.compute_depthwise_conv2d_nhwc, - topi.arm_cpu.schedule_depthwise_conv2d_nhwc, - ), - "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc), -} +in_dtype, out_dtype = tvm.testing.parameters(("float32", "float32")) -def compile_depthwise_NHWC_int8_arm( - batch, - in_channel, - in_size, - kernel, - depth_multiplier, - stride, - padding, - add_bias=False, - dilation=1, -): - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) - padding_sum = pad_top + pad_left + pad_bottom + pad_right - - in_height = in_width = in_size - A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int16") - W = te.placeholder((kernel, kernel, in_channel, depth_multiplier), name="W", dtype="int16") - bias = te.placeholder((in_channel * depth_multiplier,), name="bias", dtype="int32") - dtype = "int32" - - target = "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu" - compute = topi.arm_cpu.compute_depthwise_conv2d_nhwc - schedule = topi.arm_cpu.schedule_depthwise_conv2d_nhwc - - if not tvm.testing.device_enabled(target): - print("Skip because %s is not enabled" % target) - return - - print("Compiling on arm AArch64 target: %s" % target) - with tvm.target.Target(target): - assert topi.arm_cpu.arm_utils.is_aarch64_arm(), "AArch64 target not recognized" - - C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype) - if add_bias: - C += bias - ins_outs = [A, W, bias, C] - else: - ins_outs = [A, W, C] - - s = schedule([C]) - - func = tvm.build( - s, - ins_outs, - target, - name="depthwise_conv2d", - ) - - -def depthwise_conv2d_with_workload_nchw( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride, - padding, - dilation=1, -): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_h = stride_w = stride - - if dilation == 1: - # here we transform the padding argument from 'str' to 'tuple' , - # because we need this to match the "workload" tuple to the records in TopHub - padt, padl, padb, padr = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (padt, padl, padb, padr) - else: - padding_args = padding - - # placeholder - Input = te.placeholder((batch, in_channel, in_height, in_width), name="Input") - Filter = te.placeholder( - (filter_channel, channel_multiplier, filter_height, filter_width), name="Filter" - ) - Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale") - Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift") - dtype = "float32" - - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - impl_list = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nchw_implement)[:] - if target == "llvm" and channel_multiplier == 1 and dilation == 1: - impl_list.append( - (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw) - ) +@tvm.testing.fixture +def input_shape(layout, batch, in_channel, in_size, filter_shape): + if layout == "NCHW": + return (batch, in_channel, in_size, in_size) + elif layout == "NHWC": + return (batch, in_size, in_size, in_channel) + elif layout == "NCHWc": + oc_block = filter_shape[-1] + ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0) + return (batch, in_channel // ic_block, in_size, in_size, ic_block) - for fcompute, fschedule in impl_list: - with tvm.target.Target(target): - # declare - DepthwiseConv2d = fcompute( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule - s1 = fschedule(DepthwiseConv2d) - s2 = fschedule(ScaleShift) - s3 = fschedule(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], target) - f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], target) - - # Prepare pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) - scale_shape = get_const_tuple(Scale.shape) - shift_shape = get_const_tuple(Shift.shape) - scale_shift_shape = get_const_tuple(ScaleShift.shape) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nchw") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - dilated_filter_np = tvm.topi.testing.dilate_python( - filter_np, (1, 1, dilation, dilation) - ) - scale_np = np.random.uniform(size=scale_shape).astype(dtype) - shift_np = np.random.uniform(size=shift_shape).astype(dtype) - # correctness with scipy - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, dilated_filter_np, stride, padding - ) - scale_shift_scipy = np.zeros(shape=scale_shift_shape) - for c in range(in_channel * channel_multiplier): - scale_shift_scipy[:, c, :, :] = ( - depthwise_conv2d_scipy[:, c, :, :] * scale_np[c] + shift_np[c] - ) - relu_scipy = np.maximum(scale_shift_scipy, 0) - return ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) - # Get the test data - ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) = get_ref_data() - - def verify_workload_padding(): - _, _, out_height, out_width = get_const_tuple(depthwise_conv2d_scipy.shape) - wkl = _get_workload( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - - # check if tile_ow candidates are the factors of the right output weight. - with tvm.target.Target(target): - cfg = autotvm.get_config() - _fallback_schedule(cfg, wkl) - ow_tile = np.prod(cfg["tile_ow"].size) - - tvm.testing.assert_allclose(ow_tile, out_width) - - if "llvm" in target: - verify_workload_padding() - - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - scale_tvm = tvm.nd.array(scale_np, dev) - shift_tvm = tvm.nd.array(shift_np, dev) - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), - dev, - ) - scale_shift_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), dev - ) - relu_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev - ) - # launch kernel 1 (depthwise_conv2d) - timer_1 = f1.time_evaluator(f1.entry_name, dev, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean - # launch kernel 2 (depthwise_conv2d + scale_shift) - timer_2 = f2.time_evaluator(f2.entry_name, dev, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean - # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - timer_3 = f3.time_evaluator(f3.entry_name, dev, number=1) - tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean - tvm.testing.assert_allclose( - depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5 - ) - tvm.testing.assert_allclose(scale_shift_tvm.numpy(), scale_shift_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -def depthwise_conv2d_with_workload_nhwc( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride_h, +@tvm.testing.fixture +def filter_shape(layout, in_channel, channel_multiplier, kernel): + filter_channel = in_channel + if layout == "NCHW": + return (filter_channel, channel_multiplier, kernel, kernel) + elif layout == "NHWC": + return (kernel, kernel, filter_channel, channel_multiplier) + elif layout == "NCHWc": + out_channel = in_channel * channel_multiplier + # For testing the functionality, we choose an arbitrary block + # size that can divide out_channel, regardless of the + # performance. + oc_block = next(bn for bn in range(16, 0, -1) if out_channel % bn == 0) + return (out_channel // oc_block, 1, kernel, kernel, 1, oc_block) + + +@tvm.testing.fixture +def scale_shape(layout, in_channel, channel_multiplier, filter_shape): + out_channel = in_channel * channel_multiplier + + if layout in ("NCHW", "NHWC"): + return (out_channel,) + + if layout == "NCHWc": + oc_block = filter_shape[-1] + return (out_channel // oc_block, oc_block) + + raise ValueError("Unknown layout {}".format(layout)) + + +@tvm.testing.fixture +def shift_shape(scale_shape): + return scale_shape + + +@tvm.testing.fixture(cache_return_value=True) +def ref_data( + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + dilation, + stride, padding, - dilation=1, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, ): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_w = stride_h - - if dilation == 1: - # here we transform the padding argument from 'str' to 'tuple' , - # because we need this to match the "workload" tuple to the records in TopHub - pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (pad_h, pad_w) - else: - padding_args = padding - - # placeholder - Input = te.placeholder((batch, in_height, in_width, in_channel), name="Input") - Filter = te.placeholder( - (filter_height, filter_width, filter_channel, channel_multiplier), name="Filter" + input_np = np.random.uniform(size=input_shape).astype(in_dtype) + filter_np = np.random.uniform(size=filter_shape).astype(in_dtype) + scale_np = np.random.uniform(size=scale_shape).astype(out_dtype) + shift_np = np.random.uniform(size=shift_shape).astype(out_dtype) + if layout == "NCHW": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchw + dilation = (1, 1, dilation, dilation) + reshape = (1, -1, 1, 1) + elif layout == "NHWC": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nhwc + dilation = (dilation, dilation, 1, 1) + reshape = (1, 1, 1, -1) + elif layout == "NCHWc": + np_depthwise_conv2d = tvm.topi.testing.depthwise_conv2d_python_nchwc + dilation = (1, 1, dilation, dilation, 1, 1) + reshape = (1, scale_shape[0], 1, 1, scale_shape[1]) + + dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, dilation) + output_np = np_depthwise_conv2d(input_np, dilated_filter_np, stride, padding) + + if use_scale_shift: + output_np = output_np * scale_np.reshape(reshape) + shift_np.reshape(reshape) + if apply_relu: + output_np = np.maximum(output_np, 0) + + return ( + input_np, + filter_np, + scale_np, + shift_np, + output_np, ) - Scale = te.placeholder((in_channel * channel_multiplier,), name="Scale") - Shift = te.placeholder((in_channel * channel_multiplier,), name="Shift") - dtype = "float32" - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - fcompute, fschedule = tvm.topi.testing.dispatch(target, _depthwise_conv2d_nhwc_implement) - with tvm.target.Target(target): - # declare - DepthwiseConv2d = fcompute( - Input, Filter, (stride_h, stride_w), padding_args, dilation, dtype - ) - ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule - s1 = fschedule(DepthwiseConv2d) - s2 = fschedule(ScaleShift) - s3 = fschedule(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], target) - f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], target) - - # Prepare pod type for test data closure - input_shape = get_const_tuple(Input.shape) - filter_shape = get_const_tuple(Filter.shape) - scale_shape = get_const_tuple(Scale.shape) - shift_shape = get_const_tuple(Shift.shape) - scale_shift_shape = get_const_tuple(ScaleShift.shape) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.nhwc.v2") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - dilated_filter_np = tvm.topi.testing.dilate_python( - filter_np, (dilation, dilation, 1, 1) - ) - scale_np = np.random.uniform(size=scale_shape).astype(dtype) - shift_np = np.random.uniform(size=shift_shape).astype(dtype) - # correctness with scipy - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nhwc( - input_np, dilated_filter_np, stride=[stride_h, stride_w], padding=padding - ) - scale_shift_scipy = np.zeros(shape=scale_shift_shape) - for c in range(in_channel * channel_multiplier): - scale_shift_scipy[:, :, :, c] = ( - depthwise_conv2d_scipy[:, :, :, c] * scale_np[c] + shift_np[c] - ) - relu_scipy = np.maximum(scale_shift_scipy, 0) - return ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) +class BaseDepthwiseConv2D: + """Provides the test_conv2d test function, to be used by other test classes. - # Get the test data - ( - input_np, - filter_np, - scale_np, - shift_np, - depthwise_conv2d_scipy, - scale_shift_scipy, - relu_scipy, - ) = get_ref_data() - - # prepare data - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - scale_tvm = tvm.nd.array(scale_np, dev) - shift_tvm = tvm.nd.array(shift_np, dev) - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), dev - ) - scale_shift_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), dev - ) - relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev) - # launch kernel 1 (depthwise_conv2d) - timer_1 = f1.time_evaluator(f1.entry_name, dev, number=1) - tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean - # launch kernel 2 (depthwise_conv2d + scale_shift) - timer_2 = f2.time_evaluator(f2.entry_name, dev, number=1) - tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean - # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - timer_3 = f3.time_evaluator(f3.entry_name, dev, number=1) - tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean - relu_scipy = np.maximum(scale_shift_scipy, 0) - tvm.testing.assert_allclose(depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5) - tvm.testing.assert_allclose(scale_shift_tvm.numpy(), scale_shift_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -def _transform_data(data, bn): - # NCHW -> NCHW[x]c - batch_size, channel, height, width = data.shape - data = np.reshape(data, (batch_size, channel // bn, bn, height, width)) - data = np.transpose(data, (0, 1, 3, 4, 2)) - return data - - -def _transform_kernel(kernel, bn): - # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block - channel, channel_multiplier, kh, kw = kernel.shape - out_channel = channel * channel_multiplier - kernel = np.reshape(kernel, (out_channel // bn, bn, kh, kw)) - kernel = np.transpose(kernel, (0, 2, 3, 1)) - out_channel_chunk, kh, kw, out_channel_block = kernel.shape - return kernel.reshape(out_channel_chunk, 1, kh, kw, 1, out_channel_block) - - -def depthwise_conv2d_with_workload_NCHWc( - target, - dev, - batch, - in_channel, - in_height, - channel_multiplier, - filter_height, - stride, - padding, - dilation=1, -): - in_width = in_height - filter_channel = in_channel - filter_width = filter_height - stride_h = stride_w = stride - - assert ( - channel_multiplier == 1 - ), "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1." - pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width)) - padding_args = (pad_h, pad_w) - - out_channel = filter_channel * channel_multiplier - # for testing functionality, - # we choose arbitrary block size that can divide the channel, - # regardless of the performance. - oc_block = 1 - for bn in range(16, 0, -1): - if out_channel % bn == 0: - oc_block = bn - break - - ic_block = 1 - for bn in range(oc_block, 0, -1): - if in_channel % bn == 0: - ic_block = bn - break - - # placeholder - Input = te.placeholder( - (batch, in_channel // ic_block, in_height, in_width, ic_block), name="Input" - ) - Filter = te.placeholder( - (out_channel // oc_block, 1, filter_height, filter_width, 1, oc_block), name="Filter" - ) - in_layout = "NCHW%dc" % ic_block - out_layout = "NCHW%dc" % oc_block - dtype = "float32" + Test parameter sets are split out into different classes for + readability (e.g. used for mobilenet), and for restrictions + (e.g. implemented only for llvm). + """ - with autotvm.tophub.context(target): # load tophub pre-tuned parameters - dev = tvm.device(target, 0) - with tvm.target.Target(target): - # declare - DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc( + layout = tvm.testing.parameter("NCHW", "NHWC") + + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 728, 32, 1, 3, 1), + (4, 256, 64, 2, 5, 2), + ) + padding = tvm.testing.parameter("SAME", "VALID") + dilation = tvm.testing.parameter(1, 2) + + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) + apply_relu = tvm.testing.parameter(True, False, ids=["with_relu", "no_relu"]) + + run_after_compile = True + + def test_conv2d( + self, + request, + target, + dev, + in_dtype, + out_dtype, + layout, + input_shape, + filter_shape, + scale_shape, + shift_shape, + use_scale_shift, + apply_relu, + batch, + in_channel, + channel_multiplier, + kernel, + stride, + padding, + dilation, + ): + # Transform the padding argument from 'str' to 'tuple' to + # match the "workload" tuple in TopHub. Which padding_args to + # use for each layout chosen to reproduce previous behavior. + if dilation == 1: + padding_args = get_pad_tuple(padding, (kernel, kernel)) + padding_args_i = [0, 1, 2, 3] if layout == "NCHW" else [0, 1] + padding_args = [padding_args[i] for i in padding_args_i] + else: + padding_args = padding + + # placeholder + Input = te.placeholder(input_shape, name="Input", dtype=in_dtype) + Filter = te.placeholder(filter_shape, name="Filter", dtype=in_dtype) + Scale = te.placeholder(scale_shape, name="Scale", dtype=out_dtype) + Shift = te.placeholder(shift_shape, name="Shift", dtype=out_dtype) + + if layout == "NCHW": + topi_scale_shift = topi.nn.scale_shift_nchw + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NHWC": + topi_scale_shift = topi.nn.scale_shift_nhwc + fcompute_args = (Input, Filter, stride, padding_args, dilation, out_dtype) + + elif layout == "NCHWc": + topi_scale_shift = topi.nn.scale_shift_nchwc + in_layout = "NCHW{}c".format(input_shape[-1]) + out_layout = "NCHW{}c".format(filter_shape[-1]) + fcompute_args = ( Input, Filter, - (stride_h, stride_w), + stride, padding, - (dilation, dilation), + dilation, in_layout, out_layout, - dtype, - ) - # TODO: add scale_shift implement for NCHWc and add test here - Relu = topi.nn.relu(DepthwiseConv2d) - # schedule - s1 = topi.x86.schedule_depthwise_conv2d_NCHWc(DepthwiseConv2d) - s2 = topi.x86.schedule_depthwise_conv2d_NCHWc(Relu) - # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], target) - f2 = tvm.build(s2, [Input, Filter, Relu], target) - - # Prepare pod type for test data closure - input_shape = (batch, in_channel, in_height, in_width) - filter_shape = (filter_channel, channel_multiplier, filter_height, filter_width) - - # Use memoize, pickle the test data for next time use. - @memoize("topi.tests.test_topi_depthwise_conv2d.NCHWc") - def get_ref_data(): - input_np = np.random.uniform(size=input_shape).astype(dtype) - filter_np = np.random.uniform(size=filter_shape).astype(dtype) - # correctness with scipy - dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype( - dtype - ) - depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw( - input_np, dw_np, stride, padding - ) - relu_scipy = np.maximum(depthwise_conv2d_scipy, 0) - return ( - _transform_data(input_np, ic_block), - _transform_kernel(filter_np, oc_block), - _transform_data(depthwise_conv2d_scipy, oc_block), - _transform_data(relu_scipy, oc_block), + out_dtype, ) - # Get the test data - (input_np, filter_np, depthwise_conv2d_scipy, relu_scipy) = get_ref_data() - - input_tvm = tvm.nd.array(input_np, dev) - filter_tvm = tvm.nd.array(filter_np, dev) - - depthwise_conv2d_tvm = tvm.nd.array( - np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), dev - ) - relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), dev) - # launch kernel 1 (depthwise_conv2d) - f1(input_tvm, filter_tvm, depthwise_conv2d_tvm) - # launch kernel 2 (depthwise_conv2d + relu) - f2(input_tvm, filter_tvm, relu_tvm) - tvm.testing.assert_allclose(depthwise_conv2d_tvm.numpy(), depthwise_conv2d_scipy, rtol=1e-5) - tvm.testing.assert_allclose(relu_tvm.numpy(), relu_scipy, rtol=1e-5) - - -@tvm.testing.parametrize_targets -def test_depthwise_conv2d_nchw(target, dev): - # mobilenet workloads - depthwise_conv2d_with_workload_nchw(target, dev, 1, 32, 112, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 64, 112, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 128, 56, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 128, 56, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 256, 28, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 256, 28, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 512, 14, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 512, 14, 1, 3, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 1024, 7, 1, 3, 1, "SAME") - - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload_nchw(target, dev, 4, 256, 64, 2, 5, 2, "VALID") - # dilation = 2 - depthwise_conv2d_with_workload_nchw(target, dev, 1, 728, 64, 1, 3, 1, "SAME", dilation=2) - - -@tvm.testing.parametrize_targets -def test_depthwise_conv2d_nhwc(target, dev): - depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_nhwc(target, dev, 4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload_nhwc(target, dev, 4, 256, 64, 2, 5, 2, "VALID") - - # dilation = 2 - # disabled because it uses too large shared memory on cuda - # depthwise_conv2d_with_workload_nhwc(target, dev, 1, 728, 64, 1, 3, 1, "SAME", dilation=2) - - -# test llvm only for now since depthwise_conv2d_NCHWc implement is missing in other backend. + with autotvm.tophub.context(target): # load tophub pre-tuned parameters + impl_list = tvm.topi.testing.dispatch(target, _depthwise_conv2d_implement[layout])[:] + if target == "llvm" and layout == "NCHW" and channel_multiplier == 1 and dilation == 1: + impl_list.append( + (topi.x86.depthwise_conv2d_nchw, topi.x86.schedule_depthwise_conv2d_nchw) + ) + + for fcompute, fschedule in impl_list: + with tvm.target.Target(target): + # Declare, build schedule + C = fcompute(*fcompute_args) + if use_scale_shift: + C = topi_scale_shift(C, Scale, Shift) + if apply_relu: + C = topi.nn.relu(C) + + s = fschedule(C) + + # Build and run + f = tvm.build(s, [Input, Filter, Scale, Shift, C], target) + + if self.run_after_compile: + input_np, filter_np, scale_np, shift_np, output_np = request.getfixturevalue( + "ref_data" + ) + input_tvm = tvm.nd.array(input_np, dev) + filter_tvm = tvm.nd.array(filter_np, dev) + scale_tvm = tvm.nd.array(scale_np, dev) + shift_tvm = tvm.nd.array(shift_np, dev) + output_tvm = tvm.nd.array( + np.zeros(shape=get_const_tuple(C.shape), dtype=C.dtype), + dev, + ) + + f(input_tvm, filter_tvm, scale_tvm, shift_tvm, output_tvm) + tvm.testing.assert_allclose(output_np, output_tvm.numpy(), rtol=1e-5) + + +class TestDepthwiseConv2D(BaseDepthwiseConv2D): + """Test variety of parameters, defined in BaseDepthwiseConv2D. Also + has llvm-specific tests for workload padding.""" + + @tvm.testing.parametrize_targets("llvm") + def test_workload_padding( + self, + out_dtype, + layout, + input_shape, + filter_shape, + target, + ref_data, + stride, + padding, + dilation, + ): + input_np, filter_np, scale_np, shift_np, output_np = ref_data + if layout == "NCHW": + _, _, out_height, out_width = output_np.shape + elif layout == "NHWC": + _, out_height, out_width, _ = output_np.shape + elif layout == "NCHWc": + _, _, out_height, out_width, _ = output_np.shape + + Input = te.placeholder(input_shape, name="Input") + Filter = te.placeholder(filter_shape, name="Filter") + wkl = _get_workload(Input, Filter, (stride, stride), padding, dilation, out_dtype, layout) + + # check if tile_ow candidates are the factors of the right output weight. + with tvm.target.Target(target): + cfg = autotvm.get_config() + _fallback_schedule(cfg, wkl) + ow_tile = np.prod(cfg["tile_ow"].size) + + tvm.testing.assert_allclose(ow_tile, out_width) + + +class TestDepthwiseConv2D_MobilenetWorkloads(BaseDepthwiseConv2D): + """Extra tests to verify functionality for workloads used by mobilenet.""" + + layout = tvm.testing.parameter("NCHW") + + batch = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1) + kernel = tvm.testing.parameter(3) + padding = tvm.testing.parameter("SAME") + dilation = tvm.testing.parameter(1) + + in_channel, in_size, stride = tvm.testing.parameters( + (32, 112, 1), + (64, 112, 2), + (128, 56, 1), + (128, 56, 2), + (256, 28, 1), + (256, 28, 2), + (512, 14, 1), + (512, 14, 2), + (1024, 7, 1), + ) + + @tvm.testing.parametrize_targets("llvm") -def test_depthwise_conv2d_nchwc(target, dev): - # NCHW[x]c - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "SAME", dilation=2) - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload_NCHWc(target, dev, 1, 728, 32, 1, 3, 1, "VALID") +class TestDepthwiseConv2D_NCHWc(BaseDepthwiseConv2D): + """Tests specific to NCHWc layouts. + + Once the implementation supports channel_multiplier>1 and GPU + devices, this class can be merged into TestDepthwiseConv2D. + """ + + # depthwise_conv2d_NCHWc currently does not support channel multiplier > 1 + layout = tvm.testing.parameter("NCHWc") + (batch, in_channel, in_size, channel_multiplier, kernel, stride) = tvm.testing.parameters( + (1, 728, 32, 1, 3, 1), + ) + + +@tvm.testing.parametrize_targets("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu") +class TestDepthwiseConv2DArmCompile(BaseDepthwiseConv2D): + """Compile-only tests for cross-compiling to ARM.""" + layout = tvm.testing.parameter("NHWC", "NCHW") + batch = tvm.testing.parameter(1) + dilation = tvm.testing.parameter(1) + in_dtype, out_dtype = tvm.testing.parameters(("int16", "int32")) + in_channel = tvm.testing.parameter(728) + in_size = tvm.testing.parameter(32) + kernel = tvm.testing.parameter(1) + channel_multiplier = tvm.testing.parameter(1, 3) + stride = tvm.testing.parameter(1) + padding = tvm.testing.parameter("SAME") + use_scale_shift = tvm.testing.parameter(True, False, ids=["with_scale_shift", "no_scale_shift"]) -def test_depthwise_conv2d_arm(): - # Test compilation on arm targets - compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 3, 1, "SAME") - compile_depthwise_NHWC_int8_arm(1, 728, 32, 1, 1, 1, "SAME", True) + run_after_compile = False if __name__ == "__main__": - test_depthwise_conv2d() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py b/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py index 8e30ed6840e3..0bbb0e6c0cca 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d_back_weight.py @@ -36,8 +36,8 @@ def verify_depthwise_conv2d_back_weight( stride_w = stride_h padding_w = padding_h - out_h = np.int((in_h + 2 * padding_h - filter_h) / stride_h + 1) - out_w = np.int((in_w + 2 * padding_w - filter_w) / stride_w + 1) + out_h = int((in_h + 2 * padding_h - filter_h) / stride_h + 1) + out_w = int((in_w + 2 * padding_w - filter_w) / stride_w + 1) out_channel = in_channel * channel_multiplier oshape = [batch, out_h, out_w, out_channel] diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 2730783907fd..fe7fba52f1ee 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -24,7 +24,7 @@ from tvm.contrib.pickle_memoize import memoize -def verify_resize( +def verify_resize2d( batch, in_channel, in_height, @@ -33,7 +33,7 @@ def verify_resize( out_width, layout="NCHW", coord_trans="align_corners", - method="bilinear", + method="linear", ): if layout == "NCHW": A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="float32") @@ -47,24 +47,16 @@ def verify_resize( a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: raise NotImplementedError("Layout not supported {} ".format(layout)) - B = topi.image.resize( + B = topi.image.resize2d( A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method, ) - if method == "bilinear": - b_np = tvm.topi.testing.bilinear_resize_python( - a_np, (out_height, out_width), layout, coord_trans - ) - else: - # TODO: Nearest neighbor case doesn't do anything with coordinate transform mode, and also - # nearest_neighbors and align_corners combination in topi doesn't match the output of this - # function. - scale_h = out_height / in_height - scale_w = out_width / in_width - b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = tvm.topi.testing.resize2d_python(a_np, (scale_h, scale_w), layout, method, coord_trans) def check_target(target, dev): print("Running on target: %s" % target) @@ -82,19 +74,21 @@ def check_target(target, dev): @tvm.testing.uses_gpu -def test_resize(): +def test_resize2d(): # Scale NCHW - verify_resize(4, 16, 32, 32, 50, 50, "NCHW") + verify_resize2d(4, 16, 32, 32, 50, 50, "NCHW") # Scale NCHW + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NCHW") + verify_resize2d(6, 32, 64, 64, 20, 20, "NCHW") # Scale NHWC - verify_resize(4, 16, 32, 32, 50, 50, "NHWC") + verify_resize2d(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NHWC") - for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric - for layout in ["NCHW", "NHWC"]: - verify_resize(4, 16, 32, 32, 50, 50, layout, coord_trans, method=method) + verify_resize2d(6, 32, 64, 64, 20, 20, "NHWC") + for layout in ["NCHW", "NHWC"]: + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "align_corners", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="linear") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="linear") def verify_resize3d( @@ -107,8 +101,8 @@ def verify_resize3d( out_height, out_width, layout="NCDHW", - coordinate_transformation_mode="half_pixel", - method="trilinear", + coordinate_transformation_mode="asymmetric", + method="linear", ): if layout == "NCDHW": A = te.placeholder( @@ -139,18 +133,14 @@ def verify_resize3d( method=method, ) - if method == "trilinear": - b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, (out_depth, out_height, out_width), layout, coordinate_transformation_mode - ) - else: - scale_d = out_depth / in_depth - scale_h = out_height / in_height - scale_w = out_width / in_width - b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + scale_d = out_depth / in_depth + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = tvm.topi.testing.resize3d_python( + a_np, (scale_d, scale_h, scale_w), layout, method, coordinate_transformation_mode + ) def check_target(target, dev): - print("Running on target: %s" % target) with tvm.target.Target(target): s = tvm.topi.testing.get_injective_schedule(target)(B) a = tvm.nd.array(a_np, dev) @@ -167,16 +157,10 @@ def check_target(target, dev): @tvm.testing.uses_gpu def test_resize3d(): # Trilinear - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW") - verify_resize3d(1, 8, 16, 16, 16, 25, 25, 25, "NDHWC") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "align_corners") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "align_corners") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "asymmetric") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "asymmetric") - - # Nearest neighbor - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW", method="nearest_neighbor") - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NDHWC", method="nearest_neighbor") + for method in ["nearest_neighbor", "linear"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NCDHW", "NDHWC"]: + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, layout, coord_trans, method) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_pooling.py b/tests/python/topi/python/test_topi_pooling.py index 57877e3d202c..b7f11de1391f 100644 --- a/tests/python/topi/python/test_topi_pooling.py +++ b/tests/python/topi/python/test_topi_pooling.py @@ -315,6 +315,7 @@ def verify_poolnd( pool_type, count_include_pad, ceil_mode, + layout=layout, ) np.testing.assert_equal(tuple(output_shape), tuple(ref_np.shape)) @@ -355,7 +356,7 @@ def verify_pool3d( padding, pool_type, ceil_mode, - layout="NCDHW", + layout=layout, count_include_pad=count_include_pad, ) @@ -363,18 +364,106 @@ def verify_pool3d( @tvm.testing.uses_gpu def test_pool3d(): """test cases of pool3d""" + verify_pool3d( + [1, 16, 32, 32, 32], [2, 2, 2], [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "avg", False, True + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [1, 1, 2, 2, 2, 1], "avg", False, True + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [3, 3, 3, 3, 3, 3], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 32, 32, 32], [2, 2, 2], [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", True + ) + + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 1, 0, 5, 4, 3], "avg", False, True + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [1, 0, 5, 4, 3, 2], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [1, 1, 1], [3, 2, 1, 0, 5, 4], "max", True + ) + + # Test non-1 dilation + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [3, 3, 3], [2, 1, 0, 5, 4, 3], "avg", False, True + ) verify_pool3d( [1, 16, 32, 32, 32], [2, 2, 2], [2, 2, 2], + [2, 2, 2], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 1, 3], [1, 0, 5, 4, 3, 2], "max", False + ) + verify_pool3d( + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 2, 3], [3, 2, 1, 0, 5, 4], "max", True + ) + # Test channel last layouts + verify_pool3d( + [1, 32, 32, 32, 16], + [2, 2, 2], + [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "avg", False, True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], @@ -382,9 +471,10 @@ def test_pool3d(): "avg", False, True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 32, 32, 32], + [1, 32, 32, 32, 16], [2, 2, 2], [2, 2, 2], [1, 1, 1], @@ -392,9 +482,10 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [4, 4, 4], [4, 4, 4], [1, 1, 1], @@ -402,9 +493,10 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [4, 4, 4], [4, 4, 4], [1, 1, 1], @@ -412,37 +504,41 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 32, 32, 32], + [1, 32, 32, 32, 16], [2, 2, 2], [2, 2, 2], [1, 1, 1], [0, 0, 0, 0, 0, 0], "max", False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [2, 2, 1, 1, 1, 2], "max", True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], @@ -450,9 +546,10 @@ def test_pool3d(): "avg", False, True, + layout="NDHWC", ) verify_pool3d( - [1, 16, 32, 32, 32], + [1, 32, 32, 32, 16], [2, 2, 2], [2, 2, 2], [1, 1, 1], @@ -460,36 +557,32 @@ def test_pool3d(): "avg", False, False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [1, 0, 5, 4, 3, 2], "max", False, + layout="NDHWC", ) verify_pool3d( - [1, 16, 31, 31, 31], + [1, 31, 31, 31, 16], [3, 3, 3], [3, 3, 3], [1, 1, 1], [3, 2, 1, 0, 5, 4], "max", True, + layout="NDHWC", ) # Test non-1 dilation verify_pool3d( - [1, 16, 31, 31, 31], - [3, 3, 3], - [3, 3, 3], - [3, 3, 3], - [2, 1, 0, 5, 4, 3], - "avg", - False, - True, + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [3, 3, 3], [2, 1, 0, 5, 4, 3], "avg", False, True ) verify_pool3d( [1, 16, 32, 32, 32], @@ -502,27 +595,23 @@ def test_pool3d(): False, ) verify_pool3d( - [1, 16, 31, 31, 31], - [3, 3, 3], - [3, 3, 3], - [2, 1, 3], - [1, 0, 5, 4, 3, 2], - "max", - False, + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 1, 3], [1, 0, 5, 4, 3, 2], "max", False ) verify_pool3d( - [1, 16, 31, 31, 31], - [3, 3, 3], - [3, 3, 3], - [2, 2, 3], - [3, 2, 1, 0, 5, 4], - "max", - True, + [1, 16, 31, 31, 31], [3, 3, 3], [3, 3, 3], [2, 2, 3], [3, 2, 1, 0, 5, 4], "max", True ) def verify_pool2d( - input_shape, kernel, stride, dilation, padding, pool_type, ceil_mode, count_include_pad=True + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad=True, + layout="NCHW", ): verify_poolnd( 2, @@ -533,7 +622,7 @@ def verify_pool2d( padding, pool_type, ceil_mode, - layout="NCHW", + layout=layout, count_include_pad=count_include_pad, ) @@ -541,162 +630,69 @@ def verify_pool2d( @tvm.testing.uses_gpu def test_pool2d(): """test cases of pool""" + verify_pool2d([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True) + verify_pool2d([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False) + verify_pool2d([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True) + + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True) + verify_pool2d([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True) + + # Test non-1 dilations + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True) + verify_pool2d([1, 16, 32, 32], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False) + verify_pool2d([1, 16, 31, 31], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True) + # Test channel last verify_pool2d( - [1, 16, 32, 32], - [2, 2], - [2, 2], - [1, 1], - [0, 0, 0, 0], - "avg", - False, - True, - ) - verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [1, 2, 1, 2], - "avg", - False, - True, + [1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 32, 32], - [2, 2], - [2, 2], - [1, 1], - [1, 2, 1, 2], - "avg", - False, - False, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [4, 4], - [4, 4], - [1, 1], - [3, 3, 3, 3], - "avg", - False, - False, + [1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [4, 4], - [4, 4], - [1, 1], - [0, 0, 0, 0], - "avg", - False, - False, + [1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 32, 32], - [2, 3], - [2, 2], - [1, 1], - [0, 0, 0, 0], - "max", - False, + [1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [2, 1, 2, 1], - "max", - False, + [1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [2, 1, 2, 1], - "max", - True, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False, layout="NHWC" ) + verify_pool2d([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True, layout="NHWC") verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [2, 1, 0, 3], - "avg", - False, - True, - ) - verify_pool2d( - [1, 16, 32, 32], - [2, 3], - [2, 2], - [1, 1], - [0, 3, 2, 1], - "avg", - False, - False, - ) - verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [1, 0, 3, 2], - "max", - False, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [1, 1], - [3, 2, 1, 0], - "max", - True, + [1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False, layout="NHWC" ) - - # Test non-1 dilations verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [2, 1], - [2, 1, 0, 3], - "avg", - False, - True, + [1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False, layout="NHWC" ) + verify_pool2d([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True, layout="NHWC") verify_pool2d( - [1, 16, 32, 32], - [2, 3], - [2, 2], - [2, 3], - [0, 3, 2, 1], - "avg", - False, - False, + [1, 31, 31, 16], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [3, 3], - [1, 0, 3, 2], - "max", - False, + [1, 32, 32, 16], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False, layout="NHWC" ) verify_pool2d( - [1, 16, 31, 31], - [3, 3], - [3, 3], - [2, 2], - [3, 2, 1, 0], - "max", - True, + [1, 31, 31, 16], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False, layout="NHWC" ) + verify_pool2d([1, 31, 31, 16], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True, layout="NHWC") def verify_pool1d( @@ -719,7 +715,7 @@ def verify_pool1d( padding, pool_type, ceil_mode, - layout="NCW", + layout=layout, count_include_pad=count_include_pad, ) @@ -727,162 +723,43 @@ def verify_pool1d( @tvm.testing.uses_gpu def test_pool1d(): """test cases of pool1d""" - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [0, 0], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [1, 2], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [1, 2], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [4], - [4], - [1], - [3, 3], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [4], - [4], - [1], - [0, 0], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [0, 0], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [2, 1], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [2, 1], - "max", - True, - ) - - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [2, 5], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [1], - [0, 3], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [1, 4], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [1], - [3, 0], - "max", - True, - ) + verify_pool1d([1, 16, 32], [2], [2], [1], [0, 0], "avg", False, True) + verify_pool1d([1, 16, 31], [3], [3], [1], [1, 2], "avg", False, True) + verify_pool1d([1, 16, 32], [2], [2], [1], [1, 2], "avg", False, False) + verify_pool1d([1, 16, 31], [4], [4], [1], [3, 3], "avg", False, False) + verify_pool1d([1, 16, 31], [4], [4], [1], [0, 0], "avg", False, False) + verify_pool1d([1, 16, 32], [2], [2], [1], [0, 0], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [1], [2, 1], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [1], [2, 1], "max", True) + + verify_pool1d([1, 16, 31], [3], [3], [1], [2, 5], "avg", False, True) + verify_pool1d([1, 16, 32], [2], [2], [1], [0, 3], "avg", False, False) + verify_pool1d([1, 16, 31], [3], [3], [1], [1, 4], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [1], [3, 0], "max", True) # Test non-1 dilations - verify_pool1d( - [1, 16, 31], - [3], - [3], - [2], - [2, 5], - "avg", - False, - True, - ) - verify_pool1d( - [1, 16, 32], - [2], - [2], - [3], - [0, 3], - "avg", - False, - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [2], - [1, 4], - "max", - False, - ) - verify_pool1d( - [1, 16, 31], - [3], - [3], - [3], - [3, 0], - "max", - True, - ) + verify_pool1d([1, 16, 31], [3], [3], [2], [2, 5], "avg", False, True) + verify_pool1d([1, 16, 32], [2], [2], [3], [0, 3], "avg", False, False) + verify_pool1d([1, 16, 31], [3], [3], [2], [1, 4], "max", False) + verify_pool1d([1, 16, 31], [3], [3], [3], [3, 0], "max", True) + # Test Channel last + verify_pool1d([1, 32, 16], [2], [2], [1], [0, 0], "avg", False, True, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [1, 2], "avg", False, True, layout="NWC") + verify_pool1d([1, 32, 16], [2], [2], [1], [1, 2], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [4], [4], [1], [3, 3], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [4], [4], [1], [0, 0], "avg", False, False, layout="NWC") + verify_pool1d([1, 32, 16], [2], [2], [1], [0, 0], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [2, 1], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [2, 1], "max", True, layout="NWC") + + verify_pool1d([1, 31, 16], [3], [3], [1], [2, 5], "avg", False, True, layout="NWC") + verify_pool1d([1, 31, 16], [2], [2], [1], [0, 3], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [1, 4], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [1], [3, 0], "max", True, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [2], [2, 5], "avg", False, True, layout="NWC") + verify_pool1d([1, 32, 16], [2], [2], [3], [0, 3], "avg", False, False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [2], [1, 4], "max", False, layout="NWC") + verify_pool1d([1, 31, 16], [3], [3], [3], [3, 0], "max", True, layout="NWC") if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 003b89f7122a..11006576fea3 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -35,21 +35,20 @@ } -def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True): +def verify_dynamic_csrmv(batch, in_dim, out_dim, dtype, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") - dtype = "float32" A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A") - B = te.placeholder((in_dim, 1), name="B") - C = te.placeholder((nr,), name="C") + B = te.placeholder((in_dim, 1), dtype=dtype, name="B") + C = te.placeholder((nr,), dtype=dtype, name="C") D = topi.sparse.csrmv(A, B, C if use_bias else None) s = te.create_schedule(D.op) dtype = A.dtype # get the test data def get_ref_data(): - a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0) - b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype) - 0.5 - c_np = np.random.uniform(size=(batch,)).astype(dtype) + a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype) + b_np = np.random.uniform(size=(in_dim, 1), high=100).astype(dtype) + c_np = np.random.uniform(size=(batch,), high=100).astype(dtype) if use_bias: d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1)) else: @@ -81,21 +80,20 @@ def check_device(device): check_device(device) -def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True): +def verify_dynamic_csrmm(batch, in_dim, out_dim, dtype, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") - dtype = "float32" A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name="A") - B = te.placeholder((in_dim, out_dim), name="B") - C = te.placeholder((nr,), name="C") + B = te.placeholder((in_dim, out_dim), dtype=dtype, name="B") + C = te.placeholder((nr,), dtype=dtype, name="C") D = topi.sparse.csrmm(A, B, C if use_bias else None) s = te.create_schedule(D.op) dtype = A.dtype # get the test data def get_ref_data(): - a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype) - 0.5, 0.0) - b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype) - 0.5 - c_np = np.random.uniform(size=(batch,)).astype(dtype) + a_np = np.random.uniform(size=(batch, in_dim), high=100).astype(dtype) + b_np = np.random.uniform(size=(in_dim, out_dim), high=100).astype(dtype) + c_np = np.random.uniform(size=(batch,), high=100).astype(dtype) if use_bias: d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1)) else: @@ -212,14 +210,15 @@ def check_device(device): def test_csrmv(): - verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False) - verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float32", use_bias=False) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="float64", use_bias=True) + verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, dtype="int32", use_bias=True) def test_csrmm(): M, K, N = 5, 7, 2 - verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False) - verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True) + verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="int64", use_bias=False) + verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, dtype="float64", use_bias=True) def test_dense_si(): diff --git a/tests/python/topi/python/test_topi_upsampling.py b/tests/python/topi/python/test_topi_upsampling.py index 0ab0e64af4c7..7793417a9a2b 100644 --- a/tests/python/topi/python/test_topi_upsampling.py +++ b/tests/python/topi/python/test_topi_upsampling.py @@ -78,11 +78,13 @@ def verify_upsampling( B = topi.nn.upsampling(A, scale_h, scale_w, layout=layout, method=method, align_corners=False) - if method == "bilinear": - out_size = (int(round(in_height * scale_h)), int(round(in_width * scale_w))) - b_np = tvm.topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric") - else: - b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) + b_np = tvm.topi.testing.resize2d_python( + a_np, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "asymmetric", + ) def check_target(target, dev): print("Running on target: %s" % target) @@ -213,20 +215,16 @@ def verify_upsampling3d( scale_w, layout=layout, method=method, - coordinate_transformation_mode="half_pixel", + coordinate_transformation_mode="asymmetric", ) - if method == "trilinear": - out_size = ( - int(round(in_depth * scale_d)), - int(round(in_height * scale_h)), - int(round(in_width * scale_w)), - ) - b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, out_size, layout, coordinate_transformation_mode="half_pixel" - ) - else: - b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + b_np = tvm.topi.testing.resize3d_python( + a_np, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + "asymmetric", + ) def check_target(target, dev): print("Running on target: %s" % target) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 7bfdfc676b67..c307034c04c9 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -285,6 +285,15 @@ def test_predicate(): ) assert len(res) == 0 + # zero iter + xo = tvm.tir.Var("xo", "int32"), 1 + xi = tvm.tir.Var("xi", "int32"), 129 + y = tvm.tir.Var("y", "int32"), 128 + + res = tvm.arith.detect_iter_map( + [xo[0] * 129 + xi[0], y[0]], var_dom([xo, xi, y]), xo[0] * 129 + xi[0] < 128 + ) + def convert_division(divisions): if divisions is None or len(divisions) == 0: @@ -643,6 +652,66 @@ def test_normalize_iter_map_to_expr(): tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) +def test_inverse_affine_iter_map(): + analyzer = tvm.arith.Analyzer() + l0 = create_iter("l0", 64) + l1 = create_iter("l1", 64) + l2 = create_iter("l2", 64) + + # simple case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l0_1_l1_1_fused = ifuse([l0_1, l1_1]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 2 + l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16 + l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4 + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + + # compound case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l2_1, l2_2 = isplit(l2, 4) + l2_0, l2_1 = isplit(l2_1, 4) + + l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0]) + + iter_map = tvm.arith.detect_iter_map( + [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2]) + ) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 3 + l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16 + l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 + l2_inverse = ( + floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2] + ) + + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0 + + # diamond-shape DAG + l0_0, l0_1 = isplit(l0, 16) + l1 = ifuse([l0_1, l0_0]) + l1_0, l1_1 = isplit(l1, 8) + l2 = ifuse([l1_1, l1_0]) + + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 1 + l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8) + l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16) + + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + + if __name__ == "__main__": test_split() test_trivial() @@ -652,3 +721,4 @@ def test_normalize_iter_map_to_expr(): test_normalize_iter_map_to_expr() test_subspace_division() test_complex() + test_inverse_affine_iter_map() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c3afa6c65627..231c376c50ca 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -275,6 +275,7 @@ def test_add_index_simplify(): def test_sub_index_simplify(): ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") + a, b = tvm.tir.Any(), tvm.tir.Any() ck.verify(x + y - y, x) ck.verify(x + y - x, y) @@ -293,6 +294,8 @@ def test_sub_index_simplify(): # mul co-efficient foldng ck.verify(x - x, 0) + ck.verify(a - a, 0) + ck.verify(a - b, a - b) ck.verify(x * y - x, x * (y + (-1))) ck.verify(x * y - 10 * x, x * (y + (-10))) ck.verify(y * x - x * z, x * (y - z)) diff --git a/tests/python/unittest/test_autotvm_measure.py b/tests/python/unittest/test_autotvm_measure.py index 9db9f18fa377..a89c69c37d64 100644 --- a/tests/python/unittest/test_autotvm_measure.py +++ b/tests/python/unittest/test_autotvm_measure.py @@ -26,6 +26,8 @@ from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task from tvm import autotvm from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult +from tvm.autotvm import measure +from inspect import Signature def test_task_tuner_without_measurement(): @@ -60,8 +62,30 @@ def test_task_tuner_without_measurement_spawn(): p.join() +def test_task_runner_with_ref_input(): + """test runner ref_input without measurement""" + refinp = [np.random.rand(128, 128) for i in range(3)] + runner = measure.LocalRunner() + runner.ref_input = refinp + + class DummyExecutor(measure.executor.Executor): + def __init__(self): + self.ran_dummy_executor = False + + def submit(self, func, *args, **kwargs): + self.ran_dummy_executor = True + sig = Signature.from_callable(func) + assert sig.bind(*args, **kwargs).arguments["ref_input"] == refinp + return measure.local_executor.LocalFutureNoFork(None) + + runner.executor = DummyExecutor() + runner.run([None], [None]) + assert runner.executor.ran_dummy_executor + + if __name__ == "__main__": logging.basicConfig(level=logging.INFO) test_task_tuner_without_measurement() test_task_tuner_without_measurement_spawn() + test_task_runner_with_ref_input() diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 2922a3adf48b..5a32385632fc 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -32,8 +32,64 @@ from tvm.contrib import utils +@tvm.testing.requires_micro +def test_export_operator_model_library_format(): + import tvm.micro as micro + + target = tvm.target.target.micro("host") + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + A = tvm.te.placeholder((2,), dtype="int8") + B = tvm.te.placeholder((1,), dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") + sched = tvm.te.create_schedule(C.op) + mod = tvm.build(sched, [A, B, C], tvm.target.Target(target, target), name="add") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + micro.export_model_library_format(mod, mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + with open(os.path.join(extract_dir, "metadata.json")) as json_f: + metadata = json.load(json_f) + assert metadata["version"] == 5 + assert metadata["model_name"] == "add" + export_datetime = datetime.datetime.strptime( + metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" + ) + assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) + assert metadata["target"] == {"1": str(target)} + + assert metadata["memory"]["add"][0]["dtype"] == "int8" + assert metadata["memory"]["add"][0]["shape"] == [2] + assert metadata["memory"]["add"][0]["size_bytes"] == 2 + + assert metadata["memory"]["add"][1]["dtype"] == "int8" + assert metadata["memory"]["add"][1]["shape"] == [1] + assert metadata["memory"]["add"][1]["size_bytes"] == 1 + + assert metadata["memory"]["add"][2]["dtype"] == "int8" + assert metadata["memory"]["add"][2]["shape"] == [2] + assert metadata["memory"]["add"][2]["size_bytes"] == 2 + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c")) + + assert ( + len(mod.ir_module_by_target) == 1 + ), f"expect 1 ir_model_by_target: {ir_module_by_target!r}" + for target, ir_mod in mod.ir_module_by_target.items(): + assert int(tvm.runtime.ndarray.device(str(target)).device_type) == 1 + with open(os.path.join(extract_dir, "src", "tir-1.txt")) as tir_f: + assert tir_f.read() == str(ir_mod) + + def validate_graph_json(extract_dir, factory): - with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f: + with open(os.path.join(extract_dir, "executor-config", "graph", "graph.json")) as graph_f: graph_json = graph_f.read() assert graph_json == factory.graph_json @@ -46,14 +102,20 @@ def validate_graph_json(extract_dir, factory): @tvm.testing.requires_micro @pytest.mark.parametrize( - "target", + "executor,target,should_generate_interface", [ - ("graph", tvm.target.target.micro("host")), - ("aot", tvm.target.target.micro("host", options="-executor=aot")), + ("graph", tvm.target.target.micro("host"), False), + ("aot", tvm.target.target.micro("host", options="-executor=aot"), False), + ( + "aot", + tvm.target.target.micro( + "host", options="-executor=aot --unpacked-api=1 --interface-api=c" + ), + True, + ), ], ) -def test_export_model_library_format_c(target): - executor, _target = target +def test_export_model_library_format_c(executor, target, should_generate_interface): with utils.TempDirectory.set_keep_for_debug(True): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): relay_mod = tvm.parser.fromtext( @@ -66,8 +128,8 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ ) factory = tvm.relay.build( relay_mod, - _target, - target_host=_target, + target, + target_host=target, mod_name="add", params={"c": numpy.array([[2.0, 4.0]], dtype="float32")}, ) @@ -85,13 +147,13 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 5 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(_target)} + assert metadata["target"] == {"1": str(target)} if executor == "graph": assert metadata["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, @@ -117,11 +179,14 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib0.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "add_lib1.c")) + assert should_generate_interface == os.path.exists( + os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_add.h") + ) if executor == "graph": validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -165,7 +230,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 5 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -198,7 +263,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -209,13 +274,9 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ @tvm.testing.requires_micro @pytest.mark.parametrize( "target", - [ - ("graph", tvm.target.target.micro("host")), - ("aot", tvm.target.target.micro("host", options="-executor=aot")), - ], + [tvm.target.target.micro("host"), tvm.target.target.micro("host", options="-executor=aot")], ) def test_export_model_library_format_workspace(target): - executor, _target = target with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): relay_mod = tvm.parser.fromtext( """ @@ -229,7 +290,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 } """ ) - factory = tvm.relay.build(relay_mod, _target, target_host=_target, mod_name="qnn_conv2d") + factory = tvm.relay.build(relay_mod, target, target_host=target, mod_name="qnn_conv2d") temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") @@ -244,13 +305,13 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 5 assert metadata["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(_target)} + assert metadata["target"] == {"1": str(target)} assert metadata["memory"]["functions"]["main"] == [ { "constants_size_bytes": 0, @@ -269,11 +330,8 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 @tvm.testing.requires_micro -def test_export_model(): +def test_export_non_dso_exportable(): module = tvm.support.FrontendTestModule() - factory = executor_factory.GraphExecutorFactoryModule( - None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}, {} - ) temp_dir = utils.tempdir() import tvm.micro as micro diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index ee8032550b39..8306f2f67fa1 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -18,6 +18,8 @@ import pytest from io import StringIO import csv +import os +import json import tvm.testing from tvm.runtime import profiler_vm @@ -26,6 +28,24 @@ from tvm.contrib.debugger import debug_executor +def read_csv(report): + f = StringIO(report.csv()) + headers = [] + rows = [] + reader = csv.reader(f, delimiter=",") + # force parsing + in_header = True + for row in reader: + if in_header: + headers = row + in_header = False + rows = [[] for x in headers] + else: + for i in range(len(row)): + rows[i].append(row[i]) + return dict(zip(headers, rows)) + + @pytest.mark.skipif(not profiler_vm.enabled(), reason="VM Profiler not enabled") @tvm.testing.parametrize_targets def test_vm(target, dev): @@ -39,14 +59,9 @@ def test_vm(target, dev): assert "fused_nn_softmax" in str(report) assert "Total" in str(report) - f = StringIO(report.csv()) - reader = csv.reader(f, delimiter=",") - # force parsing - in_header = True - for row in reader: - if in_header: - assert "Hash" in row - in_header = False + csv = read_csv(report) + assert "Hash" in csv.keys() + assert all([float(x) > 0 for x in csv["Duration (us)"]]) @tvm.testing.parametrize_targets @@ -61,3 +76,60 @@ def test_graph_executor(target, dev): assert "fused_nn_softmax" in str(report) assert "Total" in str(report) assert "Hash" in str(report) + + +@tvm.testing.parametrize_targets("cuda", "llvm") +@pytest.mark.skipif( + tvm.get_global_func("runtime.profiling.PAPIMetricCollector", allow_missing=True) is None, + reason="PAPI profiling not enabled", +) +def test_papi(target, dev): + target = tvm.target.Target(target) + if str(target.kind) == "llvm": + metric = "PAPI_FP_OPS" + elif str(target.kind) == "cuda": + metric = "cuda:::event:shared_load:device=0" + else: + pytest.skip(f"Target {target.kind} not supported by this test") + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, target, params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, dev) + + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev) + report = vm.profile( + [data], + func_name="main", + collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: [metric]})], + ) + print(report) + assert metric in str(report) + + csv = read_csv(report) + assert metric in csv.keys() + assert any([float(x) > 0 for x in csv[metric]]) + + +@tvm.testing.requires_llvm +def test_json(): + mod, params = mlp.get_workload(1) + + exe = relay.vm.compile(mod, "llvm", params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, tvm.cpu()) + + data = np.random.rand(1, 1, 28, 28).astype("float32") + report = vm.profile(data, func_name="main") + parsed = json.loads(report.json()) + assert "device_metrics" in parsed + assert "calls" in parsed + assert "Duration (us)" in parsed["calls"][0] + assert "microseconds" in parsed["calls"][0]["Duration (us)"] + assert len(parsed["calls"]) > 0 + for call in parsed["calls"]: + assert isinstance(call["Name"], str) + assert isinstance(call["Count"]["count"], int) + assert isinstance(call["Duration (us)"]["microseconds"], float) + + +if __name__ == "__main__": + test_papi("llvm", tvm.cpu()) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index e3422ae45945..75b61d281840 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -29,7 +29,9 @@ def test_basic(dev, target): return exe = relay.vm.compile(mod, target, params=params) - vm = profiler_vm.VirtualMachineProfiler(exe, dev) + code, lib = exe.save() + des_exe = tvm.runtime.vm.Executable.load_exec(code, lib) + vm = profiler_vm.VirtualMachineProfiler(des_exe, dev) data = np.random.rand(1, 1, 28, 28).astype("float32") res = vm.profile(tvm.nd.array(data), func_name="main") diff --git a/tests/python/unittest/test_target_codegen_device.py b/tests/python/unittest/test_target_codegen_device.py index 99b504219c14..b4181fb7b014 100644 --- a/tests/python/unittest/test_target_codegen_device.py +++ b/tests/python/unittest/test_target_codegen_device.py @@ -45,7 +45,7 @@ def check_target(device): assert a.numpy()[0] == value + 3 check_target("cuda") - check_target("vulkan") + check_target("vulkan -from_device=0") @tvm.testing.requires_gpu diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 0551fcd54855..85e9cb12d8d2 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -433,5 +433,99 @@ def do_compute(A, B, n): tvm.testing.assert_allclose(b.numpy(), a_np) +class TestVectorizedIndices: + load_type, store_type = tvm.testing.parameters( + # Load N values, write to N locations. + # Vectorized copy. + ("ramp", "ramp"), + # Load 1 value, write to N locations. + # Scalar load, vectorized store. + # + # Most TVM operations (e.g. schedule[tensor].vectorize(axis)) have + # the broadcast outside of the index, but it is semantically okay + # for the broadcast to be inside the index, and it shows up with + # some optimizations. + ("broadcast", "ramp"), + # Load 1 values, write to 1 location. + # Broadcasting on both sides should be equivalent to a scalar copy. + ("broadcast", "broadcast"), + # Loads N values, write to 1 location. + # Disabled as it would have unclear semantics. + # ("ramp","broadcoast"), + ) + indirect_indices = tvm.testing.parameter(True, False, ids=["reorder", "no_reorder"]) + + @tvm.testing.fixture + def ref_data(self, load_type, store_type, indirect_indices): + n = 4 + + index_map = { + "ramp": np.arange(n), + "broadcast": np.zeros(n, dtype="int32"), + } + + a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32") + b_np = np.zeros(shape=n, dtype=a_np.dtype) + reorder_np = np.arange(n, dtype="int32")[::-1] + + load_index = index_map[load_type] + store_index = index_map[store_type] + + if indirect_indices: + load_index = reorder_np[load_index] + + b_np[store_index] = a_np[load_index] + + return a_np, reorder_np, b_np + + @tvm.testing.fixture + def mod(self, target, load_type, store_type, indirect_indices): + target = tvm.target.Target(target) + + n = 4 + dtype = "int32" + A = te.placeholder((n,), dtype=dtype, name="A") + R = te.placeholder((n,), dtype=dtype, name="R") + + def do_compute(ins, outs): + ib = tvm.tir.ir_builder.create() + A, R = map(ib.buffer_ptr, ins) + B = ib.buffer_ptr(outs[0]) + + if "gpu" in target.keys: + ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) + + index_map = { + "ramp": tvm.tir.Ramp(0, 1, 4), + "broadcast": tvm.tir.Broadcast(0, 4), + } + + load_index = index_map[load_type] + store_index = index_map[store_type] + + if indirect_indices: + load_index = tvm.tir.expr.Load("int32x4", R, load_index) + + transfer = tvm.tir.expr.Load("int32x4", A, load_index) + ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + + return ib.get() + + B = te.extern(A.shape, [A, R], do_compute, dtype="int32") + s = te.create_schedule(B.op) + + return tvm.lower(s, [A, R, B]) + + def test_ramp_broadcast_index(self, target, dev, mod, ref_data): + f = tvm.build(mod, target=target) + + a_np, reorder_np, b_np = ref_data + a = tvm.nd.array(a_np, dev) + r = tvm.nd.array(reorder_np, dev) + b = tvm.nd.array(np.zeros(shape=b_np.shape, dtype="int32"), dev) + f(a, r, b) + tvm.testing.assert_allclose(b.numpy(), b_np) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 30b96546f991..e9626e7f31b4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -189,9 +189,7 @@ def fanout(n, a): assert ir.min.value == 0 assert tvm.ir.structural_equal(ir.extent, n - 3) # Check loopbody - ibody = ir.body - assert isinstance(ibody, tvm.tir.AttrStmt) - abody = ibody.body + abody = ir.body assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 60a324727f81..8b504df120e0 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -218,7 +218,7 @@ def test_rfactor(): assert set(BF.op.body[0].axis) == set([k2]) assert s[B].op.body[0].axis[0].dom.extent == n assert len(s[B].all_iter_vars) == 2 - # schedule with splot + # schedule with split s = te.create_schedule(B.op) ko, ki = s[B].split(k1, factor=4) xo, xi = s[B].split(B.op.axis[0], factor=8) diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index e2c2f7f7e0e5..ae5e7051bfba 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -379,8 +379,8 @@ def intrin_func(ins, outs): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # The loop that we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body, tvm.tir.For) - assert stmt.body.body.loop_var.name == C.op.axis[0].var.name + assert isinstance(stmt.body, tvm.tir.For) + assert stmt.body.loop_var.name == C.op.axis[0].var.name if __name__ == "__main__": diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index ed4a21397885..2931925965b7 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -309,7 +309,7 @@ def get_B1_realize(x): ret = [] tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) - assert stmt.node == C.op and len(ret) == 1 + assert stmt.producer == C and len(ret) == 1 def test_tensor_inputs(): diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 36fd80fd07de..8c2b2710f1ba 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -70,6 +70,22 @@ def lca_is_func_root(a: ty.handle) -> None: A.data[0] = 1.0 +@tvm.script.tir +def match_buffer_func(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.match_buffer(b, (128, 128), "float32") + with tir.block([8, 8], "block") as [vi, vj]: + tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + AA = tir.match_buffer(A[i, j], ()) + AA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + def test_buffer_load_store(): func = buffer_load_store_func A, B = [func.buffer_map[x] for x in func.params] @@ -115,7 +131,24 @@ def test_lca_func_root(): assert lca[A] is None +def test_match_buffer(): + func = match_buffer_func + A, B = [func.buffer_map[x] for x in func.params] + lca = tir.analysis.detect_buffer_access_lca(func) + + root_block = func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + + # LCA of Buffer C is the inner block + assert lca[A] == block_inner + + # LCA of Buffer C is the main block + assert lca[B] == block + + if __name__ == "__main__": test_buffer_load_store() test_opaque_access() test_lca_func_root() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 7e4d7d87c1e1..7641f0ac46cb 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -39,6 +39,48 @@ def func() -> None: tir.evaluate(D.data) +@tvm.script.tir +def match_buffer_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + with tir.block([8, 8], "block") as [vi, vj]: + tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with tir.block([16, 16], "AAA") as [i, j]: + tir.reads([]) + tir.writes(AA[i, j]) + AAA = tir.match_buffer(AA[i, j], ()) + AAA[()] = 1.0 + tir.evaluate(B0.data) + tir.evaluate(B1.data) + + +@tvm.script.tir +def opaque_block_func() -> None: + with tir.block([], "root"): + A = tir.alloc_buffer((16, 16), "float32") + B = tir.alloc_buffer((16, 16), "float32") + tir.reads([]) + tir.writes([]) + # Need add read/write region manually to avoid triggering block access region detector + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes([B[i, 0:16]]) + for j in range(0, 16): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -53,5 +95,41 @@ def test_block_access_region_detector(): ) +def test_opaque_block(): + alloc_buffers = opaque_block_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + block0 = opaque_block_func.body.block.body.body.block + ret = tir.analysis.get_block_access_region(block0, buffer_var_map) + tvm.ir.assert_structural_equal(block0.reads, ret[0]) + tvm.ir.assert_structural_equal(block0.writes, ret[1]) + + block1 = block0.body.body.block + ret = tir.analysis.get_block_access_region(block1, buffer_var_map) + tvm.ir.assert_structural_equal(block1.reads, ret[0]) + tvm.ir.assert_structural_equal(block1.writes, ret[1]) + + +def test_match_buffer(): + root_block = match_buffer_func.body.block + block = root_block.body.body.body.block + block_inner = block.body[0].body.body.block + alloc_buffers = func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + # Check inner block AAA + ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map) + tvm.ir.assert_structural_equal(block_inner.reads, ret[0]) + tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) + + # Check block + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + # B is opaque access + tvm.ir.assert_structural_equal(block.reads, ret[2]) + + if __name__ == "__main__": test_block_access_region_detector() + test_opaque_block() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 6e081a179059..66f3ef9e599f 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. import tvm +import pytest from tvm import tir +from tvm._ffi.base import TVMError from tvm.ir.transform import PassContext +import itertools +import pytest def build_tir_func(func): @@ -30,15 +34,69 @@ def build_tir_func(func): def test_scalar_add(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") - c = a + b - c = tir.ret(c) - c = tir.Evaluate(c) - func = tir.PrimFunc([a, b], c) + # All these types should be interchangeable with each other + # E.g. float16 + float32 upconverts the float16 --> float32 + # Meanwhile if an int or float or together the int will be + # cast to the float type. + lhs_types = ["float32", "float16", "int32", "int64"] + rhs_types = ["float32", "float16"] + for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types): + # Input vars should be float32, we will cast to test for upcasting between them + lhs_input = tir.Var("lhs", "float32") + rhs_input = tir.Var("rhs", "float32") + lhs = tir.Cast(lhs_type, lhs_input) + rhs = tir.Cast(rhs_type, rhs_input) + output = lhs + rhs + output = tir.ret(output) + output = tir.Evaluate(output) + func = tir.PrimFunc([lhs_input, rhs_input], output) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 3.0 + + +def assignment_helper(store_dtype, value_dtype): + store = tir.Var("store", dtype=store_dtype) + value = tir.Var("value", dtype=value_dtype) + tir.Let(store, value, body=store) + + +def test_fail_implicit_downcasts_same_type(): + # These lists should be sorted + bits = [8, 16, 32, 64] + for type in ["float", "int", "uint"]: + for i in range(len(bits) - 1): + with pytest.raises(TVMError): + assignment_helper( + store_dtype=f"{type}{bits[i]}", value_dtype=f"{type}{bits[i + 1]}" + ) + + +def test_cast_between_types(): + # We should only be able to assign values with the same types + bits = [16, 32] + types = ["float", "int", "uint"] + for store_type, store_bits, value_type, value_bits in itertools.product( + types, bits, types, bits + ): + store_dtype = f"{store_type}{store_bits}" + value_dtype = f"{value_type}{value_bits}" + if store_dtype == value_dtype: + assignment_helper(store_dtype, value_dtype) + else: + # TODO: we might want to allow casts between uint and int types + with pytest.raises(TVMError): + assignment_helper(store_dtype, value_dtype) + + +def test_ret_const(): + a = tir.const(0) + b = tir.ret(a) + b = tir.Evaluate(b) + func = tir.PrimFunc([], b) func = build_tir_func(func) - out = func(1.0, 2.0) - assert out == 3.0 + out = func() + assert out == 0 def test_control_flow_jump(): @@ -55,6 +113,13 @@ def test_control_flow_jump(): assert out == 1.0 +def test_exception(): + with pytest.raises(tvm.TVMError): + x = tir.Var(name=1, dtype="int") + + if __name__ == "__main__": test_scalar_add() + test_ret_const() test_control_flow_jump() + test_exception() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 83377e443764..42f9c34133df 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -131,6 +131,23 @@ def assert_simplified_equal(index_simplified, index_direct): ) assert_simplified_equal(index_simplified, index_direct) + # Test Case5 + B = tvm.tir.decl_buffer((1, 14, 14, 1024)) + i = te.size_var("i") + j = te.size_var("j") + k = te.size_var("k") + + index_simplified = B.vload( + ( + idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), + idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), + idxm(idxd((i * 50176 + j * 28672 + k), 1024), 14), + idxm((i * 50176 + j * 28672 + k), 1024), + ) + ) + index_direct = B.vload((0, 0, 0, (i * 50176 + j * 28672 + k))) + assert_simplified_equal(index_simplified, index_direct) + @tvm.testing.requires_llvm def test_buffer_broadcast(): diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 355d3abed559..5b123e883849 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -18,6 +18,7 @@ from tvm import te import numpy as np import tvm.testing +from tvm.topi.math import cast def test_for(): @@ -30,8 +31,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - assert isinstance(body, tvm.tir.AttrStmt) - body = body.body assert isinstance(body, tvm.tir.Allocate) body = body.body assert isinstance(body, tvm.tir.For) @@ -497,6 +496,62 @@ def check_target(target, ir): check_target("vulkan", searchsorted_ir_gpu) +@tvm.testing.requires_gpu +def test_dyn_shared(): + n = te.size_var("n") + dtype = "float32" + A = te.placeholder((n,), name="A") + + def test_device_ir(A, B): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + + temp[tx] = Aptr[tx] + depth = tvm.tir.log2(cast(n, "float32")) + + with ib.for_range(0, depth) as i: + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + d = n >> (i + 1) + with ib.if_scope(tx < d): + temp[tx] += temp[tx + d] + + Bptr[0] = temp[0] + return ib.get() + + B = te.extern( + (1,), + [A], + lambda ins, outs: test_device_ir(ins[0], outs[0]), + name="reduce", + dtype=dtype, + ) + s = te.create_schedule(B.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + freduce = tvm.build(s, [A, B], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.zeros(1, dtype=B.dtype), dev) + freduce(a, b) + tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + if __name__ == "__main__": test_prefetch() test_if() @@ -507,3 +562,4 @@ def check_target(target, ir): test_while_collatz() test_while_mandel() test_while_binary_search() + test_dyn_shared() diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py new file mode 100644 index 000000000000..78a8c5117849 --- /dev/null +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -0,0 +1,455 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm import tir +from tvm.script import ty + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original) + mod = tvm.tir.transform.LowerMatchBuffer()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + + +def _check_fail(original): + mod = tvm.IRModule.from_expr(original) + with pytest.raises(tvm.TVMError): + mod = tvm.tir.transform.LowerMatchBuffer()(mod) + + +@tvm.script.tir +def buffer_load_store(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i, j, k in tir.grid(4, 16, 8): + with tir.block([]): + tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) + tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) + sub_A = tir.match_buffer( + A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2], (4, 1, 2), offset_factor=1 + ) + sub_C = tir.match_buffer( + C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2], (4, 2), offset_factor=1 + ) + for ii, kk in tir.grid(4, 2): + sub_A[ii, 0, kk] += sub_C[ii, kk] + + +@tvm.script.tir +def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i, j, k in tir.grid(4, 16, 8): + with tir.block([]): + tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) + tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) + for ii, kk in tir.grid(4, 2): + A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] + + +@tvm.ir.register_op_attr("tir.intrin_test", "") +def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): + return 0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (32, 64, 128)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(2, 64, 8): + with tir.block([]): + tir.reads([]) + tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) + sub_A = tir.match_buffer( + A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16], + (16, 1, 16), + strides=[8192, 128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_A.data, + sub_A.elem_offset, + sub_A.strides[0], + sub_A.strides[1], + sub_A.shape[0], + sub_A.shape[1], + dtype="handle", + ) + ) + for i, j, k in tir.grid(64, 2, 8): + with tir.block([]): + Bs_0 = tir.var("int32") + Bs_1 = tir.var("int32") + tir.reads([]) + tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) + sub_B = tir.match_buffer( + B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8], + (32, 8), + strides=[Bs_0, Bs_1], + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_B.data, + sub_B.elem_offset, + sub_B.strides[0], + sub_B.strides[1], + sub_B.shape[0], + sub_B.shape[1], + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (32, 64, 128)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(2, 64, 8): + with tir.block([]): + tir.reads([]) + tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) + tir.evaluate( + tir.intrin_test( + A.data, + i * 131072 + j * 128 + k * 16, + 8192, + 128, + 16, + 1, + dtype="handle", + ) + ) + for i, j, k in tir.grid(64, 2, 8): + with tir.block([]): + tir.reads([]) + tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) + tir.evaluate( + tir.intrin_test( + B.data, + i * 4096 + j * 2048 + k * 8, + 64, + 1, + 32, + 8, + dtype="handle", + ) + ) + + +@tvm.script.tir +def recursive_match(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (64, 64, 64)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(64, 4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + ] + ) + As_0 = tir.var("int32") + As_1 = tir.var("int32") + sub_A = tir.match_buffer( + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + (16, 16), + strides=[As_0, As_1], + offset_factor=1, + ) + sub_B = tir.match_buffer( + B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + (16, 16), + offset_factor=1, + ) + for jj, kk in tir.grid(4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + ] + ) + Ass_0 = tir.var("int32") + Ass_1 = tir.var("int32") + sub_sub_A = tir.match_buffer( + sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + (4, 4), + strides=[Ass_0, Ass_1], + offset_factor=1, + ) + sub_sub_B = tir.match_buffer( + sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], + (4, 4), + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_sub_A.data, + sub_sub_A.elem_offset, + sub_sub_A.strides[0], + sub_sub_A.strides[1], + sub_sub_A.shape[0], + sub_sub_A.shape[1], + dtype="handle", + ) + ) + for jjj, kkk in tir.grid(4, 4): + sub_sub_B[jjj, kkk] = 1 + + +@tvm.script.tir +def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (64, 64, 64)) + B = tir.match_buffer(b, (64, 64, 64)) + for i, j, k in tir.grid(64, 4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + ] + ) + for jj, kk in tir.grid(4, 4): + with tir.block([]): + tir.reads([]) + tir.writes( + [ + A[ + i, + j * 16 + jj * 4 : j * 16 + jj * 4 + 4, + k * 16 + kk * 4 : k * 16 + kk * 4 + 4, + ], + B[ + i, + j * 16 + jj * 4 : j * 16 + jj * 4 + 4, + k * 16 + kk * 4 : k * 16 + kk * 4 + 4, + ], + ] + ) + tir.evaluate( + tir.intrin_test( + A.data, + i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4, + 64, + 1, + 4, + 4, + dtype="handle", + ) + ) + for jjj, kkk in tir.grid(4, 4): + B[i, j * 16 + jj * 4 + jjj, k * 16 + kk * 4 + kkk] = 1 + + +@tvm.script.tir +def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None: + A = tir.match_buffer(a, (n * m, m)) + B = tir.match_buffer(b, (n * 2, m * 4)) + for i in range(0, n): + with tir.block([]): + tir.reads([]) + tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) + Bs_0 = tir.var("int32") + Bs_1 = tir.var("int32") + sub_A = tir.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1) + sub_B = tir.match_buffer( + B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), strides=[Bs_0, Bs_1], offset_factor=1 + ) + for ii, jj in tir.grid(m, m): + sub_A[ii, jj] = 1 + for j in range(0, 4): + tir.evaluate( + tir.intrin_test( + sub_B.data, + sub_B.elem_offset, + sub_B.strides[0], + sub_B.strides[1], + sub_B.shape[0], + sub_B.shape[1], + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None: + A = tir.match_buffer(a, (n * m, m)) + B = tir.match_buffer(b, (n * 2, m * 4)) + for i in range(0, n): + with tir.block([]): + tir.reads([]) + tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) + for ii, jj in tir.grid(m, m): + A[i * m + ii, jj] = 1 + for j in range(0, 4): + tir.evaluate( + tir.intrin_test( + B.data, + i * n * (m * 4), + m * 4, + 1, + 2, + m * 4, + dtype="handle", + ) + ) + + +@tvm.script.tir +def rank0_buffer(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + B = tir.match_buffer(b, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j], B[i, j]]) + sub_A = tir.match_buffer(A[i, j], (), offset_factor=1) + sub_B = tir.match_buffer(B[i, j], (), offset_factor=1) + sub_A[()] = 1 + tir.evaluate( + tir.intrin_test( + sub_B.data, + sub_B.elem_offset, + 0, + 0, + 0, + 0, + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + B = tir.match_buffer(b, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j], B[i, j]]) + A[i, j] = 1 + tir.evaluate( + tir.intrin_test( + B.data, + i * 8 + j, + 0, + 0, + 0, + 0, + dtype="handle", + ) + ) + + +@tvm.script.tir +def fail_match_load(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads(A[i, j]) + tir.writes([]) + sub_A = tir.match_buffer(A[i, j], ()) + tir.evaluate(tir.load("float32", sub_A.data, 0)) + + +@tvm.script.tir +def fail_match_store(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 8): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, j]) + sub_A = tir.match_buffer(A[i, j], ()) + sub_A.data[0] = 1 + + +@tvm.script.tir +def fail_buffer_bind(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + stride = tir.var("int32") + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 + ) + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +@tvm.script.tir +def fail_match_func_param(a: ty.handle, m: ty.handle, n: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1 + ) + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +def test_buffer_load_store(): + _check(buffer_load_store, transformed_buffer_load_store) + + +def test_opaque_access(): + _check(opaque_access, transformed_opaque_access) + + +def test_recursive_match(): + _check(recursive_match, transformed_recursive_match) + + +def test_symbolic_match(): + _check(symbolic_match, transformed_symbolic_match) + + +def test_rank0_buffer(): + _check(rank0_buffer, transformed_rank0_buffer) + + +def test_fail_load_store(): + _check_fail(fail_match_load) + _check_fail(fail_match_store) + + +def test_fail_buffer_bind(): + _check_fail(fail_buffer_bind) + + +def test_fail_match_func_param(): + _check_fail(fail_match_func_param) + + +if __name__ == "__main__": + test_buffer_load_store() + test_opaque_access() + test_recursive_match() + test_symbolic_match() + test_rank0_buffer() + test_fail_load_store() + test_fail_buffer_bind() + test_fail_match_func_param() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 89ca9ac70b92..dbae0b6fa516 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -29,7 +29,7 @@ def test_const(): def test_scalar_dtype_inference(): for data in [ True, - np.bool(1), + bool(1), np.uint8(1), np.uint16(1), np.uint32(1), @@ -48,7 +48,7 @@ def test_scalar_dtype_inference(): for data in [ True, - np.bool(1), + bool(1), np.uint8(1), np.uint16(1), np.uint32(1), @@ -398,7 +398,7 @@ def test_block_blockrealize(): ) ] writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])] - match_buffer_region = tvm.tir.MatchBufferRegion( + block_match_buffer = tvm.tir.MatchBufferRegion( match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)]) ) @@ -410,7 +410,7 @@ def test_block_blockrealize(): body, init=init_body, alloc_buffers=[alloc_buffer], - match_buffers=[match_buffer_region], + match_buffers=[block_match_buffer], annotations={"attr_key": "attr_value"}, ) @@ -462,7 +462,7 @@ def test_block_blockrealize(): assert output.find("reads") != -1 assert output.find("writes") != -1 assert output.find("alloc_buffer") != -1 - assert output.find("match_buffer_region") != -1 + assert output.find("match_buffer") != -1 assert output.find("attr") != -1 assert output.find("with init()") != -1 @@ -471,7 +471,6 @@ def test_block_blockrealize(): test_intimm_cond() test_buffer_load_store() test_vars() - test_scoped_storage_var() test_prim_func() test_cast() test_attr() diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index f1f8cf70d0c9..78eab6bdde9f 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -146,13 +146,22 @@ def verify_callop_float_only(f): rhs = te.var("rhs", dtype=rhs_dtype) if "float" not in lhs_dtype and "float" not in rhs_dtype: check_throws(lambda: f(lhs, rhs)) - elif "float" in lhs_dtype and "float" in rhs_dtype and lhs_dtype != rhs_dtype: - check_throws(lambda: f(lhs, rhs)) elif "float" in lhs_dtype: out = f(lhs, rhs) - assert out.dtype == lhs_dtype - assert out.args[0].dtype == lhs_dtype - assert out.args[1].dtype == lhs_dtype + + # Upcasting for floating point types + dtypes = [lhs_dtype, rhs_dtype] + if "float64" in dtypes: + target_dtype = "float64" + elif "float32" in dtypes: + target_dtype = "float32" + else: + target_dtype = "int32" + assert out.dtype == target_dtype + + # Final inputs are the right type + assert out.args[0].dtype == target_dtype + assert out.args[1].dtype == target_dtype else: out = f(lhs, rhs) assert out.dtype == rhs_dtype diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index 4a914f5063f8..ced8d78ff11a 100644 --- a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest import tvm from tvm import tir from tvm.script import ty @@ -140,6 +143,4 @@ def test_war_dependency(): if __name__ == "__main__": - test_elementwise_dependency() - test_matmul_dependency() - test_war_dependency() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index c34ec8d610d6..d6934c6f407f 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + import pytest import tvm from tvm import tir @@ -171,7 +173,7 @@ def buffer_matched(a: ty.handle, c: ty.handle) -> None: with tir.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 with tir.block([128, 128], "C") as [vi, vj]: - Bb = tir.match_buffer_region(B[vi : vi + 1, vj]) + Bb = tir.match_buffer(B[vi : vi + 1, vj], (1, 1)) C[vi, vj] = Bb[0, 0] + 1.0 @@ -354,20 +356,4 @@ def test_compute_inline_multi_loads(): if __name__ == "__main__": - test_compute_inline_elementwise() - test_compute_inline_under_loop() - test_compute_inline_as_dce() - test_compute_inline_multi_consumer() - test_compute_inline_fail_multi_writer() - test_reverse_compute_inline_elementwise() - test_reverse_compute_inline_under_loop() - test_reverse_compute_inline_fail_as_dce() - test_reverse_compute_inline_fail_multi_producer() - test_reverse_compute_inline_fail_multi_reader() - test_reverse_compute_multi_reverse_loads() - test_reverse_compute_fail_multi_reverse_loads() - test_opaque_access_load() - test_opaque_access_store() - test_buffer_matched() - test_compute_inline_predicate() - test_compute_inline_multi_loads() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 1fa658feabe3..6f56eb598894 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + import pytest import tvm from tvm import tir from tvm.script import ty - # pylint: disable=no-member,invalid-name,unused-variable @@ -65,6 +66,4 @@ def test_tir_schedule_error_none(): if __name__ == "__main__": - test_tir_schedule_error_detail() - test_tir_schedule_error_fast() - test_tir_schedule_error_none() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_instruction.py b/tests/python/unittest/test_tir_schedule_instruction.py new file mode 100644 index 000000000000..9e6f447dd3e6 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_instruction.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +# mypy: ignore-errors +import sys + +import pytest +from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV + + +def test_inst_kind_get(): + kind = InstructionKind.get("EnterPostproc") + assert not kind.is_pure + assert kind.name == "EnterPostproc" + + +def test_inst_construct_1(): + block = BlockRV() + loop0 = LoopRV() + loop1 = LoopRV() + inst = Instruction( + kind=InstructionKind.get("GetLoops"), + inputs=[block], + attrs=[], + outputs=[loop0, loop1], + ) + assert str(inst) == "_, _ = sch.get_loops(block=_)" + assert len(inst.inputs) == 1 + assert len(inst.attrs) == 0 + assert len(inst.outputs) == 2 + assert inst.kind.same_as(InstructionKind.get("GetLoops")) + assert inst.inputs[0].same_as(block) + assert inst.outputs[0].same_as(loop0) + assert inst.outputs[1].same_as(loop1) + + +def test_inst_construct_2(): + block = BlockRV() + inst = Instruction( + kind=InstructionKind.get("ComputeInline"), + inputs=[block], + attrs=[], + outputs=[], + ) + assert str(inst) == "sch.compute_inline(block=_)" + assert len(inst.inputs) == 1 + assert len(inst.attrs) == 0 + assert len(inst.outputs) == 0 + assert inst.kind.same_as(InstructionKind.get("ComputeInline")) + assert inst.inputs[0].same_as(block) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py new file mode 100644 index 000000000000..b285f72ca59f --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -0,0 +1,675 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable,missing-function-docstring,missing-module-docstring + + +@tvm.script.tir +def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) + tir.writes([C[vi, vj]]) + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + + +@tvm.script.tir +def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + C_rf = tir.alloc_buffer([4, 128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): + with tir.block( + [4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)], "update_rf" + ) as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]: + tir.bind(vi2_inner_inner, i2_inner_inner) + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.bind(vi2_outer, i2_outer) + tir.bind(vi2_inner_outer, i2_inner_outer) + with tir.init(): + C_rf[vi2_inner_inner, vi, vj] = 0.0 + C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( + A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + ) + + for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4): + with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [ + vi2_inner_inner_1, + vi_1, + vj_1, + ]: + tir.bind(vi2_inner_inner_1, i2_inner_inner_1) + tir.bind(vi_1, i0_1) + tir.bind(vj_1, i1_1) + with tir.init(): + C[vi_1, vj_1] = 0.0 + C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] + + +@tvm.script.tir +def matmul_not_stage_pipeline(a: ty.handle, b: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [256, 256]) + B = tir.match_buffer(b, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + C = tir.alloc_buffer([256, 256]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + with tir.block([256, 256], "D") as [vi, vj]: + D[vi, vj] = C[vi, vj] + + +@tvm.script.tir +def matmul_not_same_buffer_access(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +def matmul_loop_multiple_children(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + D = tir.match_buffer(d, [128, 128]) + + for k, i, j in tir.grid(128, 128, 128): + with tir.block([tir.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: + tir.bind(ck, k) + tir.bind(ci, i) + tir.bind(cj, j) + with tir.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] + with tir.block([tir.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: + tir.bind(dk, k) + tir.bind(di, i) + tir.bind(dj, j) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] + + +@tvm.script.tir +def square_sum(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + C = tir.match_buffer(c, [16]) + + with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: + with tir.init(): + C[b] = 0.0 + C[b] = C[b] + A[b, i, j] * A[b, i, j] + + +@tvm.script.tir +def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + C = tir.match_buffer(c, [16]) + C_rf = tir.alloc_buffer([16, 256]) + + for i0, i1, i2 in tir.grid(16, 256, 256): + with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: + tir.bind(vi2, i2) + tir.bind(b, i0) + tir.bind(i, i1) + with tir.init(): + C_rf[b, vi2] = 0.0 + C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) + + for i0_1, i2_1 in tir.grid(16, 256): + with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: + tir.bind(vi2_1, i2_1) + tir.bind(b_1, i0_1) + with tir.init(): + C[b_1] = 0.0 + C[b_1] = C[b_1] + C_rf[b_1, vi2_1] + + +@tvm.script.tir +def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + D = tir.match_buffer(d, [16]) + C = tir.alloc_buffer([16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): + with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: + tir.bind(b, i0) + tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) + tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) + tir.reads([C[b], A[b, i, j]]) + tir.writes([C[b]]) + with tir.init(): + C[b] = 0.0 + C[b] = C[b] + (A[b, i, j] * A[b, i, j]) + for i0_1 in tir.serial(0, 16): + with tir.block([16], "D") as [b_1]: + tir.bind(b_1, i0_1) + tir.reads([C[b_1]]) + tir.writes([D[b_1]]) + D[b_1] = tir.sqrt(C[b_1], dtype="float32") + + +@tvm.script.tir +def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + D = tir.match_buffer(d, [16]) + C = tir.alloc_buffer([16]) + C_rf = tir.alloc_buffer([1, 16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): + with tir.block([1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [ + vi1_i2_fused_inner, + b, + i, + j, + ]: + tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner) + tir.bind(b, i0) + tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) + tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) + with tir.init(): + C_rf[vi1_i2_fused_inner, b] = 0.0 + C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) + + for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1): + with tir.block([tir.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: + tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) + tir.bind(b_1, i0_1) + with tir.init(): + C[b_1] = 0.0 + C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] + + for i0_2 in tir.serial(0, 16): + with tir.block([16], "D") as [b_2]: + tir.bind(b_2, i0_2) + D[b_2] = tir.sqrt(C[b_2], dtype="float32") + + +@tvm.script.tir +def element_wise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def rowsum(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i, k in tir.grid(128, 16): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, tir.floordiv(k * k, 2)) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_dominant(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi, vk] = 0.0 + B[vi, vk] = B[vi, vk] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i in tir.serial(0, 128): + for k in tir.parallel(0, 128): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, k) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_wrong_reduce_pattern1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 1.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_wrong_reduce_pattern2(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] - A[vi, vk] + + +@tvm.script.tir +def rowsum_transformed(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for io, ii_ko_fused, ki in tir.grid(32, 128, 4): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32)) + tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_zero_dim(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128]) + B = tir.match_buffer(b, []) + + with tir.block([tir.reduce_axis(0, 128)], "B") as [k]: + with tir.init(): + B[()] = 0.0 + B[()] = B[()] + A[k] + + +@tvm.script.tir +def rowsum_zero_dim_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128]) + B = tir.match_buffer(b, []) + B_rf = tir.alloc_buffer([128]) + + with tir.block([128], "B_rf") as [vi0]: + with tir.init(): + B_rf[vi0] = 0.0 + B_rf[vi0] = B_rf[vi0] + A[vi0] + + with tir.block([tir.reduce_axis(0, 128)], "B") as [vi0_1]: + with tir.init(): + B[()] = 0.0 + B[()] = B[()] + B_rf[vi0_1] + + +@tvm.script.tir +def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.alloc_buffer((16, 16)) + D = tir.alloc_buffer((16, 16)) + E = tir.alloc_buffer((16, 16)) + F = tir.match_buffer(f, (16, 16)) + + for i in tir.serial(0, 16): + for j1 in tir.serial(0, 16): + for k1o, k1i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]: + tir.bind(ci, i) + tir.bind(cj, j1) + tir.bind(ck, k1o * 4 + k1i) + with tir.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, cj, ck] + for k2o, k2i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: + tir.bind(di, i) + tir.bind(dj, j1) + tir.bind(dk, k2o * 4 + k2i) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] + for j2 in tir.serial(0, 16): + for k3o, k3i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + tir.bind(ei, i) + tir.bind(ej, j2) + tir.bind(ek, k3o * 4 + k3i) + with tir.init(): + E[ei, ej] = 0.0 + E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] + for k4o, k4i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + tir.bind(fi, i) + tir.bind(fj, j2) + tir.bind(fk, k4o * 4 + k4i) + with tir.init(): + F[fi, fj] = 0.0 + F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] + + +@tvm.script.tir +def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16, 16]) + C = tir.alloc_buffer([16, 16]) + D = tir.alloc_buffer([16, 16]) + E = tir.alloc_buffer([16, 16]) + F = tir.match_buffer(f, [16, 16]) + C_rf = tir.alloc_buffer([16, 16, 4]) + + for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4): + with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: + tir.bind(vk1o, k1o) + tir.bind(ci, i) + tir.bind(cj, j1) + tir.bind(vk1i, k1i) + with tir.init(): + C_rf[ci, cj, vk1o] = 0.0 + C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] + for i_1 in tir.serial(0, 16): + for j1_1 in tir.serial(0, 16): + for k1o_1 in tir.serial(0, 4): + with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: + tir.bind(vk1o_1, k1o_1) + tir.bind(ci_1, i_1) + tir.bind(cj_1, j1_1) + with tir.init(): + C[ci_1, cj_1] = 0.0 + C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] + for k2o, k2i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: + tir.bind(di, i_1) + tir.bind(dj, j1_1) + tir.bind(dk, (k2o * 4) + k2i) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] + for j2 in tir.serial(0, 16): + for k3o, k3i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + tir.bind(ei, i_1) + tir.bind(ej, j2) + tir.bind(ek, (k3o * 4) + k3i) + with tir.init(): + E[ei, ej] = 0.0 + E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] + for k4o, k4i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + tir.bind(fi, i_1) + tir.bind(fj, j2) + tir.bind(fk, (k4o * 4) + k4i) + with tir.init(): + F[fi, fj] = 0.0 + F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_reduction_rfactor_matmul(): + s = tir.Schedule(transformed_matmul, debug_mode=True) + C = s.get_block("update") + _, _, _, _, kii = s.get_loops(C) + rf_block = s.rfactor(kii, 0) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(128, 128)).astype("float32") + b_np = np.random.uniform(size=(128, 128)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(b_np) + c = tvm.nd.array(np.zeros((128, 128), dtype="float32")) + func(a, b, c) + c_np = np.matmul(a_np, b_np.T) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + + +def test_reduction_rfactor_square_sum(): + s = tir.Schedule(square_sum, debug_mode=True) + C = s.get_block("C") + _, _, j = s.get_loops(C) + rf_block = s.rfactor(j, 1) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(16, 256, 256)).astype("float32") + a = tvm.nd.array(a_np) + c = tvm.nd.array(np.zeros((16,), dtype="float32")) + func(a, c) + c_np = np.sum(a_np * a_np, axis=(1, 2)) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + + +def test_reduction_rfactor_square_sum_square_root(): + s = tir.Schedule(transformed_square_sum_square_root, debug_mode=True) + C = s.get_block("C") + _, _, fi = s.get_loops(C) + rf_block = s.rfactor(fi, 0) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(16, 256, 256)).astype("float32") + a = tvm.nd.array(a_np) + d = tvm.nd.array(np.zeros((16,), dtype="float32")) + func(a, d) + d_np = np.sqrt(np.sum(a_np * a_np, axis=(1, 2))) + tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-4, atol=1e-4) + + +def test_reduction_rfactor_loop_multiple_children(): + s = tir.Schedule(matmul_loop_multiple_children, debug_mode=True) + C = s.get_block("C") + k, _, _ = s.get_loops(C) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_stage_pipeline(): + s = tir.Schedule(matmul_not_stage_pipeline, debug_mode=True) + C = s.get_block("C") + _, _, k = s.get_loops(C) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_reduction_block1(): + s = tir.Schedule(element_wise, debug_mode=True) + B = s.get_block("B") + i, _ = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(i, 0) + + +def test_reduction_rfactor_not_reduction_block2(): + s = tir.Schedule(rowsum_not_quasi_affine, debug_mode=True) + B = s.get_block("B") + _, k = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_reduction_block3(): + s = tir.Schedule(rowsum_not_dominant, debug_mode=True) + B = s.get_block("B") + _, k = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_serial_loop(): + s = tir.Schedule(rowsum_not_serial, debug_mode=True) + B = s.get_block("B") + _, k = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_not_same_buffer_access(): + s = tir.Schedule(matmul_not_same_buffer_access, debug_mode=True) + C = s.get_block("C") + _, _, k = s.get_loops(C) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_factor_axis_range(): + s = tir.Schedule(transformed_matmul, debug_mode=True) + C = s.get_block("update") + _, _, _, _, kii = s.get_loops(C) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(kii, 3) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(kii, -4) + + rf_block = s.rfactor(kii, -3) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(128, 128)).astype("float32") + b_np = np.random.uniform(size=(128, 128)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(b_np) + c = tvm.nd.array(np.zeros((128, 128), dtype="float32")) + func(a, b, c) + c_np = np.matmul(a_np, b_np.T) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + + +def test_reduction_rfactor_wrong_reduce_pattern1(): + s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mode=True) + B = s.get_block("B") + _, k = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_wrong_reduce_pattern2(): + s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mode=True) + B = s.get_block("B") + _, k = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + +def test_reduction_rfactor_wrong_loops1(): + s = tir.Schedule(rowsum, debug_mode=True) + B = s.get_block("B") + i, _ = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(i, 0) + + +def test_reduction_rfactor_wrong_loops2(): + s = tir.Schedule(rowsum_transformed, debug_mode=True) + B = s.get_block("B") + _, _, ki = s.get_loops(B) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 0) + + +def test_reduction_rfactor_zero_dim(): + s = tir.Schedule(rowsum_zero_dim, debug_mode=True) + B = s.get_block("B") + (k,) = s.get_loops(B) + s.rfactor(k, 0) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(128,)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(np.array(1, dtype="float32")) + func(a, b) + b_np = np.array(np.sum(a_np)) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-4, atol=1e-4) + + +def test_reduction_rfactor_outermost_loop_multiple_children(): + s = tir.Schedule(multiple_reduction_blocks, debug_mode=True) + D = s.get_block("D") + E = s.get_block("E") + F = s.get_block("F") + _, _, k2o, k2i = s.get_loops(D) + _, _, k3o, k3i = s.get_loops(E) + _, _, k4o, k4i = s.get_loops(F) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k2o, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k2i, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k3o, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k3i, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k4o, 0) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k4i, 0) + + C = s.get_block("C") + i, j1, k1o, k1i = s.get_loops(C) + s.rfactor(k1o, 2) + tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(16, 16, 16)).astype("float32") + a = tvm.nd.array(a_np) + f = tvm.nd.array(np.zeros((16, 16), dtype="float32")) + func(a, f) + f_np = np.sum(a_np, axis=2) * 4369 + tvm.testing.assert_allclose(f.numpy(), f_np, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py new file mode 100644 index 000000000000..9ac15b8c1986 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -0,0 +1,455 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k in tir.grid(128, 128, n): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i_j_k_fused in tir.serial(0, (n * 16384)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128))) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, n)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n)) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128, annotations={"useless_annotation": True}): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(10, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i, j, k]]) + tir.writes([B[i, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for fused in tir.serial(0, 2097152): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128)) + tir.bind(vk, tir.floormod(fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, ((i1 * 64) + i3)) + tir.bind(vj, ((j1 * 32) + j2)) + tir.bind(vk, ((k1 * 8) + k2)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i1 * 64 + i3) + tir.bind(vj, j1 * 64 + j3) + tir.bind(vk, k1 * 64 + k3) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.where( + ( + ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) + and (((k0 * 43) + k1) < 128) + ) + ) + tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) + tir.bind(vj, j1) + tir.bind(vk, ((k0 * 43) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i_j_k_fused in tir.serial(0, 2097152): + with tir.block([], "opaque"): + tir.reads( + [ + A[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + tir.writes( + [ + B[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + + for i0, i1, j, k in tir.grid(8, 16, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i0 * 16 + i1, j, k]]) + tir.writes([B[i0 * 16 + i1, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i0 * 16 + i1) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16]) + B = tir.match_buffer(b, [16, 16]) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +@tvm.script.tir +def opaque_access_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + B = tir.match_buffer(b, (16, 16)) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_fuse(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + + +def test_split(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[2, 1, 64]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, 8]) + tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + + +def test_split_with_inferred_factor(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 1, 64]) + sch.split(j, factors=[2, None, 64]) + sch.split(k, factors=[2, 1, None]) + tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + + +def test_split_with_predicate(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[1000, 2, 3]) + sch.split(j, factors=[None, 129]) + sch.split(k, factors=[3, None]) + tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + + +def test_fuse_fail_not_only_child(): + sch = tir.Schedule(elementwise_with_seq, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + + +def test_fuse_split_fail_with_annotation(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_split_fail_not_start_with_zero(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + + +def test_fuse_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.fuse(i, j) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.fuse(i, j) + tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + + +def test_split_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.split(i, factors=[None, 16]) + tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + + +def test_split_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.split(j, factors=[None, 4]) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.split(j, factors=[None, 4]) + tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + + +def test_fuse_split_fail_with_thread_binding(): + sch = tir.Schedule(elementwise_with_thread_binding, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + + +def test_split_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(k, factors=[10, None]) + tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index 34041120f252..ca2ee796a2ba 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring - import gc +import sys +import pytest import tvm from tvm import tir from tvm.ir import IRModule @@ -338,16 +339,4 @@ def test_replace_ir_module(): if __name__ == "__main__": - test_replace_direct_write0() - test_replace_direct_write1() - test_replace_copy() - test_replace_partial_copy0() - test_replace_partial_copy1() - test_replace_root_write() - test_replace_root_copy0() - test_replace_root_copy1() - test_replace_root_copy2() - test_replace_root_copy3() - test_replace_block_remap() - test_replace_block_in_opaque_block() - test_replace_ir_module() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index a320812b339f..f77ec0318eea 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest import tvm from tvm import tir from tvm.script import ty @@ -651,19 +653,4 @@ def test_warp_memory_negative(): if __name__ == "__main__": - test_elementwise() - test_matmul() - test_block_in_opaque_block() - test_write_after_read() - test_loop_carried_dependency() - test_concatenate_multi_producer_covered() - test_concatenate_multi_producer_uncovered() - test_lca_at_loop() - test_multi_producer_consumer() - test_elementwise_affine_producer() - test_subblock() - test_subblock_uncovered() - test_thread_binding() - test_equal_ranked_threads() - test_warp_memory() - test_warp_memory_negative() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py new file mode 100644 index 000000000000..cafc6fe1d292 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +# mypy: ignore-errors +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV, Trace + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def elementwise_inlined(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def _make_get_block(name, output): + return Instruction( + kind=InstructionKind.get("GetBlock"), + inputs=[], + attrs=[name, "main"], + outputs=[output], + ) + + +def _make_get_loops(input, outputs): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("GetLoops"), + inputs=[input], + attrs=[], + outputs=outputs, + ) + + +def _make_compute_inline(input): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("ComputeInline"), + inputs=[input], + attrs=[], + outputs=[], + ) + + +def _make_enter_postproc(): + return Instruction( + kind=InstructionKind.get("EnterPostproc"), + inputs=[], + attrs=[], + outputs=[], + ) + + +def _make_trace_1(b0, l1, l2): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="block", output=b0), + _make_get_loops(input=b0, outputs=[l1, l2]), + ], + decisions={}, + ) + + +def _make_trace_2(b0): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="B", output=b0), + _make_compute_inline(input=b0), + ], + decisions={}, + ) + + +def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name + if add_postproc: + insts = [ + _make_get_block(name="B", output=b0), + _make_compute_inline(input=b0), + _make_get_block(name="C", output=b1), + _make_enter_postproc(), + _make_compute_inline(input=b1), + ] + else: + insts = [ + _make_get_block(name="B", output=b0), + _make_compute_inline(input=b0), + _make_get_block(name="C", output=b1), + ] + return Trace(insts=insts, decisions={}) + + +def test_trace_construct_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="block", func_name="main")', + "l1, l2 = sch.get_loops(block=b0)", + ) + ) + assert len(trace.insts) == 2 + assert len(trace.decisions) == 0 + + +def test_trace_construct_get_decision_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + assert trace.get_decision(trace.insts[0]) is None + assert trace.get_decision(trace.insts[1]) is None + + +def test_trace_construct_append_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + trace.append(inst=_make_get_block("block2", BlockRV())) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="block", func_name="main")', + "l1, l2 = sch.get_loops(block=b0)", + 'b3 = sch.get_block(name="block2", func_name="main")', + ) + ) + + +def test_trace_construct_pop_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + last_inst = trace.insts[-1] + assert trace.pop().same_as(last_inst) + assert str(trace) == 'b0 = sch.get_block(name="block", func_name="main")' + + +def test_trace_construct_pop_2(): + trace = Trace([], {}) + assert str(trace) == "" + assert trace.pop() is None + assert str(trace) == "" + + +def test_trace_apply_to_schedule(): + trace = _make_trace_2(BlockRV()) + sch = tir.Schedule(elementwise, debug_mode=True) + trace.apply_to_schedule(sch, remove_postproc=False, decision_provider=None) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + + +def test_trace_as_json_1(): + trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) + obj = trace.as_json() + assert obj == [ + [ + ["GetBlock", [], ["block", "main"], ["b0"]], + ["GetLoops", ["b0"], [], ["l1", "l2"]], + ], + [], + ] + + +def test_trace_simplified_1(): + trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + 'b1 = sch.get_block(name="C", func_name="main")', + "sch.enter_postproc()", + "sch.compute_inline(block=b1)", + ) + ) + trace = trace.simplified(remove_postproc=True) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + ) + ) + + +def test_trace_simplified_2(): + trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + 'b1 = sch.get_block(name="C", func_name="main")', + "sch.enter_postproc()", + "sch.compute_inline(block=b1)", + ) + ) + trace = trace.simplified(remove_postproc=False) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b0)", + 'b1 = sch.get_block(name="C", func_name="main")', + "sch.enter_postproc()", + "sch.compute_inline(block=b1)", + ) + ) + + +def test_apply_json_to_schedule_1(): + trace = _make_trace_2(BlockRV()) + json_obj = trace.as_json() + sch = tir.Schedule(elementwise, debug_mode=True) + Trace.apply_json_to_schedule(json_obj, sch) + tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index af89ca252738..07658978db52 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring +import sys + import pytest import tvm from tvm import tir from tvm.ir import IRModule from tvm.script import ty - # pylint: disable=no-member,invalid-name,unused-variable @@ -108,8 +109,4 @@ def test_tir_schedule_remove_rv(): if __name__ == "__main__": - test_tir_schedule_creation() - test_tir_schedule_get_block() - test_tir_schedule_get_loops() - test_tir_schedule_copy() - test_tir_schedule_remove_rv() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py new file mode 100644 index 000000000000..2e9f1110732a --- /dev/null +++ b/tests/python/unittest/test_tir_specialize.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring, missing-module-docstring + +import tvm +from tvm import tir +from tvm.script import ty + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None: + m = tir.var("int32") + A = tir.match_buffer(a, [m, n]) + B = tir.match_buffer(b, [m, n]) + C = tir.match_buffer(c, [m, m]) + + with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + m = tir.var("int32") + A = tir.match_buffer(a, [m, 128]) + B = tir.match_buffer(b, [m, 128]) + C = tir.match_buffer(c, [m, m]) + + with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + x = tir.var("int32") + m = tir.var("int32") + A = tir.match_buffer(a, [m, x * 8]) + B = tir.match_buffer(b, [m, x * 8]) + C = tir.match_buffer(c, [m, m]) + + with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def element_wise(a: ty.handle, c: ty.handle) -> None: + m = tir.var("int32") + n = tir.var("int32") + A = tir.match_buffer(a, (m, n), "float32") + C = tir.match_buffer(c, (m, n), "float32") + + B = tir.alloc_buffer((m, n), "float32") + + with tir.block([m, n], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([m, n], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_128_64(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 64), "float32") + C = tir.match_buffer(c, (128, 64), "float32") + B = tir.alloc_buffer((128, 64), "float32") + + with tir.block([128, 64], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([128, 64], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_128_n(a: ty.handle, c: ty.handle) -> None: + n = tir.var("int32") + A = tir.match_buffer(a, (128, n), "float32") + C = tir.match_buffer(c, (128, n), "float32") + B = tir.alloc_buffer((128, n), "float32") + + with tir.block([128, n], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([128, n], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def mem_copy( + a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32 +) -> None: + A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) + B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +@tvm.script.tir +def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) + B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) + + with tir.block([16, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +@tvm.script.tir +def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) + B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +def test_specialize_nothing(): + func = matmul.specialize({}) + assert func.same_as(matmul) # Pointer the same + + +def test_specialize_matmul(): + a, _, _, n = matmul.params + # fully specialized + func = matmul.specialize({a: tir.decl_buffer((128, 128))}) + tvm.ir.assert_structural_equal(func, matmul_128) + # partially specialized + func = matmul.specialize({n: 128}) + tvm.ir.assert_structural_equal(func, matmul_m_128) + # symbolic specialized + func = matmul.specialize({n: tir.Var("x", "int32") * 8}) + tvm.ir.assert_structural_equal(func, matmul_m_8x) + + +def test_specialize_elemwise(): + a, c = element_wise.params + C = element_wise.buffer_map[c] + # fully specialized + func = element_wise.specialize({a: tir.decl_buffer((128, 64))}) + tvm.ir.assert_structural_equal(func, element_wise_128_64) + # partially specialized + func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))}) + tvm.ir.assert_structural_equal(func, element_wise_128_n) + + +def test_specialize_mem_copy(): + a, _, m, n, p, q = mem_copy.params + # fully specialized + func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) + tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) + func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4}) + tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) + # partially specialized + func = mem_copy.specialize({q: n}) + tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n) + + +def test_specialize_recursive_load(): + # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]] + pass + + +if __name__ == "__main__": + test_specialize_nothing() + test_specialize_matmul() + test_specialize_elemwise() + test_specialize_mem_copy() + test_specialize_recursive_load() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 7c06b5ef5ca1..a469c6d0cc13 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -293,6 +293,52 @@ def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: C[i, j] = B[0, j] +@tvm.script.tir +def match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((16, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[i, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[i, j], ()) + C1[()] = B2[()] * 2.0 + + +@tvm.script.tir +def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + C0 = tir.match_buffer(C[i, 0:16], (16)) + B = tir.alloc_buffer((1, 16)) + with tir.block([]): + B0 = tir.match_buffer(B[0, 0:16], (16)) + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + B1 = tir.match_buffer(B0[j], ()) + B1[()] = A1[()] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + C1 = tir.match_buffer(C0[j], ()) + B2 = tir.match_buffer(B[0, j], ()) + C1[()] = B2[()] * 2.0 + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -321,6 +367,10 @@ def test_complex(): _check(complex_func, compacted_complex_func) +def test_match_buffer(): + _check(match_buffer_func, compacted_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -329,3 +379,4 @@ def test_complex(): test_warp_mem() test_symbolic() test_complex() + test_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 2d45118f39f2..7dacd8e046cc 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -51,7 +51,7 @@ def meminfo_cache(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - body = stmt.body.body.body + body = stmt.body.body blist = tvm.tir.stmt_list(body) assert blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")) @@ -112,7 +112,7 @@ def __check_list(tvm_array, py_list): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - slist = tvm.tir.stmt_list(stmt[0].body.body) + slist = tvm.tir.stmt_list(stmt[0].body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c997748649cd..6929a329ac0f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -35,7 +35,7 @@ def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: with tir.block([]): tir.reads(A[i, 0:16]) tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([1, 16], "float32") + B = tir.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): with tir.block() as []: tir.reads(A[i, j]) @@ -111,7 +111,7 @@ def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32 with tir.block([]): tir.reads(A[i, m]) tir.writes(C[i, m]) - B = tir.alloc_buffer((m,), "float32") + B = tir.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): with tir.block([]) as []: tir.reads(A[i, j]) @@ -190,8 +190,8 @@ def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: with tir.block([]) as []: tir.reads(A[i]) tir.writes(D[i]) - B = tir.alloc_buffer((32,)) - C = tir.alloc_buffer((32,)) + B = tir.alloc_buffer((32,), scope="global") + C = tir.alloc_buffer((32,), scope="global") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 252a187dbdc5..b111e2be75c7 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -636,7 +636,7 @@ def test_hoisting_block_scope_4(): def test_hoisting_block_scope_5(): ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") + data = ib.pointer("float32", name="data", scope="global") l = te.var("l") m = te.var("m") n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index ceb32c484c6d..9b37bcaaacbc 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,8 +47,8 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 3e7a5a0cb300..673267a9b1fa 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -49,13 +49,13 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert len(stmt.body.body.extents) == 3 @@ -94,11 +94,11 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.body.body.extents) == 3 + assert stmt.body.body.body.body.extents[0].value == 2 + assert len(stmt.body.body.body.body.extents) == 3 def test_vthread_if_then_else(): @@ -119,7 +119,7 @@ def test_vthread_if_then_else(): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - )["main"].body + )["main"] assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_transform_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py index 12ad16dfe092..65e317dfbcb8 100644 --- a/tests/python/unittest/test_tir_transform_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -38,7 +38,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.node == cp @@ -58,7 +58,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 9e8848083908..c632f744bb81 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -40,7 +40,7 @@ def test_basic(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -156,7 +156,7 @@ def test_thread_axis(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -178,7 +178,7 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] body = stmt.body.body.body.body assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -229,7 +229,7 @@ def test_thread_axis2(): _, x = s[C].split(x, factor=m) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] for_body = stmt.body.body.body.body[0] assert "threadIdx" not in str(for_body.extent) diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 3fb8331d39fc..badf5e0e4d10 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -18,6 +18,8 @@ from tvm import tir from tvm.script import ty +# pylint: disable=no-self-argument + @tvm.script.tir class WithInit: @@ -43,11 +45,46 @@ def main(a: ty.handle, b: ty.handle) -> None: B[i] += A[i, j, k] +@tvm.script.tir +class InitWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with tir.init(): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + +@tvm.script.tir +class BranchWithMatchBuffer: + def main(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [64, 64, 64]) + B = tir.match_buffer(b, [64]) + + with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: + BB = tir.match_buffer(B[i], ()) + AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + if (j == 0) and (k == 32): + BB[()] = tir.float32(0) + BB[()] += AA[j, k] + + def test_lower_reduction(): origin_mod = WithInit() mod = tvm.tir.transform.LowerInitBlock()(origin_mod) tvm.ir.assert_structural_equal(mod, WithBranch(), True) +def test_lower_match_buffer(): + origin_mod = InitWithMatchBuffer() + mod = tvm.tir.transform.LowerInitBlock()(origin_mod) + tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) + + if __name__ == "__main__": test_lower_reduction() + test_lower_match_buffer() diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index ef474c15cfbb..84bf0c4d52fd 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -47,8 +47,9 @@ def test_lower_warp_memory_local_scope(): fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] - assert fdevice.body.body.value.value == "local" - assert fdevice.body.body.body.extents[0].value == 2 + allocate = fdevice.body.body + assert allocate.buffer_var.type_annotation.storage_scope == "local" + assert fdevice.body.body.extents[0].value == 2 @tvm.testing.requires_cuda @@ -72,8 +73,8 @@ def test_lower_warp_memory_correct_indices(): bounds = tvm.te.schedule.InferBound(s) ir = tvm.te.schedule.ScheduleOps(s, bounds) - inner_func = ir.body.body.body.body - store_A_warp = inner_func.body.seq[0].body.body + inner_func = ir.body.body.body + store_A_warp = inner_func.seq[0].body.body indices = list(store_A_warp.indices) # A.warp is actually many buffers, one for each warp, although they are all called A.warp diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py new file mode 100644 index 000000000000..9c511f1de6b9 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -0,0 +1,259 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import numpy as np +import tvm.testing +from tvm.topi.math import cast + + +def run_passes(sch, args): + bounds = tvm.te.schedule.InferBound(sch) + assert isinstance(bounds, tvm.container.Map) + stmt = tvm.te.schedule.ScheduleOps(sch, bounds) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func) + return tvm.transform.Sequential( + [ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify(), + tvm.tir.transform.VectorizeLoop(), + tvm.tir.transform.StorageRewrite(), + tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), + ] + )(mod) + + +def verify_single_allocation(stmt, alloc_size=None): + num_alloc = [0] + alloc_extents = [] + + def verify(n): + if ( + isinstance(n, tvm.tir.Allocate) + and n.buffer_var.type_annotation.storage_scope == "shared.dyn" + ): + num_alloc[0] += 1 + alloc_extents.append(n.extents[0]) + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + if alloc_size: + assert alloc_extents[0] == alloc_size + + +@tvm.testing.requires_gpu +def test_matmul_dyn_shared(): + n = 1024 + block = 16 + A = te.placeholder((n, n), name="A", dtype="float16") + B = te.placeholder((n, n), name="B", dtype="float16") + + def syncthread(): + return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) + + def test_matmul_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", block) + ib.scope_attr(ty, "thread_extent", block) + ib.scope_attr(bx, "thread_extent", n // block) + ib.scope_attr(by, "thread_extent", n // block) + + A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16 + B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16 + # Create a dynamic shared memory for the accumulation. + # This is for testing merging dynamic shared memory alloctions with different data type. + # In practice, there is no need to allocate a shared memory for C. + C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 + + A_ptr = ib.buffer_ptr(A) + B_ptr = ib.buffer_ptr(B) + C_ptr = ib.buffer_ptr(C) + + C_sh[ty, tx] = 0.0 + + with ib.for_range(0, n // block, name="i") as i: + A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] + B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx] + ib.emit(syncthread()) + + with ib.for_range(0, block, name="k") as k: + C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") + + ib.emit(syncthread()) + + C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] + + return ib.get() + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]), + name="matmul", + dtype="float32", + ) + s = te.create_schedule(C.op) + mod = run_passes(s, [A, B, C]) + expected_alloc_size = block * block * 3 * 4 + verify_single_allocation(mod["main"].body, expected_alloc_size) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fmatmul = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + size = (n, n) + a_np = np.random.uniform(size=size).astype(A.dtype) + b_np = np.random.uniform(size=size).astype(B.dtype) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev) + fmatmul(a, b, c) + np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32")) + tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +@tvm.testing.requires_gpu +def test_dyn_shared_vectorized_store(): + """Test vectorized store into dynamic shared memory""" + n = te.size_var("n") + A = te.placeholder((n,), name="A", dtype="float16") + B = te.placeholder((n,), name="B", dtype="float32") + + def test_device_ir(A, B, C): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + values_per_thread = 4 + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # fp16 + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # fp32 + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, values_per_thread, kind="vectorize") as i: + A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i] + B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i] + + with ib.for_range(0, values_per_thread) as i: + Cptr[tx * values_per_thread + i] = ( + cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i] + ) + + return ib.get() + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(C.op) + + mod = run_passes(s, [A, B, C]) + verify_single_allocation(mod["main"].body) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) + fadd(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4 + ) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +@tvm.testing.requires_gpu +def test_dyn_shared_reuse_and_merge(): + n = 64 + A = te.placeholder((n,), name="A", dtype="float32") + B = te.placeholder((n,), name="B", dtype="float32") + C = te.placeholder((te.size_var("n_dyn"),), name="C", dtype="float32") + + def test_device_ir(A, B, C, D): + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn", name="A_sh") + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn", name="B_sh") + C_sh = ib.allocate(C.dtype, (C.shape[0],), scope="shared.dyn", name="C_sh") + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + Dptr = ib.buffer_ptr(D) + + A_sh[tx] = Aptr[tx] + Dptr[tx] = A_sh[tx] + + B_sh[tx] = Bptr[tx] + Dptr[tx] += B_sh[tx] + + C_sh[tx] = Cptr[tx] # C cannot reuse other buffers since it size is dynamic + Dptr[tx] += C_sh[tx] + + return ib.get() + + D = te.extern( + (n,), + [A, B, C], + lambda ins, outs: test_device_ir(ins[0], ins[1], ins[2], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(D.op) + + mod = run_passes(s, [A, B, C, D]) + # merged allocation + # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]); + verify_single_allocation(mod["main"].body) + + +if __name__ == "__main__": + test_matmul_dyn_shared() + test_dyn_shared_vectorized_store() + test_dyn_shared_reuse_and_merge() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index d42c5e1f8626..022c964df0c7 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -115,6 +115,28 @@ def transformed_func() -> None: ) +@tvm.script.tir +def match_buffer_func() -> None: + C = tir.alloc_buffer((128, 128)) + with tir.block([128]) as [vi]: + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + +@tvm.script.tir +def transformed_match_buffer_func() -> None: + for i in range(0, 128): + with tir.block([128]) as [vi]: + tir.bind(vi, i) + C = tir.alloc_buffer((128, 128)) + C0 = tir.match_buffer(C[vi, 0:128], (128)) + with tir.block([128]) as [jj]: + C1 = tir.match_buffer(C0[jj], ()) + C1[()] = 0 + + def test_elementwise(): _check(element_func, transformed_element_func) @@ -123,6 +145,11 @@ def test_locate_buffer_allocation(): _check(original_func, transformed_func) +def test_match_buffer_allocation(): + _check(match_buffer_func, transformed_match_buffer_func) + + if __name__ == "__main__": test_elementwise() test_locate_buffer_allocation() + test_match_buffer_allocation() diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 2d1fea01aa32..0e9ab862a9c8 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -79,7 +79,7 @@ def test_flatten_storage_align(): )(mod) stmt = mod["main"].body - assert stmt.body.extents[0].value == 17 * 8 + assert stmt.extents[0].value == 17 * 8 def test_flatten_double_buffer(): @@ -114,8 +114,8 @@ def test_flatten_double_buffer(): )(mod) stmt = mod["main"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index dbe7e04700d9..9e738b136b17 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -228,6 +228,47 @@ def verify(n): assert num_alloc[0] == 1 +def test_storage_combine_with_vectorization(): + n = 1024 + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute((n,), lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + AA = s.cache_read(A, "global:tag", readers=[C]) + BB = s.cache_read(B, "global:tag", readers=[C]) + CC = s.cache_write(C, "global:tag") + s[CC].vectorize(s[CC].op.axis[0]) + bounds = tvm.te.schedule.InferBound(s) + stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.VectorizeLoop()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + mod = tvm.tir.transform.Simplify()(mod) + stmt = mod["main"].body + num_alloc = [0] + + def verify(v): + # find add op + if ( + isinstance(v, tvm.tir.Add) + and isinstance(v.a, tvm.tir.Load) + and isinstance(v.b, tvm.tir.Load) + ): + lhs_ramp = v.a.index + rhs_ramp = v.b.index + # these two ramp load should not overlap + assert lhs_ramp.lanes == n + assert rhs_ramp.lanes == n + assert lhs_ramp.base >= rhs_ramp.base + n or rhs_ramp.base >= lhs_ramp.base + n + elif isinstance(v, tvm.tir.Allocate): + num_alloc[0] += 1 + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + def test_storage_share_gpu(): m = te.var("m") A = [te.placeholder((m), name="A")] @@ -257,9 +298,9 @@ def test_storage_share_gpu(): alloc_stats = {"global": 0, "shared": 0} def verify(n): - if isinstance(n, tvm.tir.AttrStmt): - if n.attr_key == "storage_scope": - alloc_stats[n.value.value] += 1 + if isinstance(n, tvm.tir.Allocate): + scope = n.buffer_var.type_annotation.storage_scope + alloc_stats[scope] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 @@ -276,7 +317,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body, tvm.tir.Allocate) @@ -293,7 +334,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body.body.body, tvm.tir.Allocate) @@ -315,7 +356,6 @@ def get_mod(kind="serial"): mod = get_mod(kind="parallel") # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -325,11 +365,9 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # j[0] = 0 # while((j[0] < 10)){ @@ -338,11 +376,10 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A mod = get_mod(kind="serial") # for (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -352,10 +389,8 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body - # // attr [j] storage_scope = "global" + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # for (i, 0, n) { # j[0] = 0 @@ -365,7 +400,7 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body, tvm.tir.Allocate) # A def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): @@ -648,6 +683,7 @@ def verify(n): test_parallel_alloc() test_while_alloc() test_storage_combine() + test_storage_combine_with_vectorization() test_storage_share_gpu() test_inplace_rule2() diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 030c01713927..ffdf4b5916c4 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -19,6 +19,21 @@ import tvm.testing +def run_passes(inputs, stmt): + func = tvm.te.schedule.SchedulePostProcToPrimFunc(inputs, stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + cuda_target = tvm.target.Target("cuda") + + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) + )(mod) + + mod = tvm.tir.transform.SplitHostDevice()(mod) + return tvm.tir.transform.ThreadSync("shared")(mod) + + @tvm.testing.requires_cuda def test_thread_storage_sync(): m = te.size_var("m") @@ -38,23 +53,46 @@ def test_thread_storage_sync(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = run_passes([A, A2], stmt) + f = mod["test_kernel0"] + body_list = tvm.tir.stmt_list(f.body.body.body) + assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) - cuda_target = tvm.target.Target("cuda") - mod = tvm.tir.transform.Apply( - lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) - )(mod._move()) +@tvm.testing.requires_cuda +def test_sync_else_branch(): + def ir(A, B): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - cuda_target = tvm.target.Target("cuda") - f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body.body) - assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", 1) + + local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local") + shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared") + + with ib.for_range(0, 8) as i: + with ib.if_scope(Aptr[i] < 0): + local[i] = Aptr[i] + with ib.else_scope(): + shared[i] = Aptr[i] + + with ib.for_range(0, 8) as i: + with ib.if_scope(Aptr[i] < 0): + Bptr[i] = local[i] + with ib.else_scope(): + Bptr[i] = shared[i] + + return ib.get() + + A = tvm.tir.decl_buffer((8,), "float32") + B = tvm.tir.decl_buffer((8,), "float32") + stmt = ir(A, B) + mod = run_passes([A, B], stmt) + assert "@tir.tvm_storage_sync" in str(mod) if __name__ == "__main__": test_thread_storage_sync() + test_sync_else_branch() diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 76d4d9b98043..4798e9e09865 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -177,22 +177,111 @@ def test_complete_part_region(): _check_elementwise(func_with_part_access_region) -def test_complete_opaque_block_error(): - def render(e): - pass +@tvm.script.tir +def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: + data_buf = tir.match_buffer(data, (16, 16), "float32") + index_buf = tir.match_buffer(index, (1,), "int32") + out_buf = tir.alloc_buffer((16, 16), "float32") + + with tir.block([16, 16]) as [vi, vj]: + out_buf[vi, vj] = data_buf[vi, index_buf[0]] + + +@tvm.script.tir +def expected_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: + index_buf = tir.match_buffer( + index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1 + ) + data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) + for i0, i1 in tir.grid(16, 16): + with tir.block([16, 16], "") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.reads([data_buf[vi, 0:16], index_buf[0]]) + tir.writes([out_buf[vi, vj]]) + out_buf[vi, vj] = data_buf[vi, index_buf[0]] + + +@tvm.script.tir +def func_with_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: + data_buf = tir.match_buffer(data, (16, 16), "float32") + index_buf = tir.match_buffer(index, (1,), "int32") + out_buf = tir.alloc_buffer((16, 16), "float32") - override_renderer(render) + with tir.block([16, 16]) as [vi, vj]: + out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] - try: - from_source(func_with_opaque_block) - except tvm.error.DiagnosticError: - return - assert False + +@tvm.script.tir +def expected_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: + index_buf = tir.match_buffer( + index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1 + ) + data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) + for i0, i1 in tir.grid(16, 16): + with tir.block([16, 16], "") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.reads([data_buf[0:16, 0:16], index_buf[0]]) + tir.writes([out_buf[vi, vj]]) + out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] + + +def test_complete_buffer_indices(): + new_func = tvm.script.from_source(tvm.script.asscript(func_with_bufferslice_indices)) + tvm.ir.assert_structural_equal(new_func, expected_bufferslice_indices) + new_func = tvm.script.from_source(tvm.script.asscript(func_with_recursive_bufferslice_indices)) + tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) + + +@tvm.script.tir +def match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + for j in range(0, 16): + with tir.block([]) as []: + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +@tvm.script.tir +def expected_match_buffer_func(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + for i in range(0, 16): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, 0:16]) + A0 = tir.match_buffer(A[i, 0:16], (16)) + with tir.block([]): + tir.reads([]) + tir.writes(A0[0:16]) + for j in range(0, 16): + with tir.block([]) as []: + tir.reads([]) + tir.writes(A0[j]) + A1 = tir.match_buffer(A0[j], ()) + A1[()] = 1.0 + + +def test_complete_match_buffer(): + tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() test_complete_with_root() - test_complete_opaque_block_error() test_complete_part_region() + test_complete_buffer_indices() + test_complete_match_buffer() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 052217b32cb5..7aeceeccfa89 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -202,7 +202,7 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: with tir.block([16, 16]) as [vi, vj]: - A = tir.match_buffer_region(vi) # error + A = tir.match_buffer(vi) # error tir.evaluate(1.0) @@ -291,8 +291,36 @@ def error_index_type() -> None: A[vi, vj] = A[vi, 0.0] + 1 # error +def error_bufferslice_index_type() -> None: + A = tir.alloc_buffer((1,), "float32") + B = tir.alloc_buffer((16, 16), "float32") + C = tir.alloc_buffer((16, 16), "float32") + with tir.block([16, 16]) as [vi, vj]: + C[vi, vj] = B[vi, A[0]] # error + + def test_error_index_type(): check_error(error_index_type, 4) + check_error(error_bufferslice_index_type, 6) + + +def error_index_with_stop() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + A[vi, vj] = A[vi, 1:10] + 1 # error + + +def error_bufferslice_index_with_stop() -> None: + A = tir.alloc_buffer((1,), "int32") + B = tir.alloc_buffer((16, 16), "float32") + C = tir.alloc_buffer((16, 16), "float32") + with tir.block([16, 16]) as [vi, vj]: + C[vi, vj] = B[vi, A[0:1]] # error + + +def test_error_index_with_stop_slice(): + check_error(error_index_with_stop, 4) + check_error(error_bufferslice_index_with_stop, 6) def mismatch_args() -> None: @@ -335,6 +363,23 @@ def test_tvm_exception_catch(): check_error(intrin_except_assign, 3) +def buffer_shape_mismatch(a: ty.handle) -> None: + A = tir.match_buffer(a, (8, 8)) + for i, j in tir.grid(8, 2): + with tir.block([]): + tir.reads([]) + tir.writes([A[i, j * 4 : j * 4 + 4]]) + sub_A = tir.match_buffer( + A[i, j * 4 : j * 4 + 4], (5) + ) # error: shape mismatched between 4 and 5 + for jj in range(0, 4): + sub_A[i, j * 4 + jj] = 1 + + +def test_match_buffer_shape_mismatch(): + check_error(buffer_shape_mismatch, 7) + + def check_error(module, rel_lineno): # Override the default renderer to accumulate errors _, start_line = inspect.getsourcelines(module) @@ -383,5 +428,7 @@ def render(e): test_opaque_access_during_complete() test_convert_slice_to_bufferload() test_error_index_type() + test_error_index_with_stop_slice() test_mismatch_args() test_tvm_exception_catch() + test_match_buffer_shape_mismatch() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 164949552859..0566ff5044d9 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -277,8 +277,8 @@ def mmult( } ) # var definition - C_global = tir.var("handle") - packedB = tir.var("handle") + C_global = tir.buffer_var("float32", "global") + packedB = tir.buffer_var("float32", "global") # body assert num_args == 3, "mmult: num_args should be 3" arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") @@ -2820,6 +2820,43 @@ def test_for_thread_binding(): assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" +@tvm.script.tir +def match_buffer_region(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16), "float32") + B = tir.match_buffer(b, (1), "float32") + + with tir.block([16, 4]) as [vi, vj]: + C = tir.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + with tir.block([4]) as [vii]: + D = tir.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in tir.grid(4, 4): + B[0] += D[i, 0, j] + + +def test_match_buffer_region(): + func = match_buffer_region + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body, tir.stmt.BlockRealize) + root = rt_func.body.block + + assert isinstance(root.body, tir.stmt.For) + assert isinstance(root.body.body, tir.stmt.For) + assert isinstance(root.body.body.body, tir.stmt.BlockRealize) + outer_block = root.body.body.body.block + assert len(outer_block.match_buffers) == 1 + buffer_C = outer_block.match_buffers[0].buffer + tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + + assert isinstance(outer_block.body, tir.stmt.For) + assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) + inner_block = outer_block.body.body.block + assert len(inner_block.match_buffers) == 1 + buffer_D = inner_block.match_buffers[0].buffer + tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + + @tvm.script.tir def block_elements(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (16, 16), "float32") @@ -2832,10 +2869,10 @@ def block_elements(a: ty.handle, b: ty.handle) -> None: tir.writes(B[0, 0]) tir.block_attr({"attr_key": "attr_value"}) C = tir.alloc_buffer((4, 4), dtype="float32") - D = tir.match_buffer_region(A[0:4, 0]) + D = tir.match_buffer(A[0:4, 0], (4, 1)) with tir.init(): B[0, 0] = tir.float32(0) - B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0] + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] def test_block_elements(): @@ -2946,6 +2983,34 @@ def test_minmax(): tvm.ir.assert_structural_equal(func, rt_func) +@tvm.script.tir +def abs(a: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + + with tir.block([128, 128], "A") as [vi, vj]: + A[vi, vj] = tir.abs(A[vi, vj]) + + +def test_abs(): + func = abs + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +@tvm.script.tir +def constant_folding(a: ty.handle) -> None: + A = tir.match_buffer(a, (), "float32") + A[()] = tir.min(2.2, 5.2) + A[()] = tir.max(tir.float32(2.2), tir.float32(tir.float32(5.2))) + A[()] = tir.min(2.2, 5.0) + + +def test_script_printer(): + func = constant_folding + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + if __name__ == "__main__": test_opt_gemm_normalize() test_opt_gemm_mod_host() @@ -2960,5 +3025,8 @@ def test_minmax(): test_element_wise() test_predicate() test_for_thread_binding() + test_match_buffer_region() test_block_elements() test_opaque_block() + test_abs() + test_script_printer() diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index e534e8bc114b..753d17d8afe5 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -36,3 +36,6 @@ python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.3.0 # Jenkinsfile. We expect config.cmake to be present from pack_lib(). # TODO(areusch): Make pack_lib() pack all the data dependencies of TVM. (cd build && cmake .. && make standalone_crt) + +# Ensure no stale pytest-results remain from a previous test run. +(cd build && rm -rf pytest-results) diff --git a/tests/scripts/task_config_build_arm.sh b/tests/scripts/task_config_build_arm.sh index cae28467830f..cb42b9a71d59 100755 --- a/tests/scripts/task_config_build_arm.sh +++ b/tests/scripts/task_config_build_arm.sh @@ -35,3 +35,4 @@ echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "/opt/acl"\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 2af91d7c6b8e..167e5becd4a7 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -46,3 +46,4 @@ echo set\(USE_ETHOSN_HW OFF\) >> config.cmake echo set\(USE_VITIS_AI ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 609325c9962b..6e20087df34a 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -45,3 +45,4 @@ echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(USE_TENSORRT_CODEGEN ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu_vulkan.sh b/tests/scripts/task_config_build_gpu_vulkan.sh index 17d11397718d..a5a26a1db0fb 100755 --- a/tests/scripts/task_config_build_gpu_vulkan.sh +++ b/tests/scripts/task_config_build_gpu_vulkan.sh @@ -30,3 +30,4 @@ echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LIBBACKTRACE OFF\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_i386.sh b/tests/scripts/task_config_build_i386.sh index 05acbb022124..ce244fa59276 100755 --- a/tests/scripts/task_config_build_i386.sh +++ b/tests/scripts/task_config_build_i386.sh @@ -34,3 +34,5 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake + diff --git a/tests/scripts/task_config_build_qemu.sh b/tests/scripts/task_config_build_qemu.sh index 086ca8034dc9..d821da0eb262 100755 --- a/tests/scripts/task_config_build_qemu.sh +++ b/tests/scripts/task_config_build_qemu.sh @@ -29,3 +29,4 @@ echo set\(USE_LLVM llvm-config-10\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_wasm.sh b/tests/scripts/task_config_build_wasm.sh index 78dc7550028b..490e9446007e 100755 --- a/tests/scripts/task_config_build_wasm.sh +++ b/tests/scripts/task_config_build_wasm.sh @@ -34,3 +34,4 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 71c03888fddf..b05acb090c2f 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -19,3 +19,9 @@ set -o pipefail echo "Checking MyPy Type defs in the schedule package." mypy --check-untyped-defs python/tvm/tir/schedule + +echo "Checking MyPy Type defs in the analysis package." +mypy --check-untyped-defs python/tvm/tir/analysis/ + +echo "Checking MyPy Type defs in the transofrm package." +mypy --check-untyped-defs python/tvm/tir/transform/ diff --git a/tests/scripts/task_python_vta_tsim.sh b/tests/scripts/task_python_vta_tsim.sh index 3a6a35e5a06f..4c21f46c5f81 100755 --- a/tests/scripts/task_python_vta_tsim.sh +++ b/tests/scripts/task_python_vta_tsim.sh @@ -27,6 +27,9 @@ export VTA_HW_PATH=`pwd`/3rdparty/vta-hw export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 +# temporary skip tsim test, enable later +exit 0 + # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 3804b1496d05..468c4d40b942 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -261,16 +261,18 @@ def visit_constant(self, c): ] ) +############################################################################### # By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will # dump out the module IR when ``FoldConstant`` is done. Users can plug in this # pass after any pass they want to debug for viewing the optimization effect. # -# There is a more flexible debugging mechanism also exposed by the build configuration -# object. One can pass a tracing function which can be used to execute arbitrary code -# before and/or after each pass. A tracing function will receive a :py::class:`tvm.IRModule`, -# a :py:class:`tvm.transform.PassInfo` object, -# and a boolean indicating whether you are executing before, or after a pass. -# An example is below. +# There is a more flexible debugging mechanism. One can implement a ``PassInstrument`` +# class to execute arbitrary code not only before and/or after each pass but also +# at entering/exiting ``PassContext``. See :ref:`pass_instrument_cpp_backend` +# for more details. +# +# Here we use :py::func`tvm.instrument.pass_instrument` decorator to implement +# a PassInsturment class printing IR before execution of each passes: @tvm.instrument.pass_instrument diff --git a/tutorials/dev/use_pass_instrument.py b/tutorials/dev/use_pass_instrument.py new file mode 100644 index 000000000000..3369304a651d --- /dev/null +++ b/tutorials/dev/use_pass_instrument.py @@ -0,0 +1,372 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +.. _tutorial-use-pass-instrument: + +How to Use TVM Pass Instrument +============================== +**Author**: `Chi-Wei Wang `_ + +As more and more passes are implemented, it becomes useful to instrument +pass execution, analyze per-pass effects, and observe various events. + +We can instrument passes by providing a list of :py:class:`tvm.ir.instrument.PassInstrument` +instances to :py:class:`tvm.transform.PassContext`. We provide a pass instrument +for collecting timing information (:py:class:`tvm.ir.instrument.PassTimingInstrument`), +but an extension mechanism is available via the :py:func:`tvm.instrument.pass_instrument` decorator. + +This tutorial demostrates how developers can use ``PassContext`` to instrument +passes. Please also refer to the :ref:`pass-infra`. +""" +import tvm +import tvm.relay as relay +from tvm.relay.testing import resnet +from tvm.contrib.download import download_testdata +from tvm.relay.build_module import bind_params_by_name +from tvm.ir.instrument import ( + PassTimingInstrument, + pass_instrument, +) + + +############################################################################### +# Create An Example Relay Program +# ------------------------------- +# We use pre-defined resnet-18 network in Relay. +batch_size = 1 +num_of_image_class = 1000 +image_shape = (3, 224, 224) +output_shape = (batch_size, num_of_image_class) +relay_mod, relay_params = resnet.get_workload(num_layers=18, batch_size=1, image_shape=image_shape) +print("Printing the IR module...") +print(relay_mod.astext(show_meta_data=False)) + + +############################################################################### +# Create PassContext With Instruments +# ----------------------------------- +# To run all passes with an instrument, pass it via the ``instruments`` argument to +# the ``PassContext`` constructor. A built-in ``PassTimingInstrument`` is used to +# profile the execution time of each passes. +timing_inst = PassTimingInstrument() +with tvm.transform.PassContext(instruments=[timing_inst]): + relay_mod = relay.transform.InferType()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + # before exiting the context, get profile results. + profiles = timing_inst.render() +print("Printing results of timing profile...") +print(profiles) + + +############################################################################### +# Use Current PassContext With Instruments +# ---------------------------------------- +# One can also use the current ``PassContext`` and register +# ``PassInstrument`` instances by ``override_instruments`` method. +# Note that ``override_instruments`` executes ``exit_pass_ctx`` method +# if any instrument already exists. Then it switches to new instruments +# and calls ``enter_pass_ctx`` method of new instruments. +# Refer to following sections and :py:func:`tvm.instrument.pass_instrument` for these methods. +cur_pass_ctx = tvm.transform.PassContext.current() +cur_pass_ctx.override_instruments([timing_inst]) +relay_mod = relay.transform.InferType()(relay_mod) +relay_mod = relay.transform.FoldScaleAxis()(relay_mod) +profiles = timing_inst.render() +print("Printing results of timing profile...") +print(profiles) + + +############################################################################### +# Register empty list to clear existing instruments. +# +# Note that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called. +# Profiles are cleared so nothing is printed. +cur_pass_ctx.override_instruments([]) +# Uncomment the call to .render() to see a warning like: +# Warning: no passes have been profiled, did you enable pass profiling? +# profiles = timing_inst.render() + + +############################################################################### +# Create Customized Instrument Class +# ---------------------------------- +# A customized instrument class can be created using the +# :py:func:`tvm.instrument.pass_instrument` decorator. +# +# Let's create an instrument class which calculates the change in number of +# occurrences of each operator caused by each pass. We can look at ``op.name`` to +# find the name of each operator. And we do this before and after passes to calculate the difference. + + +@pass_instrument +class RelayCallNodeDiffer: + def __init__(self): + self._op_diff = [] + # Passes can be nested. + # Use stack to make sure we get correct before/after pairs. + self._op_cnt_before_stack = [] + + def enter_pass_ctx(self): + self._op_diff = [] + self._op_cnt_before_stack = [] + + def exit_pass_ctx(self): + assert len(self._op_cnt_before_stack) == 0, "The stack is not empty. Something wrong." + + def run_before_pass(self, mod, info): + self._op_cnt_before_stack.append((info.name, self._count_nodes(mod))) + + def run_after_pass(self, mod, info): + # Pop out the latest recorded pass. + name_before, op_to_cnt_before = self._op_cnt_before_stack.pop() + assert name_before == info.name, "name_before: {}, info.name: {} doesn't match".format( + name_before, info.name + ) + cur_depth = len(self._op_cnt_before_stack) + op_to_cnt_after = self._count_nodes(mod) + op_diff = self._diff(op_to_cnt_after, op_to_cnt_before) + # only record passes causing differences. + if op_diff: + self._op_diff.append((cur_depth, info.name, op_diff)) + + def get_pass_to_op_diff(self): + """ + return [ + (depth, pass_name, {op_name: diff_num, ...}), ... + ] + """ + return self._op_diff + + @staticmethod + def _count_nodes(mod): + """Count the number of occurrences of each operator in the module""" + ret = {} + + def visit(node): + if isinstance(node, relay.expr.Call): + if hasattr(node.op, "name"): + op_name = node.op.name + else: + # Some CallNode may not have 'name' such as relay.Function + return + ret[op_name] = ret.get(op_name, 0) + 1 + + relay.analysis.post_order_visit(mod["main"], visit) + return ret + + @staticmethod + def _diff(d_after, d_before): + """Calculate the difference of two dictionary along their keys. + The result is values in d_after minus values in d_before. + """ + ret = {} + key_after, key_before = set(d_after), set(d_before) + for k in key_before & key_after: + tmp = d_after[k] - d_before[k] + if tmp: + ret[k] = d_after[k] - d_before[k] + for k in key_after - key_before: + ret[k] = d_after[k] + for k in key_before - key_after: + ret[k] = -d_before[k] + return ret + + +############################################################################### +# Apply Passes and Multiple Instrument Classes +# -------------------------------------------- +# We can use multiple instrument classes in a ``PassContext``. +# However, it should be noted that instrument methods are executed sequentially, +# obeying the order of ``instruments`` argument. +# So for instrument classes like ``PassTimingInstrument``, it is inevitable to +# count-up the execution time of other instrument classes to the final +# profile result. +call_node_inst = RelayCallNodeDiffer() +desired_layouts = { + "nn.conv2d": ["NHWC", "HWIO"], +} +pass_seq = tvm.transform.Sequential( + [ + relay.transform.FoldConstant(), + relay.transform.ConvertLayout(desired_layouts), + relay.transform.FoldConstant(), + ] +) +relay_mod["main"] = bind_params_by_name(relay_mod["main"], relay_params) +# timing_inst is put after call_node_inst. +# So the execution time of ``call_node.inst.run_after_pass()`` is also counted. +with tvm.transform.PassContext(opt_level=3, instruments=[call_node_inst, timing_inst]): + relay_mod = pass_seq(relay_mod) + profiles = timing_inst.render() +# Uncomment the next line to see timing-profile results. +# print(profiles) + + +############################################################################### +# We can see how many CallNode increase/decrease per op type. +from pprint import pprint + +print("Printing the change in number of occurrences of each operator caused by each pass...") +pprint(call_node_inst.get_pass_to_op_diff()) + + +############################################################################### +# Exception Handling +# ------------------ +# Let's see what happens if an exception occurs in a method of a ``PassInstrument``. +# +# Define ``PassInstrument`` classes which raise exceptions in enter/exit ``PassContext``: +class PassExampleBase: + def __init__(self, name): + self._name = name + + def enter_pass_ctx(self): + print(self._name, "enter_pass_ctx") + + def exit_pass_ctx(self): + print(self._name, "exit_pass_ctx") + + def should_run(self, mod, info): + print(self._name, "should_run") + return True + + def run_before_pass(self, mod, pass_info): + print(self._name, "run_before_pass") + + def run_after_pass(self, mod, pass_info): + print(self._name, "run_after_pass") + + +@pass_instrument +class PassFine(PassExampleBase): + pass + + +@pass_instrument +class PassBadEnterCtx(PassExampleBase): + def enter_pass_ctx(self): + print(self._name, "bad enter_pass_ctx!!!") + raise ValueError("{} bad enter_pass_ctx".format(self._name)) + + +@pass_instrument +class PassBadExitCtx(PassExampleBase): + def exit_pass_ctx(self): + print(self._name, "bad exit_pass_ctx!!!") + raise ValueError("{} bad exit_pass_ctx".format(self._name)) + + +############################################################################### +# If an exception occurs in ``enter_pass_ctx``, ``PassContext`` will disable the pass +# instrumentation. And it will run the ``exit_pass_ctx`` of each ``PassInstrument`` +# which successfully finished ``enter_pass_ctx``. +# +# In following example, we can see ``exit_pass_ctx`` of `PassFine_0` is executed after exception. +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadEnterCtx("PassBadEnterCtx"), + PassFine("PassFine_1"), + ] +) +try: + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Exceptions in ``PassInstrument`` instances cause all instruments of the current ``PassContext`` +# to be cleared, so nothing is printed when ``override_instruments`` is called. +demo_ctx.override_instruments([]) # no PassFine_0 exit_pass_ctx printed....etc + +############################################################################### +# If an exception occurs in ``exit_pass_ctx``, then the pass instrument is disabled. +# Then exception is propagated. That means ``PassInstrument`` instances registered +# after the one throwing the exception do not execute ``exit_pass_ctx``. +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadExitCtx("PassBadExitCtx"), + PassFine("PassFine_1"), + ] +) +try: + # PassFine_1 execute enter_pass_ctx, but not exit_pass_ctx. + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Exceptions occured in ``should_run``, ``run_before_pass``, ``run_after_pass`` +# are not handled explicitly -- we rely on the context manager (the ``with`` syntax) +# to exit ``PassContext`` safely. +# +# We use ``run_before_pass`` as an example: +@pass_instrument +class PassBadRunBefore(PassExampleBase): + def run_before_pass(self, mod, pass_info): + print(self._name, "bad run_before_pass!!!") + raise ValueError("{} bad run_before_pass".format(self._name)) + + +demo_ctx = tvm.transform.PassContext( + instruments=[ + PassFine("PassFine_0"), + PassBadRunBefore("PassBadRunBefore"), + PassFine("PassFine_1"), + ] +) +try: + # All exit_pass_ctx are called. + with demo_ctx: + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Also note that pass instrumentation is not disable. So if we call +# ``override_instruments``, the ``exit_pass_ctx`` of old registered ``PassInstrument`` +# is called. +demo_ctx.override_instruments([]) + +############################################################################### +# If we don't wrap pass execution with ``with`` syntax, ``exit_pass_ctx`` is not +# called. Let try this with current ``PassContext``: +cur_pass_ctx = tvm.transform.PassContext.current() +cur_pass_ctx.override_instruments( + [ + PassFine("PassFine_0"), + PassBadRunBefore("PassBadRunBefore"), + PassFine("PassFine_1"), + ] +) + +############################################################################### +# Then call passes. ``exit_pass_ctx`` is not executed after the exception, +# as expectation. +try: + # No ``exit_pass_ctx`` got executed. + relay_mod = relay.transform.InferType()(relay_mod) +except ValueError as ex: + print("Catching", str(ex).split("\n")[-1]) + +############################################################################### +# Clear instruments. +cur_pass_ctx.override_instruments([]) diff --git a/tutorials/frontend/from_mxnet.py b/tutorials/frontend/from_mxnet.py index 0ce610a2cdd6..027e9e6eb757 100644 --- a/tutorials/frontend/from_mxnet.py +++ b/tutorials/frontend/from_mxnet.py @@ -32,7 +32,7 @@ pip install mxnet --user -or please refer to offical installation guide. +or please refer to official installation guide. https://mxnet.apache.org/versions/master/install/index.html """ # some standard imports diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index 4eba297935f0..26aeb6ecaf38 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -29,7 +29,7 @@ pip install onnx --user -or please refer to offical site. +or please refer to official site. https://github.com/onnx/onnx """ import onnx @@ -122,7 +122,7 @@ # Notes # --------------------------------------------- # By default, ONNX defines models in terms of dynamic shapes. The ONNX importer -# retains that dynamism upon import, and the compiler attemps to convert the model +# retains that dynamism upon import, and the compiler attempts to convert the model # into a static shapes at compile time. If this fails, there may still be dynamic # operations in the model. Not all TVM kernels currently support dynamic shapes, # please file an issue on discuss.tvm.apache.org if you hit an error with dynamic kernels. diff --git a/tutorials/get_started/tvmc_command_line_driver.py b/tutorials/get_started/tvmc_command_line_driver.py index b1fd1d1c2035..c729b86a3245 100644 --- a/tutorials/get_started/tvmc_command_line_driver.py +++ b/tutorials/get_started/tvmc_command_line_driver.py @@ -72,13 +72,6 @@ # -################################################################################ -# .. note:: Supported operating systems -# -# TVMC is only supported on Linux. Currently macOS and Windows default -# threading models do not support the model used for tuning by TVMC. - - ################################################################################ # Obtaining the Model # ------------------- diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index b7a9c79392d2..dcf564dd0314 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -34,7 +34,6 @@ from ..libinfo import find_libvta -@tvm.register_func("tvm.rpc.server.start", override=True) def server_start(): """VTA RPC server extension.""" # pylint: disable=unused-variable @@ -148,8 +147,21 @@ def main(): else: tracker_addr = None + # register the initialization callback + def server_init_callback(): + # pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self + import tvm + import vta.exec.rpc_server + + tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True) + server = rpc.Server( - args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr + args.host, + args.port, + args.port_end, + key=args.key, + tracker_addr=tracker_addr, + server_init_callback=server_init_callback, ) server.proc.join() diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 7c7d02b40fbb..383841f19e34 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -495,21 +495,21 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): # FIXME: pad_value is ignored... env = get_env() _ = pad_value - if dst.scope == "global": + if dst.scope() == "global": # Store if pad_before or pad_after: raise RuntimeError("Do not support copy into DRAM with pad") - if src.scope == env.acc_scope: + if src.scope() == env.acc_scope: elem_width = env.OUT_WIDTH elem_bytes = env.OUT_ELEM_BYTES mem_type = env.dev.MEM_ID_OUT data_type = "int%d" % env.OUT_WIDTH task_qid = env.dev.QID_STORE_OUT else: - raise RuntimeError("Do not support copy %s->dram" % (src.scope)) + raise RuntimeError("Do not support copy %s->dram" % (src.scope())) _check_compact(src) x_size, y_size, x_stride, offset = _get_2d_pattern( - dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True + dst, elem_width, elem_bytes, data_type, src.scope(), allow_fold=True ) irb = tvm.tir.ir_builder.create() irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid)) @@ -528,27 +528,27 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): ) ) return irb.get() - elif src.scope == "global": - if dst.scope == env.acc_scope: + elif src.scope() == "global": + if dst.scope() == env.acc_scope: elem_width = env.ACC_WIDTH elem_bytes = env.ACC_ELEM_BYTES mem_type = env.dev.MEM_ID_ACC data_type = "int%d" % env.ACC_WIDTH task_qid = env.dev.QID_LOAD_OUT - elif dst.scope == env.inp_scope: + elif dst.scope() == env.inp_scope: elem_width = env.INP_WIDTH elem_bytes = env.INP_ELEM_BYTES mem_type = env.dev.MEM_ID_INP data_type = "int%d" % env.INP_WIDTH task_qid = env.dev.QID_LOAD_INP - elif dst.scope == env.wgt_scope: + elif dst.scope() == env.wgt_scope: elem_width = env.WGT_WIDTH elem_bytes = env.WGT_ELEM_BYTES mem_type = env.dev.MEM_ID_WGT data_type = "int%d" % env.WGT_WIDTH task_qid = env.dev.QID_LOAD_WGT else: - raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) + raise RuntimeError("Do not support copy dram->%s" % (dst.scope())) # collect pad statistics if pad_before: assert pad_after @@ -586,7 +586,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): _check_compact(dst) x_size, y_size, x_stride, offset = _get_2d_pattern( - src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold + src, elem_width, elem_bytes, data_type, dst.scope(), allow_fold=allow_fold ) if data_type != src.dtype: @@ -617,7 +617,7 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): return irb.get() else: - raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) + raise RuntimeError("Do not support copy %s->%s" % (src.scope(), dst.scope())) return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy) diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index f12837f421f8..226797eb7d19 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -39,7 +39,7 @@ export async function detectGPUDevice(): Promise { interface FunctionInfo { name: string; arg_types: Array; - thread_axis_tags: Array; + launch_param_tags: Array; } /** @@ -114,8 +114,8 @@ export class WebGPUContext { const dispatchToDim: Array = []; - for (let i = 0; i < finfo.thread_axis_tags.length; ++i) { - const tag: string = finfo.thread_axis_tags[i]; + for (let i = 0; i < finfo.launch_param_tags.length; ++i) { + const tag: string = finfo.launch_param_tags[i]; if (tag.startsWith("blockIdx.")) { const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); assert(target >= 0 && target < 3);