diff --git a/.gitignore b/.gitignore index 068cb87484a04..1fcb2dc2d3fca 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,10 @@ __pycache__/ *.py[cod] *$py.class - +*.S # C extensions *.so +*.ll # Distribution / packaging .Python diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc index 56d05efb03a50..674feb4f29c0d 100644 --- a/3rdparty/bfloat16/bfloat16.cc +++ b/3rdparty/bfloat16/bfloat16.cc @@ -17,6 +17,7 @@ ==============================================================================*/ #include + #include #include @@ -50,8 +51,7 @@ void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) { #endif } -void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, - size_t size) { +void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, size_t size) { float a_f, b_f; BFloat16ToFloat(a, &a_f, 1); BFloat16ToFloat(b, &b_f, 1); diff --git a/3rdparty/cma/cma.h b/3rdparty/cma/cma.h index f005b3065c3a2..2cd5501226145 100644 --- a/3rdparty/cma/cma.h +++ b/3rdparty/cma/cma.h @@ -27,20 +27,17 @@ #ifndef VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ #define VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ - /* Should be defined in settings.mk file */ #ifndef CMA_IOCTL_MAGIC -#define CMA_IOCTL_MAGIC 0xf2 +#define CMA_IOCTL_MAGIC 0xf2 #endif +#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4) +#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4) +#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) +#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4) +#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4) -#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4) -#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4) -#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) -#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4) -#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4) - -#define CMA_IOCTL_MAXNR 5 - +#define CMA_IOCTL_MAXNR 5 #endif // VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ diff --git a/3rdparty/cma/cma_api_impl.h b/3rdparty/cma/cma_api_impl.h index 12c0e3b27efcb..317be5c9af1af 100644 --- a/3rdparty/cma/cma_api_impl.h +++ b/3rdparty/cma/cma_api_impl.h @@ -30,48 +30,47 @@ * \brief Application layer implementation for contigous memory allocation. 
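The `cma.h` and `cma_api_impl.h` hunks above and below realign the `_IOC(...)` request codes that user space uses to talk to the CMA kernel module. As a reminder of what those macros encode, here is a minimal, self-contained sketch of issuing one of them; the `/dev/cma` path follows the `DRIVER_NODE_NAME` default below, and the mmap step is illustrative rather than copied from the driver contract:

```cpp
#include <fcntl.h>      // open
#include <sys/ioctl.h>  // ioctl, _IOC, _IOC_READ, _IOC_WRITE
#include <sys/mman.h>   // mmap
#include <unistd.h>     // close

#include <cstdio>

#define CMA_IOCTL_MAGIC 0xf2
// _IOC packs direction (read+write), the driver magic 0xf2, the command
// number (1) and the argument size (4 bytes) into a single request code.
#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4)

int main() {
  int fd = open("/dev/cma", O_RDWR);
  if (fd < 0) return 1;
  unsigned data = 4096;  // in: requested size; out: value the driver hands back
  if (ioctl(fd, CMA_ALLOC_CACHED, &data) == -1) return 1;
  // The driver's reply is then used to map the buffer into this process,
  // as cma_alloc() does after its ioctl succeeds.
  void* mem = mmap(NULL, 4096, PROT_READ | PROT_WRITE, MAP_SHARED, fd, data);
  printf("contiguous buffer mapped at %p\n", mem);
  close(fd);
  return 0;
}
```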
*/ +#include +#include #include #include -#include -#include -#include #include -#include #include #include +#include +#include #include "cma_api.h" #ifndef CMA_IOCTL_MAGIC -#define CMA_IOCTL_MAGIC 0xf2 +#define CMA_IOCTL_MAGIC 0xf2 #endif -#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4) -#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4) -#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) -#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4) -#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4) +#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4) +#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4) +#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) +#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4) +#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4) -#define CMA_IOCTL_MAXNR 5 +#define CMA_IOCTL_MAXNR 5 #ifndef CMA_DEBUG - #define CMA_DEBUG 0 +#define CMA_DEBUG 0 #endif #ifndef DRIVER_NODE_NAME - #define DRIVER_NODE_NAME "cma" +#define DRIVER_NODE_NAME "cma" #endif #if CMA_DEBUG == 1 - #define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args) +#define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args) #else - #define __DEBUG(fmt, args...) +#define __DEBUG(fmt, args...) #endif -#define ROUND_UP(N, S) ((((N) + (S) - 1) / (S)) * (S)) - +#define ROUND_UP(N, S) ((((N) + (S)-1) / (S)) * (S)) /* Private functions */ -void *cma_alloc(size_t size, unsigned ioctl_cmd); +void* cma_alloc(size_t size, unsigned ioctl_cmd); /* Global file descriptor */ int cma_fd = 0; @@ -99,23 +98,19 @@ int cma_release(void) { return 0; } -void *cma_alloc_cached(size_t size) { - return cma_alloc(size, CMA_ALLOC_CACHED); -} +void* cma_alloc_cached(size_t size) { return cma_alloc(size, CMA_ALLOC_CACHED); } -void *cma_alloc_noncached(size_t size) { - return cma_alloc(size, CMA_ALLOC_NONCACHED); -} +void* cma_alloc_noncached(size_t size) { return cma_alloc(size, CMA_ALLOC_NONCACHED); } -int cma_free(void *mem) { +int cma_free(void* mem) { __DEBUG("Releasing contigous memory from 0x%x\n", (unsigned)mem); unsigned data, v_addr; /* save user space pointer value */ - data = (unsigned)mem; + data = (unsigned)mem; v_addr = (unsigned)mem; - if ( ioctl(cma_fd, CMA_GET_SIZE, &data) == -1 ) { + if (ioctl(cma_fd, CMA_GET_SIZE, &data) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful - 0\n"); return -1; } @@ -125,7 +120,7 @@ int cma_free(void *mem) { munmap(mem, data); /* free cma entry */ - if ( ioctl(cma_fd, CMA_FREE, &v_addr) == -1 ) { + if (ioctl(cma_fd, CMA_FREE, &v_addr) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful - 1\n"); return -1; } @@ -133,7 +128,7 @@ int cma_free(void *mem) { return 0; } -unsigned cma_get_phy_addr(void *mem) { +unsigned cma_get_phy_addr(void* mem) { unsigned data; __DEBUG("Getting physical address from 0x%x\n", (unsigned)mem); @@ -141,7 +136,7 @@ unsigned cma_get_phy_addr(void *mem) { data = (unsigned)mem; /* get physical address */ - if ( ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1 ) { + if (ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful\n"); return 0; } @@ -150,10 +145,9 @@ unsigned cma_get_phy_addr(void *mem) { return data; } - -void *cma_alloc(size_t size, unsigned ioctl_cmd) { +void* cma_alloc(size_t size, unsigned ioctl_cmd) { unsigned data; - void *mem; + void* mem; __DEBUG("Allocating 0x%x bytes of contigous 
memory\n", size); /* Page align size */ @@ -161,7 +155,7 @@ void *cma_alloc(size_t size, unsigned ioctl_cmd) { /* ioctl cmd to allocate contigous memory */ data = (unsigned)size; - if ( ioctl(cma_fd, ioctl_cmd, &data) == -1 ) { + if (ioctl(cma_fd, ioctl_cmd, &data) == -1) { __DEBUG("cma_alloc - ioctl command unsuccsessful\n"); return NULL; } diff --git a/3rdparty/compiler-rt/builtin_fp16.h b/3rdparty/compiler-rt/builtin_fp16.h index fa8efddcd4caf..8048980819968 100644 --- a/3rdparty/compiler-rt/builtin_fp16.h +++ b/3rdparty/compiler-rt/builtin_fp16.h @@ -29,16 +29,33 @@ static inline uint32_t __clz(uint32_t x) { int n = 32; uint32_t y; - y = x >>16; if (y) { n = n -16; x = y; } - y = x >> 8; if (y) { n = n - 8; x = y; } - y = x >> 4; if (y) { n = n - 4; x = y; } - y = x >> 2; if (y) { n = n - 2; x = y; } - y = x >> 1; if (y) return n - 2; + y = x >> 16; + if (y) { + n = n - 16; + x = y; + } + y = x >> 8; + if (y) { + n = n - 8; + x = y; + } + y = x >> 4; + if (y) { + n = n - 4; + x = y; + } + y = x >> 2; + if (y) { + n = n - 2; + x = y; + } + y = x >> 1; + if (y) return n - 2; return n - x; } -template +template static inline DST_T __truncXfYf2__(SRC_T a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. @@ -71,7 +88,10 @@ static inline DST_T __truncXfYf2__(SRC_T a) { const DST_REP_T dstNaNCode = dstQNaN - 1; // Break a into a sign and representation of the absolute value - union SrcExchangeType { SRC_T f; SRC_REP_T i; }; + union SrcExchangeType { + SRC_T f; + SRC_REP_T i; + }; SrcExchangeType src_rep; src_rep.f = a; const SRC_REP_T aRep = src_rep.i; @@ -88,25 +108,21 @@ static inline DST_T __truncXfYf2__(SRC_T a) { const SRC_REP_T roundBits = aAbs & roundMask; // Round to nearest - if (roundBits > halfway) - absResult++; - // Ties to even + if (roundBits > halfway) absResult++; + // Ties to even else if (roundBits == halfway) absResult += absResult & 1; - } - else if (aAbs > srcInfinity) { + } else if (aAbs > srcInfinity) { // a is NaN. // Conjure the result by beginning with infinity, setting the qNaN // bit and inserting the (truncated) trailing NaN field. absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; absResult |= dstQNaN; absResult |= ((aAbs & srcNaNCode) >> (SRC_SIG_BITS - DST_SIG_BITS)) & dstNaNCode; - } - else if (aAbs >= overflow) { + } else if (aAbs >= overflow) { // a overflows to infinity. absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; - } - else { + } else { // a underflows on conversion to the destination type or is an exact // zero. The result may be a denormal or zero. Extract the exponent // to get the shift amount for the denormalization. @@ -124,9 +140,8 @@ static inline DST_T __truncXfYf2__(SRC_T a) { absResult = denormalizedSignificand >> (SRC_SIG_BITS - DST_SIG_BITS); const SRC_REP_T roundBits = denormalizedSignificand & roundMask; // Round to nearest - if (roundBits > halfway) - absResult++; - // Ties to even + if (roundBits > halfway) absResult++; + // Ties to even else if (roundBits == halfway) absResult += absResult & 1; } @@ -134,14 +149,17 @@ static inline DST_T __truncXfYf2__(SRC_T a) { // Apply the signbit to (DST_T)abs(a). 
const DST_REP_T result = absResult | sign >> (srcBits - dstBits); - union DstExchangeType { DST_T f; DST_REP_T i; }; + union DstExchangeType { + DST_T f; + DST_REP_T i; + }; DstExchangeType dst_rep; dst_rep.i = result; return dst_rep.f; } -template +template static inline DST_T __extendXfYf2__(SRC_T a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. @@ -157,7 +175,7 @@ static inline DST_T __extendXfYf2__(SRC_T a) { const SRC_REP_T srcQNaN = SRC_REP_T(1) << (SRC_SIG_BITS - 1); const SRC_REP_T srcNaNCode = srcQNaN - 1; - const int dstBits = sizeof(DST_T)*8; + const int dstBits = sizeof(DST_T) * 8; const int dstExpBits = dstBits - DST_SIG_BITS - 1; const int dstInfExp = (1 << dstExpBits) - 1; const int dstExpBias = dstInfExp >> 1; @@ -165,7 +183,10 @@ static inline DST_T __extendXfYf2__(SRC_T a) { const DST_REP_T dstMinNormal = DST_REP_T(1) << DST_SIG_BITS; // Break a into a sign and representation of the absolute value - union SrcExchangeType { SRC_T f; SRC_REP_T i; }; + union SrcExchangeType { + SRC_T f; + SRC_REP_T i; + }; SrcExchangeType src_rep; src_rep.f = a; const SRC_REP_T aRep = src_rep.i; @@ -191,8 +212,7 @@ static inline DST_T __extendXfYf2__(SRC_T a) { absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; absResult |= (DST_REP_T)(aAbs & srcQNaN) << (DST_SIG_BITS - SRC_SIG_BITS); absResult |= (DST_REP_T)(aAbs & srcNaNCode) << (DST_SIG_BITS - SRC_SIG_BITS); - } - else if (aAbs) { + } else if (aAbs) { // a is denormal. // renormalize the significand and clear the leading bit, then insert // the correct adjusted exponent in the destination type. @@ -201,15 +221,17 @@ static inline DST_T __extendXfYf2__(SRC_T a) { absResult ^= dstMinNormal; const int resultExponent = dstExpBias - srcExpBias - scale + 1; absResult |= (DST_REP_T)resultExponent << DST_SIG_BITS; - } - else { + } else { // a is zero. absResult = 0; } // Apply the signbit to (DST_T)abs(a). const DST_REP_T result = absResult | (DST_REP_T)sign << (dstBits - srcBits); - union DstExchangeType { DST_T f; DST_REP_T i; }; + union DstExchangeType { + DST_T f; + DST_REP_T i; + }; DstExchangeType dst_rep; dst_rep.i = result; return dst_rep.f; diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 981b1c32f9166..ff3db4367a30f 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 981b1c32f91668e669ee376856f92f36cfd2a351 +Subproject commit ff3db4367a30f542aafb83b4af45e685b80102d0 diff --git a/CMakeLists.txt b/CMakeLists.txt index fc7c67c83a488..7c9fe1d8a8eb2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,11 +146,35 @@ else(MSVC) if(BUILD_FOR_HEXAGON) message(STATUS "Building for Hexagon") endif() + + # Detect if we're compiling for Android. + set(TEST_FOR_ANDROID_CXX + "#ifndef __ANDROID__" + "#error" + "#endif" + "int main() {}") + set(TEST_FOR_ANDROID_DIR + "${CMAKE_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/CMakeTmp") + set(TEST_FOR_ANDROID_FILE "${TEST_FOR_ANDROID_DIR}/test_for_android.cc") + string(REPLACE ";" "\n" TEST_FOR_ANDROID_CXX_TEXT "${TEST_FOR_ANDROID_CXX}") + file(WRITE "${TEST_FOR_ANDROID_FILE}" "${TEST_FOR_ANDROID_CXX_TEXT}") + try_compile(BUILD_FOR_ANDROID "${CMAKE_BINARY_DIR}${CMAKE_FILES_DIRECTORY}" + "${TEST_FOR_ANDROID_FILE}") + file(REMOVE "${TEST_FOR_ANDROID_FILE}") + if(BUILD_FOR_ANDROID) + message(STATUS "Building for Android") + endif() endif(MSVC) # Hexagon has dlopen built into QuRT (no need for static library). 
if(NOT BUILD_FOR_HEXAGON) - string(APPEND TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) +endif() + +if(BUILD_FOR_ANDROID) + # EmuTLS on Android is in libgcc. Without it linked in, libtvm_runtime.so + # won't load on Android due to missing __emutls_XXX symbols. + list(APPEND TVM_RUNTIME_LINKER_LIBS "gcc") endif() # add source group @@ -304,12 +328,15 @@ include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) include(cmake/modules/contrib/CoreML.cmake) +include(CheckCXXCompilerFlag) if(NOT MSVC) - include(CheckCXXCompilerFlag) check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14) - message(STATUS "Build with c++14") set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_STANDARD 14) +else() + check_cxx_compiler_flag("/std:c++14" SUPPORT_CXX14) + set(CMAKE_CXX_FLAGS "/std:c++14 ${CMAKE_CXX_FLAGS}") + set(CMAKE_CUDA_STANDARD 14) endif() add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) diff --git a/CMakeLists.txt.user b/CMakeLists.txt.user new file mode 100644 index 0000000000000..9f16b45af3d8d --- /dev/null +++ b/CMakeLists.txt.user @@ -0,0 +1,1734 @@ + [1,734 added lines of Qt Creator per-user project settings: EnvironmentId, editor configuration, Desktop-kit build configurations (Debug, Release, Release with Debug Information, Minimum Size Release), and perf/valgrind run configurations for the C++ unit-test targets (packed_func_test, container_test, crt_memory_test, and others), all referencing local /home/xiaoquan.li paths.]
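Back in `CMakeLists.txt`, the new `BUILD_FOR_ANDROID` detection follows the same probe idiom as the Hexagon check: write a tiny translation unit that only compiles when the toolchain defines the right macro, and let `try_compile()` report the answer. Reassembled as a standalone file, the probe is:

```cpp
// test_for_android.cc: compiles only when the compiler defines __ANDROID__,
// i.e. under an Android NDK toolchain, so try_compile() success means the
// build targets Android.
#ifndef __ANDROID__
#error
#endif
int main() {}
```

When the probe succeeds, `gcc` is appended to `TVM_RUNTIME_LINKER_LIBS`, pulling in Android's libgcc so that the emulated-TLS helpers (`__emutls_get_address` and friends) referenced by any `thread_local` data resolve when `libtvm_runtime.so` is dlopen'ed.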
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 10b247ae8eb17..b14e65cc0721a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ We do encourage everyone to work anything they are interested in. - [Aditya Atluri](https://github.com/adityaatluri): @adityaatluri - rocm - [Tianqi Chen](https://github.com/tqchen) (PPMC): @tqchen - topi, compiler, relay, docs +- [Liangfu Chen](https://github.com/liangfu): @liangfu - vta, chisel, intel FPGA, c runtime - [Wei Chen](https://github.com/wweic): @wweic - runtime, relay, vm - [Zhi Chen](https://github.com/zhiics): @zhiics - relay, quantization, pass manager - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends @@ -99,6 +100,7 @@ We do encourage everyone to work anything they are interested in. - [Kazutaka Morita](https://github.com/kazum): @kazum - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 +- [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - [Josh Pollock](https://github.com/joshpoll): @joshpoll - [Jared Roesch](https://github.com/jroesch): @jroesch - [Siva](https://github.com/srkreddy1238): @srkreddy1238 diff --git a/Jenkinsfile b/Jenkinsfile index f469a474d477b..1ad83508422d0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,7 +44,7 @@ // ci_lint = "tvmai/ci-lint:v0.60" -ci_gpu = "tvmai/ci-gpu:v0.61" +ci_gpu = "tvmai/ci-gpu:v0.63" ci_cpu = "tvmai/ci-cpu:v0.61" ci_i386 = "tvmai/ci-i386:v0.52" diff --git a/Makefile b/Makefile index 757b3300f7d52..e54b9a93b2309 100644 --- a/Makefile +++ b/Makefile @@ -73,7 +73,8 @@ build/libtvm_web_runtime.js: build/libtvm_web_runtime.bc cpplint: python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include; - python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \ + python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp \ + include src \ examples/extension/src examples/graph_executor/src pylint: diff --git a/apps/android_camera/app/src/main/jni/tvm_runtime.h b/apps/android_camera/app/src/main/jni/tvm_runtime.h index a58252e780fe3..bc10bdaa508c0 100644 --- a/apps/android_camera/app/src/main/jni/tvm_runtime.h +++ b/apps/android_camera/app/src/main/jni/tvm_runtime.h @@ -22,6 +22,7 @@ * \brief Pack all tvm runtime source files */ #include + #include /* Enable custom logging - this will cause TVM to pass every log message @@ -38,23 +39,23 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include
"../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_module.cc" +#include "../src/runtime/rpc/rpc_server_env.cc" +#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" @@ -69,7 +70,6 @@ #include "../src/runtime/contrib/sort/sort.cc" #endif - #include void dmlc::CustomLogMessage::Log(const std::string& msg) { diff --git a/apps/android_deploy/app/src/main/jni/tvm_runtime.h b/apps/android_deploy/app/src/main/jni/tvm_runtime.h index 0d038fb1060cd..f1a47a674281b 100644 --- a/apps/android_deploy/app/src/main/jni/tvm_runtime.h +++ b/apps/android_deploy/app/src/main/jni/tvm_runtime.h @@ -22,23 +22,23 @@ * \brief Pack all tvm runtime source files */ #include + #include #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" -#include "../src/runtime/object.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/ndarray.cc" - -#include "../src/runtime/graph/graph_runtime.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5d2bca2e216dd..0b713b88ba9e8 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -22,6 +22,7 @@ * \brief Pack all tvm runtime source files */ #include + #include /* Enable custom logging - this will cause TVM to pass every log message @@ -38,23 +39,23 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include 
"../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_module.cc" +#include "../src/runtime/rpc/rpc_server_env.cc" +#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" @@ -69,7 +70,6 @@ #include "../src/runtime/contrib/sort/sort.cc" #endif - #include void dmlc::CustomLogMessage::Log(const std::string& msg) { diff --git a/apps/bundle_deploy/bundle.cc b/apps/bundle_deploy/bundle.cc index 3e5080927db4f..d8ff683decc35 100644 --- a/apps/bundle_deploy/bundle.cc +++ b/apps/bundle_deploy/bundle.cc @@ -17,51 +17,47 @@ * under the License. */ -#include #include #include +#include + #define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) extern "C" { -TVM_BUNDLE_FUNCTION void *tvm_runtime_create(const char * build_graph_json, - const char * build_params_bin, +TVM_BUNDLE_FUNCTION void* tvm_runtime_create(const char* build_graph_json, + const char* build_params_bin, const uint64_t build_params_bin_len) { const int build_graph_json_len = strlen(build_graph_json); - const std::string json_data(&build_graph_json[0], - &build_graph_json[0] + build_graph_json_len); - tvm::runtime::Module mod_syslib = - (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); + const std::string json_data(&build_graph_json[0], &build_graph_json[0] + build_graph_json_len); + tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); int device_type = kDLCPU; int device_id = 0; - tvm::runtime::Module mod = - (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( - json_data, mod_syslib, device_type, device_id); + tvm::runtime::Module mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( + json_data, mod_syslib, device_type, device_id); TVMByteArray params; - params.data = reinterpret_cast(&build_params_bin[0]); + params.data = reinterpret_cast(&build_params_bin[0]); params.size = build_params_bin_len; mod.GetFunction("load_params")(params); return new tvm::runtime::Module(mod); } -TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void *handle) { - delete reinterpret_cast(handle); +TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void* handle) { + delete reinterpret_cast(handle); } -TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void *handle, const char *name, - void *tensor) { - reinterpret_cast(handle)->GetFunction("set_input")( - name, reinterpret_cast(tensor)); +TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void* handle, const char* name, void* tensor) { + reinterpret_cast(handle)->GetFunction("set_input")( + name, reinterpret_cast(tensor)); } -TVM_BUNDLE_FUNCTION void tvm_runtime_run(void *handle) { - reinterpret_cast(handle)->GetFunction("run")(); +TVM_BUNDLE_FUNCTION void tvm_runtime_run(void* handle) { + reinterpret_cast(handle)->GetFunction("run")(); } -TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void *handle, int index, - void *tensor) { - reinterpret_cast(handle)->GetFunction("get_output")( - index, reinterpret_cast(tensor)); +TVM_BUNDLE_FUNCTION void 
tvm_runtime_get_output(void* handle, int index, void* tensor) { + reinterpret_cast(handle)->GetFunction("get_output")( + index, reinterpret_cast(tensor)); } } diff --git a/apps/bundle_deploy/bundle.h b/apps/bundle_deploy/bundle.h index aa57faa38666f..80238e1e231af 100644 --- a/apps/bundle_deploy/bundle.h +++ b/apps/bundle_deploy/bundle.h @@ -22,20 +22,15 @@ #include -TVM_DLL void * tvm_runtime_create(const char * json_data, - const char * params_data, - const uint64_t params_size); +TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, + const uint64_t params_size); -TVM_DLL void tvm_runtime_destroy(void * runtime); +TVM_DLL void tvm_runtime_destroy(void* runtime); -TVM_DLL void tvm_runtime_set_input(void * runtime, - const char * name, - DLTensor * tensor); +TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* tensor); -TVM_DLL void tvm_runtime_run(void * runtime); +TVM_DLL void tvm_runtime_run(void* runtime); -TVM_DLL void tvm_runtime_get_output(void * runtime, - int32_t index, - DLTensor * tensor); +TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tensor); #endif /* TVM_APPS_BUNDLE_DEPLOY_BUNDLE_H_ */ diff --git a/apps/bundle_deploy/demo.cc b/apps/bundle_deploy/demo.cc index 0de10d7177eb5..5c210a2cab88d 100644 --- a/apps/bundle_deploy/demo.cc +++ b/apps/bundle_deploy/demo.cc @@ -17,44 +17,44 @@ * under the License. */ +#include +#include //dlopen +#include #include -#include -#include //dlopen #include #include #include -#include #include "build/graph.json.c" #include "build/params.bin.c" -template auto getFunc(void *bundle, const char *name) { +template +auto getFunc(void* bundle, const char* name) { dlerror(); - auto *f = - reinterpret_cast::type>(dlsym(bundle, name)); + auto* f = reinterpret_cast::type>(dlsym(bundle, name)); assert(!dlerror()); return f; } -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 3 && "Usage: demo "); - auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); - char * json_data = reinterpret_cast(build_graph_json); - char * params_data = reinterpret_cast(build_params_bin); + char* json_data = reinterpret_cast(build_graph_json); + char* params_data = reinterpret_cast(build_params_bin); uint64_t params_size = build_params_bin_len; struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = getFunc(bundle, "tvm_runtime_create")( + auto* handle = getFunc(bundle, "tvm_runtime_create")( json_data, params_data, params_size); gettimeofday(&t1, 0); float input_storage[1 * 3 * 224 * 224]; - FILE * fp = fopen(argv[2], "rb"); + FILE* fp = fopen(argv[2], "rb"); fread(input_storage, 3 * 224 * 224, 4, fp); fclose(fp); @@ -68,12 +68,10 @@ int main(int argc, char **argv) { input.strides = nullptr; input.byte_offset = 0; - getFunc(bundle, "tvm_runtime_set_input")( - handle, "data", &input); + getFunc(bundle, "tvm_runtime_set_input")(handle, "data", &input); gettimeofday(&t2, 0); - auto *ftvm_runtime_run = - (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); gettimeofday(&t3, 0); @@ -89,8 +87,7 @@ int main(int argc, char **argv) { output.strides = nullptr; output.byte_offset = 0; - getFunc(bundle, "tvm_runtime_get_output")( - handle, 0, &output); + getFunc(bundle, "tvm_runtime_get_output")(handle, 0, &output); gettimeofday(&t4, 0); float 
max_iter = -std::numeric_limits::max(); @@ -102,19 +99,19 @@ int main(int argc, char **argv) { } } - getFunc(bundle, "tvm_runtime_destroy")(handle); + getFunc(bundle, "tvm_runtime_destroy")(handle); gettimeofday(&t5, 0); - printf("The maximum position in output vector is: %d, with max-value %f.\n", - max_index, max_iter); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f); + printf("The maximum position in output vector is: %d, with max-value %f.\n", max_index, max_iter); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); dlclose(bundle); - + return 0; } diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 7a116e89fa880..8e294a05775de 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -19,19 +19,19 @@ #include #include -#include #include +#include #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" -#include "../../src/runtime/workspace_pool.cc" +#include "../../src/runtime/file_util.cc" +#include "../../src/runtime/graph/graph_runtime.cc" #include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" -#include "../../src/runtime/registry.cc" -#include "../../src/runtime/file_util.cc" -#include "../../src/runtime/threading_backend.cc" -#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" #include "../../src/runtime/system_library.cc" -#include "../../src/runtime/graph/graph_runtime.cc" +#include "../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc index c92400d295163..882e04be8ef9b 100644 --- a/apps/bundle_deploy/test.cc +++ b/apps/bundle_deploy/test.cc @@ -17,35 +17,35 @@ * under the License. 
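One detail of `demo.cc` and `test.cc` above that the plain-text rendering loses is that `getFunc` is a template, declared `template <typename F>`, with each call site naming the expected function type in angle brackets. A reconstructed, self-contained sketch of the lookup pattern (the parameter types at each call site should be read from `bundle.h`, so treat the ones here as indicative):

```cpp
#include <dlfcn.h>  // dlopen, dlsym, dlerror

#include <cassert>
#include <cstdint>
#include <type_traits>

// Resolve `name` from a dlopen'ed library and cast it to the function type F.
template <typename F>
auto getFunc(void* bundle, const char* name) {
  dlerror();  // clear any stale error so the assert below is meaningful
  auto* f = reinterpret_cast<typename std::add_pointer<F>::type>(dlsym(bundle, name));
  assert(!dlerror());
  return f;
}

int main(int argc, char** argv) {
  assert(argc == 2 && "Usage: sketch <bundle.so>");
  void* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL);
  assert(bundle);
  // F here is the shape of tvm_runtime_create: graph JSON, params blob and
  // its length in, an opaque runtime handle out.
  auto* create = getFunc<void*(const char*, const char*, uint64_t)>(bundle, "tvm_runtime_create");
  (void)create;  // a full program would call it with real graph/params data
  dlclose(bundle);
  return 0;
}
```

`std::add_pointer<F>::type` turns the function type `F` into `F*`, so the `reinterpret_cast` of the `void*` returned by `dlsym` yields a pointer that can be called directly with the right signature.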
*/ +#include +#include //dlopen +#include +#include #include -#include -#include //dlopen #include #include #include -#include -#include -template auto getFunc(void *bundle, const char *name) { +template +auto getFunc(void* bundle, const char* name) { dlerror(); - auto *f = - reinterpret_cast::type>(dlsym(bundle, name)); + auto* f = reinterpret_cast::type>(dlsym(bundle, name)); assert(!dlerror()); return f; } -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 6 && "Usage: test "); - auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); struct stat st; - char * json_data; - char * params_data; + char* json_data; + char* params_data; uint64_t params_size; - FILE * fp = fopen(argv[4], "rb"); + FILE* fp = fopen(argv[4], "rb"); stat(argv[4], &st); json_data = (char*)malloc(st.st_size); fread(json_data, st.st_size, 1, fp); @@ -61,7 +61,7 @@ int main(int argc, char **argv) { struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = getFunc(bundle, "tvm_runtime_create")( + auto* handle = getFunc(bundle, "tvm_runtime_create")( json_data, params_data, params_size); gettimeofday(&t1, 0); @@ -85,12 +85,10 @@ int main(int argc, char **argv) { input.strides = nullptr; input.byte_offset = 0; - getFunc(bundle, "tvm_runtime_set_input")( - handle, "x", &input); + getFunc(bundle, "tvm_runtime_set_input")(handle, "x", &input); gettimeofday(&t2, 0); - auto *ftvm_runtime_run = - (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); gettimeofday(&t3, 0); @@ -106,8 +104,7 @@ int main(int argc, char **argv) { output.strides = nullptr; output.byte_offset = 0; - getFunc(bundle, "tvm_runtime_get_output")( - handle, 0, &output); + getFunc(bundle, "tvm_runtime_get_output")(handle, 0, &output); gettimeofday(&t4, 0); for (auto i = 0; i < 10 * 5; ++i) { @@ -117,20 +114,21 @@ int main(int argc, char **argv) { } } - getFunc(bundle, "tvm_runtime_destroy")(handle); + getFunc(bundle, "tvm_runtime_destroy")(handle); gettimeofday(&t5, 0); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); free(json_data); free(params_data); dlclose(bundle); - + return 0; } diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 5168da31d696d..ae2636da75558 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -21,20 +21,21 @@ * \file rpc_server.cc * \brief RPC Server for TVM. 
*/ -#include #include #include +#include #if defined(__linux__) || defined(__ANDROID__) #include #endif #include -#include + #include -#include +#include #include +#include -#include "../../src/support/util.h" #include "../../src/support/socket.h" +#include "../../src/support/util.h" #include "rpc_server.h" #if defined(_WIN32) @@ -45,21 +46,21 @@ using namespace std; using namespace tvm::runtime; using namespace tvm::support; -static const string kUsage = \ -"Command line usage\n" \ -" server - Start the server\n" \ -"--host - The hostname of the server, Default=0.0.0.0\n" \ -"--port - The port of the RPC, Default=9090\n" \ -"--port-end - The end search port of the RPC, Default=9199\n" \ -"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \ -"--key - The key used to identify the device type in tracker. Default=\"\"\n" \ -"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ -"--silent - Whether to run in silent mode. Default=False\n" \ -"\n" \ -" Example\n" \ -" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " -" --tracker=127.0.0.1:9190 --key=rasp" \ -"\n"; +static const string kUsage = + "Command line usage\n" + " server - Start the server\n" + "--host - The hostname of the server, Default=0.0.0.0\n" + "--port - The port of the RPC, Default=9090\n" + "--port-end - The end search port of the RPC, Default=9199\n" + "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" + "--key - The key used to identify the device type in tracker. Default=\"\"\n" + "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" + "--silent - Whether to run in silent mode. Default=False\n" + "\n" + " Example\n" + " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " + " --tracker=127.0.0.1:9190 --key=rasp" + "\n"; /*! * \brief RpcServerArgs. @@ -95,7 +96,7 @@ void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "tracker = " << args.tracker; LOG(INFO) << "key = " << args.key; LOG(INFO) << "custom_addr = " << args.custom_addr; - LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); + LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False")); } #if defined(__linux__) || defined(__ANDROID__) @@ -151,7 +152,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { * \param tracker The tracker input. * \return result of operation. */ -bool ValidateTracker(string &tracker) { +bool ValidateTracker(string& tracker) { vector list = Split(tracker, ':'); if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { return false; @@ -168,7 +169,7 @@ bool ValidateTracker(string &tracker) { * \param argv arg values * \param args the output structure which holds the parsed values */ -void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { +void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) { const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; @@ -232,12 +233,11 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { } #if defined(WIN32) const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); - if(!mmap_path.empty()) { + if (!mmap_path.empty()) { args.mmap_path = mmap_path; dmlc::InitLogging("--minloglevel=0"); } #endif - } /*! @@ -246,7 +246,7 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { * \param argv arg values * \return result of operation. 
*/ -int RpcServer(int argc, char * argv[]) { +int RpcServer(int argc, char* argv[]) { RpcServerArgs args; /* parse the command line args */ @@ -260,21 +260,21 @@ int RpcServer(int argc, char * argv[]) { #endif #if defined(WIN32) - if(!args.mmap_path.empty()) { + if (!args.mmap_path.empty()) { int ret = 0; try { - ChildProcSocketHandler(args.mmap_path); + ChildProcSocketHandler(args.mmap_path); } catch (const std::exception&) { - ret = -1; + ret = -1; } return ret; } #endif - RPCServerCreate(args.host, args.port, args.port_end, args.tracker, - args.key, args.custom_addr, args.silent); + RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr, + args.silent); return 0; } @@ -284,7 +284,7 @@ int RpcServer(int argc, char * argv[]) { * \param argv arg values * \return result of operation. */ -int main(int argc, char * argv[]) { +int main(int argc, char* argv[]) { if (argc <= 1) { LOG(INFO) << kUsage; return 0; diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b5dc51b9e7efe..a690fd85a59be 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,8 +20,9 @@ * \file rpc_env.cc * \brief Server environment of the RPC. */ -#include #include + +#include #ifndef _WIN32 #include #include @@ -30,44 +31,53 @@ #include #include namespace { - int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } -} +int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } +} // namespace #endif #include #include #include #include #include -#include -#include "../../src/support/util.h" #include "../../src/runtime/file_util.h" +#include "../../src/support/util.h" #include "rpc_env.h" namespace { - std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { - std::string untar_cmd; - untar_cmd.reserve(512); +std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { + std::string untar_cmd; + untar_cmd.reserve(512); #if defined(__linux__) || defined(__ANDROID__) - untar_cmd += "tar -C "; - untar_cmd += output_dir; - untar_cmd += " -zxf "; - untar_cmd += tar_file; + untar_cmd += "tar -C "; + untar_cmd += output_dir; + untar_cmd += " -zxf "; + untar_cmd += tar_file; #elif defined(_WIN32) - untar_cmd += "python -m tarfile -e "; - untar_cmd += tar_file; - untar_cmd += " "; - untar_cmd += output_dir; + untar_cmd += "python -m tarfile -e "; + untar_cmd += tar_file; + untar_cmd += " "; + untar_cmd += output_dir; #endif - return untar_cmd; - } + return untar_cmd; +} -}// Anonymous namespace +} // Anonymous namespace namespace tvm { namespace runtime { RPCEnv::RPCEnv() { +#ifndef _WIN32 + char cwd[PATH_MAX]; + if (char* rc = getcwd(cwd, sizeof(cwd))) { + base_ = std::string(cwd) + "/rpc"; + } else { + base_ = "./rpc"; + } +#else base_ = "./rpc"; +#endif + mkdir(base_.c_str(), 0777); TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { static RPCEnv env; @@ -162,22 +172,20 @@ std::vector ListDir(const std::string& dirname) { * \param options The compiler options * \param cc The compiler */ -void LinuxShared(const std::string output, - const std::vector &files, - std::string options = "", - std::string cc = "g++") { - std::string cmd = cc; - cmd += " -shared -fPIC "; - cmd += " -o " + output; - for (auto f = files.begin(); f != files.end(); ++f) { - cmd += " " + *f; - } - cmd += " " + options; - std::string err_msg; - auto executed_status = support::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } 
+void LinuxShared(const std::string output, const std::vector& files, + std::string options = "", std::string cc = "g++") { + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (auto f = files.begin(); f != files.end(); ++f) { + cmd += " " + *f; + } + cmd += " " + options; + std::string err_msg; + auto executed_status = support::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } } #endif @@ -189,10 +197,8 @@ void LinuxShared(const std::string output, * \param options The compiler options * \param cc The compiler */ -void WindowsShared(const std::string& output, - const std::vector& files, - const std::string& options = "", - const std::string& cc = "clang") { +void WindowsShared(const std::string& output, const std::vector& files, + const std::string& options = "", const std::string& cc = "clang") { std::string cmd = cc; cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; cmd += " -o " + output; @@ -233,7 +239,7 @@ void CreateShared(const std::string& output, const std::vector& fil * \param fmt The format of file * \return Module The loaded module */ -Module Load(std::string *fileIn, const std::string& fmt) { +Module Load(std::string* fileIn, const std::string& fmt) { const std::string& file = *fileIn; if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) { return Module::LoadFromFile(file, fmt); diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index d046f6ecb480c..464b10a2714c7 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -25,6 +25,7 @@ #define TVM_APPS_CPP_RPC_ENV_H_ #include + #include namespace tvm { @@ -40,13 +41,13 @@ namespace runtime { * \param file The format of file * \return Module The loaded module */ -Module Load(std::string *path, const std::string& fmt = ""); +Module Load(std::string* path, const std::string& fmt = ""); /*! * \brief CleanDir Removes the files from the directory * \param dirname THe name of the directory */ -void CleanDir(const std::string &dirname); +void CleanDir(const std::string& dirname); /*! * \brief RPCEnv The RPC Environment parameters for c++ rpc server diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index ea4ab00c113b9..2628ff77a5f79 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -32,9 +32,9 @@ #include #include -#include "../../src/support/socket.h" -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" +#include "../../src/support/socket.h" #include "rpc_env.h" #include "rpc_server.h" #include "rpc_tracker_client.h" @@ -66,6 +66,22 @@ static pid_t waitPidEintr(int* status) { } #endif +#ifdef __ANDROID__ +static std::string getNextString(std::stringstream* iss) { + std::string str = iss->str(); + size_t start = iss->tellg(); + size_t len = str.size(); + // Skip leading spaces. + while (start < len && isspace(str[start])) start++; + + size_t end = start; + while (end < len && !isspace(str[end])) end++; + + iss->seekg(end); + return str.substr(start, end - start); +} +#endif + /*! * \brief RPCServer RPC Server class. * \param host The hostname of the server, Default=0.0.0.0 @@ -80,14 +96,15 @@ class RPCServer { /*! * \brief Constructor. 
*/ - RPCServer(std::string host, int port, int port_end, std::string tracker_addr, - std::string key, std::string custom_addr) : - host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end), - tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), - custom_addr_(std::move(custom_addr)) - { - - } + RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key, + std::string custom_addr) + : host_(std::move(host)), + port_(port), + my_port_(0), + port_end_(port_end), + tracker_addr_(std::move(tracker_addr)), + key_(std::move(key)), + custom_addr_(std::move(custom_addr)) {} /*! * \brief Destructor. @@ -97,8 +114,7 @@ class RPCServer { // Free the resources tracker_sock_.Close(); listen_sock_.Close(); - } catch(...) { - + } catch (...) { } } @@ -144,7 +160,7 @@ class RPCServer { } int timeout = GetTimeOutFromOpts(opts); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); @@ -164,9 +180,9 @@ class RPCServer { int status = 0; const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); + kill(worker_pid, SIGTERM); } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); + kill(timer_pid, SIGTERM); } else { LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; } @@ -197,7 +213,6 @@ class RPCServer { try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { - } auto dur = high_resolution_clock::now() - start_time; @@ -217,11 +232,8 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, - support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, - int ping_period = 2) { + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, + support::SockAddr* addr, std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -233,7 +245,7 @@ class RPCServer { support::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; @@ -260,7 +272,12 @@ class RPCServer { std::stringstream ssin(remote_key); std::string arg0; +#ifndef __ANDROID__ ssin >> arg0; +#else + arg0 = getNextString(&ssin); +#endif + if (arg0 != expect_header) { code = kRPCMismatch; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); @@ -274,7 +291,11 @@ class RPCServer { CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); +#ifndef __ANDROID__ ssin >> *opts; +#else + *opts = getNextString(&ssin); +#endif *conn_sock = conn; return; } @@ -301,8 +322,9 @@ class RPCServer { int GetTimeOutFromOpts(const std::string& opts) const { const std::string option = "-timeout="; - if (opts.find(option) == 0) { - const std::string cmd = opts.substr(opts.find_last_of(option) + 1); + size_t pos = opts.rfind(option); + if (pos != std::string::npos) { + const std::string cmd = opts.substr(pos + option.size()); CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } @@ -322,15 +344,15 @@ class RPCServer { #if defined(WIN32) 
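The `GetTimeOutFromOpts` change in the hunk above fixes a real parsing bug: `find_last_of` matches any single character from the set `-timeout=` rather than the whole literal, and the old guard also required the flag to sit at offset 0. A minimal standalone sketch of the corrected logic, assuming nothing beyond the standard library; the `ParseTimeout` helper name is illustrative and not part of the patch:

```cpp
#include <iostream>
#include <string>

// Sketch of the corrected "-timeout=<seconds>" parsing; returns 0 when the
// flag is absent. The patched method additionally validates with IsNumber().
int ParseTimeout(const std::string& opts) {
  const std::string option = "-timeout=";
  size_t pos = opts.rfind(option);
  if (pos == std::string::npos) return 0;
  // Everything after the literal is the value: "-timeout=300" -> "300".
  return std::stoi(opts.substr(pos + option.size()));
}

int main() {
  std::cout << ParseTimeout("key=android -timeout=300") << "\n";  // prints 300
  std::cout << ParseTimeout("") << "\n";                          // prints 0
}
```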
/*! -* \brief ServerLoopFromChild The Server loop process. -* \param socket The socket information -*/ + * \brief ServerLoopFromChild The Server loop process. + * \param socket The socket information + */ void ServerLoopFromChild(SOCKET socket) { // Server loop tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); - + sock.Close(); env.CleanUp(); } @@ -341,10 +363,10 @@ void ServerLoopFromChild(SOCKET socket) { * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 - * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" - * \param key The key used to identify the device type in tracker. Default="" - * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" - * \param silent Whether run in silent mode. Default=True + * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 + * Default="" \param key The key used to identify the device type in tracker. Default="" \param + * custom_addr Custom IP Address to Report to RPC Tracker. Default="" \param silent Whether run in + * silent mode. Default=True */ void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, std::string key, std::string custom_addr, bool silent) { @@ -353,13 +375,13 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr)); + RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), + std::move(custom_addr)); rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") -.set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); - }); +TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); +}); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index db7c89d823dd6..0936c51bb2ce6 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -25,6 +25,7 @@ #define TVM_APPS_CPP_RPC_SERVER_H_ #include + #include "tvm/runtime/c_runtime_api.h" namespace tvm { @@ -49,13 +50,9 @@ void ServerLoopFromChild(SOCKET socket); * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. 
Default=True */ -void RPCServerCreate(std::string host = "", - int port = 9090, - int port_end = 9099, - std::string tracker_addr = "", - std::string key = "", - std::string custom_addr = "", - bool silent = true); +void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099, + std::string tracker_addr = "", std::string key = "", + std::string custom_addr = "", bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index dfd576f4c1951..cdfb64780ba61 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -24,14 +24,14 @@ #ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ #define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ -#include -#include #include +#include #include -#include +#include #include +#include -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/support/socket.h" namespace tvm { @@ -47,29 +47,28 @@ class TrackerClient { public: /*! * \brief Constructor. - */ - TrackerClient(const std::string& tracker_addr, - const std::string& key, + */ + TrackerClient(const std::string& tracker_addr, const std::string& key, const std::string& custom_addr) - : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), - gen_(std::random_device{}()), dis_(0.0, 1.0) { - } + : tracker_addr_(tracker_addr), + key_(key), + custom_addr_(custom_addr), + gen_(std::random_device{}()), + dis_(0.0, 1.0) {} /*! * \brief Destructor. - */ + */ ~TrackerClient() { // Free the resources Close(); } /*! * \brief IsValid Check tracker is valid. - */ - bool IsValid() { - return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); - } + */ + bool IsValid() { return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); } /*! * \brief TryConnect Connect to tracker if the tracker address is valid. - */ + */ void TryConnect() { if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) { tracker_sock_ = ConnectWithRetry(); @@ -80,8 +79,8 @@ class TrackerClient { CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kUpdateInfo) - << ", {\"key\": \"server:"<< key_ << "\"}]"; + ss << "[" << static_cast(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_ + << "\"}]"; tracker_sock_.SendBytes(ss.str()); // Receive status and validate @@ -91,20 +90,19 @@ class TrackerClient { } /*! * \brief Close Clean up tracker resources. - */ + */ void Close() { // close tracker resource if (!tracker_sock_.IsClosed()) { tracker_sock_.Close(); } } - /*! - * \brief ReportResourceAndGetKey Report resource to tracker. - * \param port listening port. - * \param matchkey Random match key output. - */ - void ReportResourceAndGetKey(int port, - std::string *matchkey) { + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param port listening port. + * \param matchkey Random match key output. 
+ */ + void ReportResourceAndGetKey(int port, std::string* matchkey) { if (!tracker_sock_.IsClosed()) { *matchkey = RandomKey(key_ + ":", old_keyset_); if (custom_addr_.empty()) { @@ -112,8 +110,8 @@ class TrackerClient { } std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" - << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port + << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; tracker_sock_.SendBytes(ss.str()); @@ -121,7 +119,7 @@ class TrackerClient { std::string remote_status = tracker_sock_.RecvBytes(); CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); } else { - *matchkey = key_; + *matchkey = key_; } } @@ -131,11 +129,9 @@ class TrackerClient { * \param port listening port. * \param ping_period Select wait time. * \param matchkey Random match key output. - */ - void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, - int port, - int ping_period, - std::string *matchkey) { + */ + void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, int port, int ping_period, + std::string* matchkey) { int unmatch_period_count = 0; int unmatch_timeout = 4; while (true) { @@ -155,9 +151,9 @@ class TrackerClient { // if match key not in pending key set // it means the key is acquired by a client but not used. if (pending_keys.find(*matchkey) == std::string::npos) { - unmatch_period_count += 1; + unmatch_period_count += 1; } else { - unmatch_period_count = 0; + unmatch_period_count = 0; } // regenerate match key if key is acquired but not used for a while if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) { @@ -166,8 +162,8 @@ class TrackerClient { *matchkey = RandomKey(key_ + ":", old_keyset_); std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" - << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port + << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; tracker_sock_.SendBytes(ss.str()); std::string remote_status = tracker_sock_.RecvBytes(); @@ -201,26 +197,25 @@ class TrackerClient { } auto period = (std::chrono::duration_cast( - std::chrono::system_clock::now() - tbegin)).count(); + std::chrono::system_clock::now() - tbegin)) + .count(); CHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); - LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() - << " retry in " << retry_period << " seconds."; + LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() << " retry in " + << retry_period << " seconds."; std::this_thread::sleep_for(std::chrono::seconds(retry_period)); } } /*! - * \brief Random Generate a random number between 0 and 1. - * \return random float value. - */ - float Random() { - return dis_(gen_); - } + * \brief Random Generate a random number between 0 and 1. + * \return random float value. + */ + float Random() { return dis_(gen_); } /*! * \brief Generate a random key. * \param prefix The string prefix. * \return cmap The conflict map set. 
*/ - std::string RandomKey(const std::string& prefix, const std::set &cmap) { + std::string RandomKey(const std::string& prefix, const std::set& cmap) { if (!cmap.empty()) { while (true) { std::string key = prefix + std::to_string(Random()); @@ -236,10 +231,9 @@ class TrackerClient { std::string key_; std::string custom_addr_; support::TCPSocket tracker_sock_; - std::set old_keyset_; + std::set old_keyset_; std::mt19937 gen_; std::uniform_real_distribution dis_; - }; } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc index c6c72d79ab81d..bbf8367903bb2 100644 --- a/apps/cpp_rpc/win32_process.cc +++ b/apps/cpp_rpc/win32_process.cc @@ -20,15 +20,18 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#include "win32_process.h" + +#include +#include #include #include + #include #include -#include -#include #include -#include -#include "win32_process.h" +#include + #include "rpc_server.h" using namespace std::chrono; @@ -82,36 +85,36 @@ UniqueHandle MakeUniqueHandle(HANDLE handle) { */ SOCKET GetSocket(const std::string& mmap_path) { WSAPROTOCOL_INFO protocol_info; - + const std::string parent_event_name = mmap_path + kParent; const std::string child_event_name = mmap_path + kChild; // Open the events UniqueHandle parent_file_mapping_event; - if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { + if ((parent_file_mapping_event = MakeUniqueHandle( + OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); } UniqueHandle child_file_mapping_event; - if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { + if ((child_file_mapping_event = MakeUniqueHandle( + OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); } - + // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read - if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { - LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != + WAIT_OBJECT_0) { + LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); } - const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, - false, - mmap_path.c_str())); + const UniqueHandle file_map = + MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, false, mmap_path.c_str())); if (!file_map) { - LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); } - void* map_view = MapViewOfFile(file_map.get(), - FILE_MAP_READ | FILE_MAP_WRITE, - 0, 0, 0); + void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); SOCKET sock_duplicated = INVALID_SOCKET; @@ -120,12 +123,8 @@ SOCKET GetSocket(const std::string& mmap_path) { UnmapViewOfFile(map_view); // Creates the duplicate socket, that was created in the parent - sock_duplicated = WSASocket(FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, - &protocol_info, - 0, - 0); + sock_duplicated = + WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &protocol_info, 0, 0); // Let the parent know we are finished dupicating the 
socket SetEvent(child_file_mapping_event.get()); @@ -135,7 +134,7 @@ SOCKET GetSocket(const std::string& mmap_path) { return sock_duplicated; } -}// Anonymous namespace +} // Anonymous namespace namespace tvm { namespace runtime { @@ -146,7 +145,7 @@ namespace runtime { */ void SpawnRPCChild(SOCKET fd, seconds timeout) { STARTUPINFOA startup_info; - + memset(&startup_info, 0, sizeof(startup_info)); startup_info.cb = sizeof(startup_info); @@ -157,13 +156,15 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { // Create an event to let the child know the socket info was set to the mmap file UniqueHandle parent_file_mapping_event; - if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { + if ((parent_file_mapping_event = MakeUniqueHandle( + CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { LOG(FATAL) << "CreateEvent for parent file mapping failed"; } UniqueHandle child_file_mapping_event; // An event to let the parent know the socket info was read from the mmap file - if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { + if ((child_file_mapping_event = MakeUniqueHandle( + CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { LOG(FATAL) << "CreateEvent for child file mapping failed"; } @@ -181,35 +182,22 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { strcpy(command_line_ptr.get(), child_command_line.c_str()); PROCESS_INFORMATION child_process_info; - if (CreateProcessA(nullptr, - command_line_ptr.get(), - nullptr, - nullptr, - false, - CREATE_NO_WINDOW, - nullptr, - nullptr, - &startup_info, - &child_process_info)) { + if (CreateProcessA(nullptr, command_line_ptr.get(), nullptr, nullptr, false, CREATE_NO_WINDOW, + nullptr, nullptr, &startup_info, &child_process_info)) { // Child process and thread handles must be closed, so wrapped in RAII auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess); auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread); WSAPROTOCOL_INFO protocol_info; // Get info needed to duplicate the socket - if (WSADuplicateSocket(fd, - child_process_info.dwProcessId, - &protocol_info) == SOCKET_ERROR) { + if (WSADuplicateSocket(fd, child_process_info.dwProcessId, &protocol_info) == SOCKET_ERROR) { LOG(FATAL) << "WSADuplicateSocket(): failed. 
Error =" << WSAGetLastError(); } // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc - UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, - nullptr, - PAGE_READWRITE, - 0, - sizeof(WSAPROTOCOL_INFO), - file_map_path.c_str())); + UniqueHandle file_map = + MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, 0, + sizeof(WSAPROTOCOL_INFO), file_map_path.c_str())); if (!file_map) { LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); } @@ -225,11 +213,13 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { // Let child proc know the mmap file is ready to be read SetEvent(parent_file_mapping_event.get()); - + // Wait for the child to finish reading mmap file - if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != + WAIT_OBJECT_0) { TerminateProcess(child_process_handle.get(), 0); - LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process."; + LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child " + "process."; } } else { TerminateProcess(child_process_handle.get(), 0); @@ -237,9 +227,8 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } } - const DWORD process_timeout = timeout.count() - ? uint32_t(duration_cast(timeout).count()) - : INFINITE; + const DWORD process_timeout = + timeout.count() ? uint32_t(duration_cast(timeout).count()) : INFINITE; // Wait for child process to exit, or hit configured timeout if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) { @@ -251,8 +240,9 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } } /*! - * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket - * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client + * socket \param mmap_path The memory mapped file path that will contain the information to + * duplicate the client socket from the parent */ void ChildProcSocketHandler(const std::string& mmap_path) { SOCKET socket; @@ -260,14 +250,12 @@ void ChildProcSocketHandler(const std::string& mmap_path) { // Set high thread priority to avoid the thread scheduler from // interfering with any measurements in the RPC server. SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); - + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { tvm::runtime::ServerLoopFromChild(socket); - } - else { + } else { LOG(FATAL) << "GetSocket() failed"; } - } } // namespace runtime } // namespace tvm \ No newline at end of file diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h index 7d1a27680ed31..621444e187640 100644 --- a/apps/cpp_rpc/win32_process.h +++ b/apps/cpp_rpc/win32_process.h @@ -17,10 +17,10 @@ * under the License. */ - /*! - * \file win32_process.h - * \brief Win32 process code to mimic a POSIX fork() - */ +/*! + * \file win32_process.h + * \brief Win32 process code to mimic a POSIX fork() + */ #ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ #define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ #include @@ -34,8 +34,9 @@ namespace runtime { */ void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout); /*! 
- * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket - * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client + * socket \param mmap_path The memory mapped file path that will contain the information to + * duplicate the client socket from the parent */ void ChildProcSocketHandler(const std::string& mmap_path); } // namespace runtime diff --git a/apps/dso_plugin_module/plugin_module.cc b/apps/dso_plugin_module/plugin_module.cc index 7c3c5accf1ec7..eed11f855693c 100644 --- a/apps/dso_plugin_module/plugin_module.cc +++ b/apps/dso_plugin_module/plugin_module.cc @@ -20,10 +20,10 @@ * \brief Example code that can be compiled and loaded by TVM runtime. * \file plugin_module.cc */ -#include #include -#include #include +#include +#include namespace tvm_dso_plugin { @@ -31,24 +31,16 @@ using namespace tvm::runtime; class MyModuleNode : public ModuleNode { public: - explicit MyModuleNode(int value) - : value_(value) {} + explicit MyModuleNode(int value) : value_(value) {} - virtual const char* type_key() const final { - return "MyModule"; - } + virtual const char* type_key() const final { return "MyModule"; } - virtual PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) final { if (name == "add") { - return TypedPackedFunc([sptr_to_self, this](int value) { - return value_ + value; - }); + return TypedPackedFunc([sptr_to_self, this](int value) { return value_ + value; }); } else if (name == "mul") { - return TypedPackedFunc([sptr_to_self, this](int value) { - return value_ * value; - }); + return TypedPackedFunc([sptr_to_self, this](int value) { return value_ * value; }); } else { LOG(FATAL) << "unknown function " << name; return PackedFunc(); @@ -64,18 +56,14 @@ void CreateMyModule_(TVMArgs args, TVMRetValue* rv) { *rv = Module(make_object(value)); } -int SubOne_(int x) { - return x - 1; -} +int SubOne_(int x) { return x - 1; } // USE TVM_DLL_EXPORT_TYPED_PACKED_FUNC to export a // typed function as packed function. TVM_DLL_EXPORT_TYPED_FUNC(SubOne, SubOne_); // TVM_DLL_EXPORT_TYPED_PACKED_FUNC also works for lambda. -TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { - return x + 1; -}); +TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { return x + 1; }); // Use TVM_EXPORT_PACKED_FUNC to export a function with TVM_DLL_EXPORT_PACKED_FUNC(CreateMyModule, tvm_dso_plugin::CreateMyModule_); diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index a92d55fc4acde..87cb69b4f4cef 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -17,16 +17,15 @@ * under the License. */ - /*! * \brief Example package that uses TVM. 
* \file tvm_ext.cc */ -#include +#include #include -#include #include -#include +#include +#include #include using namespace tvm; @@ -50,8 +49,7 @@ class NDSubClass : public tvm::runtime::NDArray { public: class SubContainer : public NDArray::Container { public: - SubContainer(int additional_info) : - additional_info_(additional_info) { + SubContainer(int additional_info) : additional_info_(additional_info) { type_index_ = SubContainer::RuntimeTypeIndex(); } int additional_info_{0}; @@ -74,14 +72,14 @@ class NDSubClass : public tvm::runtime::NDArray { data_ = GetObjectPtr(ptr); } - NDSubClass AddWith(const NDSubClass &other) const { - SubContainer *a = static_cast(get_mutable()); - SubContainer *b = static_cast(other.get_mutable()); + NDSubClass AddWith(const NDSubClass& other) const { + SubContainer* a = static_cast(get_mutable()); + SubContainer* b = static_cast(other.get_mutable()); CHECK(a != nullptr && b != nullptr); return NDSubClass(a->additional_info_ + b->additional_info_); } int get_additional_info() const { - SubContainer *self = static_cast(get_mutable()); + SubContainer* self = static_cast(get_mutable()); CHECK(self != nullptr); return self->additional_info_; } @@ -116,60 +114,48 @@ TVM_REGISTER_OBJECT_TYPE(IntVectorObj); namespace tvm_ext { -TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { - auto n = tvm::runtime::make_object(); - for (int i = 0; i < args.size(); ++i) { - n->vec.push_back(args[i].operator int()); - } - *rv = IntVector(n); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.ivec_get") -.set_body([](TVMArgs args, TVMRetValue *rv) { - IntVector p = args[0]; - *rv = p->vec[args[1].operator int()]; - }); - - -TVM_REGISTER_GLOBAL("tvm_ext.bind_add") -.set_body([](TVMArgs args_, TVMRetValue *rv_) { - PackedFunc pf = args_[0]; - int b = args_[1]; - *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue *rv) { - *rv = pf(b, args[0]); - }); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.sym_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { - Var a = args[0]; - Var b = args[1]; - *rv = a + b; - }); - -TVM_REGISTER_GLOBAL("device_api.ext_dev") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.nd_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.ivec_create").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = tvm::runtime::make_object(); + for (int i = 0; i < args.size(); ++i) { + n->vec.push_back(args[i].operator int()); + } + *rv = IntVector(n); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.ivec_get").set_body([](TVMArgs args, TVMRetValue* rv) { + IntVector p = args[0]; + *rv = p->vec[args[1].operator int()]; +}); + +TVM_REGISTER_GLOBAL("tvm_ext.bind_add").set_body([](TVMArgs args_, TVMRetValue* rv_) { + PackedFunc pf = args_[0]; + int b = args_[1]; + *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue* rv) { *rv = pf(b, args[0]); }); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.sym_add").set_body([](TVMArgs args, TVMRetValue* rv) { + Var a = args[0]; + Var b = args[1]; + *rv = a + b; +}); + +TVM_REGISTER_GLOBAL("device_api.ext_dev").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.nd_create").set_body([](TVMArgs args, TVMRetValue* rv) { int additional_info = args[0]; *rv = NDSubClass(additional_info); CHECK_EQ(rv->type_code(), kTVMNDArrayHandle); - }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two").set_body([](TVMArgs args, TVMRetValue* rv) { NDSubClass a = args[0]; NDSubClass b = args[1]; *rv = a.AddWith(b); }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info").set_body([](TVMArgs args, TVMRetValue* rv) { NDSubClass a = args[0]; *rv = a.get_additional_info(); }); @@ -177,17 +163,14 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") } // namespace tvm_ext // External function exposed to runtime. -extern "C" float TVMTestAddOne(float y) { - return y + 1; -} +extern "C" float TVMTestAddOne(float y) { return y + 1; } // This callback approach allows extension allows tvm to extract // This way can be helpful when we want to use a header only // minimum version of TVM Runtime. extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) { - const PackedFunc& fregister = - *static_cast(pregister); - auto mul = [](TVMArgs args, TVMRetValue *rv) { + const PackedFunc& fregister = *static_cast(pregister); + auto mul = [](TVMArgs args, TVMRetValue* rv) { int x = args[0]; int y = args[1]; *rv = x * y; diff --git a/apps/howto_deploy/cpp_deploy.cc b/apps/howto_deploy/cpp_deploy.cc index a386dffa0b30e..b7a60f49d9178 100644 --- a/apps/howto_deploy/cpp_deploy.cc +++ b/apps/howto_deploy/cpp_deploy.cc @@ -21,11 +21,12 @@ * \brief Example code on load and run TVM module.s * \file cpp_deploy.cc */ -#include #include #include -#include #include +#include + +#include void Verify(tvm::runtime::Module mod, std::string fname) { // Get the function from the module. @@ -52,10 +53,8 @@ void Verify(tvm::runtime::Module mod, std::string fname) { int device_type = kDLCPU; int device_id = 0; int64_t shape[1] = {10}; - TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, - device_type, device_id, &x); - TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, - device_type, device_id, &y); + TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &x); + TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &y); for (int i = 0; i < shape[0]; ++i) { static_cast(x->data)[i] = i; } @@ -72,8 +71,7 @@ void Verify(tvm::runtime::Module mod, std::string fname) { int main(void) { // Normally we can directly - tvm::runtime::Module mod_dylib = - tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so"); + tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so"); LOG(INFO) << "Verify dynamic loading from test_addone_dll.so"; Verify(mod_dylib, "addone"); // For libraries that are directly packed as system lib and linked together with the app diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 81bab497bebb8..37e3968ca3125 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -39,15 +39,15 @@ */ #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" -#include "../../src/runtime/workspace_pool.cc" +#include "../../src/runtime/file_util.cc" #include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" -#include "../../src/runtime/registry.cc" -#include "../../src/runtime/file_util.cc" -#include "../../src/runtime/threading_backend.cc" -#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" +#include 
"../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. diff --git a/apps/ios_rpc/tvmrpc/AppDelegate.h b/apps/ios_rpc/tvmrpc/AppDelegate.h index 0c54a47e7a2dd..a810aeafa47f9 100644 --- a/apps/ios_rpc/tvmrpc/AppDelegate.h +++ b/apps/ios_rpc/tvmrpc/AppDelegate.h @@ -25,7 +25,6 @@ @interface AppDelegate : UIResponder -@property (strong, nonatomic) UIWindow *window; - +@property(strong, nonatomic) UIWindow* window; @end diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.h b/apps/ios_rpc/tvmrpc/TVMRuntime.h index 96a5c1bfa3180..f6a6dc64c53a7 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.h +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,8 @@ #define DMLC_LOG_CUSTOMIZE 1 #define TVM_METAL_RUNTIME 1 -#include #include +#include #include namespace tvm { @@ -52,8 +52,7 @@ using FEventHandler = std::function(data) - maxLength:size]; + ssize_t nbytes = [stream_ write:reinterpret_cast(data) maxLength:size]; if (nbytes < 0) { - NSLog(@"%@",[stream_ streamError].localizedDescription); + NSLog(@"%@", [stream_ streamError].localizedDescription); throw dmlc::Error("Stream error"); } return nbytes; @@ -83,8 +79,8 @@ size_t Recv(void* data, size_t size) final { NSOutputStream* stream_; }; -FEventHandler CreateServerEventHandler( - NSOutputStream *outputStream, std::string name, std::string remote_key) { +FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, + std::string remote_key) { std::unique_ptr ch(new NSStreamChannel(outputStream)); std::shared_ptr sess = RPCSession::Create(std::move(ch), name, remote_key); return [sess](const std::string& in_bytes, int flag) { @@ -103,9 +99,7 @@ FEventHandler CreateServerEventHandler( } } // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + file_name; - } + std::string GetPath(const std::string& file_name) { return base_ + file_name; } private: std::string base_; @@ -115,49 +109,44 @@ void LaunchSyncServer() { // only load dylib from frameworks. 
NSBundle* bundle = [NSBundle mainBundle]; NSString* base = [bundle privateFrameworksPath]; - NSString* path = [base stringByAppendingPathComponent: @"tvm/rpc_config.txt"]; + NSString* path = [base stringByAppendingPathComponent:@"tvm/rpc_config.txt"]; std::string name = [path UTF8String]; std::ifstream fs(name, std::ios::in); std::string url, key; int port; - CHECK(fs >> url >> port >> key) - << "Invalid RPC config file " << name; - RPCConnect(url, port, "server:" + key) - ->ServerLoop(); + CHECK(fs >> url >> port >> key) << "Invalid RPC config file " << name; + RPCConnect(url, port, "server:" + key)->ServerLoop(); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string name = args[0]; - std::string fmt = GetFileFormat(name, ""); - NSString* base; - if (fmt == "dylib") { - // only load dylib from frameworks. - NSBundle* bundle = [NSBundle mainBundle]; - base = [[bundle privateFrameworksPath] - stringByAppendingPathComponent: @"tvm"]; - } else { - // Load other modules in tempdir. - base = NSTemporaryDirectory(); - } - NSString* path = [base stringByAppendingPathComponent: - [NSString stringWithUTF8String:name.c_str()]]; - name = [path UTF8String]; - *rv = Module::LoadFromFile(name, fmt); - LOG(INFO) << "Load module from " << name << " ..."; - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); +}); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + std::string fmt = GetFileFormat(name, ""); + NSString* base; + if (fmt == "dylib") { + // only load dylib from frameworks. + NSBundle* bundle = [NSBundle mainBundle]; + base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; + } else { + // Load other modules in tempdir. + base = NSTemporaryDirectory(); + } + NSString* path = + [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; + name = [path UTF8String]; + *rv = Module::LoadFromFile(name, fmt); + LOG(INFO) << "Load module from " << name << " ..."; +}); } // namespace runtime } // namespace tvm @implementation TVMRuntime -+(void) launchSyncServer { ++ (void)launchSyncServer { tvm::runtime::LaunchSyncServer(); } diff --git a/apps/ios_rpc/tvmrpc/ViewController.h b/apps/ios_rpc/tvmrpc/ViewController.h index 3a3c928f81120..b188a87b20d33 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.h +++ b/apps/ios_rpc/tvmrpc/ViewController.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,12 +24,11 @@ #import #include "TVMRuntime.h" -@interface ViewController : UIViewController -{ +@interface ViewController : UIViewController { // input socket stream - NSInputStream *inputStream_; + NSInputStream* inputStream_; // output socket stream - NSOutputStream *outputStream_; + NSOutputStream* outputStream_; // temporal receive buffer. std::string recvBuffer_; // Whether connection is initialized. 
@@ -46,11 +45,11 @@ tvm::runtime::FEventHandler handler_; } -@property (weak, nonatomic) IBOutlet UITextField *proxyURL; -@property (weak, nonatomic) IBOutlet UITextField *proxyPort; -@property (weak, nonatomic) IBOutlet UITextField *proxyKey; -@property (weak, nonatomic) IBOutlet UILabel *statusLabel; -@property (weak, nonatomic) IBOutlet UITextView *infoText; +@property(weak, nonatomic) IBOutlet UITextField* proxyURL; +@property(weak, nonatomic) IBOutlet UITextField* proxyPort; +@property(weak, nonatomic) IBOutlet UITextField* proxyKey; +@property(weak, nonatomic) IBOutlet UILabel* statusLabel; +@property(weak, nonatomic) IBOutlet UITextView* infoText; - (IBAction)connect:(id)sender; - (IBAction)disconnect:(id)sender; diff --git a/apps/ios_rpc/tvmrpc/ViewController.mm b/apps/ios_rpc/tvmrpc/ViewController.mm index 0f7611002042a..6c618c48096f5 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.mm +++ b/apps/ios_rpc/tvmrpc/ViewController.mm @@ -21,12 +21,12 @@ * \file ViewController.mm */ -#include #import "ViewController.h" +#include @implementation ViewController -- (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event { +- (void)stream:(NSStream*)strm handleEvent:(NSStreamEvent)event { std::string buffer; switch (event) { case NSStreamEventOpenCompleted: { @@ -45,7 +45,7 @@ - (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event { break; } case NSStreamEventErrorOccurred: { - NSLog(@"%@",[strm streamError].localizedDescription); + NSLog(@"%@", [strm streamError].localizedDescription); break; } case NSStreamEventEndEncountered: { @@ -64,8 +64,7 @@ - (void)onReadAvailable { constexpr int kRPCMagic = 0xff271; if (!initialized_) { int code; - size_t nbytes = [inputStream_ read:reinterpret_cast(&code) - maxLength:sizeof(code)]; + size_t nbytes = [inputStream_ read:reinterpret_cast(&code) maxLength:sizeof(code)]; if (nbytes != sizeof(code)) { self.infoText.text = @"Fail to receive remote confirmation code."; [self close]; @@ -115,7 +114,7 @@ - (void)onShutdownReceived { - (void)onWriteAvailable { if (initSendPtr_ < initBytes_.length()) { initSendPtr_ += [outputStream_ write:reinterpret_cast(&initBytes_[initSendPtr_]) - maxLength:(initBytes_.length() - initSendPtr_)]; + maxLength:(initBytes_.length() - initSendPtr_)]; } if (initialized_) { try { @@ -148,13 +147,10 @@ - (void)open { // Initialize the network. 
CFReadStreamRef readStream; CFWriteStreamRef writeStream; - CFStreamCreatePairWithSocketToHost( - NULL, - (__bridge CFStringRef) self.proxyURL.text, - [self.proxyPort.text intValue], - &readStream, &writeStream); - inputStream_ = (__bridge_transfer NSInputStream *)readStream; - outputStream_ = (__bridge_transfer NSOutputStream *)writeStream; + CFStreamCreatePairWithSocketToHost(NULL, (__bridge CFStringRef)self.proxyURL.text, + [self.proxyPort.text intValue], &readStream, &writeStream); + inputStream_ = (__bridge_transfer NSInputStream*)readStream; + outputStream_ = (__bridge_transfer NSOutputStream*)writeStream; [inputStream_ setDelegate:self]; [outputStream_ setDelegate:self]; [inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; diff --git a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm b/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm index c4a6f8bd240f7..eb538f07bf492 100644 --- a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm +++ b/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm @@ -32,16 +32,15 @@ @interface tvmrpcLauncher : XCTestCase @implementation tvmrpcLauncher - (void)setUp { - [super setUp]; + [super setUp]; } - (void)tearDown { - [super tearDown]; + [super tearDown]; } - (void)testRPC { [TVMRuntime launchSyncServer]; } - @end diff --git a/apps/rocm_rpc/rocm_runtime_pack.cc b/apps/rocm_rpc/rocm_runtime_pack.cc index a137a9b28f8af..de5c504523405 100644 --- a/apps/rocm_rpc/rocm_runtime_pack.cc +++ b/apps/rocm_rpc/rocm_runtime_pack.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,7 @@ #define TVM_USE_MIOPEN 1 #define __HIP_PLATFORM_HCC__ 1 -#include "../../src/runtime/rocm/rocm_device_api.cc" -#include "../../src/runtime/rocm/rocm_module.cc" #include "../../src/contrib/miopen/conv_forward.cc" #include "../../src/contrib/miopen/miopen_utils.cc" +#include "../../src/runtime/rocm/rocm_device_api.cc" +#include "../../src/runtime/rocm/rocm_module.cc" diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index 5b56982a42c5b..30b4ccbc56180 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
+include(ExternalProject) + set(PICK_SIM "sim") set(PICK_HW "target") set(PICK_NONE "OFF") @@ -77,6 +79,13 @@ if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") include_directories("${HEXAGON_TOOLCHAIN}/include/iss") link_directories("${HEXAGON_TOOLCHAIN}/lib/iss") list(APPEND TVM_RUNTIME_LINKER_LIBS "-lwrapper") + ExternalProject_Add(sim_dev + SOURCE_DIR "${CMAKE_SOURCE_DIR}/src/runtime/hexagon/sim/driver" + CMAKE_ARGS + "-DCMAKE_C_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang" + "-DCMAKE_CXX_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" + INSTALL_COMMAND "true" + ) elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") find_hexagon_sdk_root() find_hexagon_toolchain() @@ -87,7 +96,11 @@ elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") include_directories( "${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64") include_directories("${HEXAGON_TOOLCHAIN}/include/iss") - list(APPEND TVM_RUNTIME_LINKER_LIBS "-ldl") + list(APPEND TVM_RUNTIME_LINKER_LIBS "dl") + if(BUILD_FOR_ANDROID) + # Hexagon runtime uses __android_log_print, which is in liblog. + list(APPEND TVM_RUNTIME_LINKER_LIBS "log") + endif() endif() file(GLOB RUNTIME_HEXAGON_SRCS src/runtime/hexagon/*.cc) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 4af39e088b236..31208bca3af42 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -89,7 +89,7 @@ elseif(PYTHON) # VTA FPGA driver sources if(USE_VTA_FPGA) - file(GLOB FPGA_RUNTIME_SRCS ${VTA_HW_PATH}/src/*.cc) + file(GLOB FPGA_RUNTIME_SRCS vta/runtime/*.cc) # Rules for Zynq-class FPGAs with pynq OS support (see pynq.io) if(${VTA_TARGET} STREQUAL "pynq" OR ${VTA_TARGET} STREQUAL "ultra96") @@ -108,6 +108,7 @@ elseif(PYTHON) endforeach() if(${VTA_TARGET} STREQUAL "pynq" OR ${VTA_TARGET} STREQUAL "ultra96") + target_include_directories(vta PUBLIC ${VTA_HW_PATH}/include) target_link_libraries(vta ${__cma_lib}) elseif(${VTA_TARGET} STREQUAL "de10nano") # DE10-Nano rules #target_compile_definitions(vta PUBLIC VTA_MAX_XFER=2097152) # (1<<21) diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 10c8c62d970b6..abd7c0d04dabf 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -63,3 +63,11 @@ RUN bash /install/ubuntu_install_antlr.sh # Chisel deps for TSIM COPY install/ubuntu_install_chisel.sh /install/ubuntu_install_chisel.sh RUN bash /install/ubuntu_install_chisel.sh + +# TFLite deps +COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh +RUN bash /install/ubuntu_install_tflite.sh + +# TensorFlow deps +COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh +RUN bash /install/ubuntu_install_tensorflow.sh diff --git a/docker/build.sh b/docker/build.sh index d5925dcae2880..43f0a08700a40 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -67,6 +67,7 @@ fi if [[ "$1" == "--cache-from" ]]; then shift 1 cached_image="$1" + CI_DOCKER_BUILD_EXTRA_PARAMS+=("--cache-from tvm.$CONTAINER_TYPE") CI_DOCKER_BUILD_EXTRA_PARAMS+=("--cache-from $cached_image") shift 1 fi diff --git a/docker/install/ubuntu_install_mxnet.sh b/docker/install/ubuntu_install_mxnet.sh index d587843d4dec0..aa04d4c19177b 100755 --- a/docker/install/ubuntu_install_mxnet.sh +++ b/docker/install/ubuntu_install_mxnet.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip3 install mxnet-mkl==1.6.0 +pip3 install mxnet==1.6.0 diff --git a/docker/install/ubuntu_install_rust.sh b/docker/install/ubuntu_install_rust.sh index ff22ea31cdd9e..310e6507e3f3f 100755 --- a/docker/install/ubuntu_install_rust.sh +++ 
b/docker/install/ubuntu_install_rust.sh @@ -29,5 +29,11 @@ curl -s -S -L https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default . $CARGO_HOME/env rustup component add rustfmt +# install wasmtime +export WASMTIME_HOME=/opt/wasmtime +curl https://wasmtime.dev/install.sh -sSf | bash +export PATH="${WASMTIME_HOME}/bin:${PATH}" +rustup target add wasm32-wasi + # make rust usable by all users chmod -R a+w /opt/rust diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index 49b0f2badf828..f7ed4841b3287 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,12 +26,18 @@ cd flatbuffers cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release make install -j8 cd .. -rm -rf flatbuffers # Install flatbuffers python packages. pip3 install flatbuffers pip2 install flatbuffers +# Build the TFLite static library, necessary for building with TFLite ON. +# The library is built at: +# tensorflow/tensorflow/lite/tools/make/gen/*/lib/libtensorflow-lite.a. +git clone https://github.com/tensorflow/tensorflow --branch=r2.1 +./tensorflow/tensorflow/lite/tools/make/download_dependencies.sh +./tensorflow/tensorflow/lite/tools/make/build_lib.sh + # Setup tflite from schema mkdir tflite cd tflite diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index b482d30515d45..8ac4e1ff7d3a1 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -48,9 +48,9 @@ tvm.contrib.dlpack .. automodule:: tvm.contrib.dlpack :members: -tvm.contrib.emscripten -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: tvm.contrib.emscripten +tvm.contrib.emcc +~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.emcc :members: tvm.contrib.miopen diff --git a/docs/conf.py b/docs/conf.py index 6ef86ca5a39e0..a0922cfa653d2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -283,7 +283,7 @@ def process_docstring(app, what, name, obj, options, lines): def setup(app): app.connect('autodoc-process-docstring', process_docstring) - app.add_stylesheet('css/tvm_theme.css') + app.add_css_file('css/tvm_theme.css') app.add_config_value('recommonmark_config', { 'url_resolver': lambda url: github_doc_root + url, 'auto_doc_ref': True diff --git a/docs/vta/install.md b/docs/vta/install.md deleted file mode 100644 index a938a67218ffb..0000000000000 --- a/docs/vta/install.md +++ /dev/null @@ -1,419 +0,0 @@ - - - - - - - - - - - - - - - - - -VTA Installation Guide -====================== - -We present four installation guides, each extending on the previous one: -1. [Simulator installation](#vta-simulator-installation) -2. [PYNQ-based test setup](#vta-pynq-based-test-setup) -3. [Custom test setup for Intel FPGA](#vta-custom-test-setup-for-intel-fpga) -4. [FPGA toolchain installation](#vta-fpga-toolchain-installation) - -## VTA Simulator Installation - -You need [TVM installed](https://tvm.apache.org/docs/install/index.html) on your machine. -For a quick and easy start, check out the [Docker Guide](https://tvm.apache.org/docs/install/docker.html).
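For reference, the Docker route boils down to roughly the following. This is only a sketch: `docker/bash.sh` is the helper script the repository ships for entering its CI images, and `tvmai/ci-cpu` is one published image name at the time of writing, so check the linked guide for the current recommendation:

```bash
# Illustrative Docker-based quick start (verify image names against the guide)
git clone --recursive https://github.com/apache/incubator-tvm tvm
cd tvm
./docker/bash.sh tvmai/ci-cpu
```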
- -You'll need to set the following paths to use VTA: -```bash -export TVM_PATH= -export VTA_HW_PATH=$TVM_PATH/3rdparty/vta-hw -``` - -The VTA functional simulation library needs to be enabled when building TVM. -```bash -cd -mkdir build -cp cmake/config.cmake build/. -echo 'set(USE_VTA_FSIM ON)' >> build/config.cmake -cd build && cmake .. && make -j4 -``` - -Add the VTA python library to your python path to run the VTA examples. - -```bash -export PYTHONPATH=/path/to/vta/python:${PYTHONPATH} -``` - -### Testing your VTA Simulation Setup - -To ensure that you've properly installed the VTA python package, run the following 2D convolution testbench. - -```bash -python /vta/tests/python/integration/test_benchmark_topi_conv2d.py -``` - -> Note: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves by evaluating the convolutions in software. - -You are invited to try out our [VTA programming tutorials](https://tvm.apache.org/docs/vta/tutorials/index.html). - - -### Advanced Configuration (optional) - -VTA is a generic configurable deep learning accelerator. -The configuration is specified by `vta_config.json` under `3rdparty/vta-hw/config`. -This file provides an architectural specification of the VTA accelerator to parameterize the TVM compiler stack and the VTA hardware stack. - -The VTA configuration file also specifies the TVM compiler target. -When `TARGET` is set to `sim`, all TVM workloads execute on the VTA simulator. -You can modify the content of the configuration file to rebuild VTA to a different parameterization. -To do so, - -```bash -cd -vim 3rdparty/vta-hw/config/vta_config.json -# edit vta_config.json -make -``` - -## VTA Pynq-Based Test Setup - -This second guide extends the *VTA Simulator Installation* guide above to run FPGA hardware tests of the complete TVM and VTA software-hardware stack. -In terms of hardware components, you'll need: -* The [Pynq](http://www.pynq.io/) FPGA development board, which can be acquired for $200, or $150 for academics, from [Digilent](https://store.digilentinc.com/pynq-z1-python-productivity-for-zynq/). -* An Ethernet-to-USB adapter to connect the Pynq board to your development machine. -* An 8+GB micro SD card. -* An AC to DC 12V 3A power adapter. - -This guide covers the following themes: -1. Pynq board setup instructions. -2. Pynq-side RPC server build and deployment. -3. Revisiting the test examples from the *VTA Simulator Installation* guide, this time executing on the Pynq board. - -### Pynq Board Setup - -Set up your Pynq board based on the [Pynq board getting started tutorial](http://pynq.readthedocs.io/en/latest/getting_started.html). -You should follow the instructions up to and including the *Turning On the PYNQ-Z1* step (no need to pursue the tutorial beyond this point). -* Make sure that you've downloaded the latest Pynq image, [PYNQ-Z1 v2.4](http://www.pynq.io/board.html) (released February 22nd 2019), and have imaged your SD card with it (we recommend the free [Etcher](https://etcher.io/) program). -* For this test setup, follow the ["Connect to a Computer"](http://pynq.readthedocs.io/en/latest/getting_started.html#connect-to-a-computer) Ethernet setup instructions.
To be able to talk to the board, make sure to [assign your computer a static IP address](http://pynq.readthedocs.io/en/latest/appendix.html#assign-your-computer-a-static-ip). - -Once the board is powered on and connected to your development machine, try connecting to it to make sure you've properly set up your Pynq board: -```bash -# To connect to the Pynq board use the [username, password] combo: [xilinx, xilinx] -ssh xilinx@192.168.2.99 -``` - -### Pynq-Side RPC Server Build & Deployment - -Because the direct board-to-computer connection prevents the board from directly accessing the internet, we'll need to mount the Pynq's file system to your development machine's file system with [sshfs](https://www.digitalocean.com/community/tutorials/how-to-use-sshfs-to-mount-remote-file-systems-over-ssh). Next we directly clone the TVM repository into the sshfs mountpoint on your development machine. - -```bash -# On the Host-side -mkdir -sshfs xilinx@192.168.2.99:/home/xilinx -cd -git clone --recursive https://github.com/apache/incubator-tvm tvm -# When finished, you can leave the mountpoint and unmount the directory -cd ~ -sudo umount -``` - -Now that we've cloned the TVM repository in the Pynq's file system, we can ssh into it and launch the build of the TVM-based RPC server. -The build process should take roughly 5 minutes. - -```bash -ssh xilinx@192.168.2.99 -# Build TVM runtime library (takes 5 mins) -cd /home/xilinx/tvm -mkdir build -cp cmake/config.cmake build/. -echo 'set(USE_VTA_FPGA ON)' >> build/config.cmake -# Copy pynq specific configuration -cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json -cd build -cmake .. -make runtime vta -j2 -# Build VTA RPC server (takes 1 min) -cd .. -sudo ./apps/vta_rpc/start_rpc_server.sh # pw is 'xilinx' -``` - -You should see the following being displayed when starting the RPC server. In order to run the next examples, you'll need to leave the RPC server running in an `ssh` session. -``` -INFO:root:RPCServer: bind to 0.0.0.0:9091 -``` - -Tips regarding the Pynq RPC Server: -* The RPC server should be listening on port `9091`. If not, an earlier process might have terminated unexpectedly; in this case it's recommended to just reboot the Pynq and re-run the RPC server. -* To kill the RPC server, just send the `Ctrl + c` command. You can re-run it with `sudo ./apps/vta_rpc/start_rpc_server.sh`. -* If unresponsive, the board can be rebooted by power-cycling it with the physical power switch. - -### Testing your Pynq-based Hardware Setup - -Before running the examples on your development machine, you'll need to configure your host environment as follows: -```bash -# On the Host-side -export VTA_RPC_HOST=192.168.2.99 -export VTA_RPC_PORT=9091 -``` - -In addition, you'll need to edit the `vta_config.json` file on the host to indicate that we are targeting the Pynq platform, by setting the `TARGET` field to `"pynq"`. -> Note: in contrast to our simulation setup, there are no libraries to compile on the host side since the host offloads all of the computation to the Pynq board. - -```bash -# On the Host-side -cd -cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json -``` - -Once again, we will run the 2D convolution testbench. -Beforehand, we need to program the Pynq board FPGA with a VTA bitstream, and build the VTA runtime via RPC.
-The following `test_program_rpc.py` script will perform two operations: -* FPGA programming, by downloading a pre-compiled bitstream from a [VTA bitstream repository](https://github.com/uwsaml/vta-distro) that matches the default `vta_config.json` configuration set by the host, and sending it over to the Pynq via RPC to program the Pynq's FPGA. -* Runtime building on the Pynq, which needs to be run every time the `vta_config.json` configuration is modified. This ensures that the VTA software runtime that generates the accelerator's executable via just-in-time (JIT) compilation matches the specifications of the VTA design that is programmed on the FPGA. The build process takes about 30 seconds to complete so be patient! - -```bash -# On the Host-side -python /vta/tests/python/pynq/test_program_rpc.py -``` - -> Tip: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq `ssh` session. - -We are now ready to run the 2D convolution testbench in hardware. - -```bash -# On the Host-side -python /vta/tests/python/integration/test_benchmark_topi_conv2d.py -``` - -The performance metrics measured on the Pynq board will be reported for each convolutional layer. - -You can also try out our [VTA programming tutorials](https://tvm.apache.org/docs/vta/tutorials/index.html). - -## VTA Custom Test Setup for Intel FPGA - -Similar to the PYNQ-side setup steps, this third guide details how to set up the Linux environment for Intel FPGA boards like the DE10-Nano. - -In terms of hardware components, you would need the [DE10-Nano Development Kit](https://www.terasic.com.tw/cgi-bin/page/archive.pl?Language=English&No=1046), which can be acquired for $130, or $100 for academics, from [Terasic](https://www.terasic.com.tw/). A microSD card is delivered with the kit. Power cables and USB cables are included as well. However, an additional Ethernet cable is needed to connect the board to a LAN. - -The rest of this guide provides the steps to - -* Flash the microSD card with the latest Angstrom Linux image -* Set up cross compilation -* Set up and deploy the device-side RPC server - -### DE10-Nano Board Setup - -Before powering up the device, we need to flash the microSD card with the latest Angstrom Linux image. - -#### Flash SD Card and Boot Angstrom Linux - -To flash the SD card and boot Linux on the DE10-Nano, it is recommended to navigate to the [Resource](https://www.terasic.com.tw/cgi-bin/page/archive.pl?Language=English&CategoryNo=167&No=1046&PartNo=4) tab of the DE10-Nano product page from Terasic Inc. -After registration and login on the webpage, the prebuilt Angstrom Linux image is available for download and flashing. -Specifically, to flash the downloaded Linux SD card image into your physical SD card: - -First, extract the gzipped archive file. - -``` bash -tar xf de10-nano-image-Angstrom-v2016.12.socfpga-sdimg.2017.03.31.tgz -``` - -This produces a single SD card image named `de10-nano-image-Angstrom-v2016.12.socfpga-sdimg` (approx. 2.4 GB); it contains all the file systems needed to boot Angstrom Linux. - -Second, plug an SD card that is ready to flash into your PC, and identify the device id for the disk with `fdisk -l`, or with `gparted` if you prefer a GUI. The typical device id for your disk is likely to be `/dev/sdb`.
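Before running `dd` in the next step, you can double-check which device is actually the SD card; the `/dev/sdb` id below is only an example, not a safe default:

```bash
# List all disks and match the SD card by its size;
# dd will silently overwrite whichever device you point it at.
sudo fdisk -l
sudo fdisk -l /dev/sdb   # inspect just the suspected SD card
```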
- -Then, flash the disk image into your physical SD card with the following command: - -``` bash -# NOTE: root privilege is typically required to run the following command. -dd if=de10-nano-image-Angstrom-v2016.12.socfpga-sdimg of=/dev/sdb status=progress -``` -This would take a few minutes for your PC to write the whole file systems into the SD card. -After this process completes, you are ready to unmount the SD card and insert it into your DE10-Nano board. -Now you can connect the power cable and serial port to boot the Angstrom Linux. - -> Note: When boot up from the microSD card, you might notice the incompatibility of the linux kernel `zImage` in the microSD card. -> In this case, you might need to build the `zImage` file of your own from [socfpga-4.9.78-ltsi](https://github.com/altera-opensource/linux-socfpga/tree/socfpga-4.9.78-ltsi) branch of the [linux-socfpga](https://github.com/altera-opensource/linux-socfpga) repository. -> For a quick fix, you can also download a prebuilt version of the `zImage` file [here](https://raw.githubusercontent.com/liangfu/de10-nano-supplement/master/zImage). - -After connecting the usb cables to the DE10-Nano board, power on the board by connecting the power cable. You may then connect to the serial port of the device by using `minicom` on your host PC: - -``` bash -# NOTE: root privilege is typically required to run the following command. -minicom -D /dev/ttyUSB0 -``` - -The default user name for the device would be `root`, and the password is empty for the default user. - -You may now start to install supporting Python3 packages (TVM has dropped the support for Python2), specifically, they are `numpy`, `attrs` and `decorator`. - -> Note: You might fail to install `numpy` by using `pip3` on the DE10-Nano device. -> In that case, you have the option to either build your own filesystem image for the board from [meta-de10-nano](https://github.com/intel/meta-de10-nano) repository; -> an alternative option is to download prebuilt packages from existing Linux distributions, e.g. Debian. -> For a quick fix, we have concatenated the supplementary binary files [here](https://raw.githubusercontent.com/liangfu/de10-nano-supplement/master/rootfs_supplement.tgz), and you can extract the files into the root filesystem. - -#### Install Required Python Packages - -After accessing bash terminal from the serial port, we need to install required Python packages before building and installing TVM and VTA programs. - -#### Build Additional Components to Use VTA Bitstream - -To use the above built bitstream on DE10-Nano hardware, several additional components need to be compiled for the system. -Specifically, to compile application executables for the system, you need to download and install [SoCEDS](http://fpgasoftware.intel.com/soceds/18.1/?edition=standard&download_manager=dlm3&platform=linux) (recommended), or alternatively install the `g++-arm-linux-gnueabihf` package on your host machine. You would also need a `cma` kernel module to allocate contigous memory, and a driver for communicating with the VTA subsystem. - -## VTA FPGA Toolchain Installation - -This last guide allows users to generate custom VTA bitstreams using free-to-use Xilinx or Intel compilation toolchains. - -### Xilinx Toolchain Installation - -We recommend using `Vivado 2018.3` since our scripts have been tested to work on this version of the Xilinx toolchains. -Our guide is written for Linux (Ubuntu) installation. 
- -You’ll need to install Xilinx’ FPGA compilation toolchain, [Vivado HL WebPACK 2018.3](https://www.xilinx.com/products/design-tools/vivado.html), which a license-free version of the Vivado HLx toolchain. - -#### Obtaining and Launching the Vivado GUI Installer - -1. Go to the [download webpage](https://www.xilinx.com/support/download/index.html/content/xilinx/en/downloadNav/vivado-design-tools/2018-3.html), and download the Linux Self Extracting Web Installer for Vivado HLx 2018.3: WebPACK and Editions. -2. You’ll have to sign in with a Xilinx account. This requires a Xilinx account creation that will take 2 minutes. -3. Complete the Name and Address Verification by clicking “Next”, and you will get the opportunity to download a binary file, called `Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin`. -4. Now that the file is downloaded, go to your `Downloads` directory, and change the file permissions so it can be executed: -```bash -chmod u+x Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin -``` -5. Now you can execute the binary: -```bash -./Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin -``` - -#### Xilinx Vivado GUI Installer Steps - -At this point you've launched the Vivado 2018.3 Installer GUI program. - -1. Click “Next” on the *Welcome* screen. -2. On the *Select Install Type* screen, enter your Xilinx user credentials under the “User Authentication” box and select the “Download and Install Now” option before clicking “Next” . -3. On the *Accept License Agreements* screen, accept all terms before clicking “Next”. -4. On the *Select Edition to Install* screen, select the “Vivado HL WebPACK” before clicking “Next” . -5. Under the *Vivado HL WebPACK* screen, before hitting “Next", check the following options (the rest should be unchecked): - * Design Tools -> Vivado Design Suite -> Vivado - * Devices -> Production Devices -> SoCs -> Zynq-7000 (if you are targeting the Pynq board) - * Devices -> Production Devices -> SoCs -> UltraScale+ MPSoC (if you are targeting the Ultra-96 board) -6. Your total download size should be about 5GB and the amount of Disk Space Required 23GB. -7. On the *Select Destination Directory* screen, set the installation directory before clicking “Next”. It might highlight some paths as red - that’s because the installer doesn’t have the permission to write to the directory. In that case select a path that doesn’t require special write permissions (e.g. your home directory). -8. On the *Installation Summary* screen, hit “Install”. -9. An *Installation Progress* window will pop-up to track progress of the download and the installation. -10. This process will take about 20-30 minutes depending on your connection speed. -11. A pop-up window will inform you that the installation completed successfully. Click "OK". -12. Finally the *Vivado License Manager* will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process. - -#### Environment Setup - -The last step is to update your `~/.bashrc` with the following lines. This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line. -```bash -# Xilinx Vivado 2018.3 environment -export XILINX_VIVADO=${XILINX_PATH}/Vivado/2018.3 -export PATH=${XILINX_VIVADO}/bin:${PATH} -``` - -### Intel Toolchain Installation - -It is recommended to use `Intel Quartus Prime 18.1`, since the test scripts contained in this document have been tested on this version. 
- -You would need to install Intel's FPGA compilation toolchain, [Quartus Prime Lite](http://fpgasoftware.intel.com/?edition=lite), which is a license-free version of the Intel Quartus Prime software. - -#### Obtaining and Launching the Quartus GUI Installer - -1. Go to the [download center](http://fpgasoftware.intel.com/?edition=lite), and download the linux version of `Quartus Prime (include Nios II EDS)` and `Cyclone V device support` files in the `Separate file` tab. This avoid downloading unused device support files. -2. Sign in the form if you have an account, or register on the right side of the web page to create an account. -3. After signed in, you are able to download the installer and the device support files. -4. Now that the files are downloaded, go to your `Downloads` directory, and change the file permissions: -```bash -chmod u+x QuartusLiteSetup-18.1.0.625-linux.run -``` -5. Now ensure both the installer and device support files are in the same directory, and you can run the install with: -```bash -./QuartusLiteSetup-18.1.0.625-linux.run -``` -6. Follow the instructions on the pop-up GUI form, and install all the content in the `/usr/local` directory. After installation, `/usr/local/intelFPGA_lite/18.1` would be created and the Quartus program along with other programs would be available in the folder. - -#### Environment Setup - -Similar to what should be done for Xilinx toolchain, the following line should be added to your `~/.bashrc`. -```bash -# Intel Quartus 18.1 environment -export QUARTUS_ROOTDIR="/usr/local/intelFPGA_lite/18.1/quartus" -export PATH=${QUARTUS_ROOTDIR}/bin:${PATH} -export PATH=${QUARTUS_ROOTDIR}/sopc_builder/bin:${PATH} -``` -This would add quartus binary path into your `PATH` environment variable, so you can launch compilation scripts from the command line. - -### HLS-based Custom VTA Bitstream Compilation for PYNQ - -High-level hardware parameters are listed in the VTA configuration file and can be customized by the user. -For this custom VTA bitstream compilation exercise, we'll change the frequency of our design, so it can be clocked a little faster. -* Set the `HW_FREQ` field to `142`. The Pynq board supports 100, 142, 167 and 200MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violation and thus faulty hardware execution. -* Set the `HW_CLK_TARGET` to `6`. This parameters refers to the target clock period in nano seconds for HLS - a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142MHz clock would require a 7ns target, but we intentionally lower the clock target to 6ns to more aggressively pipeline our design. - -Bitstream generation is driven by a top-level `Makefile` under `/3rdparty/vta-hw/hardware/xilinx/`. - -If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter: -```bash -cd /3rdparty/vta-hw/hardware/xilinx -make ip MODE=sim -``` - -If you just want to generate the HLS-based VTA IP cores without launching the entire design place and route, enter: -```bash -make ip -``` -You'll be able to view the HLS synthesis reports under `/3rdparty/vta-hw/build/hardware/xilinx/hls/` `//solution0/syn/report/_csynth.rpt` -> Note: The `` name is a string that summarizes the VTA configuration parameters listed in the `vta_config.json`. The `` name refers to the specific module (or HLS function) that compose the high-level VTA pipeline. 
- -Finally to run the full hardware compilation and generate the VTA bitstream, run: - -```bash -make -``` - -This process is lengthy, and can take around up to an hour to complete depending on your machine's specs. -We recommend setting the `VTA_HW_COMP_THREADS` variable in the Makefile to take full advantage of all the cores on your development machine. - -Once the compilation completes, the generated bitstream can be found under `/3rdparty/vta-hw/build/hardware/xilinx/vivado//export/vta.bit`. - -### Chisel-based Custom VTA Bitstream Compilation for DE10-Nano - -Similar to the HLS-based design, high-level hardware parameters in Chisel-based design are listed in the VTA configuration file [Configs.scala](https://github.com/apache/incubator-tvm/blob/master/3rdparty/vta-hw/hardware/chisel/src/main/scala/core/Configs.scala), and they can be customized by the user. - -For Intel FPGA, bitstream generation is driven by a top-level `Makefile` under `/3rdparty/vta-hw/hardware/intel`. - -If you just want to generate the Chisel-based VTA IP core for the DE10-Nano board without compiling the design for the FPGA hardware, enter: -```bash -cd /3rdparty/vta-hw/hardware/intel -make ip -``` -Then you'll be able to locate the generated verilog file at `/3rdparty/vta-hw/build/hardware/intel/chisel//VTA.DefaultDe10Config.v`. - -If you would like to run the full hardware compilation for the `de10nano` board: -```bash -make -``` - -This process might be a bit lengthy, and might take up to half an hour to complete depending on the performance of your PC. The Quartus Prime software would automatically detect the number of cores available on your PC and try to utilize all of them to perform such process. - -Once the compilation completes, the generated bistream can be found under `/3rdparty/vta-hw/build/hardware/intel/quartus//export/vta.rbf`. You can also open the Quartus project file (.qpf) available at `/3rdparty/vta-hw/build/hardware/intel/quartus//de10_nano_top.qpf` to look around the generated reports. - -### Use the Custom Bitstream - -We can program the new VTA FPGA bitstream by setting the bitstream path of the `vta.program_fpga()` function in the tutorial examples, or in the `test_program_rpc.py` script. - -```python -vta.program_fpga(remote, bitstream="/3rdparty/vta-hw/build/hardware/xilinx/vivado//export/vta.bit") -``` - -Instead of downloading a pre-built bitstream from the VTA bitstream repository, TVM will instead use the new bitstream you just generated, which is a VTA design clocked at a higher frequency. -Do you observe a noticeable performance increase on the ImageNet classification example? diff --git a/docs/vta/install.rst b/docs/vta/install.rst new file mode 100644 index 0000000000000..b68fab7da2d1e --- /dev/null +++ b/docs/vta/install.rst @@ -0,0 +1,488 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+VTA Installation Guide
+======================
+
+We present five installation guides, each building on the previous one:
+
+1. `Simulator Installation`_
+2. `Xilinx Pynq FPGA Setup`_
+3. `Intel DE10 FPGA Setup`_
+4. `Bitstream Generation with Xilinx Toolchains`_
+5. `Bitstream Generation with Intel Toolchains`_
+
+
+Simulator Installation
+----------------------
+
+You need `TVM installed `_ on your machine.
+For a quick and easy start, check out the `Docker Guide `_.
+
+You'll need to set the following paths to use VTA:
+
+.. code:: bash
+
+   export TVM_PATH=<tvm root>
+   export VTA_HW_PATH=$TVM_PATH/3rdparty/vta-hw
+
+The VTA functional simulation library needs to be enabled when building TVM.
+
+.. code:: bash
+
+   cd <tvm root>
+   mkdir build
+   cp cmake/config.cmake build/.
+   echo 'set(USE_VTA_FSIM ON)' >> build/config.cmake
+   cd build && cmake .. && make -j4
+
+Add the VTA python library to your python path to run the VTA examples.
+
+.. code:: bash
+
+   export PYTHONPATH=/path/to/vta/python:${PYTHONPATH}
+
+Testing your VTA Simulation Setup
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To ensure that you've properly installed the VTA python package, run the following 2D convolution testbench.
+
+.. code:: bash
+
+   python <tvm root>/vta/tests/python/integration/test_benchmark_topi_conv2d.py
+
+You are invited to try out our `VTA programming tutorials `_.
+
+   **Note**: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves by evaluating the convolutions in software.
+
+Advanced Configuration (optional)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+VTA is a generic configurable deep learning accelerator.
+The configuration is specified by ``vta_config.json`` under ``3rdparty/vta-hw/config``.
+This file provides an architectural specification of the VTA accelerator to parameterize the TVM compiler stack and the VTA hardware stack.
+
+The VTA configuration file also specifies the TVM compiler target.
+When ``TARGET`` is set to ``sim``, all TVM workloads execute on the VTA simulator.
+You can modify the content of the configuration file to rebuild VTA to a different parameterization.
+To do so:
+
+.. code:: bash
+
+   cd <tvm root>
+   vim 3rdparty/vta-hw/config/vta_config.json
+   # edit vta_config.json
+   make
+
+
+
+Xilinx Pynq FPGA Setup
+----------------------
+
+This second guide extends the *VTA Simulator Installation* guide above to run FPGA hardware tests of the complete TVM and VTA software-hardware stack.
+In terms of hardware components, you'll need:
+
+* The `Pynq `_ FPGA development board, which can be acquired for $200, or $150 for academics, from `Digilent `_.
+* An Ethernet-to-USB adapter to connect the Pynq board to your development machine.
+* An 8+GB micro SD card.
+* An AC to DC 12V 3A power adapter.
+
+This guide covers the following topics:
+
+1. Pynq board setup instructions.
+2. Pynq-side RPC server build and deployment.
+3. Revisiting the test examples from the *VTA Simulator Installation* guide, this time executing on the Pynq board.
+
+Pynq Board Setup
+^^^^^^^^^^^^^^^^
+
+Set up your Pynq board based on the `Pynq board getting started tutorial `_.
+
+You should follow the instructions up to and including the *Turning On the PYNQ-Z1* step (no need to pursue the tutorial beyond this point).
+
+* Make sure that you've downloaded the latest Pynq image, `PYNQ-Z1 v2.4 `_ (released February 22nd, 2019), and have imaged your SD card with it (we recommend the free `Etcher `_ program).
+* For this test setup, follow the `"Connect to a Computer" `_ Ethernet setup instructions. To be able to talk to the board, make sure to `assign your computer a static IP address `_.
+
+Once the board is powered on and connected to your development machine, try connecting to it to make sure you've properly set up your Pynq board:
+
+.. code:: bash
+
+   # To connect to the Pynq board use the username/password combo: xilinx/xilinx
+   ssh xilinx@192.168.2.99
+
+Pynq-Side RPC Server Build & Deployment
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Because the direct board-to-computer connection prevents the board from directly accessing the internet, we'll need to mount the Pynq's file system to your development machine's file system with `sshfs `_. Next we directly clone the TVM repository into the sshfs mountpoint on your development machine.
+
+.. code:: bash
+
+   # On the Host-side
+   mkdir <mountpoint>
+   sshfs xilinx@192.168.2.99:/home/xilinx <mountpoint>
+   cd <mountpoint>
+   git clone --recursive https://github.com/apache/incubator-tvm tvm
+   # When finished, you can leave the mountpoint and unmount the directory
+   cd ~
+   sudo umount <mountpoint>
+
+Now that we've cloned the TVM repository into the Pynq's file system, we can ssh into it and launch the build of the TVM-based RPC server.
+The build process should take roughly 5 minutes.
+
+.. code:: bash
+
+   ssh xilinx@192.168.2.99
+   # Build TVM runtime library (takes 5 mins)
+   cd /home/xilinx/tvm
+   mkdir build
+   cp cmake/config.cmake build/.
+   echo 'set(USE_VTA_FPGA ON)' >> build/config.cmake
+   # Copy pynq specific configuration
+   cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json
+   cd build
+   cmake ..
+   make runtime vta -j2
+   # Build VTA RPC server (takes 1 min)
+   cd ..
+   sudo ./apps/vta_rpc/start_rpc_server.sh # pw is 'xilinx'
+
+
+You should see the following being displayed when starting the RPC server. In order to run the next examples, you'll need to leave the RPC server running in an ``ssh`` session.
+
+.. code:: bash
+
+   INFO:root:RPCServer: bind to 0.0.0.0:9091
+
+
+Tips regarding the Pynq RPC Server:
+
+* The RPC server should be listening on port ``9091``. If not, an earlier process might have terminated unexpectedly, and in this case it's recommended to simply reboot the Pynq and re-run the RPC server.
+* To kill the RPC server, just send the ``Ctrl + c`` command. You can re-run it with ``sudo ./apps/vta_rpc/start_rpc_server.sh``.
+* If unresponsive, the board can be rebooted by power-cycling it with the physical power switch.
+
+Testing your Pynq-based Hardware Setup
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Before running the examples on your development machine, you'll need to configure your host environment as follows:
+
+.. code:: bash
+
+   # On the Host-side
+   export VTA_RPC_HOST=192.168.2.99
+   export VTA_RPC_PORT=9091
+
+
+In addition, you'll need to edit the ``vta_config.json`` file on the host to indicate that we are targeting the Pynq platform, by setting the ``TARGET`` field to ``"pynq"``.
+
+   **Note**: In contrast to our simulation setup, there are no libraries to compile on the host side since the host offloads all of the computation to the Pynq board.
+
+.. code:: bash
+
+   # On the Host-side
+   cd <tvm root>
+   cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json
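+
+For reference, the host-side scripts use these two environment variables to open their RPC session to the board. The following is only a sketch of what happens under the hood, assuming just the standard ``tvm.rpc`` API (the actual test scripts wrap this in helper code):
+
+.. code:: python
+
+   import os
+   from tvm import rpc
+
+   # Read the board address exported above.
+   host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
+   port = int(os.environ.get("VTA_RPC_PORT", "9091"))
+
+   # Open an RPC session; requests are served by the RPC server
+   # left running on the Pynq.
+   remote = rpc.connect(host, port)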
+
+Once again, we will run the 2D convolution testbench.
+Beforehand, we need to program the Pynq board FPGA with a VTA bitstream, and build the VTA runtime via RPC.
+The following ``test_program_rpc.py`` script will perform two operations:
+
+* FPGA programming, by downloading a pre-compiled bitstream from a `VTA bitstream repository `_ that matches the default ``vta_config.json`` configuration set by the host, and sending it over to the Pynq via RPC to program the Pynq's FPGA.
+* Runtime building on the Pynq, which needs to be run every time the ``vta_config.json`` configuration is modified. This ensures that the VTA software runtime, which generates the accelerator's executable via just-in-time (JIT) compilation, matches the specifications of the VTA design that is programmed on the FPGA. The build process takes about 30 seconds to complete, so be patient!
+
+.. code:: bash
+
+   # On the Host-side
+   python <tvm root>/vta/tests/python/pynq/test_program_rpc.py
+
+
+We are now ready to run the 2D convolution testbench in hardware.
+
+.. code:: bash
+
+   # On the Host-side
+   python <tvm root>/vta/tests/python/integration/test_benchmark_topi_conv2d.py
+
+The performance metrics measured on the Pynq board will be reported for each convolutional layer.
+
+**Tip**: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq ``ssh`` session.
+
+You can also try out our `VTA programming tutorials `_.
+
+
+
+Intel DE10 FPGA Setup
+---------------------
+
+Similar to the Pynq-side setup steps, this third guide details how to set up a Linux environment for Intel FPGA boards such as the DE10-Nano.
+
+In terms of hardware components, you will need the `DE10-Nano Development Kit `_, which can be acquired for $130, or $100 for academics, from `Terasic `_. A microSD card is included with the kit, as are power and USB cables. However, you will need an additional Ethernet cable to connect the board to your LAN.
+
+The rest of this guide covers the steps to:
+
+* Flash the microSD card with the latest Angstrom Linux image
+* Set up cross compilation
+* Set up and deploy the device-side RPC server
+
+DE10-Nano Board Setup
+^^^^^^^^^^^^^^^^^^^^^
+
+Before powering up the device, we need to flash the microSD card with the latest Angstrom Linux image.
+
+Flash SD Card and Boot Angstrom Linux
+"""""""""""""""""""""""""""""""""""""
+
+To flash the SD card and boot Linux on the DE10-Nano, navigate to the `Resource `_ tab of the DE10-Nano product page from Terasic Inc.
+After registering and logging in on the webpage, the prebuilt Angstrom Linux image is available for download and flashing.
+Specifically, to flash the downloaded Linux SD card image into your physical SD card:
+
+First, extract the gzipped archive file.
+
+.. code:: bash
+
+   tar xf de10-nano-image-Angstrom-v2016.12.socfpga-sdimg.2017.03.31.tgz
+
+This produces a single SD card image named ``de10-nano-image-Angstrom-v2016.12.socfpga-sdimg`` (approx. 2.4 GB), which contains all the file systems needed to boot Angstrom Linux.
+
+Second, plug an SD card that is ready to flash into your PC, and identify the device id for the disk with ``fdisk -l`` (or ``gparted`` if you prefer a GUI). The typical device id is likely to be ``/dev/sdb``.
+
+Then, flash the disk image into your physical SD card with the following command:
+
+.. code:: bash
+
+   # NOTE: root privilege is typically required to run the following command.
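+   # Tip: double-check the device node reported by `fdisk -l` above before
+   # running this; writing the image to the wrong disk destroys its contents.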
+   dd if=de10-nano-image-Angstrom-v2016.12.socfpga-sdimg of=/dev/sdb status=progress
+
+Writing the whole file system image to the SD card takes a few minutes.
+After this process completes, you are ready to unmount the SD card and insert it into your DE10-Nano board.
+Now you can connect the power cable and serial port to boot Angstrom Linux.
+
+   **Note**: When booting from the microSD card, you might notice that the Linux kernel ``zImage`` shipped on the card is incompatible.
+   In this case, you might need to build your own ``zImage`` from the `socfpga-4.9.78-ltsi `_ branch of the `linux-socfpga `_ repository.
+   For a quick fix, you can also download a prebuilt version of the ``zImage`` file `from this link `_.
+
+After connecting the USB cables to the DE10-Nano board, power on the board by connecting the power cable. You may then connect to the serial port of the device by using ``minicom`` on your host PC:
+
+.. code:: bash
+
+   # NOTE: root privilege is typically required to run the following command.
+   minicom -D /dev/ttyUSB0
+
+The default user name is ``root``, with an empty password.
+
+You may now install the supporting Python 3 packages (TVM has dropped Python 2 support): ``numpy``, ``attrs`` and ``decorator``.
+
+   **Note**: Installing ``numpy`` with ``pip3`` may fail on the DE10-Nano device.
+   In that case, you can either build your own filesystem image for the board from the `meta-de10-nano `_ repository, or download prebuilt packages from an existing Linux distribution, e.g. Debian.
+   For a quick fix, we have bundled the supplementary binary files `here `_, and you can extract the files into the root filesystem.
+
+Install Required Python Packages
+""""""""""""""""""""""""""""""""
+
+After accessing a bash terminal over the serial port, we need to install the required Python packages before building and installing TVM and VTA programs.
+
+Build Additional Components to Use VTA Bitstream
+""""""""""""""""""""""""""""""""""""""""""""""""
+
+To use the bitstream built above on the DE10-Nano hardware, several additional components need to be compiled for the system.
+Specifically, to compile application executables for the system, you need to download and install `SoCEDS `_ (recommended), or alternatively install the ``g++-arm-linux-gnueabihf`` package on your host machine. You will also need a ``cma`` kernel module to allocate contiguous memory, and a driver for communicating with the VTA subsystem.
+
+
+Bitstream Generation with Xilinx Toolchains
+-------------------------------------------
+
+If you're interested in generating the Xilinx FPGA bitstream on your own instead of using the pre-built VTA bitstreams, follow the instructions below.
+
+Xilinx Toolchain Installation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We recommend using Vivado 2018.3, since our scripts have been tested to work on this version of the Xilinx toolchains.
+Our guide is written for a Linux (Ubuntu) installation.
+
+You’ll need to install Xilinx’s FPGA compilation toolchain, `Vivado HL WebPACK 2018.3 `_, which is a license-free version of the Vivado HLx toolchain.
+
+Obtaining and Launching the Vivado GUI Installer
+""""""""""""""""""""""""""""""""""""""""""""""""
+
+1. Go to the `download webpage `_, and download the Linux Self Extracting Web Installer for Vivado HLx 2018.3: WebPACK and Editions.
+2. You’ll have to sign in with a Xilinx account.
Creating a Xilinx account takes about two minutes.
+3. Complete the Name and Address Verification by clicking “Next”, and you will get the opportunity to download a binary file called ``Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin``.
+4. Now that the file is downloaded, go to your ``Downloads`` directory, and change the file permissions so it can be executed:
+
+.. code:: bash
+
+   chmod u+x Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin
+
+5. Now you can execute the binary:
+
+.. code:: bash
+
+   ./Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin
+
+Xilinx Vivado GUI Installer Steps
+"""""""""""""""""""""""""""""""""
+
+At this point you've launched the Vivado 2018.3 Installer GUI program.
+
+1. Click “Next” on the "Welcome" screen.
+2. On the "Select Install Type" screen, enter your Xilinx user credentials under the “User Authentication” box and select the “Download and Install Now” option before clicking “Next”.
+3. On the "Accept License Agreements" screen, accept all terms before clicking “Next”.
+4. On the "Select Edition to Install" screen, select “Vivado HL WebPACK” before clicking “Next”.
+5. Under the "Vivado HL WebPACK" screen, before hitting “Next”, check the following options (the rest should be unchecked):
+   * Design Tools -> Vivado Design Suite -> Vivado
+   * Devices -> Production Devices -> SoCs -> Zynq-7000 (if you are targeting the Pynq board)
+   * Devices -> Production Devices -> SoCs -> UltraScale+ MPSoC (if you are targeting the Ultra-96 board)
+6. The total download size should be about 5 GB, and the required disk space about 23 GB.
+7. On the "Select Destination Directory" screen, set the installation directory before clicking “Next”. It might highlight some paths in red; that’s because the installer doesn’t have permission to write to those directories. In that case select a path that doesn’t require special write permissions (e.g. your home directory).
+8. On the "Installation Summary" screen, hit “Install”.
+9. An "Installation Progress" window will pop up to track the progress of the download and the installation.
+10. This process will take about 20-30 minutes depending on your connection speed.
+11. A pop-up window will inform you that the installation completed successfully. Click "OK".
+12. Finally, the "Vivado License Manager" will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process.
+
+Environment Setup
+"""""""""""""""""
+
+The last step is to update your ``~/.bashrc`` with the following lines. This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line.
+
+.. code:: bash
+
+   # Xilinx Vivado 2018.3 environment
+   export XILINX_VIVADO=${XILINX_PATH}/Vivado/2018.3
+   export PATH=${XILINX_VIVADO}/bin:${PATH}
+
+HLS-based Custom VTA Bitstream Compilation for PYNQ
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+High-level hardware parameters are listed in the VTA configuration file and can be customized by the user.
+For this custom VTA bitstream compilation exercise, we'll change the frequency of our design so that it can be clocked a little faster.
+
+* Set the ``HW_FREQ`` field to ``142``. The Pynq board supports 100, 142, 167 and 200 MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violations and thus faulty hardware execution.
+* Set the ``HW_CLK_TARGET`` field to ``6``. This parameter refers to the target clock period in nanoseconds for HLS: a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142 MHz clock would require a 7 ns target, but we intentionally lower the clock target to 6 ns to more aggressively pipeline our design (see the resulting configuration snippet after this list).
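+
+Only the fields discussed above are shown; the actual ``3rdparty/vta-hw/config/vta_config.json`` contains many more parameters. After the edits, the relevant entries would read:
+
+.. code:: json
+
+   {
+     "TARGET" : "pynq",
+     "HW_FREQ" : 142,
+     "HW_CLK_TARGET" : 6
+   }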
+
+Bitstream generation is driven by a top-level ``Makefile`` under ``<tvm root>/3rdparty/vta-hw/hardware/xilinx/``.
+
+If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter:
+
+.. code:: bash
+
+   cd <tvm root>/3rdparty/vta-hw/hardware/xilinx
+   make ip MODE=sim
+
+
+If you just want to generate the HLS-based VTA IP cores without launching the entire design place and route, enter:
+
+.. code:: bash
+
+   make ip
+
+You'll be able to view the HLS synthesis reports under ``<tvm root>/3rdparty/vta-hw/build/hardware/xilinx/hls/<configuration>/<block>/solution0/syn/report/<block>_csynth.rpt``.
+
+   **Note**: The ``<configuration>`` name is a string that summarizes the VTA configuration parameters listed in ``vta_config.json``. The ``<block>`` name refers to the specific module (or HLS function) composing the high-level VTA pipeline.
+
+Finally, to run the full hardware compilation and generate the VTA bitstream, run ``make``.
+
+This process is lengthy, and can take up to an hour to complete depending on your machine's specs.
+We recommend setting the ``VTA_HW_COMP_THREADS`` variable in the Makefile to take full advantage of all the cores on your development machine.
+
+Once the compilation completes, the generated bitstream can be found under ``<tvm root>/3rdparty/vta-hw/build/hardware/xilinx/vivado/<configuration>/export/vta.bit``.
+
+Using A Custom Bitstream
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+We can program the new VTA FPGA bitstream by setting the bitstream path of the ``vta.program_fpga()`` function in the tutorial examples, or in the ``test_program_rpc.py`` script.
+
+.. code:: python
+
+   vta.program_fpga(remote, bitstream="<tvm root>/3rdparty/vta-hw/build/hardware/xilinx/vivado/<configuration>/export/vta.bit")
+
+Instead of downloading a pre-built bitstream from the VTA bitstream repository, TVM will use the new bitstream you just generated, which is a VTA design clocked at a higher frequency.
+Do you observe a noticeable performance increase on the ImageNet classification example?
+
+
+
+Bitstream Generation with Intel Toolchains
+------------------------------------------
+
+If you're interested in generating the Intel FPGA bitstream on your own instead of using the pre-built VTA bitstreams, follow the instructions below.
+
+Intel Toolchain Installation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It is recommended to use ``Intel Quartus Prime 18.1``, since the test scripts contained in this document have been tested on this version.
+
+You will need to install Intel's FPGA compilation toolchain, `Quartus Prime Lite `_, which is a license-free version of the Intel Quartus Prime software.
+
+Obtaining and Launching the Quartus GUI Installer
+"""""""""""""""""""""""""""""""""""""""""""""""""
+
+1. Go to the `download center `_, and download the Linux versions of the "Quartus Prime (include Nios II EDS)" and "Cyclone V device support" files in the "Separate file" tab. This avoids downloading unused device support files.
+2. Sign in if you already have an account, or register on the right side of the web page to create one.
+3. After signing in, you can download the installer and the device support files.
+4. Now that the files are downloaded, go to your ``Downloads`` directory, and change the file permissions:
+
+.. code:: bash
+
+   chmod u+x QuartusLiteSetup-18.1.0.625-linux.run
+
+5. Now ensure that the installer and the device support files are in the same directory, and run the installer with:
+
+.. code:: bash
+
+   ./QuartusLiteSetup-18.1.0.625-linux.run
+
+6. Follow the instructions in the pop-up GUI form, and install all the content in the ``/usr/local`` directory. After installation, ``/usr/local/intelFPGA_lite/18.1`` will be created, and the Quartus program, along with the other tools, will be available in that folder.
+
+Environment Setup
+"""""""""""""""""
+
+As with the Xilinx toolchain, add the following lines to your ``~/.bashrc``.
+
+.. code:: bash
+
+   # Intel Quartus 18.1 environment
+   export QUARTUS_ROOTDIR="/usr/local/intelFPGA_lite/18.1/quartus"
+   export PATH=${QUARTUS_ROOTDIR}/bin:${PATH}
+   export PATH=${QUARTUS_ROOTDIR}/sopc_builder/bin:${PATH}
+
+This adds the Quartus binary paths to your ``PATH`` environment variable, so you can launch compilation scripts from the command line.
+
+Chisel-based Custom VTA Bitstream Compilation for DE10-Nano
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Similar to the HLS-based design, the high-level hardware parameters of the Chisel-based design are listed in the VTA configuration file `Configs.scala `_, and they can be customized by the user.
+
+For Intel FPGAs, bitstream generation is driven by a top-level ``Makefile`` under ``<tvm root>/3rdparty/vta-hw/hardware/intel``.
+
+If you just want to generate the Chisel-based VTA IP core for the DE10-Nano board without compiling the design for the FPGA hardware, enter:
+
+.. code:: bash
+
+   cd <tvm root>/3rdparty/vta-hw/hardware/intel
+   make ip
+
+Then you'll be able to locate the generated Verilog file at ``<tvm root>/3rdparty/vta-hw/build/hardware/intel/chisel/<configuration>/VTA.DefaultDe10Config.v``.
+
+If you would like to run the full hardware compilation for the ``de10nano`` board:
+
+.. code:: bash
+
+   make
+
+This process is fairly lengthy, and might take up to half an hour to complete depending on the performance of your PC. The Quartus Prime software automatically detects the number of cores available on your PC and uses all of them for the compilation.
+
+Once the compilation completes, the generated bitstream can be found under ``<tvm root>/3rdparty/vta-hw/build/hardware/intel/quartus/<configuration>/export/vta.rbf``. You can also open the Quartus project file (.qpf) available at ``<tvm root>/3rdparty/vta-hw/build/hardware/intel/quartus/<configuration>/de10_nano_top.qpf`` to inspect the generated reports.
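+
+To try out the Chisel-generated design, you can point ``vta.program_fpga()`` at the new ``.rbf`` file, mirroring the Xilinx flow described earlier. This is only a sketch: it assumes an open RPC session ``remote`` (as in the tutorials) and a host-side build configured for the ``de10nano`` target.
+
+.. code:: python
+
+   # Hedged example: reuse the RPC session from the tutorials, and substitute
+   # your actual <tvm root> and <configuration> names.
+   vta.program_fpga(remote, bitstream="<tvm root>/3rdparty/vta-hw/build/hardware/intel/quartus/<configuration>/export/vta.rbf")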
+ diff --git a/golang/sample/gen_mobilenet_lib.py b/golang/sample/gen_mobilenet_lib.py index 4f6a615d14c9b..8becd078fd5e4 100644 --- a/golang/sample/gen_mobilenet_lib.py +++ b/golang/sample/gen_mobilenet_lib.py @@ -18,7 +18,6 @@ import os from tvm import relay from tvm.contrib.download import download_testdata -import tflite.Model ################################################ @@ -49,7 +48,12 @@ def extract(path): # get TFLite model from buffer tflite_model_buf = open(model_file, "rb").read() -tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) +try: + import tflite + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) +except AttributeError: + import tflite.Model + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) ############################## diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc index af6e4303a85ad..f599c405d5d58 100644 --- a/golang/src/gotvm.cc +++ b/golang/src/gotvm.cc @@ -24,14 +24,17 @@ // Standard includes #include +#include #include #include #include #include -#include // golang string compatible definition -typedef struct { char *p; int n; } _gostring_; +typedef struct { + char* p; + int n; +} _gostring_; #include #ifdef __cplusplus @@ -39,8 +42,8 @@ extern "C" { #endif // TVM runtime C interface -#include #include +#include /*! * \brief Convert native char array to _gostring_ structure. @@ -53,7 +56,7 @@ extern "C" { * \return _gostring_ object corresponding to native char array. * Caller is responsible to free the memory block allocated here. */ -static _gostring_ _native_to_gostring(const char *p, size_t l) { +static _gostring_ _native_to_gostring(const char* p, size_t l) { _gostring_ ret; ret.p = reinterpret_cast(malloc(l)); if (NULL == ret.p) { @@ -72,10 +75,10 @@ static _gostring_ _native_to_gostring(const char *p, size_t l) { * \param off is the offset in the string object. * \param v is the uint64_t value which need to embed into given string. */ -static void putuint64(std::string *s, size_t off, uint64_t v) { - for (int i = 0; i < 8; i++) { - (*s)[off + i] = (v >> (i * 8)) & 0xff; - } +static void putuint64(std::string* s, size_t off, uint64_t v) { + for (int i = 0; i < 8; i++) { + (*s)[off + i] = (v >> (i * 8)) & 0xff; + } } // TVM runtime C interface wrappers @@ -86,7 +89,7 @@ static void putuint64(std::string *s, size_t off, uint64_t v) { * \return char pointer to TVM-VERSION */ const char* _TVM_VERSION(void) { - const char *version = TVM_VERSION; + const char* version = TVM_VERSION; return version; } @@ -101,16 +104,16 @@ const char* _TVM_VERSION(void) { */ int _TVMFuncListGlobalNames(_gostring_* names) { int names_size; - char **names_array; + char** names_array; int result; - result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array); + result = TVMFuncListGlobalNames(&names_size, (char const***)&names_array); if (result) { return result; } size_t tot = 8; - for (int ii = 0; ii < names_size ; ++ii) { + for (int ii = 0; ii < names_size; ++ii) { tot += 8 + strlen(names_array[ii]); } @@ -118,7 +121,7 @@ int _TVMFuncListGlobalNames(_gostring_* names) { str.resize(tot); putuint64(&str, 0, names_size); size_t off = 8; - for (int64_t ii = 0; ii < names_size ; ++ii) { + for (int64_t ii = 0; ii < names_size; ++ii) { putuint64(&str, off, strlen(names_array[ii])); off += 8; str.replace(off, strlen(names_array[ii]), names_array[ii]); @@ -143,9 +146,9 @@ int _TVMFuncListGlobalNames(_gostring_* names) { * \param array index in native array. 
*/ void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { - TVMValue *from_p = reinterpret_cast(from_ptr); - TVMValue *to_p = reinterpret_cast(to_ptr); - memcpy(to_p+ind, from_p, sizeof(TVMValue)); + TVMValue* from_p = reinterpret_cast(from_ptr); + TVMValue* to_p = reinterpret_cast(to_ptr); + memcpy(to_p + ind, from_p, sizeof(TVMValue)); } /*! @@ -157,9 +160,9 @@ void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { * \param array index in native array. */ void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { - TVMValue *from_p = reinterpret_cast(from_ptr); - TVMValue *to_p = reinterpret_cast(to_ptr); - memcpy(to_p, from_p+ind, sizeof(TVMValue)); + TVMValue* from_p = reinterpret_cast(from_ptr); + TVMValue* to_p = reinterpret_cast(to_ptr); + memcpy(to_p, from_p + ind, sizeof(TVMValue)); } extern int goTVMCallback(void*, void*, int, void*, void*); @@ -175,21 +178,16 @@ extern int goTVMCallback(void*, void*, int, void*, void*); * * \returns the error status as TVM_DLL */ -int _TVMCallback(TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, +int _TVMCallback(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, void* resource_handle) { - return goTVMCallback(args, type_codes, num_args, ret, resource_handle); + return goTVMCallback(args, type_codes, num_args, ret, resource_handle); } /*! * _TVMPackedCFuncFinalizer is finalizer for packed function system. * */ -void _TVMPackedCFuncFinalizer(void* resource_handle) { - return; -} +void _TVMPackedCFuncFinalizer(void* resource_handle) { return; } /*! * /brief _ConvertFunction creates a packed function for with given resource handle. @@ -199,11 +197,8 @@ void _TVMPackedCFuncFinalizer(void* resource_handle) { * * /return is an int indicating the return status. */ -int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) { - int ret = TVMFuncCreateFromCFunc(_TVMCallback, - fptr, - _TVMPackedCFuncFinalizer, - fhandle); +int _ConvertFunction(void* fptr, TVMFunctionHandle* fhandle) { + int ret = TVMFuncCreateFromCFunc(_TVMCallback, fptr, _TVMPackedCFuncFinalizer, fhandle); return ret; } diff --git a/golang/src/gotvm.h b/golang/src/gotvm.h index 12b594b8c9a92..a053e39bd79a9 100644 --- a/golang/src/gotvm.h +++ b/golang/src/gotvm.h @@ -32,11 +32,11 @@ extern "C" { #endif +#include #include #include #include #include -#include // Some type definitions for golang "C" typedef void* native_voidp; diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index 416067dcdca1a..644249fa75c98 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -23,15 +23,15 @@ */ #include "src/runtime/c_runtime_api.cc" #include "src/runtime/cpu_device_api.cc" -#include "src/runtime/workspace_pool.cc" +#include "src/runtime/file_util.cc" #include "src/runtime/library_module.cc" #include "src/runtime/module.cc" -#include "src/runtime/registry.cc" -#include "src/runtime/file_util.cc" -#include "src/runtime/threading_backend.cc" -#include "src/runtime/thread_pool.cc" #include "src/runtime/ndarray.cc" #include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/thread_pool.cc" +#include "src/runtime/threading_backend.cc" +#include "src/runtime/workspace_pool.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. 
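Aside: as shown above, `_TVMFuncListGlobalNames` packs all registered global function names into a single buffer for the Go side: an 8-byte little-endian count (written with `putuint64`), then, for each name, an 8-byte little-endian length followed by the raw bytes. A minimal Python sketch of a decoder for that layout, for illustration only (the real consumer lives in the Go bindings):

```python
import struct

def decode_name_list(buf: bytes):
    # First 8 bytes: little-endian count of names (see putuint64 above).
    (count,) = struct.unpack_from("<Q", buf, 0)
    off = 8
    names = []
    for _ in range(count):
        # Each entry: 8-byte little-endian length, then the name bytes.
        (length,) = struct.unpack_from("<Q", buf, off)
        off += 8
        names.append(buf[off:off + length].decode("utf-8"))
        off += length
    return names
```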
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 6ca3ba9cfd559..4623b5e0a2853 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -24,14 +24,14 @@ #ifndef TVM_ARITH_ANALYZER_H_ #define TVM_ARITH_ANALYZER_H_ -#include -#include #include +#include +#include -#include -#include -#include #include +#include +#include +#include namespace tvm { /*! \brief namespace of arithmetic analysis. */ @@ -107,6 +107,7 @@ class ConstIntBound : public ObjectRef { */ class ConstIntBoundAnalyzer { public: + using BoundMapType = std::unordered_map; /*! * \brief analyze the expr * \param expr The expression of interest. @@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer { * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ - TVM_DLL ConstIntBound operator()(const PrimExpr& expr, - std::unordered_map* bound); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound); /*! * \brief Update constant int bound information of var. @@ -130,16 +130,15 @@ class ConstIntBoundAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const ConstIntBound& info, - bool override = false); + TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false); /*! * \brief Bind variable to a range. * * \param var The variable. * \param range The range we bind to. + * \param override Whether we allow overriding an existing var's range. */ - TVM_DLL void Bind(const Var& var, const Range& range); + TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); private: friend class Analyzer; @@ -220,9 +219,7 @@ class ModularSetAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const ModularSet& info, - bool override = false); + TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false); private: friend class Analyzer; @@ -261,9 +258,7 @@ class RewriteSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); std::function EnterConstraint(const PrimExpr& constraint); @@ -297,9 +292,7 @@ class CanonicalSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - TVM_DLL void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); private: friend class Analyzer; @@ -411,8 +404,9 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param expr The expression we bind to. + * \param override Whether we allow overriding an existing var's expression. */ - void Bind(const Var& var, const PrimExpr& expr); + void Bind(const Var& var, const PrimExpr& expr, bool override = false); /*! * \brief Notify all the sub-analyzers that var * is created and binded to a range. @@ -421,14 +415,16 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param range The range we bind to. + * \param override Whether we allow overriding an existing var's expression. 
*/ - void Bind(const Var& var, const Range& range); + void Bind(const Var& var, const Range& range, bool override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. + * \param override Whether we allow overriding an existing var's expression. */ - void Bind(const Map& variables); + void Bind(const Map& variables, bool override = false); /*! * \brief Whether can we prove expr >= val. @@ -442,6 +438,19 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound); + /*! + * \brief Whether can we prove expr < val. + + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param upper_bound The upper bound. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveLess(const PrimExpr& expr, int64_t upper_bound); /*! * \brief Whether can we prove condition. * diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index b1cb779b4227d..df1a9e7c7a43a 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -23,9 +23,9 @@ #ifndef TVM_ARITH_BOUND_H_ #define TVM_ARITH_BOUND_H_ -#include -#include #include +#include +#include #include #include @@ -38,10 +38,10 @@ class Tensor; } namespace arith { -using tir::Var; -using tir::VarNode; using tir::Domain; using tir::Stmt; +using tir::Var; +using tir::VarNode; /*! * \brief Deduce the bound of the target variable in a expression, @@ -58,8 +58,7 @@ using tir::Stmt; * The deduce bound must implies e for all value in relax_map * \return An integer set that always satisfies the condition. */ -IntSet DeduceBound(PrimExpr v, PrimExpr cond, - const Map& hint_map, +IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map& hint_map, const Map& relax_map); /*! * \brief Same as DeduceBound with unordered_map signature. @@ -83,9 +82,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Domain DomainTouched(const Stmt& body, - const tir::Buffer& buffer, - bool consider_loads, +Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads, bool consider_stores); } // namespace arith diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 86ef906fef0a6..7cd74d245e29b 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -26,14 +26,15 @@ #include #include + #include namespace tvm { namespace arith { +using tir::IterVar; using tir::Var; using tir::VarNode; -using tir::IterVar; //----------------------------------------------- // Integer set data structure. @@ -44,12 +45,7 @@ using tir::IterVar; /*! * \brief Sign type of an integer expression. */ -enum SignType { - kPositive, - kNegative, - kZero, - kUnknown -}; +enum SignType { kPositive, kNegative, kZero, kUnknown }; /*! * \brief Base class of all Integer set containers. @@ -77,9 +73,7 @@ class IntSet : public ObjectRef { * \brief access the internal node container * \return the pointer to the internal node container */ - const IntSetNode* operator->() const { - return static_cast(get()); - } + const IntSetNode* operator->() const { return static_cast(get()); } /*! * \brief Find a range that covers the region. * \param max_range The range to be covered. 
@@ -152,6 +146,13 @@ class IntSet : public ObjectRef { //----------------------------------------------- // Integer set legacy API. //------------------------------------------------ +/*! + * \brief Convert std::unordered_map to Map + * + * \param dom_map The domain map to convert. + * \return The converted map. + */ +Map ConvertDomMap(const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -160,8 +161,7 @@ class IntSet : public ObjectRef { * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const Map& dom_map); +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -169,9 +169,7 @@ IntSet EvalSet(PrimExpr e, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map); - +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. @@ -180,8 +178,7 @@ IntSet EvalSet(PrimExpr e, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(Range r, - const Map& dom_map); +IntSet EvalSet(Range r, const Map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over @@ -191,8 +188,7 @@ IntSet EvalSet(Range r, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map); +IntSet EvalSet(IntSet s, const std::unordered_map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -200,9 +196,7 @@ IntSet EvalSet(IntSet s, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(Range r, - const std::unordered_map& dom_map); - +IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; /*! @@ -213,9 +207,8 @@ using ExprIntSetMap = std::unordered_map& dom_map); +ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, + const std::unordered_map& dom_map); /*! * \brief Create an union set of all sets diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 57f3af4bb67b3..ae18cab0a9fa1 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -26,15 +26,16 @@ #include #include + #include #include namespace tvm { namespace arith { +using tir::IterVar; using tir::Var; using tir::VarNode; -using tir::IterVar; /*! 
* \brief Represent integer constrains including (integer) variables, their ranges and @@ -60,10 +61,8 @@ class IntConstraintsNode : public Object { } bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { - return - equal(variables, other->variables) && - equal(ranges, other->ranges) && - equal(relations, other->relations); + return equal(variables, other->variables) && equal(ranges, other->ranges) && + equal(relations, other->relations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -90,9 +89,7 @@ class IntConstraints : public ObjectRef { * \param relations The linear relations between the variables * (either equations or inequalities) */ - TVM_DLL IntConstraints(Array variables, - Map ranges, - Array relations); + TVM_DLL IntConstraints(Array variables, Map ranges, Array relations); TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); }; @@ -126,11 +123,8 @@ class IntConstraintsTransformNode : public Object { } bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { - return - equal(src, other->src) && - equal(dst, other->dst) && - equal(src_to_dst, other->src_to_dst) && - equal(dst_to_src, other->dst_to_src); + return equal(src, other->src) && equal(dst, other->dst) && + equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src); } void SHashReduce(SHashReducer hash_reduce) const { @@ -161,10 +155,8 @@ class IntConstraintsTransform : public ObjectRef { * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src, * e.g., {m -> a, n -> -b} */ - TVM_DLL IntConstraintsTransform(IntConstraints src, - IntConstraints dst, - Map src_to_dst, - Map dst_to_src); + TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst, + Map src_to_dst, Map dst_to_src); TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; @@ -176,20 +168,16 @@ class IntConstraintsTransform : public ObjectRef { * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy * s_i | s_{i+1} (| means divides), the implement here does not guarantee it. * TODO(yzhliu): From sergei-grechanik: - * computing the proper Smith normal form may improve stability of automatic differentiation - * (generating the same gradient code for slightly different but equivalent input code - * U_{mxm} and V_{nxn} are invertible matrices. - * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, - * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. - * \param S the original A_{mxn}, it will be modified to S_{mxn} - * \param V an identity matrix, it will be modified to V_{nxn} - * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} - * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1} + * computing the proper Smith normal form may improve stability of automatic differentiation + * (generating the same gradient code for slightly different but equivalent input code + * U_{mxm} and V_{nxn} are invertible matrices. + * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, + * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. + * \param S the original A_{mxn}, it will be modified to S_{mxn} + * \param V an identity matrix, it will be modified to V_{nxn} + * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} + * \param y the y in A x = y.
it will be modified to U_{mxm} y_{mx1} */ -void SmithNormalFormDiag(std::vector> *S, - std::vector> *V, - std::vector* x, - std::vector *y); +void SmithNormalFormDiag(std::vector>* S, std::vector>* V, + std::vector* x, std::vector* y); /*! * \brief Solve linear equations. @@ -201,7 +189,7 @@ void SmithNormalFormDiag(std::vector> *S, * as well as inequalities inferred from the \p system_to_solve. * You can get the mapping from the original variables to the solution via ret->src_to_dst. */ -IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve); +IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve); } // namespace arith } // namespace tvm diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index d3ba3e980430a..301d95636ca43 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -24,8 +24,8 @@ #ifndef TVM_ARITH_PATTERN_H_ #define TVM_ARITH_PATTERN_H_ -#include #include +#include #include namespace tvm { @@ -38,8 +38,7 @@ namespace arith { * \param vars List of variables to be used in detection. * \return [coeff[i]] if it is possible, empty array if it is not. */ -Array DetectLinearEquation(const PrimExpr& e, - const Array& vars); +Array DetectLinearEquation(const PrimExpr& e, const Array& vars); /*! * \brief Detect if expression corresponds to clip bound of the vars @@ -49,8 +48,7 @@ Array DetectLinearEquation(const PrimExpr& e, * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value * return empty if the e does not match the pattern. */ -Array DetectClipBound(const PrimExpr& e, - const Array& vars); +Array DetectClipBound(const PrimExpr& e, const Array& vars); } // namespace arith } // namespace tvm diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index e6d4427544465..1d4d4931d0503 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -29,47 +29,42 @@ #ifndef TVM_DRIVER_DRIVER_API_H_ #define TVM_DRIVER_DRIVER_API_H_ +#include #include -#include #include -#include +#include #include #include -#include -#include #include #include +#include +#include namespace tvm { /*! -* \brief Build an IRModule given a schedule, args and binds -* \param sch The schedule to lower. -* \param args The arguments to the function. -* \param name The name of the lowered function. -* \param binds Buffer assignments. -* \param config The build configuration. -* \return The result module. -*/ -TVM_DLL IRModule lower( - te::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config); + * \brief Build an IRModule given a schedule, args and binds + * \param sch The schedule to lower. + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \param config The build configuration. + * \return The result module. + */ +TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds, + const BuildConfig& config); /*! -* \brief Build a device and host module for a specific target from an IRModule. -* \param funcs The functions to be built. -* \param target The target device to build for. -* \param target_host The target for building host code. To use the default, pass Target() -* \param config The build configuration. -* \return The built module. 
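For orientation, a sketch of calling the `lower` entry point declared above on a trivial elementwise schedule. The `te::` helpers and `BuildConfig::Create()` come from surrounding TVM headers of this era and are assumptions here, not part of this hunk:

```cpp
#include <tvm/driver/driver_api.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

#include <unordered_map>

using namespace tvm;

IRModule LowerVecAdd(PrimExpr n) {
  te::Tensor A = te::placeholder({n}, DataType::Float(32), "A");
  te::Tensor B = te::placeholder({n}, DataType::Float(32), "B");
  te::Tensor C = te::compute({n}, [&](tir::Var i) { return A(i) + B(i); }, "C");
  te::Schedule s = te::create_schedule({C->op});
  std::unordered_map<te::Tensor, tir::Buffer> binds;  // no explicit buffer bindings
  return lower(s, {A, B, C}, "vecadd", binds, BuildConfig::Create());  // Create() assumed
}
```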
-*/ -TVM_DLL runtime::Module build(const IRModule& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config); + * \brief Build a device and host module for a specific target from an IRModule. + * \param funcs The functions to be built. + * \param target The target device to build for. + * \param target_host The target for building host code. To use the default, pass Target() + * \param config The build configuration. + * \return The built module. + */ +TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, + const Target& target_host, const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from a map @@ -81,8 +76,7 @@ TVM_DLL runtime::Module build(const IRModule& funcs, * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map& input, - const Target& target_host, +TVM_DLL runtime::Module build(const Map& input, const Target& target_host, const BuildConfig& config); /*! @@ -95,8 +89,7 @@ TVM_DLL runtime::Module build(const Map& input, * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map& input, - const Target& target_host, +TVM_DLL runtime::Module build(const Map& input, const Target& target_host, const BuildConfig& config); } // namespace tvm diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index f9cb622255848..9d45dc10800e1 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -27,11 +27,12 @@ #ifndef TVM_IR_ADT_H_ #define TVM_IR_ADT_H_ -#include -#include -#include #include #include +#include +#include +#include + #include namespace tvm { @@ -66,9 +67,7 @@ class ConstructorNode : public RelayExprNode { bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const { // Use namehint for now to be consistent with the legacy relay impl // TODO(tvm-team) revisit, need to check the type var. - return - equal(name_hint, other->name_hint) && - equal(inputs, other->inputs); + return equal(name_hint, other->name_hint) && equal(inputs, other->inputs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -92,9 +91,7 @@ class Constructor : public RelayExpr { * \param inputs The input types. * \param belong_to The data type var the constructor will construct. */ - TVM_DLL Constructor(std::string name_hint, - Array inputs, - GlobalTypeVar belong_to); + TVM_DLL Constructor(std::string name_hint, Array inputs, GlobalTypeVar belong_to); TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode); }; @@ -122,10 +119,8 @@ class TypeDataNode : public TypeNode { } bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const { - return - equal.DefEqual(header, other->header) && - equal.DefEqual(type_vars, other->type_vars) && - equal(constructors, other->constructors); + return equal.DefEqual(header, other->header) && equal.DefEqual(type_vars, other->type_vars) && + equal(constructors, other->constructors); } void SHashReduce(SHashReducer hash_reduce) const { @@ -157,9 +152,7 @@ class TypeData : public Type { * \param type_vars type variables. * \param constructors constructors field. 
*/ - TVM_DLL TypeData(GlobalTypeVar header, - Array type_vars, - Array constructors); + TVM_DLL TypeData(GlobalTypeVar header, Array type_vars, Array constructors); TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); }; diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index d12f1b85114c7..819aafa0281cd 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -50,12 +50,12 @@ #include #include -#include -#include #include -#include #include +#include +#include #include +#include namespace tvm { /*! @@ -63,34 +63,30 @@ namespace tvm { * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ -#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ - static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ - template \ +#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ + static constexpr const char* _type_key = TypeKey; \ + TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ + template \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) - /*! * \brief Declare an attribute field. * \param FieldName The field name. */ -#define TVM_ATTR_FIELD(FieldName) \ - __fvisit__(#FieldName, &FieldName) - +#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName) /*! * \brief Create a NodeRef type that represents null. * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ -template +template inline TObjectRef NullValue() { - static_assert(TObjectRef::_type_is_nullable, - "Can only get NullValue for nullable types"); + static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } -template<> +template <> inline DataType NullValue() { return DataType(DataType::kHandle, 0, 0); } @@ -101,8 +97,7 @@ struct AttrError : public dmlc::Error { * \brief constructor * \param msg error message */ - explicit AttrError(const std::string &msg) - : dmlc::Error(msg) {} + explicit AttrError(const std::string& msg) : dmlc::Error(msg) {} }; /*! @@ -154,13 +149,13 @@ class BaseAttrsNode : public Object { * \param args The postional arguments in the form * [key0, value0, key1, value1, ..., key_n, value_n] */ - template - inline void InitBySeq(Args&& ...args); + template + inline void InitBySeq(Args&&... args); /*! * \brief Print readible docstring to ostream, add newline. * \param os the stream to print the docstring to. */ - inline void PrintDocString(std::ostream &os) const; // NOLINT(*) + inline void PrintDocString(std::ostream& os) const; // NOLINT(*) /*! * \brief Visit attributes that do not equal the default value. 
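The attrs macros above are designed to be used together in one declaration block. A sketch of the intended pattern, with made-up field names; `InitBySeq` then fills the fields from positional key/value pairs as documented above:

```cpp
#include <tvm/ir/attrs.h>

// Illustrative attrs node; the field names are made up.
struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
  double learning_rate;
  int num_hidden;

  TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
    TVM_ATTR_FIELD(num_hidden).set_lower_bound(1).describe("Number of hidden units.");
    TVM_ATTR_FIELD(learning_rate).set_default(0.01).describe("Learning rate.");
  }
};

void InitDemo() {
  auto n = tvm::runtime::make_object<MyAttrs>();
  // Positional key/value pairs, validated against the declared fields.
  n->InitBySeq("num_hidden", 128, "learning_rate", 0.05);
}
```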
* @@ -212,9 +207,7 @@ class DictAttrsNode : public BaseAttrsNode { return equal(dict, other->dict); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dict); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } // implementations void VisitAttrs(AttrVisitor* v) final; @@ -239,7 +232,6 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); - TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -252,18 +244,16 @@ using runtime::TVMArgValue; struct AttrNopEntry { using TSelf = AttrNopEntry; - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } - template + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } + template TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { return *this; } - template + template TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } - template + template TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } @@ -272,10 +262,8 @@ struct AttrNopEntry { // Wrapper for normal visitor. class AttrNormalVisitor { public: - explicit AttrNormalVisitor(AttrVisitor* visitor) - : visitor_(visitor) { - } - template + explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {} + template AttrNopEntry operator()(const char* key, T* value) { visitor_->Visit(key, value); return AttrNopEntry(); @@ -290,16 +278,13 @@ class AttrsSEqualVisitor { bool result_{true}; // constructor AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal) - : lhs_(lhs), rhs_(rhs), equal_(equal) { - } - template + : lhs_(lhs), rhs_(rhs), equal_(equal) {} + template AttrNopEntry operator()(const char* key, T* lhs_value) { if (!result_) return AttrNopEntry(); - const T* rhs_value = - reinterpret_cast( - reinterpret_cast(rhs_) + - (reinterpret_cast(lhs_value) - - reinterpret_cast(lhs_))); + const T* rhs_value = reinterpret_cast( + reinterpret_cast(rhs_) + + (reinterpret_cast(lhs_value) - reinterpret_cast(lhs_))); if (!equal_(*lhs_value, *rhs_value)) { result_ = false; } @@ -314,10 +299,9 @@ class AttrsSEqualVisitor { class AttrsSHashVisitor { public: - explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) - : hash_reducer_(hash_reducer) {} + explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {} - template + template AttrNopEntry operator()(const char* key, T* value) { hash_reducer_(*value); return AttrNopEntry(); @@ -328,7 +312,7 @@ class AttrsSHashVisitor { }; // helper entry that does initialization, set default. -template +template struct AttrInitEntry { // The attributes using TSelf = AttrInitEntry; @@ -344,34 +328,31 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ - << "\' during initialization"; + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; throw AttrError(os.str()); } } // override fields. // This function sets the lower bound of the attribute TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - if (this->value_missing_) return *this; + if (this->value_missing_) return *this; const T& val = *value_; if (begin > val) { std::ostringstream os; os << type_key_ << "." 
<< key_ << ": " - << "value " << val - << " is smaller than the lower bound " << begin; + << "value " << val << " is smaller than the lower bound " << begin; throw AttrError(os.str()); } return *this; } // This function sets the upper bound of the attribute TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - if (this->value_missing_) return *this; + if (this->value_missing_) return *this; const T& val = *value_; if (val > end) { std::ostringstream os; os << type_key_ << "." << key_ << ": " - << "value " << val - << " is bigger than the upper bound " << end; + << "value " << val << " is bigger than the upper bound " << end; throw AttrError(os.str()); } return *this; @@ -383,19 +364,17 @@ struct AttrInitEntry { value_missing_ = false; return *this; } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } }; // Template function to allow smart conversion // from Expr types into the constants. -template +template inline void SetValue(T* ptr, const TVMArgValue& val) { *ptr = val.operator T(); } -template +template inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast(val.value().v_int64); @@ -405,7 +384,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } } -template<> +template <> inline void SetValue(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kTVMStr) { *ptr = val.operator std::string(); @@ -414,7 +393,7 @@ inline void SetValue(std::string* ptr, const TVMArgValue& val) { } } -template<> +template <> inline void SetValue(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); @@ -430,36 +409,34 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } } } -template<> +template <> inline void SetValue(int* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(int64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(uint64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(bool* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } // Visitor for value initialization -template +template class AttrInitVisitor { public: // Counter of number of matched attributes during visit. // This is used to decide if there is additional unmatched attributes. size_t hit_count_{0}; // constructor - AttrInitVisitor(const char* type_key, FFind ffind) - : type_key_(type_key), ffind_(ffind) { - } + AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {} - template + template AttrInitEntry operator()(const char* key, T* value) { TVMArgValue val; AttrInitEntry opt; @@ -482,10 +459,8 @@ class AttrInitVisitor { FFind ffind_; }; -template -inline AttrInitVisitor CreateInitVisitor( - const char* type_key, - FFind ffind) { +template +inline AttrInitVisitor CreateInitVisitor(const char* type_key, FFind ffind) { return AttrInitVisitor(type_key, ffind); } @@ -493,47 +468,47 @@ inline AttrInitVisitor CreateInitVisitor( * \brief Helper struct to get the type name known to tvm. * \tparam T the type we are interested in. 
*/ -template +template struct TypeName { static constexpr const char* value = T::ContainerType::_type_key; }; -template<> +template <> struct TypeName { static constexpr const char* value = "int"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "int64"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "uint64_t"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "DataType"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "str"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "bool"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "handle"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "double"; }; @@ -542,25 +517,23 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(ObjectPtr info) - : info_(info) { - } + explicit AttrDocEntry(ObjectPtr info) : info_(info) {} TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { info_->description = str; return *this; } - template + template TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { std::ostringstream os; os << info_->type_info << ", default=" << value; info_->type_info = os.str(); return *this; } - template + template TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { return *this; } - template + template TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { return *this; } @@ -571,10 +544,9 @@ class AttrDocEntry { class AttrDocVisitor { public: - template + template AttrDocEntry operator()(const char* key, T* v) { - ObjectPtr info - = make_object(); + ObjectPtr info = make_object(); info->name = key; info->type_info = TypeName::value; fields_.push_back(AttrFieldInfo(info)); @@ -589,7 +561,7 @@ class AttrExistVisitor { std::string key_; bool exist_{false}; - template + template AttrNopEntry operator()(const char* key, T* v) { if (exist_) return AttrNopEntry(); if (key == key_) exist_ = true; @@ -597,12 +569,11 @@ class AttrExistVisitor { } }; -template +template struct AttrTriggerNonDefaultEntry { using TSelf = AttrTriggerNonDefaultEntry; // constructor - AttrTriggerNonDefaultEntry( - AttrVisitor* visitor, const char* key, T* data) + AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data) : visitor_(visitor), key_(key), data_(data) {} ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { @@ -610,37 +581,28 @@ struct AttrTriggerNonDefaultEntry { visitor_->Visit(key_, data_); } } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } TSelf& set_default(const T& value) { if (tvm::StructuralEqual()(value, *data_)) { trigger_ = false; } return *this; } - TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - return *this; - } - TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - return *this; - } + TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } + TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } private: AttrVisitor* visitor_; - const char * key_; - T *data_; + const char* key_; + T* data_; bool trigger_{true}; }; class AttrNonDefaultVisitor { public: - explicit AttrNonDefaultVisitor(AttrVisitor* visitor) - : visitor_(visitor) { - } - template - AttrTriggerNonDefaultEntry - operator()(const char* key, T* value) { + explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : 
visitor_(visitor) {} + template + AttrTriggerNonDefaultEntry operator()(const char* key, T* value) { return AttrTriggerNonDefaultEntry(visitor_, key, value); } @@ -655,7 +617,7 @@ class AttrNonDefaultVisitor { * * \tparam DerivedType The final attribute type. */ -template +template class AttrsNode : public BaseAttrsNode { public: void VisitAttrs(AttrVisitor* v) { @@ -695,7 +657,7 @@ class AttrsNode : public BaseAttrsNode { CHECK_EQ(args.type_codes[i], kTVMStr); kwargs[args[i].operator std::string()] = args[i + 1]; } - auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) { + auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) { auto it = kwargs.find(key); if (it != kwargs.end()) { *val = it->second; @@ -715,8 +677,7 @@ class AttrsNode : public BaseAttrsNode { self()->__VisitAttrs__(visitor); if (!visitor.exist_) { std::ostringstream os; - os << DerivedType::_type_key - << ": does not have field \'" << visitor.key_ + os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ << "\', Possible fields:\n"; os << "----------------\n"; this->PrintDocString(os); @@ -746,21 +707,18 @@ class AttrsNode : public BaseAttrsNode { private: DerivedType* self() const { - return const_cast( - static_cast(this)); + return const_cast(static_cast(this)); } }; - -template -inline void BaseAttrsNode::InitBySeq(Args&& ...args) { - runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) { - this->InitByPackedArgs(args); - }); +template +inline void BaseAttrsNode::InitBySeq(Args&&... args) { + runtime::PackedFunc pf( + [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); }); pf(std::forward(args)...); } -inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) +inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*) Array entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { os << info->name << " : " << info->type_info << '\n'; diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 67492ab24ba6c..320d6e38e610e 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -47,9 +47,7 @@ class EnvFuncNode : public Object { /*! \brief constructor */ EnvFuncNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { // name uniquely identifies the env function. @@ -76,15 +74,13 @@ class EnvFunc : public ObjectRef { EnvFunc() {} explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! \return The internal global function pointer */ - const EnvFuncNode* operator->() const { - return static_cast(get()); - } + const EnvFuncNode* operator->() const { return static_cast(get()); } /*! * \brief Invoke the function. * \param args The arguments * \returns The return value. */ - template + template runtime::TVMRetValue operator()(Args&&... args) const { const EnvFuncNode* n = operator->(); CHECK(n != nullptr); @@ -104,7 +100,7 @@ class EnvFunc : public ObjectRef { /*! * \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc" */ -template +template class TypedEnvFunc; /*! @@ -116,7 +112,7 @@ class TypedEnvFunc; * \tparam Args The argument signature of the function. * \sa EnvFunc */ -template +template class TypedEnvFunc : public ObjectRef { public: /*! \brief short hand for this function type */ @@ -133,9 +129,7 @@ class TypedEnvFunc : public ObjectRef { return *this; } /*! 
\return The internal global function pointer */ - const EnvFuncNode* operator->() const { - return static_cast(get()); - } + const EnvFuncNode* operator->() const { return static_cast(get()); } /*! * \brief Invoke the function. * \param args The arguments @@ -144,8 +138,8 @@ class TypedEnvFunc : public ObjectRef { R operator()(Args... args) const { const EnvFuncNode* n = operator->(); CHECK(n != nullptr); - return runtime::detail::typed_packed_call_dispatcher - ::run(n->func, std::forward(args)...); + return runtime::detail::typed_packed_call_dispatcher::run(n->func, + std::forward(args)...); } /*! \brief specify container node */ using ContainerType = EnvFuncNode; diff --git a/include/tvm/ir/error.h b/include/tvm/ir/error.h index 94064ae8c8faa..c6576c8a9d0e6 100644 --- a/include/tvm/ir/error.h +++ b/include/tvm/ir/error.h @@ -24,13 +24,13 @@ #ifndef TVM_IR_ERROR_H_ #define TVM_IR_ERROR_H_ -#include #include +#include -#include -#include #include +#include #include +#include namespace tvm { /*! @@ -51,7 +51,7 @@ namespace tvm { */ struct ErrorBuilder { public: - template + template ErrorBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; return *this; @@ -78,12 +78,12 @@ class Error : public dmlc::Error { * \brief construct error from error builder. * \param err The error builder */ - Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) + Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) /*! * \brief copy constructor. * \param other The other ereor. */ - Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) + Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) /*! * \brief default constructor. */ Error() : dmlc::Error(""), span(nullptr) {} @@ -173,9 +173,7 @@ class ErrorReporter { */ void RenderErrors(const IRModule& module, bool use_color = true); - inline bool AnyErrors() { - return errors_.size() != 0; - } + inline bool AnyErrors() { return errors_.size() != 0; } private: std::vector errors_; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fba35a9193f94..717ffb1b4826d 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,14 +24,15 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ -#include -#include -#include #include #include -#include +#include +#include +#include + #include #include +#include #include namespace tvm { @@ -111,9 +112,7 @@ class PrimExpr : public BaseExpr { TVM_DLL PrimExpr(float value); // NOLINT(*) /*! \return the data type of this expression. */ - DataType dtype() const { - return static_cast(get())->dtype; - } + DataType dtype() const { return static_cast(get())->dtype; } TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); @@ -160,7 +159,7 @@ class RelayExprNode : public BaseExprNode { * \return The corresponding TTypeNode pointer. * \tparam The specific TypeNode we look for. */ - template + template inline const TTypeNode* type_as() const; static constexpr const char* _type_key = "RelayExpr"; @@ -199,9 +198,7 @@ class GlobalVarNode : public RelayExprNode { bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { // name matters for global var. 
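`ErrorBuilder` streams like `std::ostream`, and the converting `Error(const ErrorBuilder&)` constructor above captures the accumulated message; a minimal sketch:

```cpp
#include <tvm/ir/error.h>

// Stream the message parts, then let Error's converting constructor capture them.
void Fail(int got) {
  throw tvm::Error(tvm::ErrorBuilder() << "expected 2 arguments, got " << got);
}
```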
- return - equal(name_hint, other->name_hint) && - equal.FreeVarEqualImpl(this, other); + return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -322,35 +319,21 @@ class FloatImm : public PrimExpr { */ class Bool : public IntImm { public: - explicit Bool(bool value) - : IntImm(DataType::Bool(), value) { - } - Bool operator!() const { - return Bool((*this)->value == 0); - } - operator bool() const { - return (*this)->value != 0; - } + explicit Bool(bool value) : IntImm(DataType::Bool(), value) {} + Bool operator!() const { return Bool((*this)->value == 0); } + operator bool() const { return (*this)->value != 0; } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); }; // Overload operators to make sure we have the most fine grained types. -inline Bool operator||(const Bool& a, bool b) { - return Bool(a.operator bool() || b); -} -inline Bool operator||(bool a, const Bool& b) { - return Bool(a || b.operator bool()); -} +inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); } +inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); } inline Bool operator||(const Bool& a, const Bool& b) { return Bool(a.operator bool() || b.operator bool()); } -inline Bool operator&&(const Bool& a, bool b) { - return Bool(a.operator bool() && b); -} -inline Bool operator&&(bool a, const Bool& b) { - return Bool(a && b.operator bool()); -} +inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); } +inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); } inline Bool operator&&(const Bool& a, const Bool& b) { return Bool(a.operator bool() && b.operator bool()); } @@ -384,8 +367,7 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. 
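The `Bool` operator overloads above exist precisely so mixed expressions keep the fine-grained IR type rather than decaying to a native `bool`; for example:

```cpp
#include <tvm/ir/expr.h>

void BoolDemo() {
  tvm::Bool a(true);
  tvm::Bool b = a && false;  // stays Bool via operator&&(const Bool&, bool)
  tvm::Bool c = !b;          // Bool::operator! also preserves the wrapper type
  if (c) {
    // implicit operator bool only where a native test is actually needed
  }
}
```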
*/ - template::value>::type> + template ::value>::type> explicit Integer(Enum value) : Integer(static_cast(value)) { static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); @@ -402,8 +384,7 @@ class Integer : public IntImm { * \brief convert to int64_t */ operator int64_t() const { - CHECK(data_ != nullptr) - << " Trying to reference a null Integer"; + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } // comparators @@ -411,16 +392,12 @@ class Integer : public IntImm { if (data_ == nullptr) return Bool(false); return Bool((*this)->value == other); } - Bool operator!=(int other) const { - return !(*this == other); - } - template::value>::type> + Bool operator!=(int other) const { return !(*this == other); } + template ::value>::type> Bool operator==(Enum other) const { return *this == static_cast(other); } - template::value>::type> + template ::value>::type> Bool operator!=(Enum other) const { return *this != static_cast(other); } @@ -482,24 +459,21 @@ class Range : public ObjectRef { // implementataions inline const Type& RelayExprNode::checked_type() const { - CHECK(checked_type_.defined()) - << "internal error: the type checker has " - << "not populated the checked_type " - << "field for " - << GetRef(this); + CHECK(checked_type_.defined()) << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " << GetRef(this); return this->checked_type_; } -template +template inline const TTypeNode* RelayExprNode::type_as() const { static_assert(std::is_base_of::value, "TType must be a special case of type"); CHECK(checked_type_.defined()) << "Type inference for this Expr has not completed. Try to call infer_type pass."; const TTypeNode* node = checked_type_.as(); - CHECK(node != nullptr) - << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->GetTypeKey(); + CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get " + << checked_type_->GetTypeKey(); return node; } @@ -507,7 +481,7 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -template<> +template <> struct PackedFuncValueConverter { // common rule for both RetValue and ArgValue. static PrimExpr From(const TVMPODValue_& val) { diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index b4a9ed0d89645..00626e6d20a4f 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -24,12 +24,12 @@ #ifndef TVM_IR_FUNCTION_H_ #define TVM_IR_FUNCTION_H_ -#include #include +#include #include -#include -#include +#include +#include namespace tvm { @@ -96,7 +96,7 @@ class BaseFuncNode : public RelayExprNode { * * \endcode */ - template + template Optional GetAttr( const std::string& attr_key, Optional default_value = Optional(nullptr)) const { @@ -111,9 +111,8 @@ class BaseFuncNode : public RelayExprNode { } } // variant that uses TObjectRef to enable implicit conversion to default value. - template - Optional GetAttr( - const std::string& attr_key, TObjectRef default_value) const { + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } /*! 
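The `checked_type`/`type_as` implementations above give a checked downcast; a small sketch (assuming `TensorTypeNode` from tensor_type.h, which appears later in this patch):

```cpp
#include <tvm/ir/expr.h>
#include <tvm/ir/tensor_type.h>

// Checked downcast: the CHECKs in type_as fire if type inference has not run
// or if the checked type is not a TensorType.
const tvm::TensorTypeNode* AsTensorType(const tvm::RelayExpr& e) {
  return e->type_as<tvm::TensorTypeNode>();
}
```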
@@ -180,12 +179,9 @@ class BaseFunc : public RelayExpr { * * \endcode */ -template::value>::type> -inline TFunc WithAttr(TFunc func, - const std::string& attr_key, - ObjectRef attr_value) { +template ::value>::type> +inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = func.CopyOnWrite(); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index b0776dee661f3..ba9a62aa0596e 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -24,15 +24,16 @@ #ifndef TVM_IR_MODULE_H_ #define TVM_IR_MODULE_H_ -#include +#include #include #include -#include +#include +#include #include -#include #include #include +#include namespace tvm { class IRModule; @@ -102,8 +103,7 @@ class IRModuleNode : public Object { * * It does not do type checking as AddTypeDef does. */ - TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, + TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false); /*! @@ -131,21 +131,21 @@ class IRModuleNode : public Object { * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalVar(const std::string& name) const; + TVM_DLL bool ContainGlobalVar(const String& name) const; /*! * \brief Check if the global_type_var_map_ contains a global type variable. * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const; + TVM_DLL bool ContainGlobalTypeVar(const String& name) const; /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; + TVM_DLL GlobalVar GetGlobalVar(const String& str) const; /*! * \brief Collect all global vars defined in this module. @@ -158,7 +158,7 @@ class IRModuleNode : public Object { * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const; /*! * \brief Collect all global type vars defined in this module. @@ -172,7 +172,7 @@ class IRModuleNode : public Object { * \param cons name of the constructor * \returns Constructor of ADT, error if not found */ - TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const; + TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const; /*! * \brief Look up a global function by its variable. @@ -186,7 +186,7 @@ class IRModuleNode : public Object { * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL BaseFunc Lookup(const std::string& name) const; + TVM_DLL BaseFunc Lookup(const String& name) const; /*! * \brief Look up a global type definition by its variable. @@ -200,7 +200,7 @@ class IRModuleNode : public Object { * \param var The name of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupTypeDef(const std::string& var) const; + TVM_DLL TypeData LookupTypeDef(const String& var) const; /*! * \brief Look up a constructor by its tag. @@ -225,18 +225,18 @@ class IRModuleNode : public Object { * relative it will be resovled against the current * working directory. 
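`WithAttr` above is the copy-on-write way to attach function attributes; a hedged sketch, where the "Primitive" key is purely illustrative:

```cpp
#include <tvm/ir/function.h>

// Copy-on-write attribute update on any BaseFunc subtype.
template <typename TFunc>
TFunc MarkPrimitive(TFunc func) {
  return tvm::WithAttr(std::move(func), "Primitive", tvm::Integer(1));
}
```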
*/ - TVM_DLL void Import(const std::string& path); + TVM_DLL void Import(const String& path); /*! * \brief Import Relay code from the file at path, relative to the standard library. * \param path The path of the Relay code to import. */ - TVM_DLL void ImportFromStd(const std::string& path); + TVM_DLL void ImportFromStd(const String& path); /*! * \brief The set of imported files. */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::unordered_set Imports() const; static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -265,7 +265,7 @@ class IRModuleNode : public Object { /*! \brief The files previously imported, required to ensure importing is idempotent for each module. */ - std::unordered_set import_set_; + std::unordered_set import_set_; friend class IRModule; }; @@ -283,7 +283,7 @@ class IRModule : public ObjectRef { */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}); + std::unordered_set import_set = {}); /*! \brief default constructor */ IRModule() {} /*! @@ -303,9 +303,7 @@ class IRModule : public ObjectRef { * * \returns The constructed module */ - static IRModule Empty() { - return IRModule(Map()); - } + static IRModule Empty() { return IRModule(Map()); } /*! * \brief Construct a module from a standalone expression. * @@ -318,10 +316,9 @@ class IRModule : public ObjectRef { * * \returns A module with expr set as the main function. */ - TVM_DLL static IRModule FromExpr( - const RelayExpr& expr, - const Map& global_funcs = {}, - const Map& type_definitions = {}); + TVM_DLL static IRModule FromExpr(const RelayExpr& expr, + const Map& global_funcs = {}, + const Map& type_definitions = {}); /*! * \brief Parse text format source file into an IRModule. @@ -329,7 +326,7 @@ class IRModule : public ObjectRef { * \param source_path The path to the source file. * \return A Relay module. */ - TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); + TVM_DLL static IRModule FromText(const String& text, const String& source_path); /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; @@ -346,7 +343,7 @@ class IRModule : public ObjectRef { * Use AsText if you want to store the text. * \sa AsText. */ -TVM_DLL std::string PrettyPrint(const ObjectRef& node); +TVM_DLL String PrettyPrint(const ObjectRef& node); /*! * \brief Render the node as a string in the text format. @@ -362,8 +359,7 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node); * \sa PrettyPrint. * \return The text representation. */ -TVM_DLL std::string AsText(const ObjectRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); +TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 48cf61d187d55..7fafb5a69421d 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -27,10 +27,10 @@ #include #include -#include #include #include #include +#include #include #include @@ -227,8 +227,7 @@ class OpRegistry { * \param description Description of the argument. * \return reference to self. */ - inline OpRegistry& add_argument(const std::string& name, - const std::string& type, + inline OpRegistry& add_argument(const std::string& name, const std::string& type, const std::string& description); /*! 
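A sketch of the `IRModule` flow documented above: `FromExpr` installs the expression as the module's `main` function, which `Lookup` can then retrieve:

```cpp
#include <tvm/ir/module.h>

tvm::BaseFunc MainOf(const tvm::RelayExpr& expr) {
  tvm::IRModule mod = tvm::IRModule::FromExpr(expr);  // expr becomes "main"
  return mod->Lookup("main");
}
```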
* \brief Attach the type function corresponding to the return type. @@ -239,16 +238,14 @@ class OpRegistry { */ inline OpRegistry& add_type_rel( const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func); + runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> + type_rel_func); /*! * \brief Set the the attrs type key and index to be AttrsType. * \tparam AttrsType the attribute type to b set. * \return reference to self. */ - template + template inline OpRegistry& set_attrs_type(); /*! * \brief Set the num_inputs @@ -306,9 +303,7 @@ class OpRegistry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, - runtime::TVMRetValue value, - int plevel); + TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel); }; /*! @@ -410,8 +405,7 @@ class OpMap { #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) // internal macros to make -#define TVM_OP_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op +#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op /*! * \def TVM_REGISTER_OP @@ -428,38 +422,28 @@ class OpMap { * * \endcode */ -#define TVM_REGISTER_OP(OpName) \ - TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::OpRegistry::Registry() \ - ->__REGISTER_OR_GET__(OpName) \ - .set_name() +#define TVM_REGISTER_OP(OpName) \ + TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() // implementations -inline const OpNode* Op::operator->() const { - return static_cast(get()); -} +inline const OpNode* Op::operator->() const { return static_cast(get()); } template inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } -inline bool Op::HasAttr(const std::string& key) { - return Op::HasGenericAttr(key); -} +inline bool Op::HasAttr(const std::string& key) { return Op::HasGenericAttr(key); } -inline OpNode* OpRegistry::get() { - return const_cast(op_.operator->()); -} +inline OpNode* OpRegistry::get() { return const_cast(op_.operator->()); } -inline OpRegistry& OpRegistry::describe( - const std::string& descr) { // NOLINT(*) +inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) get()->description = descr; return *this; } -inline OpRegistry& OpRegistry::add_argument(const std::string& name, - const std::string& type, +inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type, const std::string& description) { auto n = make_object(); n->name = name; @@ -471,10 +455,8 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, inline OpRegistry& OpRegistry::add_type_rel( const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func) { + runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> + type_rel_func) { auto func_name = std::string("tvm.relay.type_relation.") + rel_name; TypeRelationFn env_type_rel_func; @@ -482,8 +464,7 @@ inline OpRegistry& OpRegistry::add_type_rel( auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } else { - runtime::Registry::Register(func_name) - .set_body(type_rel_func.packed()); + runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); auto env_func = EnvFunc::Get(func_name); env_type_rel_func = 
env_func; } @@ -517,13 +498,9 @@ inline OpRegistry& OpRegistry::add_type_rel( // A common example is sum(x, axis), where the choice of axis // can affect the type of the function. TypeConstraint type_rel = - TypeRelation(env_type_rel_func, - ty_call_args, - arg_types.size(), - Attrs()); + TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs()); - auto func_type = - FuncType(arg_types, out_param, type_params, {type_rel}); + auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); get()->op_type = func_type; @@ -535,7 +512,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -template +template inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) get()->attrs_type_key = AttrsType::_type_key; get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); @@ -567,13 +544,11 @@ inline int GenericOpMap::count(const Op& op) const { } } -inline const runtime::TVMRetValue& -GenericOpMap::operator[](const Op& op) const { +inline const runtime::TVMRetValue& GenericOpMap::operator[](const Op& op) const { CHECK(op.defined()); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ << " has not been registered for Operator " - << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; return data_[idx].first; } @@ -614,14 +589,12 @@ inline ValueType OpMap::operator[](const Op& op) const { } template -inline ValueType OpMap::get(const Op& op, - ValueType def_value) const { +inline ValueType OpMap::get(const Op& op, ValueType def_value) const { return map_.get(op, def_value); } template -inline ValueType OpMap::get(const RelayExpr& expr, - ValueType def_value) const { +inline ValueType OpMap::get(const RelayExpr& expr, ValueType def_value) const { return map_.get(expr, def_value); } diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 7194e903549ca..1ed6848eb9e10 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -24,8 +24,9 @@ #ifndef TVM_IR_SPAN_H_ #define TVM_IR_SPAN_H_ -#include #include +#include + #include namespace tvm { @@ -40,7 +41,7 @@ class SourceName; class SourceNameNode : public Object { public: /*! \brief The source name. */ - std::string name; + String name; // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } @@ -64,7 +65,7 @@ class SourceName : public ObjectRef { * \param name Name of the operator. * \return SourceName valid throughout program lifetime. 
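Putting the registry pieces together, operator registration is a fluent chain over the methods shown above; the op name and argument docs below are made up:

```cpp
#include <tvm/ir/op.h>

TVM_REGISTER_OP("myadd")
    .describe("Elementwise add of two tensors (illustrative).")
    .set_num_inputs(2)
    .add_argument("lhs", "Tensor", "The left operand.")
    .add_argument("rhs", "Tensor", "The right operand.");
```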
*/ - TVM_DLL static SourceName Get(const std::string& name); + TVM_DLL static SourceName Get(const String& name); TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); }; @@ -92,10 +93,8 @@ class SpanNode : public Object { } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return - equal(source, other->source) && - equal(lineno, other->lineno) && - equal(col_offset, other->col_offset); + return equal(source, other->source) && equal(lineno, other->lineno) && + equal(col_offset, other->col_offset); } TVM_DLL static Span make(SourceName source, int lineno, int col_offset); @@ -104,7 +103,6 @@ class SpanNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; - class Span : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h index 489ea64409006..7a700258f23c0 100644 --- a/include/tvm/ir/tensor_type.h +++ b/include/tvm/ir/tensor_type.h @@ -24,8 +24,8 @@ #ifndef TVM_IR_TENSOR_TYPE_H_ #define TVM_IR_TENSOR_TYPE_H_ -#include #include +#include namespace tvm { /*! @@ -75,9 +75,7 @@ class TensorTypeNode : public BaseTensorTypeNode { } bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { - return - equal(shape, other->shape) && - equal(dtype, other->dtype); + return equal(shape, other->shape) && equal(dtype, other->dtype); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 3680f6db9afec..558d2da79361b 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -56,11 +56,12 @@ #ifndef TVM_IR_TRANSFORM_H_ #define TVM_IR_TRANSFORM_H_ -#include -#include -#include #include #include +#include +#include +#include + #include #include @@ -74,9 +75,7 @@ class PassInfo; * */ using TraceFunc = - runtime::TypedPackedFunc; + runtime::TypedPackedFunc; /*! * \brief PassContextNode contains the information that a pass can rely on, @@ -117,7 +116,6 @@ class PassContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; - /*! * \brief PassContext that is used to configure the pass behavior. * @@ -226,9 +224,7 @@ class PassInfo : public ObjectRef { * \param name Name of the pass. * \param required The passes that are required to perform the current pass. */ - TVM_DLL PassInfo(int opt_level, - std::string name, - Array required); + TVM_DLL PassInfo(int opt_level, std::string name, Array required); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -264,8 +260,7 @@ class PassNode : public Object { * * \return The transformed module. */ - virtual IRModule operator()(IRModule mod, - const PassContext& pass_ctx) const = 0; + virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; void VisitAttrs(AttrVisitor* v) {} @@ -303,8 +298,7 @@ class Pass : public ObjectRef { * * \return The transformed module. */ - IRModule operator()(IRModule mod, - const PassContext& pass_ctx) const { + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); CHECK(node != nullptr); return node->operator()(std::move(mod), pass_ctx); @@ -352,12 +346,9 @@ class Sequential : public Pass { * * \return The created module pass. 
*/ -TVM_DLL Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const Array& required); - +TVM_DLL Pass +CreateModulePass(const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, const Array& required); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 0ef03c42c03e3..ed648411266cf 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -49,11 +49,12 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include -#include -#include -#include #include +#include +#include +#include +#include + #include namespace tvm { @@ -109,23 +110,18 @@ class PrimTypeNode : public TypeNode { */ runtime::DataType dtype; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); } bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } static constexpr const char* _type_key = "PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; - /* * \brief Managed reference to PrimTypeNode. * \sa PrimTypeNode @@ -141,7 +137,6 @@ class PrimType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; - /*! * \brief Low-level raw pointer type. * @@ -159,17 +154,13 @@ class PointerTypeNode : public TypeNode { */ Type element_type; - void VisitAttrs(AttrVisitor* v) { - v->Visit("element_type", &element_type); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); } bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { return equal(element_type, other->element_type); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(element_type); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); } static constexpr const char* _type_key = "PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); @@ -190,7 +181,6 @@ class PointerType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); }; - /*! \brief Possible kinds of TypeVars. */ enum TypeKind : int { kType = 0, @@ -227,7 +217,7 @@ class TypeVarNode : public TypeNode { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; /*! \brief The kind of type parameter */ TypeKind kind; @@ -238,9 +228,7 @@ class TypeVarNode : public TypeNode { } bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const { - return - equal(kind, other->kind) && - equal.FreeVarEqualImpl(this, other); + return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -263,7 +251,7 @@ class TypeVar : public Type { * \param name_hint The name of the type var. * \param kind The kind of the type var. */ - TVM_DLL TypeVar(std::string name_hint, TypeKind kind); + TVM_DLL TypeVar(String name_hint, TypeKind kind); TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); }; @@ -290,9 +278,7 @@ class GlobalTypeVarNode : public TypeNode { bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const { // name matters for now in global type var. 
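A sketch of `CreateModulePass` as reformatted above, building a no-op pass. The element type of the `required` array is elided in this capture, so an empty braced list sidesteps it, and the lambda-to-`TypedPackedFunc` conversion is assumed:

```cpp
#include <tvm/ir/transform.h>

// A no-op module pass at opt_level 0 with no required passes.
tvm::transform::Pass MakeIdentityPass() {
  auto pass_func = [](tvm::IRModule m, tvm::transform::PassContext ctx) { return m; };
  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "Identity", {});
}
```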
- return - equal(name_hint, other->name_hint) && - equal.FreeVarEqualImpl(this, other); + return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -340,9 +326,7 @@ class TupleTypeNode : public TypeNode { return equal(fields, other->fields); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(fields); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } static constexpr const char* _type_key = "TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); @@ -372,9 +356,7 @@ class TupleType : public Type { /*! * \return a type that represents void. */ -inline Type VoidType() { - return TupleType::Empty(); -} +inline Type VoidType() { return TupleType::Empty(); } /*! * \brief Check whether the tyep represents void. @@ -439,11 +421,8 @@ class FuncTypeNode : public TypeNode { bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { // type params first as they defines type vars. - return - equal.DefEqual(type_params, other->type_params) && - equal(arg_types, other->arg_types) && - equal(ret_type, other->ret_type) && - equal(type_constraints, other->type_constraints); + return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) && + equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints); } void SHashReduce(SHashReducer hash_reduce) const { @@ -471,9 +450,7 @@ class FuncType : public Type { * \param type_constraints The type constraints. * \sa FuncTypeNode for more docs about these fields. */ - TVM_DLL FuncType(Array arg_types, - Type ret_type, - Array type_params, + TVM_DLL FuncType(Array arg_types, Type ret_type, Array type_params, Array type_constraints); TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); @@ -500,14 +477,10 @@ class IncompleteTypeNode : public TypeNode { } bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { - return - equal(kind, other->kind) && - equal.FreeVarEqualImpl(this, other); + return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(kind); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); } static constexpr const char* _type_key = "IncompleteType"; TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); @@ -528,7 +501,6 @@ class IncompleteType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); }; - /*! * \brief Reference Type High-level Relay IR. * @@ -550,9 +522,7 @@ class RelayRefTypeNode : public TypeNode { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } // Keep the relay prefix in the type as this type is specific // to the relay itself. diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 55071911fb80a..2a6314cf76443 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -25,11 +25,12 @@ #define TVM_IR_TYPE_FUNCTOR_H_ #include -#include #include +#include + #include -#include #include +#include namespace tvm { @@ -37,16 +38,13 @@ template class TypeFunctor; // functions to be overriden. 
-#define TYPE_FUNCTOR_DEFAULT \ +#define TYPE_FUNCTOR_DEFAULT \ { return VisitTypeDefault_(op, std::forward(args)...); } - -#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitType_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast(n.get()), std::forward(args)...); \ + }); template class TypeFunctor { @@ -65,9 +63,7 @@ class TypeFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Type& n, Args... args) { - return VisitType(n, std::forward(args)...); - } + R operator()(const Type& n, Args... args) { return VisitType(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The expression node. @@ -80,8 +76,7 @@ class TypeFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitType_(const TensorTypeNode* op, - Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -126,8 +121,7 @@ class TypeFunctor { /*! * \brief A type visitor that recursively visit types. */ -class TVM_DLL TypeVisitor : - public TypeFunctor { +class TVM_DLL TypeVisitor : public TypeFunctor { public: void VisitType_(const TypeVarNode* op) override; void VisitType_(const IncompleteTypeNode* op) override; @@ -146,8 +140,7 @@ class TVM_DLL TypeVisitor : /*! * \brief TypeMutator that mutates expressions. */ -class TVM_DLL TypeMutator : - public TypeFunctor { +class TVM_DLL TypeMutator : public TypeFunctor { public: Type VisitType(const Type& t) override; Type VisitType_(const TypeVarNode* op) override; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 06bcb7207c74b..dbd241afa4580 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -24,10 +24,10 @@ #ifndef TVM_IR_TYPE_RELATION_H_ #define TVM_IR_TYPE_RELATION_H_ -#include -#include -#include #include +#include +#include +#include namespace tvm { @@ -51,9 +51,7 @@ class TypeCallNode : public TypeNode { } bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(args, other->args); + return equal(func, other->func) && equal(args, other->args); } void SHashReduce(SHashReducer hash_reduce) const { @@ -105,7 +103,7 @@ class TypeReporterNode : public Object { * \return false if assertation can be proven to have failed * true if solver can still proceed. */ - TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0; + TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0; /*! * \brief assert shape expression equals each other. * \param lhs The left operand. 
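A small sketch of the visitor machinery set up above (`TypeFunctor` dispatch plus the `TypeVisitor` convenience base): override one `VisitType_` hook and let the base class keep recursing:

```cpp
#include <tvm/ir/type_functor.h>

// Count tensor types in a type tree by overriding one hook; the base
// TypeVisitor keeps recursing into the remaining children.
class TensorTypeCounter : public tvm::TypeVisitor {
 public:
  int count{0};
  void VisitType_(const tvm::TensorTypeNode* op) override {
    ++count;
    tvm::TypeVisitor::VisitType_(op);
  }
};
```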
@@ -141,11 +139,9 @@ class TypeReporterNode : public Object { class TypeReporter : public ObjectRef { public: TypeReporter() {} - explicit TypeReporter(ObjectPtr n) : ObjectRef(n) { - } + explicit TypeReporter(ObjectPtr n) : ObjectRef(n) {} TypeReporterNode* operator->() const { - return const_cast( - static_cast(get())); + return const_cast(static_cast(get())); } using ContainerType = TypeReporterNode; }; @@ -169,11 +165,8 @@ class TypeReporter : public ObjectRef { * \return false if This relation cannot be resolved. * true if this relation has been resolved. */ -using TypeRelationFn = - TypedEnvFunc& args, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter)>; +using TypeRelationFn = TypedEnvFunc& args, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter)>; /*! * \brief User defined type relation, it is an input-output relation on types. @@ -207,11 +200,8 @@ class TypeRelationNode : public TypeConstraintNode { } bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(args, other->args) && - equal(num_inputs, other->num_inputs) && - equal(attrs, other->attrs); + return equal(func, other->func) && equal(args, other->args) && + equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -239,10 +229,7 @@ class TypeRelation : public TypeConstraint { * \param attrs Attributes to the relation function. * \sa TypeRelationNode for more docs about these fields. */ - TVM_DLL TypeRelation(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs); + TVM_DLL TypeRelation(TypeRelationFn func, Array args, int num_inputs, Attrs attrs); TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); }; diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index ba1edf84383eb..2b6645fa165b7 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,28 +23,28 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ -#include +#include #include +#include #include -#include -#include -#include #include +#include +#include #include #include -#include namespace tvm { -using runtime::String; -using runtime::StringObj; +using runtime::make_object; using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; using runtime::ObjectPtr; using runtime::ObjectRef; -using runtime::make_object; -using runtime::ObjectHash; -using runtime::ObjectEqual; +using runtime::String; +using runtime::StringObj; /*! \brief array node content in array */ class ArrayNode : public Object { @@ -60,10 +60,7 @@ class MapNode : public Object { public: /*! \brief The corresponding container type */ - using ContainerType = std::unordered_map< - ObjectRef, - ObjectRef, - ObjectHash, ObjectEqual>; + using ContainerType = std::unordered_map; /*! \brief the data content */ ContainerType data; @@ -72,7 +69,6 @@ TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); }; - /*! \brief specialized map node with string as key */ class StrMapNode : public Object { public: @@ -91,14 +87,13 @@ class StrMapNode : public Object { * \tparam Converter a struct that contains converting function * \tparam TIter the content iterator type.
*/ -template +template class IterAdapter { public: using difference_type = typename std::iterator_traits::difference_type; using value_type = typename Converter::ResultType; using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) + using reference = typename Converter::ResultType&; // NOLINT(*) using iterator_category = typename std::iterator_traits::iterator_category; explicit IterAdapter(TIter iter) : iter_(iter) {} @@ -106,26 +101,18 @@ class IterAdapter { ++iter_; return *this; } - inline IterAdapter operator+(difference_type offset) const { - return IterAdapter(iter_ + offset); - } + inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - template + template typename std::enable_if::value, - typename T::difference_type>::type - inline operator-(const IterAdapter& rhs) const { + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { return iter_ - rhs.iter_; } - inline bool operator==(IterAdapter other) const { - return iter_ == other.iter_; - } - inline bool operator!=(IterAdapter other) const { - return !(*this == other); - } - inline const value_type operator*() const { - return Converter::convert(*iter_); - } + inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + inline bool operator!=(IterAdapter other) const { return !(*this == other); } + inline const value_type operator*() const { return Converter::convert(*iter_); } private: TIter iter_; @@ -139,28 +126,26 @@ class IterAdapter { * operator[] only provide const acces, use Set to mutate the content. * \tparam T The content NodeRef type. */ -template::value>::type > +template ::value>::type> class Array : public ObjectRef { public: /*! * \brief default constructor */ - Array() { - data_ = make_object(); - } + Array() { data_ = make_object(); } /*! * \brief move constructor * \param other source */ - Array(Array && other) : ObjectRef() { // NOLINT(*) + Array(Array&& other) : ObjectRef() { // NOLINT(*) data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Array(const Array &other) : ObjectRef() { // NOLINT(*) + Array(const Array& other) : ObjectRef() { // NOLINT(*) data_ = std::move(other.data_); } /*! @@ -174,7 +159,7 @@ class Array : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template Array(IterType begin, IterType end) { assign(begin, end); } @@ -182,14 +167,14 @@ class Array : public ObjectRef { * \brief constructor from initializer list * \param init The initalizer list */ - Array(std::initializer_list init) { // NOLINT(*) + Array(std::initializer_list init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init The vector */ - Array(const std::vector& init) { // NOLINT(*) + Array(const std::vector& init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! @@ -209,7 +194,7 @@ class Array : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Array& operator=(Array && other) { + Array& operator=(Array&& other) { data_ = std::move(other.data_); return *this; } @@ -218,7 +203,7 @@ class Array : public ObjectRef { * \param other The source of assignment * \return reference to self. 
*/ - Array& operator=(const Array & other) { + Array& operator=(const Array& other) { data_ = other.data_; return *this; } @@ -228,7 +213,7 @@ class Array : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template void assign(IterType begin, IterType end) { auto n = make_object(); for (IterType it = begin; it != end; ++it) { @@ -242,8 +227,7 @@ class Array : public ObjectRef { * \return the i-th element. */ inline const T operator[](size_t i) const { - return DowncastNoCheck( - static_cast(data_.get())->data[i]); + return DowncastNoCheck(static_cast(data_.get())->data[i]); } /*! \return The size of the array */ inline size_t size() const { @@ -259,7 +243,7 @@ class Array : public ObjectRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline ArrayNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -292,16 +276,14 @@ class Array : public ObjectRef { n->data[i] = value; } /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } /*! * \brief Helper function to apply fmutate to mutate an array. * \param fmutate The transformation function T -> T. * \tparam F the type of the mutation function. * \note This function performs copy on write optimization. */ - template + template inline void MutateByApply(F fmutate) { ArrayNode* ptr = static_cast(data_.get()); if (ptr == nullptr) return; @@ -342,16 +324,12 @@ class Array : public ObjectRef { struct ValueConverter { using ResultType = T; - static inline T convert(const ObjectRef& n) { - return DowncastNoCheck(n); - } + static inline T convert(const ObjectRef& n) { return DowncastNoCheck(n); } }; - using iterator = IterAdapter::const_iterator>; + using iterator = IterAdapter::const_iterator>; - using reverse_iterator = IterAdapter< - ValueConverter, - std::vector::const_reverse_iterator>; + using reverse_iterator = + IterAdapter::const_reverse_iterator>; /*! \return begin iterator */ inline iterator begin() const { @@ -380,32 +358,28 @@ class Array : public ObjectRef { * \tparam K The key NodeRef type. * \tparam V The value NodeRef type. */ -template::value || - std::is_base_of::value >::type, - typename = typename std::enable_if::value>::type> +template ::value || + std::is_base_of::value>::type, + typename = typename std::enable_if::value>::type> class Map : public ObjectRef { public: /*! * \brief default constructor */ - Map() { - data_ = make_object(); - } + Map() { data_ = make_object(); } /*! * \brief move constructor * \param other source */ - Map(Map && other) { // NOLINT(*) + Map(Map&& other) { // NOLINT(*) data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) + Map(const Map& other) : ObjectRef(other.data_) { // NOLINT(*) } /*! 
* \brief constructor from pointer @@ -418,7 +392,7 @@ class Map : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template Map(IterType begin, IterType end) { assign(begin, end); } @@ -426,15 +400,15 @@ * \brief constructor from initializer list * \param init The initializer list */ - Map(std::initializer_list > init) { // NOLINT(*) + Map(std::initializer_list > init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init The vector */ - template - Map(const std::unordered_map& init) { // NOLINT(*) + template + Map(const std::unordered_map& init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! @@ -442,7 +416,7 @@ * \param other The source of assignment * \return reference to self. */ - Map& operator=(Map && other) { + Map& operator=(Map&& other) { data_ = std::move(other.data_); return *this; } @@ -451,7 +425,7 @@ * \param other The source of assignment * \return reference to self. */ - Map& operator=(const Map & other) { + Map& operator=(const Map& other) { data_ = other.data_; return *this; } @@ -461,7 +435,7 @@ * \param end end of iterator * \tparam IterType The type of iterator */ - template + template void assign(IterType begin, IterType end) { ObjectPtr n = make_object(); for (IterType i = begin; i != end; ++i) { @@ -475,8 +449,7 @@ * \return the corresponding element. */ inline const V operator[](const K& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } /*! * \brief Read element from map. @@ -484,8 +457,7 @@ * \return the corresponding element. */ inline const V at(const K& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } /*! \return The size of the map */ inline size_t size() const { @@ -506,7 +478,7 @@ * \return Handle to the internal node container(which guarantees to be unique) */ inline MapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -524,24 +496,18 @@ } /*! \return whether the map is empty */ - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } /*! \brief specify container node */ using ContainerType = MapNode; struct ValueConverter { using ResultType = std::pair; - static inline ResultType convert(const std::pair< - ObjectRef, - ObjectRef>& n) { - return std::make_pair(DowncastNoCheck(n.first), - DowncastNoCheck(n.second)); + static inline ResultType convert(const std::pair& n) { + return std::make_pair(DowncastNoCheck(n.first), DowncastNoCheck(n.second)); } }; - using iterator = IterAdapter< - ValueConverter, MapNode::ContainerType::const_iterator>; + using iterator = IterAdapter; /*! \return begin iterator */ inline iterator begin() const { @@ -553,46 +519,43 @@ } /*!
\return begin iterator */ inline iterator find(const K& key) const { - return iterator( - static_cast(data_.get())->data.find(key)); + return iterator(static_cast(data_.get())->data.find(key)); } }; // specialize of string map -template +template class Map : public ObjectRef { public: // for code reuse - Map() { - data_ = make_object(); - } - Map(Map && other) { // NOLINT(*) + Map() { data_ = make_object(); } + Map(Map&& other) { // NOLINT(*) data_ = std::move(other.data_); } - Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) + Map(const Map& other) : ObjectRef(other.data_) { // NOLINT(*) } explicit Map(ObjectPtr n) : ObjectRef(n) {} - template + template Map(IterType begin, IterType end) { assign(begin, end); } - Map(std::initializer_list > init) { // NOLINT(*) + Map(std::initializer_list > init) { // NOLINT(*) assign(init.begin(), init.end()); } - template - Map(const std::unordered_map& init) { // NOLINT(*) + template + Map(const std::unordered_map& init) { // NOLINT(*) assign(init.begin(), init.end()); } - Map& operator=(Map && other) { + Map& operator=(Map&& other) { data_ = std::move(other.data_); return *this; } - Map& operator=(const Map & other) { + Map& operator=(const Map& other) { data_ = other.data_; return *this; } - template + template void assign(IterType begin, IterType end) { auto n = make_object(); for (IterType i = begin; i != end; ++i) { @@ -601,12 +564,10 @@ class Map : public ObjectRef { data_ = std::move(n); } inline const V operator[](const std::string& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } inline const V at(const std::string& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } inline size_t size() const { if (data_.get() == nullptr) return 0; @@ -617,7 +578,7 @@ class Map : public ObjectRef { return static_cast(data_.get())->data.count(key); } inline StrMapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -628,22 +589,17 @@ class Map : public ObjectRef { StrMapNode* n = this->CopyOnWrite(); n->data[key] = value; } - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } using ContainerType = StrMapNode; struct ValueConverter { using ResultType = std::pair; - static inline ResultType convert(const std::pair< - std::string, - ObjectRef>& n) { + static inline ResultType convert(const std::pair& n) { return std::make_pair(n.first, DowncastNoCheck(n.second)); } }; - using iterator = IterAdapter< - ValueConverter, StrMapNode::ContainerType::const_iterator>; + using iterator = IterAdapter; /*! \return begin iterator */ inline iterator begin() const { @@ -663,7 +619,7 @@ class Map : public ObjectRef { namespace tvm { namespace runtime { // Additional overloads for PackedFunc checking. 
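// ---- Usage sketch (illustrative only, not part of the patch) ----------------
// The copy-on-write contract of the Array/Map containers above: copies alias
// one node until mutated, then CopyOnWrite() detaches the writer. push_back
// lives elsewhere in this header but follows the same pattern as Set.
#include <tvm/node/container.h>

void SketchArrayCopyOnWrite() {
  tvm::Array<tvm::runtime::String> a;
  a.push_back(tvm::runtime::String("x"));
  tvm::Array<tvm::runtime::String> b = a;  // shallow copy: shares one ArrayNode
  b.Set(0, tvm::runtime::String("y"));     // CopyOnWrite() clones the node for b
  // a still holds {"x"}; b now holds {"y"}.
}
// ------------------------------------------------------------------------------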
-template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -676,12 +632,10 @@ struct ObjectTypeChecker > { } return true; } - static std::string TypeName() { - return "List[" + ObjectTypeChecker::TypeName() + "]"; - } + static std::string TypeName() { return "List[" + ObjectTypeChecker::TypeName() + "]"; } }; -template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -692,13 +646,10 @@ struct ObjectTypeChecker > { } return true; } - static std::string TypeName() { - return "Map[str, " + - ObjectTypeChecker::TypeName()+ ']'; - } + static std::string TypeName() { return "Map[str, " + ObjectTypeChecker::TypeName() + ']'; } }; -template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -711,10 +662,8 @@ struct ObjectTypeChecker > { return true; } static std::string TypeName() { - return "Map[" + - ObjectTypeChecker::TypeName() + - ", " + - ObjectTypeChecker::TypeName()+ ']'; + return "Map[" + ObjectTypeChecker::TypeName() + ", " + ObjectTypeChecker::TypeName() + + ']'; } }; } // namespace runtime diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index e11fda892c305..0837f35bd7156 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -26,9 +26,9 @@ #include #include -#include #include #include +#include namespace tvm { @@ -60,16 +60,16 @@ using runtime::ObjectRef; * \tparam FType function signature * This type is only defined for FType with function signature */ -template +template class NodeFunctor; -template +template class NodeFunctor { private: /*! \brief internal function pointer type */ - typedef R (*FPointer)(const ObjectRef&n, Args...); + typedef R (*FPointer)(const ObjectRef& n, Args...); /*! \brief refer to itself. */ - using TSelf = NodeFunctor; + using TSelf = NodeFunctor; /*! \brief internal function table */ std::vector func_; @@ -92,9 +92,8 @@ class NodeFunctor { * \return The result. */ R operator()(const ObjectRef& n, Args... args) const { - CHECK(can_dispatch(n)) - << "NodeFunctor calls un-registered function on type " - << n->GetTypeKey(); + CHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " + << n->GetTypeKey(); return (*func_[n->type_index()])(n, std::forward(args)...); } /*! @@ -103,37 +102,32 @@ class NodeFunctor { * \tparam TNode the type of Node to be dispatched. * \return reference to self. */ - template + template TSelf& set_dispatch(FPointer f) { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } - CHECK(func_[tindex] == nullptr) - << "Dispatch for " << TNode::_type_key - << " is already set"; + CHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; return *this; } /*! - * \brief unset the dispacher for type TNode - * - * \tparam TNode the type of Node to be dispatched. - * \return reference to self. - */ - template + * \brief unset the dispatcher for type TNode + * + * \tparam TNode the type of Node to be dispatched. + * \return reference to self.
+ */ + template TSelf& clear_dispatch() { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); - CHECK_LT(tindex, func_.size()) - << "clear_dispatch: index out of range"; + CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; } }; - -#define TVM_REG_FUNC_VAR_DEF(ClsName) \ - static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName +#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName /*! * \brief Useful macro to set NodeFunctor dispatch in a global static field. @@ -176,8 +170,7 @@ class NodeFunctor { * \param ClsName The name of the class * \param FField The static function that returns a singleton of NodeFunctor. */ -#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ - TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \ - ClsName::FField() +#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ + TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField() } // namespace tvm #endif // TVM_NODE_FUNCTOR_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 471a0de361b7c..b622fc7ee47f1 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -34,35 +34,35 @@ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ -#include -#include -#include -#include +#include #include #include -#include #include #include +#include +#include +#include +#include #include -#include -#include #include +#include +#include namespace tvm { -using runtime::TypeIndex; +using runtime::Downcast; +using runtime::GetRef; +using runtime::make_object; using runtime::Object; +using runtime::ObjectEqual; +using runtime::ObjectHash; using runtime::ObjectPtr; using runtime::ObjectRef; -using runtime::GetRef; -using runtime::Downcast; -using runtime::ObjectHash; -using runtime::ObjectEqual; -using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using runtime::TypeIndex; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 9ed87df46618e..643b63895d5c5 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -23,18 +23,18 @@ #ifndef TVM_NODE_REFLECTION_H_ #define TVM_NODE_REFLECTION_H_ +#include +#include #include -#include +#include #include -#include #include -#include -#include -#include +#include +#include -#include #include #include +#include namespace tvm { @@ -51,7 +51,7 @@ using runtime::ObjectRef; */ class AttrVisitor { public: -//! \cond Doxygen_Suppress + //! \cond Doxygen_Suppress TVM_DLL virtual ~AttrVisitor() = default; TVM_DLL virtual void Visit(const char* key, double* value) = 0; TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0; @@ -63,14 +63,13 @@ class AttrVisitor { TVM_DLL virtual void Visit(const char* key, DataType* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; - template::value>::type> + template ::value>::type> void Visit(const char* key, ENum* ptr) { static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); this->Visit(key, reinterpret_cast(ptr)); } -//! \endcond + //! \endcond }; /*! @@ -166,7 +165,7 @@ class ReflectionVTable { TVM_DLL static ReflectionVTable* Global(); class Registry; - template + template inline Registry Register(); private: @@ -174,7 +173,7 @@ class ReflectionVTable { std::vector fvisit_attrs_; /*! 
\brief Structural equal function. */ std::vector fsequal_reduce_; - /*! \brief Structural hash function. */ + /*! \brief Structural hash function. */ std::vector fshash_reduce_; /*! \brief Creation function. */ std::vector fcreate_; @@ -186,7 +185,7 @@ class ReflectionVTable { class ReflectionVTable::Registry { public: Registry(ReflectionVTable* parent, uint32_t type_index) - : parent_(parent), type_index_(type_index) { } + : parent_(parent), type_index_(type_index) {} /*! * \brief Set fcreate function. * \param f The creator function. @@ -213,10 +212,8 @@ class ReflectionVTable::Registry { uint32_t type_index_; }; - -#define TVM_REFLECTION_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \ - __make_reflectiion +#define TVM_REFLECTION_REG_VAR_DEF \ + static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflectiion /*! * \brief Directly register reflection VTable. @@ -247,122 +244,108 @@ class ReflectionVTable::Registry { * \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE. * And can be used to register the related reflection functions for runtime objects. */ -#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ - TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::ReflectionVTable::Global()->Register() \ +#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ + TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ + ::tvm::ReflectionVTable::Global()->Register() /*! * \brief Register a node type to object registry and reflection registry. * \param TypeName The name of the type. * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well. */ -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - TVM_REGISTER_OBJECT_TYPE(TypeName); \ +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait) \ - .set_creator([](const std::string&) -> ObjectPtr { \ - return ::tvm::runtime::make_object(); \ - }) - + .set_creator([](const std::string&) -> ObjectPtr { \ + return ::tvm::runtime::make_object(); \ + }) // Implementation details namespace detail { -template +template struct ImplVisitAttrs { static constexpr const std::nullptr_t VisitAttrs = nullptr; }; -template +template struct ImplVisitAttrs { - static void VisitAttrs(T* self, AttrVisitor* v) { - self->VisitAttrs(v); - } + static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); } }; -template +template struct ImplSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; }; -template +template struct ImplSEqualReduce { static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) { return self->SEqualReduce(other, equal); } }; -template +template struct ImplSHashReduce { static constexpr const std::nullptr_t SHashReduce = nullptr; }; -template +template struct ImplSHashReduce { static void SHashReduce(const T* self, SHashReducer hash_reduce) { self->SHashReduce(hash_reduce); } }; -template -struct ReflectionTrait : - public ImplVisitAttrs, - public ImplSEqualReduce, - public ImplSHashReduce { -}; +template +struct ReflectionTrait : public ImplVisitAttrs, + public ImplSEqualReduce, + public ImplSHashReduce {}; -template::value> +template ::value> struct SelectVisitAttrs { static constexpr const std::nullptr_t VisitAttrs = nullptr; }; -template +template struct SelectVisitAttrs { static void VisitAttrs(Object* self, AttrVisitor* v) { TraitName::VisitAttrs(static_cast(self), v); } }; 
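// ---- Usage sketch (illustrative only, not part of the patch) ----------------
// The shape of a node these Impl*/Select* traits detect. `MyNode` is
// hypothetical; the _type_has_method_* flags follow the convention the nodes
// earlier in this patch rely on to opt into structural equality and hashing.
#include <tvm/node/node.h>

#include <cstdint>

class MyNode : public tvm::Object {
 public:
  int64_t value{0};
  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); }
  bool SEqualReduce(const MyNode* other, tvm::SEqualReducer equal) const {
    return equal(value, other->value);
  }
  void SHashReduce(tvm::SHashReducer hash_reduce) const { hash_reduce(value); }
  static constexpr const char* _type_key = "test.MyNode";
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(MyNode, tvm::Object);
};
// In a .cc file: TVM_REGISTER_NODE_TYPE(MyNode);  // wires the hooks into the
// ReflectionVTable through the SelectVisitAttrs/SelectSEqualReduce helpers.
// ------------------------------------------------------------------------------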
-template::value> +template ::value> struct SelectSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; }; -template +template struct SelectSEqualReduce { - static bool SEqualReduce(const Object* self, - const Object* other, - SEqualReducer equal) { - return TraitName::SEqualReduce(static_cast(self), - static_cast(other), + static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) { + return TraitName::SEqualReduce(static_cast(self), static_cast(other), equal); } }; -template::value> +template ::value> struct SelectSHashReduce { static constexpr const std::nullptr_t SHashReduce = nullptr; }; -template +template struct SelectSHashReduce { - static void SHashReduce(const Object* self, - SHashReducer hash_reduce) { - return TraitName::SHashReduce(static_cast(self), - hash_reduce); + static void SHashReduce(const Object* self, SHashReducer hash_reduce) { + return TraitName::SHashReduce(static_cast(self), hash_reduce); } }; } // namespace detail -template -inline ReflectionVTable::Registry -ReflectionVTable::Register() { +template +inline ReflectionVTable::Registry ReflectionVTable::Register() { uint32_t tindex = T::RuntimeTypeIndex(); if (tindex >= fvisit_attrs_.size()) { fvisit_attrs_.resize(tindex + 1, nullptr); @@ -372,20 +355,16 @@ ReflectionVTable::Register() { fshash_reduce_.resize(tindex + 1, nullptr); } // functor that implemnts the redirection. - fvisit_attrs_[tindex] = - ::tvm::detail::SelectVisitAttrs::VisitAttrs; + fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs::VisitAttrs; - fsequal_reduce_[tindex] = - ::tvm::detail::SelectSEqualReduce::SEqualReduce; + fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce::SEqualReduce; - fshash_reduce_[tindex] = - ::tvm::detail::SelectSHashReduce::SHashReduce; + fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce::SHashReduce; return Registry(this, tindex); } -inline void ReflectionVTable:: -VisitAttrs(Object* self, AttrVisitor* visitor) const { +inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const { uint32_t tindex = self->type_index(); if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << self->GetTypeKey() @@ -394,8 +373,7 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const { fvisit_attrs_[tindex](self, visitor); } -inline bool ReflectionVTable::GetReprBytes(const Object* self, - std::string* repr_bytes) const { +inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const { uint32_t tindex = self->type_index(); if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { if (repr_bytes != nullptr) { diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 57824306620c8..532425a51b3ec 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -24,6 +24,7 @@ #define TVM_NODE_REPR_PRINTER_H_ #include + #include namespace tvm { diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index f719e24f619cb..9424f6dc30f29 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -23,9 +23,10 @@ #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ #define TVM_NODE_STRUCTURAL_EQUAL_H_ -#include -#include #include +#include +#include + #include namespace tvm { @@ -43,26 +44,13 @@ class BaseValueEqual { return diff > -atol && diff < atol; } - bool operator()(const int64_t& lhs, const int64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const 
uint64_t& lhs, const uint64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const int& lhs, const int& rhs) const { - return lhs == rhs; - } - bool operator()(const bool& lhs, const bool& rhs) const { - return lhs == rhs; - } - bool operator()(const std::string& lhs, const std::string& rhs) const { - return lhs == rhs; - } - bool operator()(const DataType& lhs, const DataType& rhs) const { - return lhs == rhs; - } - template::value>::type> + bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } + bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } + bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } + bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; } + bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } + bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; } + template ::value>::type> bool operator()(const ENum& lhs, const ENum& rhs) const { return lhs == rhs; } @@ -127,9 +115,7 @@ class SEqualReducer : public BaseValueEqual { * \note This function may save the equality condition of (lhs == rhs) in an internal * stack and try to resolve later. */ - virtual bool SEqualReduce(const ObjectRef& lhs, - const ObjectRef& rhs, - bool map_free_vars) = 0; + virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0; /*! * \brief Lookup the graph node equal map for vars that are already mapped. * @@ -185,7 +171,7 @@ class SEqualReducer : public BaseValueEqual { * \param rhs The right operand. * \return the immediate check result. */ - template + template bool operator()(const Array& lhs, const Array& rhs) const { // quick specialization for Array to reduce amount of recursion // depth as array comparison is pretty common. @@ -210,9 +196,7 @@ class SEqualReducer : public BaseValueEqual { } /*! \return Get the internal handler. */ - Handler* operator->() const { - return handler_; - } + Handler* operator->() const { return handler_; } private: /*! \brief Internal class pointer. 
*/ diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index affc5f4dc3774..ed89d841cd655 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -23,11 +23,12 @@ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ #define TVM_NODE_STRUCTURAL_HASH_H_ -#include -#include #include -#include +#include +#include + #include +#include namespace tvm { @@ -36,39 +37,25 @@ namespace tvm { */ class BaseValueHash { public: - size_t operator()(const double& key) const { - return std::hash()(key); - } + size_t operator()(const double& key) const { return std::hash()(key); } - size_t operator()(const int64_t& key) const { - return std::hash()(key); - } + size_t operator()(const int64_t& key) const { return std::hash()(key); } - size_t operator()(const uint64_t& key) const { - return std::hash()(key); - } + size_t operator()(const uint64_t& key) const { return std::hash()(key); } - size_t operator()(const int& key) const { - return std::hash()(key); - } + size_t operator()(const int& key) const { return std::hash()(key); } - size_t operator()(const bool& key) const { - return std::hash()(key); - } + size_t operator()(const bool& key) const { return std::hash()(key); } - size_t operator()(const std::string& key) const { - return std::hash()(key); - } + size_t operator()(const std::string& key) const { return std::hash()(key); } size_t operator()(const runtime::DataType& key) const { - return std::hash()( - static_cast(key.code()) | - (static_cast(key.bits()) << 8) | - (static_cast(key.lanes()) << 16)); + return std::hash()(static_cast(key.code()) | + (static_cast(key.bits()) << 8) | + (static_cast(key.lanes()) << 16)); } - template::value>::type> + template ::value>::type> bool operator()(const ENum& key) const { return std::hash()(static_cast(key)); } @@ -173,9 +160,8 @@ class SHashReducer { * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. */ - template::value>::type> + template ::value>::type> void operator()(const T& key) const { // handle normal values. handler_->SHashReduceHashedValue(BaseValueHash()(key)); @@ -184,17 +170,13 @@ class SHashReducer { * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. */ - void operator()(const ObjectRef& key) const { - return handler_->SHashReduce(key, map_free_vars_); - } + void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); } /*! * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. * \note This function indicate key could contain var defintions. */ - void DefHash(const ObjectRef& key) const { - return handler_->SHashReduce(key, true); - } + void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } /*! * \brief Implementation for hash for a free var. * \param var The variable. @@ -205,9 +187,7 @@ class SHashReducer { } /*! \return Get the internal handler. */ - Handler* operator->() const { - return handler_; - } + Handler* operator->() const { return handler_; } private: /*! \brief Internal class pointer. */ diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 1ee7c9c09728e..b2164ba8c1f79 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -24,13 +24,14 @@ #ifndef TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_ -#include #include +#include #include #include #include -#include + #include +#include #include namespace tvm { @@ -72,16 +73,11 @@ class PatternWildcard; /*! 
\brief PatternWildcard container node */ class PatternWildcardNode : public PatternNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("span", &span); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } - bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { - return true; - } + bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} static constexpr const char* _type_key = "relay.PatternWildcard"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); @@ -131,9 +127,7 @@ class PatternVarNode : public PatternNode { return equal.DefEqual(var, other->var); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(var); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); } static constexpr const char* _type_key = "relay.PatternVar"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); @@ -167,9 +161,7 @@ class PatternConstructorNode : public PatternNode { } bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const { - return - equal(constructor, other->constructor) && - equal(patterns, other->patterns); + return equal(constructor, other->constructor) && equal(patterns, other->patterns); } void SHashReduce(SHashReducer hash_reduce) const { @@ -210,9 +202,7 @@ class PatternTupleNode : public PatternNode { return equal(patterns, other->patterns); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(patterns); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); } static constexpr const char* _type_key = "relay.PatternTuple"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); @@ -297,10 +287,8 @@ class MatchNode : public ExprNode { bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(data, other->data) && - equal(clauses, other->clauses) && - equal(complete, other->complete); + return equal(data, other->data) && equal(clauses, other->clauses) && + equal(complete, other->complete); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index a2c0c75b66eff..b4b1b9dcc4e85 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -24,11 +24,12 @@ #ifndef TVM_RELAY_ANALYSIS_H_ #define TVM_RELAY_ANALYSIS_H_ +#include #include #include #include -#include #include + #include #include @@ -73,9 +74,9 @@ TVM_DLL bool ConstantCheck(const Expr& e); * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, * although x is not shadowed. * - * \param expr the expression to check. + * \param expr the expression to check. * - * \return true iff all Var in expr is bound at most once. + * \return true iff all Var in expr is bound at most once. */ TVM_DLL bool WellFormed(const Expr& expr); @@ -233,8 +234,7 @@ TVM_DLL Array UnmatchedCases(const Match& match, const IRModule& mod); * * \return The reference count mapping. 
*/ -TVM_DLL std::unordered_map -GetExprRefCount(const Expr& body); +TVM_DLL std::unordered_map GetExprRefCount(const Expr& body); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 2d1b9028732df..a7d4708e2c926 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -38,14 +39,15 @@ struct ArgsortAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis along which to sort the input tensor." - "If not given, the flattened array is used."); - TVM_ATTR_FIELD(is_ascend).set_default(true) - .describe("Whether to sort in ascending or descending order." - "By default, sort in ascending order"); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("DType of the output indices."); + TVM_ATTR_FIELD(axis).set_default(-1).describe( + "Axis along which to sort the input tensor." + "If not given, the flattened array is used."); + TVM_ATTR_FIELD(is_ascend).set_default(true).describe( + "Whether to sort in ascending or descending order." + "By default, sort in ascending order"); + TVM_ATTR_FIELD(dtype) + .set_default(NullValue()) + .describe("DType of the output indices."); } }; @@ -57,20 +59,19 @@ struct TopKAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") { - TVM_ATTR_FIELD(k).set_default(1) - .describe("Number of top elements to select"); - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis along which to sort the input tensor."); - TVM_ATTR_FIELD(ret_type).set_default("both") - .describe("The return type [both, values, indices]." - "both - return both top k data and indices." - "values - return top k data only." - "indices - return top k indices only."); - TVM_ATTR_FIELD(is_ascend).set_default(false) - .describe("Whether to sort in ascending or descending order." - "By default, sort in descending order"); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("Data type of the output indices."); + TVM_ATTR_FIELD(k).set_default(1).describe("Number of top elements to select"); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor."); + TVM_ATTR_FIELD(ret_type).set_default("both").describe( + "The return type [both, values, indices]." + "both - return both top k data and indices." + "values - return top k data only." + "indices - return top k indices only."); + TVM_ATTR_FIELD(is_ascend).set_default(false).describe( + "Whether to sort in ascending or descending order." 
+ "By default, sort in descending order"); + TVM_ATTR_FIELD(dtype) + .set_default(NullValue()) + .describe("Data type of the output indices."); } }; diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index cc21e34b41255..4a2eb63c7e6af 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_ANNOTATION_H_ #include + #include namespace tvm { @@ -38,9 +39,8 @@ struct OnDeviceAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(device_type) - .describe( - "The virutal device/context type that an expression is annotated with.") - .set_default(0); + .describe("The virutal device/context type that an expression is annotated with.") + .set_default(0); } }; @@ -51,9 +51,7 @@ struct CastHintAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") { - TVM_ATTR_FIELD(dtype) - .describe( - "The data type denoted to be cast."); + TVM_ATTR_FIELD(dtype).describe("The data type denoted to be cast."); } }; @@ -65,8 +63,7 @@ struct CompilerAttrs : public tvm::AttrsNode { std::string compiler; TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") { - TVM_ATTR_FIELD(compiler) - .describe("A 3rd party compiler used for code generation."); + TVM_ATTR_FIELD(compiler).describe("A 3rd party compiler used for code generation."); } }; diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h index 962afc29fdbc7..ed04c59ec8659 100644 --- a/include/tvm/relay/attrs/bitserial.h +++ b/include/tvm/relay/attrs/bitserial.h @@ -27,6 +27,7 @@ #include #include + #include namespace tvm { @@ -112,23 +113,18 @@ struct BinaryDenseAttrs : public tvm::AttrsNode { bool unipolar; TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") { - TVM_ATTR_FIELD(units) - .describe("Number of hidden units of the dense transformation."); - TVM_ATTR_FIELD(data_bits) - .set_default(1) - .describe("Number of bits to pack for incoming tensor."); + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(data_bits).set_default(1).describe( + "Number of bits to pack for incoming tensor."); TVM_ATTR_FIELD(weight_bits) - .set_default(1) - .describe("Number of bits to pack for weight tensor."); + .set_default(1) + .describe("Number of bits to pack for weight tensor."); TVM_ATTR_FIELD(pack_dtype) - .set_default(NullValue()) - .describe("Datatype to pack bits into before computation."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); - TVM_ATTR_FIELD(unipolar) - .set_default(true) - .describe("Whether to use unipolar or bipolar quantization for inputs."); + .set_default(NullValue()) + .describe("Datatype to pack bits into before computation."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + TVM_ATTR_FIELD(unipolar).set_default(true).describe( + "Whether to use unipolar or bipolar quantization for inputs."); } }; diff --git a/include/tvm/relay/attrs/debug.h b/include/tvm/relay/attrs/debug.h index ed9ed4ee06268..112228bb41ee5 100644 --- a/include/tvm/relay/attrs/debug.h +++ b/include/tvm/relay/attrs/debug.h @@ -25,6 +25,8 @@ #define TVM_RELAY_ATTRS_DEBUG_H_ #include +#include + #include namespace tvm { @@ -37,8 +39,7 @@ struct DebugAttrs : public tvm::AttrsNode { EnvFunc debug_func; TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") { - 
TVM_ATTR_FIELD(debug_func) - .describe("The function to use when debugging."); + TVM_ATTR_FIELD(debug_func).describe("The function to use when debugging."); } }; diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 2486fcdf473db..7da92b3ff7639 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_ #include + #include namespace tvm { @@ -39,13 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") { TVM_ATTR_FIELD(src_dev_type) - .describe( - "The virtual device/context type where the op copies data from.") - .set_default(0); + .describe("The virtual device/context type where the op copies data from.") + .set_default(0); TVM_ATTR_FIELD(dst_dev_type) - .describe( - "The virtual device/context type where the op copies data to.") - .set_default(0); + .describe("The virtual device/context type where the op copies data to.") + .set_default(0); } }; diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 52bb2efc63a80..b927c98615374 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -40,26 +41,27 @@ struct ResizeAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { - TVM_ATTR_FIELD(size).set_default(NullValue >()) - .describe("Output Size."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("bilinear") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Describes how to transform the coordinate in the resized tensor" - "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details" - "Available options are half_pixel, align_corners and asymmetric"); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." 
+ "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; @@ -72,22 +74,22 @@ struct CropAndResizeAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") { - TVM_ATTR_FIELD(crop_size).set_default(NullValue >()) - .describe("Target Size."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("bilinear") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation"); - TVM_ATTR_FIELD(extrapolation_value).set_default(0.0) + TVM_ATTR_FIELD(crop_size).set_default(NullValue >()).describe("Target Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .set_default(0.0) .describe("Specify value for extrapolation."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; @@ -101,25 +103,33 @@ struct Dilation2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Dilation2DAttrs, "relay.attrs.Dilation2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the sliding window. [stride_height, stride_width]."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilations).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilations) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use. [dilation_height, dilation_width]"); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("IHW") - .describe("Dimension ordering of weight. Can be 'IHW', 'HWI', etc." 
- "'I', 'H', 'W' stands for input_channel, height, and width" - "dimensions respectively."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("IHW") + .describe( + "Dimension ordering of weight. Can be 'IHW', 'HWI', etc." + "'I', 'H', 'W' stands for input_channel, height, and width" + "dimensions respectively."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index d232f867a7775..7429c396ea006 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -26,6 +26,7 @@ #include #include + #include #include @@ -46,15 +47,10 @@ struct AllocStorageAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") { TVM_ATTR_FIELD(dtype) - .describe( - "The dtype of the tensor to allocate.") - .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(device_id) - .describe( - "The device id on which to allocate memory."); - TVM_ATTR_FIELD(device_type) - .describe( - "The device type on which to allocate memory."); + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory."); + TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); } }; @@ -68,16 +64,13 @@ struct AllocTensorAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") { TVM_ATTR_FIELD(dtype) - .describe( - "The dtype of the tensor to allocate.") - .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(const_shape) - .describe( - "The shape of constant used to aid in type inference."); + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(const_shape).describe("The shape of constant used to aid in type inference."); TVM_ATTR_FIELD(assert_shape) - .describe( - "The shape to cast the return type of the allocation to, "\ - "used to specify the shape obtained via further analysis."); + .describe( + "The shape to cast the return type of the allocation to, " + "used to specify the shape obtained via further analysis."); } }; @@ -88,10 +81,9 @@ struct ShapeFuncAttrs : public tvm::AttrsNode { Array is_input; TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") { - TVM_ATTR_FIELD(is_input) - .describe( - "A bool indicating whether the shape function should"\ - "expect shape or input in each position."); + TVM_ATTR_FIELD(is_input).describe( + "A bool indicating whether the shape function should" + "expect shape or input in each position."); } }; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index f985a9010961e..a9c305935ad3d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -42,13 +43,10 @@ struct BiasAddAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis to add the bias") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1); } }; - /*! 
\brief Attributes used in 1D convolution operators */ struct Conv1DAttrs : public tvm::AttrsNode { Array strides; @@ -63,31 +61,44 @@ struct Conv1DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, })) + TVM_ATTR_FIELD(strides) + .set_default(Array({ + 1, + })) .describe("Specifies the stride of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, })) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({ + 1, + })) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Currently unused but may be added in the future."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Currently unused but may be added in the future."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Convolution is applied on the 'W'" - "dimension."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIW") - .describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." - "'O', 'I', 'W' stands for num_filter, input_channel, and width" - "dimensions respectively."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCW") + .describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the 'W'" + "dimension."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIW") + .describe( + "Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -96,7 +107,6 @@ struct Conv1DAttrs : public tvm::AttrsNode { } }; - /*! 
\brief Attributes used in convolution operators */ struct Conv2DAttrs : public tvm::AttrsNode { Array strides; @@ -111,42 +121,53 @@ struct Conv2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -156,14 +177,13 @@ struct Conv2DAttrs : public tvm::AttrsNode { }; /*! \brief Attributes used in winograd weight transformation operators */ -struct ConvWinogradWeightTransformAttrs : - public tvm::AttrsNode { +struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode { int tile_size; TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs, - "relay.attrs.ConvWinogradWeightTransformAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + "relay.attrs.ConvWinogradWeightTransformAttrs") { + TVM_ATTR_FIELD(tile_size).describe( + "Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); } }; @@ -182,44 +202,55 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(tile_size).describe( + "The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." 
+ "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. 
TVM_ATTR_FIELD(out_dtype) @@ -261,43 +292,54 @@ struct Conv3DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom," - "right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Convolution is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") - .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." - "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," - "and width dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." 
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -321,45 +363,56 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(tile_size).describe( + "The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom," - "right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." 
+ " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Convolution is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") - .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." - "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," - "and width dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -368,14 +421,12 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { } }; - /*! \brief Attributes used in softmax operators */ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("The axis to sum over when computing softmax."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("The axis to sum over when computing softmax."); } }; @@ -395,53 +446,77 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") { TVM_ATTR_FIELD(channels) - .set_default(NullValue()) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) - .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("The strides of the convolution."); - TVM_ATTR_FIELD(output_padding).set_default(Array({0, 0})) - .describe("Zero-padding added to one side of the output." 
- "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0, 0})) + .describe( + "Zero-padding added to one side of the output." + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." 
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); } }; +/*! \brief Attributes used in dilate operator */ +struct DilateAttrs : public tvm::AttrsNode { + Array strides; + + TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") { + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Dilation stride on each dimension, 1 means no dilation."); + } +}; + /*! \brief Attributes used in 1D transposed convolution operator */ struct Conv1DTransposeAttrs : public tvm::AttrsNode { IndexExpr channels; @@ -458,42 +533,54 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") { TVM_ATTR_FIELD(channels) - .set_default(NullValue()) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) - .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("The strides of the convolution."); - TVM_ATTR_FIELD(output_padding).set_default(Array({0})) - .describe("Zero-padding added to one side of the output."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("Symmetric or asymmetric padding." - "Single value: the input is implicitly zero-padded on both sides." - "Two values: padding[0] is used for left input padding, " - "padding[1] is used for right input padding,"); - TVM_ATTR_FIELD(dilation).set_default(Array({1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(data_layout).set_default("NCW") - .describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Convolution is applied on the" - "'W' dimension."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIW") - .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." - "'O', 'I', 'W' stands for num_filter, input_channel, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. 
Default to be same as input layout."); + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0})) + .describe("Zero-padding added to one side of the output."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "Symmetric or asymmetric padding." + "Single value: the input is implicitly zero-padded on both sides." + "Two values: padding[0] is used for left input padding, " + "padding[1] is used for right input padding."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCW") + .describe( + "Dimension ordering of data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIW") + .describe( + "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -509,23 +596,25 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. 
Pooling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -539,25 +628,28 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool2DAttrs, "relay.attrs.AvgPool2DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; @@ -566,11 +658,11 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -580,13 +672,14 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output height and width."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(output_size) + .set_default(Array({})) + .describe("Output height and width."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -595,17 +688,17 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output depth, height and width."); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on 'D', 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(output_size) + .set_default(Array({})) + .describe("Output depth, height and width."); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on 'D', 'H' and" + "'W' dimensions."); } }; - /*! 
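// For reference, a concrete reading of output_size on the adaptive pools
// above (hedged; shapes follow the operator-level docs): for an NCHW input
// of shape (1, 16, 32, 32),
//   output_size = {}     -> (1, 16, 1, 1)  (the empty default acts like global pooling)
//   output_size = {8}    -> (1, 16, 8, 8)  (one value is used for both 'H' and 'W')
//   output_size = {8, 4} -> (1, 16, 8, 4)  (explicit (H, W))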
\brief Attributes for 1D max pool operator */ struct MaxPool1DAttrs : public tvm::AttrsNode { Array pool_size; @@ -615,22 +708,24 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Pooling is applied on the 'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimension."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -644,28 +739,30 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool1DAttrs, "relay.attrs.AvgPool1DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NHC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. 
Pooling is applied on the 'W' dimension."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimension."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; - /*! \brief Attributes for 3D max pool operator */ struct MaxPool3DAttrs : public tvm::AttrsNode { Array pool_size; @@ -675,23 +772,25 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. 
Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -705,37 +804,38 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool3DAttrs, "relay.attrs.AvgPool3DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; - /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; DataType out_dtype; TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") { - TVM_ATTR_FIELD(units) - .describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -772,21 +872,22 @@ struct UpSamplingAttrs : public tvm::AttrsNode { bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { - TVM_ATTR_FIELD(scale_h) - .describe("The upsampling factor for height"); - TVM_ATTR_FIELD(scale_w) - .describe("The upsampling factor for width"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." 
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Upsampling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("nearest_neighbor") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(align_corners).set_default(false) + TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Upsampling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("nearest_neighbor") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(align_corners) + .set_default(false) .describe("Should be true to preserve the values at the corner pixels"); } }; @@ -801,26 +902,27 @@ struct UpSampling3DAttrs : public tvm::AttrsNode { std::string coordinate_transformation_mode; TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") { - TVM_ATTR_FIELD(scale_d) - .describe("The upsampling factor for depth"); - TVM_ATTR_FIELD(scale_h) - .describe("The upsampling factor for height"); - TVM_ATTR_FIELD(scale_w) - .describe("The upsampling factor for width"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Upsampling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("nearest_neighbor") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "trilinear - Trilinear Interpolation"); - TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Describes how to transform the coordinate in the resized tensor" - "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details" - "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(scale_d).describe("The upsampling factor for depth"); + TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Upsampling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("nearest_neighbor") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." 
+ "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); } }; @@ -831,15 +933,17 @@ struct PadAttrs : public tvm::AttrsNode { std::string pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { - TVM_ATTR_FIELD(pad_value).set_default(0.0) - .describe("The value used for padding when mode is 'constant'."); - TVM_ATTR_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ..., (before_N, after_N))"); - TVM_ATTR_FIELD(pad_mode).set_default("constant") - .describe("Padding type to use. \"constant\" pads with constant_value, " - "\"edge\" pads using the edge values of the input array, " - "\"reflect\" pads by reflecting values with respect to the edges."); + TVM_ATTR_FIELD(pad_value).set_default(0.0).describe( + "The value used for padding when mode is 'constant'."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(pad_mode) + .set_default("constant") + .describe( + "Padding type to use. \"constant\" pads with constant_value, " + "\"edge\" pads using the edge values of the input array, " + "\"reflect\" pads by reflecting values with respect to the edges."); } }; @@ -849,11 +953,12 @@ struct MirrorPadAttrs : public tvm::AttrsNode { Array > pad_width; TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") { - TVM_ATTR_FIELD(mode).set_default("SYMMETRIC") - .describe("Specifies how mirroring should be performed."); - TVM_ATTR_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(mode) + .set_default("SYMMETRIC") + .describe("Specifies how mirroring should be performed."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); } }; @@ -862,30 +967,28 @@ struct LeakyReluAttrs : public tvm::AttrsNode { double alpha; TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") { - TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25) - .describe("Slope coefficient for the negative half axis."); + TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25).describe( + "Slope coefficient for the negative half axis."); } }; - /*! \brief Attributes for prelu operator */ struct PReluAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") { - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Specify which shape axis the channel is specified."); + TVM_ATTR_FIELD(axis).set_default(1).describe( + "Specify which shape axis the channel is specified."); } }; - /*! 
\brief Attributes used in dropout operator */ struct DropoutAttrs : public tvm::AttrsNode { double rate; TVM_DECLARE_ATTRS(DropoutAttrs, "relay.attrs.DropoutAttrs") { TVM_ATTR_FIELD(rate) - .describe("Fraction of the input that gets dropped out during training time") - .set_default(0.5); + .describe("Fraction of the input that gets dropped out during training time") + .set_default(0.5); } }; // struct DropoutAttrs @@ -897,24 +1000,22 @@ struct BatchNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") { - TVM_ATTR_FIELD(axis) - .describe("Specify which shape axis denotes the channel.") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1); TVM_ATTR_FIELD(epsilon) - .describe("Small float added to variance to avoid dividing by zero") - .set_default(1e-5); + .describe("Small float added to variance to avoid dividing by zero") + .set_default(1e-5); TVM_ATTR_FIELD(center) - .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored") - .set_default(true); + .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored") + .set_default(true); TVM_ATTR_FIELD(scale) - .describe("If True, multiply by gamma. If False, gamma is not used. " - "When the next layer is piecewise linear (also, e.g., nn.relu), " - "this can be disabled since the scaling will be done by the next layer.") - .set_default(true); + .describe( + "If True, multiply by gamma. If False, gamma is not used. " + "When the next layer is piecewise linear (also, e.g., nn.relu), " + "this can be disabled since the scaling will be done by the next layer.") + .set_default(true); } }; // struct BatchNormAttrs - /*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public tvm::AttrsNode { int axis; @@ -923,21 +1024,18 @@ struct InstanceNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") { - TVM_ATTR_FIELD(axis) - .describe("Specify which shape axis denotes the channel.") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1); TVM_ATTR_FIELD(epsilon) - .describe("Small float added to variance to avoid dividing by zero") - .set_default(1e-5); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + .describe("Small float added to variance to avoid dividing by zero") + .set_default(1e-5); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct InstanceNormAttrs - /*! 
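// For reference: batch_norm, instance_norm, layer_norm and group_norm share
// one normalization core and differ only in the axes the statistics are
// computed over (standard definition):
//   y = (x - mean) / sqrt(var + epsilon) * gamma + beta
// `scale` toggles the gamma term and `center` toggles the beta term, which
// is why the same four fields recur across these structs.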
\brief Attributes used in layer_norm operator */ struct LayerNormAttrs : public tvm::AttrsNode { int axis; @@ -946,19 +1044,17 @@ struct LayerNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Specify which shape axis denotes the channel."); - TVM_ATTR_FIELD(epsilon).set_default(1e-5) - .describe("Small float added to variance to avoid dividing by zero"); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe( + "Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct LayerNormAttrs - /*! \brief Attributes used in group_norm operator */ struct GroupNormAttrs : public tvm::AttrsNode { int num_groups; @@ -968,21 +1064,20 @@ struct GroupNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") { - TVM_ATTR_FIELD(num_groups).set_default(0) - .describe("Specify number of groups to separate the channels into."); - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Specify which shape axis denotes the channel."); - TVM_ATTR_FIELD(epsilon).set_default(1e-5) - .describe("Small float added to variance to avoid dividing by zero"); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + TVM_ATTR_FIELD(num_groups) + .set_default(0) + .describe("Specify number of groups to separate the channels into."); + TVM_ATTR_FIELD(axis).set_default(1).describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe( + "Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct GroupNormAttrs - /*! 
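// For reference, the LRN fields below map onto the usual cross-channel local
// response normalization (hedged, standard form):
//   out[c] = in[c] / pow(bias + (alpha / size) * sum(in[c']^2), beta)
// where the sum runs over the window of `size` channels centred on c along
// `axis`.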
\brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> { int size; @@ -992,34 +1087,26 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> { double beta; TVM_DECLARE_ATTRS(LRNAttrs, "relay.attrs.LRNAttrs") { - TVM_ATTR_FIELD(size).set_default(5) - .describe("The size of the local region to be considered for normalization."); - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Axis of input data layout channel."); - TVM_ATTR_FIELD(bias).set_default(2) - .describe("The offset parameter to avoid division by 0."); - TVM_ATTR_FIELD(alpha).set_default(0.0001) - .describe("The scaling parameter."); - TVM_ATTR_FIELD(beta).set_default(0.75) - .describe("The exponent parameter."); + TVM_ATTR_FIELD(size).set_default(5).describe( + "The size of the local region to be considered for normalization."); + TVM_ATTR_FIELD(axis).set_default(1).describe("Axis of input data layout channel."); + TVM_ATTR_FIELD(bias).set_default(2).describe("The offset parameter to avoid division by 0."); + TVM_ATTR_FIELD(alpha).set_default(0.0001).describe("The scaling parameter."); + TVM_ATTR_FIELD(beta).set_default(0.75).describe("The exponent parameter."); } }; - /*! \brief Attributes for L2Normalize operator */ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> { double eps; Array<Integer> axis; TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") { - TVM_ATTR_FIELD(eps) - .describe("A lower bound value for the norm, to avoid division by 0."); - TVM_ATTR_FIELD(axis) - .describe("Axis over the normalization applied."); + TVM_ATTR_FIELD(eps).describe("A lower bound value for the norm, to avoid division by 0."); + TVM_ATTR_FIELD(axis).describe("Axis over the normalization applied."); } }; - /*! \brief Attributes for DeformableConv2D operator */ struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> { Array<IndexExpr> strides; @@ -1035,46 +1122,59 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> { DataType out_dtype; TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array<IndexExpr>({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array<IndexExpr>({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array<IndexExpr>({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(deformable_groups).set_default(1) - .describe("Controls the connections between inputs and offsets." - "Input channels are partitioned into multiple deformable groups. Offsets" - "are shared across input channels in the same deformable group."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs."
- "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(deformable_groups) + .set_default(1) + .describe( + "Controls the connections between inputs and offsets." + "Input channels are partitioned into multiple deformable groups. Offsets" + "are shared across input channels in the same deformable group."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 443efb5b0c32a..f57c1f4ddc589 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_REDUCE_H_ #include + #include namespace tvm { @@ -37,7 +38,8 @@ struct ReduceAttrs : public tvm::AttrsNode { bool exclude; TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue>()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. 
The default, `axis=()`, will compute over all elements into a @@ -51,11 +53,11 @@ struct ReduceAttrs : public tvm::AttrsNode { If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead.)code"); - TVM_ATTR_FIELD(keepdims).set_default(false) - .describe("If this is set to `True`, the reduced axes are left " - "in the result as dimension with size one."); - TVM_ATTR_FIELD(exclude).set_default(false) - .describe("Whether to perform reduction on axis that are NOT in axis instead."); + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); } }; } // namespace relay diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 9e7ff714defa1..824ea93651263 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -27,6 +27,7 @@ #include #include #include + #include namespace tvm { @@ -37,8 +38,7 @@ struct CastAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type"); + TVM_ATTR_FIELD(dtype).describe("Target data type"); } }; // struct CastAttrs. @@ -48,11 +48,11 @@ struct ExpandDimsAttrs : public tvm::AttrsNode { int num_newaxis; TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis at which the input array is expanded." - "Should lie in range `[-data.ndim - 1, data.ndim]`." - "If `axis < 0`, it is the first axis inserted;" - "If `axis >= 0`, it is the last axis inserted in Python's negative indexing."); + TVM_ATTR_FIELD(axis).describe( + "The axis at which the input array is expanded." + "Should lie in range `[-data.ndim - 1, data.ndim]`." + "If `axis < 0`, it is the first axis inserted;" + "If `axis >= 0`, it is the last axis inserted in Python's negative indexing."); TVM_ATTR_FIELD(num_newaxis) .describe("Number of axises to be inserted. Should be >= 0.") .set_lower_bound(0) @@ -65,8 +65,9 @@ struct ConcatenateAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") { TVM_ATTR_FIELD(axis) - .describe("The axis at which the input arrays are concatenated." - "Should lie in range `[-ndim, ndim)`.") + .describe( + "The axis at which the input arrays are concatenated." + "Should lie in range `[-ndim, ndim)`.") .set_default(0); } }; // struct ConcatenateAttrs @@ -75,8 +76,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode { struct TransposeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("The target axes order, reverse order if not specified."); + TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified."); } }; // struct TransposeAttrs @@ -85,8 +85,8 @@ struct ReshapeAttrs : public tvm::AttrsNode { Expr newshape; bool reverse; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { - TVM_ATTR_FIELD(newshape) - .describe("The new shape. Should be compatible with the original shape."); + TVM_ATTR_FIELD(newshape).describe( + "The new shape. 
Should be compatible with the original shape."); TVM_ATTR_FIELD(reverse) .describe("Infer the special values from right to left if true") .set_default(false); @@ -98,13 +98,14 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { std::string mode; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue<Integer>()) .describe("The axis over which to select values."); - TVM_ATTR_FIELD(mode).set_default("clip") - .describe("Specify how out-of-bound indices will behave." - "clip - clip to the range (default)" - "wrap - wrap around the indices" - "fast - no clip or wrap around (user must make sure indices are in-bound)"); + TVM_ATTR_FIELD(mode).set_default("clip").describe( + "Specify how out-of-bound indices will behave." + "clip - clip to the range (default)" + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); } }; @@ -114,11 +115,8 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { DataType dtype; TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") { - TVM_ATTR_FIELD(shape) - .describe("Target shape."); - TVM_ATTR_FIELD(dtype) - .describe("Target data type.") - .set_default(NullValue<DataType>()); + TVM_ATTR_FIELD(shape).describe("Target shape."); + TVM_ATTR_FIELD(dtype).describe("Target data type.").set_default(NullValue<DataType>()); } }; // struct InitOpAttrs @@ -130,14 +128,10 @@ struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> { DataType dtype; TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") { - TVM_ATTR_FIELD(start) - .describe("Start of interval. The interval includes this value."); - TVM_ATTR_FIELD(stop) - .describe("Stop of interval. The interval does not include this value."); - TVM_ATTR_FIELD(step) - .describe("Spacing between values."); - TVM_ATTR_FIELD(dtype) - .describe("Target data type."); + TVM_ATTR_FIELD(start).describe("Start of interval. The interval includes this value."); + TVM_ATTR_FIELD(stop).describe("Stop of interval. The interval does not include this value."); + TVM_ATTR_FIELD(step).describe("Spacing between values."); + TVM_ATTR_FIELD(dtype).describe("Target data type."); } }; // struct ArangeAttrs @@ -145,8 +139,8 @@ struct StackAttrs : public tvm::AttrsNode<StackAttrs> { Integer axis; TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") { - TVM_ATTR_FIELD(axis).set_default(0) - .describe("The axis in the result array along which the input arrays are stacked."); + TVM_ATTR_FIELD(axis).set_default(0).describe( + "The axis in the result array along which the input arrays are stacked."); } }; // struct StackAttrs @@ -155,9 +149,9 @@ struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> { Integer repeats; Integer axis; TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") { - TVM_ATTR_FIELD(repeats) - .describe("The number of repetitions for each element."); - TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>()) + TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element."); + TVM_ATTR_FIELD(axis) + .set_default(NullValue<Integer>()) .describe(" The axis along which to repeat values."); } }; // struct RepeatAttrs @@ -166,9 +160,9 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> { Array<Integer> reps; TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") { - TVM_ATTR_FIELD(reps) - .describe("The number of times for repeating the tensor a." - "Each dim sizeof reps must be a positive integer."); + TVM_ATTR_FIELD(reps).describe( + "The number of times for repeating the tensor a."
+ "Each dim sizeof reps must be a positive integer."); } }; // struct TileAttrs @@ -176,7 +170,8 @@ struct TileAttrs : public tvm::AttrsNode { struct ReverseAttrs : public tvm::AttrsNode { Integer axis; TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe("The axis along which to reverse elements."); } }; // struct ReverseAttrs @@ -188,10 +183,11 @@ struct SqueezeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { TVM_ATTR_FIELD(axis) - .describe("The axis to squeeze in the input tensor." - "If `axis = None`, all axis of dimension 1 get squeezed;" - "Else, the dimension in axes get squeezed." - "It is an error if an axis does not has dimension 1.") + .describe( + "The axis to squeeze in the input tensor." + "If `axis = None`, all axis of dimension 1 get squeezed;" + "Else, the dimension in axes get squeezed." + "It is an error if an axis does not has dimension 1.") .set_default(NullValue >()); } }; // struct SqueezeAttrs @@ -202,13 +198,13 @@ struct SplitAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { TVM_ATTR_FIELD(indices_or_sections) - .describe("Indices or sections to split into. Accepts an int or a tuple" - "If indices_or_sections is an integer, the input will be divided equally" - "along given axis. If such a split is not possible, an error is raised." - "If indices_or_sections is a tuple of sorted integers," - "the entries indicate where along axis the array is split."); - TVM_ATTR_FIELD(axis).set_default(0) - .describe("the axis to be splitted."); + .describe( + "Indices or sections to split into. Accepts an int or a tuple" + "If indices_or_sections is an integer, the input will be divided equally" + "along given axis. If such a split is not possible, an error is raised." + "If indices_or_sections is a tuple of sorted integers," + "the entries indicate where along axis the array is split."); + TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be splitted."); } }; @@ -219,12 +215,9 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Array strides; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { - TVM_ATTR_FIELD(begin) - .describe("Indices for begin of slice, begin index is also inclusive"); - TVM_ATTR_FIELD(end) - .describe("Indices for end of slice, end index is exclusive"); - TVM_ATTR_FIELD(strides).set_default(Array({})) - .describe("Stride values of the slice"); + TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); + TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive"); + TVM_ATTR_FIELD(strides).set_default(Array({})).describe("Stride values of the slice"); } }; @@ -232,10 +225,10 @@ struct SliceLikeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("List of axes on which input data will be sliced according to the " - "corresponding size of the second input. By default will slice " - "on all axes. Negative axes mean counting in reverse."); + TVM_ATTR_FIELD(axes).describe( + "List of axes on which input data will be sliced according to the " + "corresponding size of the second input. By default will slice " + "on all axes. 
Negative axes mean counting in reverse."); } }; @@ -245,10 +238,8 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { double a_max; TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { - TVM_ATTR_FIELD(a_min) - .describe("The minimum clip value."); - TVM_ATTR_FIELD(a_max) - .describe("The maximum clip value."); + TVM_ATTR_FIELD(a_min).describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max).describe("The maximum clip value."); } }; @@ -258,10 +249,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> { std::string dst_layout; TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") { - TVM_ATTR_FIELD(src_layout) - .describe("The source layout of the tensor. (e.g. NCHW)"); - TVM_ATTR_FIELD(dst_layout) - .describe("The destination layout of the tensor. (e.g. NCHW16c)"); + TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. NCHW)"); + TVM_ATTR_FIELD(dst_layout).describe("The destination layout of the tensor. (e.g. NCHW16c)"); } }; @@ -270,9 +259,7 @@ struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> { DataType dtype; TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type") - .set_default(NullValue<DataType>()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>()); } }; @@ -281,10 +268,9 @@ struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> { int axis; TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") { - TVM_ATTR_FIELD(mask_value).set_default(0) - .describe("The masking value."); - TVM_ATTR_FIELD(axis).set_default(0) - .describe("The axis of the length dimension. Can only be 0 or 1."); + TVM_ATTR_FIELD(mask_value).set_default(0).describe("The masking value."); + TVM_ATTR_FIELD(axis).set_default(0).describe( + "The axis of the length dimension. Can only be 0 or 1."); } }; // struct SequenceMaskAttrs.
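// The attrs hunks above are mechanical clang-format re-flows, but they all
// revolve around one declaration idiom. A minimal sketch of that idiom
// follows; ExampleScaleAttrs, its fields, and its registry key are
// hypothetical illustrations, not part of this patch.
#include <tvm/ir/attrs.h>

namespace tvm {
namespace relay {

struct ExampleScaleAttrs : public tvm::AttrsNode<ExampleScaleAttrs> {
  int axis;
  double scale;

  TVM_DECLARE_ATTRS(ExampleScaleAttrs, "relay.attrs.ExampleScaleAttrs") {
    // These describe()/set_default() chains are exactly what clang-format is
    // re-flowing throughout this file.
    TVM_ATTR_FIELD(axis).set_default(-1).describe("The axis to scale over.");
    TVM_ATTR_FIELD(scale).set_default(1.0).describe("The multiplicative scale factor.");
  }
};

}  // namespace relay
}  // namespace tvm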
@@ -293,9 +279,7 @@ struct NdarraySizeAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type") - .set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); } }; @@ -306,12 +290,9 @@ struct OneHotAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { - TVM_ATTR_FIELD(depth).set_default(1) - .describe("Depth of the one hot dimension."); - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis to fill."); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(depth).set_default(1).describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill."); + TVM_ATTR_FIELD(dtype).set_default(NullValue()).describe("Output data type."); } }; // struct OneHotAttrs diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index c4a30ce8b1594..e7e24b19228b8 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -41,39 +42,32 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(MultiBoxPriorAttrs, "relay.attrs.MultiBoxPriorAttrs") { TVM_ATTR_FIELD(sizes) - .set_default(Array({static_cast(1.0)})) - .describe("List of sizes of generated MultiBoxPriores."); + .set_default(Array({static_cast(1.0)})) + .describe("List of sizes of generated MultiBoxPriores."); TVM_ATTR_FIELD(ratios) - .set_default(Array({static_cast(1.0)})) - .describe("List of aspect ratios of generated MultiBoxPriores."); + .set_default(Array({static_cast(1.0)})) + .describe("List of aspect ratios of generated MultiBoxPriores."); TVM_ATTR_FIELD(steps) - .set_default(Array({static_cast(-1.0), - static_cast(-1.0)})) - .describe("Priorbox step across y and x, -1 for auto calculation."); + .set_default(Array({static_cast(-1.0), static_cast(-1.0)})) + .describe("Priorbox step across y and x, -1 for auto calculation."); TVM_ATTR_FIELD(offsets) - .set_default(Array({static_cast(0.5), - static_cast(0.5)})) - .describe("Priorbox center offsets, y and x respectively."); - TVM_ATTR_FIELD(clip).set_default(false) - .describe("Whether to clip out-of-boundary boxes."); + .set_default(Array({static_cast(0.5), static_cast(0.5)})) + .describe("Priorbox center offsets, y and x respectively."); + TVM_ATTR_FIELD(clip).set_default(false).describe("Whether to clip out-of-boundary boxes."); } }; -struct MultiBoxTransformLocAttrs - : public tvm::AttrsNode { +struct MultiBoxTransformLocAttrs : public tvm::AttrsNode { bool clip; double threshold; Array variances; - TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, - "relay.attrs.MultiBoxTransformLocAttrs") { - TVM_ATTR_FIELD(clip).set_default(true) - .describe("Clip out-of-boundary boxes."); - TVM_ATTR_FIELD(threshold).set_default(0.01) - .describe("Threshold to be a positive prediction."); + TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") { + TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes."); + TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction."); TVM_ATTR_FIELD(variances) - .set_default(Array({0.1f, 0.1f , 0.2f, 0.2f})) - .describe("Variances to be decoded from box regression output."); + .set_default(Array({0.1f, 0.1f, 0.2f, 0.2f})) + .describe("Variances to be decoded from box 
regression output."); } }; @@ -84,12 +78,11 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { - TVM_ATTR_FIELD(score_threshold).set_default(0.0) - .describe("Lower limit of score for valid bounding boxes."); - TVM_ATTR_FIELD(id_index).set_default(0) - .describe("Axis index of id."); - TVM_ATTR_FIELD(score_index).set_default(1) - .describe("Index of the scores/confidence of boxes."); + TVM_ATTR_FIELD(score_threshold) + .set_default(0.0) + .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id."); + TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes."); } }; @@ -106,25 +99,28 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { Integer stride; TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") { - TVM_ATTR_FIELD(stride) - .set_default(1) - .describe("Stride value for yolo reorg"); + TVM_ATTR_FIELD(stride).set_default(1).describe("Stride value for yolo reorg"); } }; @@ -206,10 +200,8 @@ struct ProposalAttrs : public tvm::AttrsNode { .describe( "The size of the receptive field each unit in the convolution layer of the rpn," "for example the product of all stride's prior to this layer."); - TVM_ATTR_FIELD(threshold) - .set_default(0.7) - .describe( - "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); + TVM_ATTR_FIELD(threshold).set_default(0.7).describe( + "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); TVM_ATTR_FIELD(rpn_pre_nms_top_n) .set_default(6000) .describe("Number of top scoring boxes to apply NMS. -1 to use all boxes"); diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 1d0120675e99e..c78ab75bcbd9e 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,10 +24,10 @@ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ - #include -#include #include +#include + #include #include @@ -42,17 +42,19 @@ namespace tvm { */ namespace relay { -#define RELAY_DEBUG(...) \ -{ auto fdebug = runtime::Registry::Get("relay.debug"); \ - CHECK(fdebug) << "Could not find Relay Python debugger function."; \ - (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ -} +#define RELAY_DEBUG(...) \ + { \ + auto fdebug = runtime::Registry::Get("relay.debug"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ + } -#define RELAY_DEBUG_INTERP(...) \ -{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \ - CHECK(fdebug) << "Could not find Relay Python debugger function."; \ - (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ -} +#define RELAY_DEBUG_INTERP(...) \ + { \ + auto fdebug = runtime::Registry::Get("relay.debug_interp"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ + } /*! * \brief Symbolic expression for tensor shape. 
@@ -93,9 +95,7 @@ class IdNode : public Object { */ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } static constexpr const char* _type_key = "relay.Id"; TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8c5026050d91f..69a60a76b3c0b 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -26,10 +26,12 @@ #include #include -#include #include -#include +#include + #include +#include + #include "./base.h" #include "./type.h" @@ -63,9 +65,7 @@ class ConstantNode : public ExprNode { TensorType tensor_type() const; /*! \return Whether it is scalar(rank-0 tensor) */ - bool is_scalar() const { - return data->ndim == 0; - } + bool is_scalar() const { return data->ndim == 0; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); @@ -77,9 +77,7 @@ class ConstantNode : public ExprNode { return equal(data, other->data); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(data); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } static constexpr const char* _type_key = "relay.Constant"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); @@ -172,9 +170,7 @@ class VarNode : public ExprNode { Type type_annotation; /*! \return The name hint of the variable */ - const std::string& name_hint() const { - return vid->name_hint; - } + const std::string& name_hint() const { return vid->name_hint; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); @@ -184,9 +180,7 @@ class VarNode : public ExprNode { } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - return - equal(type_annotation, other->type_annotation) && - equal.FreeVarEqualImpl(this, other); + return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -194,11 +188,9 @@ class VarNode : public ExprNode { hash_reduce.FreeVarHashImpl(this); } - TVM_DLL static Var make(std::string name_hint, - Type type_annotation); + TVM_DLL static Var make(std::string name_hint, Type type_annotation); - TVM_DLL static Var make(Id vid, - Type type_annotation); + TVM_DLL static Var make(Id vid, Type type_annotation); static constexpr const char* _type_key = "relay.Var"; TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); @@ -211,8 +203,7 @@ class Var : public Expr { * \param name_hint The name hint of a variable. * \param type_annotation The type annotation of a variable. */ - TVM_DLL Var(std::string name_hint, Type type_annotation) : - Var(Id(name_hint), type_annotation) {} + TVM_DLL Var(std::string name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {} /*! * \brief The constructor @@ -278,11 +269,8 @@ class CallNode : public ExprNode { bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { // skip type_args check for primitive ops. equal->MarkGraphNode(); - return - equal(op, other->op) && - equal(args, other->args) && - equal(attrs, other->attrs) && - (IsPrimitiveOp(op) || equal(type_args, other->type_args)); + return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && + (IsPrimitiveOp(op) || equal(type_args, other->type_args)); } void SHashReduce(SHashReducer hash_reduce) const { @@ -308,9 +296,7 @@ class Call : public Expr { * \param attrs The attributes of the call node. 
* \param type_args The type arguments passed to a polymorphic function. */ - TVM_DLL Call(Expr op, - Array args, - Attrs attrs = Attrs(), + TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), Array type_args = Array()); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); @@ -348,10 +334,8 @@ class LetNode : public ExprNode { bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal.DefEqual(var, other->var) && equal(value, other->value) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -410,10 +394,8 @@ class IfNode : public ExprNode { bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(cond, other->cond) && - equal(true_branch, other->true_branch) && - equal(false_branch, other->false_branch); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch); } void SHashReduce(SHashReducer hash_reduce) const { @@ -457,9 +439,7 @@ class TupleGetItemNode : public ExprNode { } bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { - return - equal(tuple, other->tuple) && - equal(index, other->index); + return equal(tuple, other->tuple) && equal(index, other->index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -576,9 +556,7 @@ class RefWriteNode : public ExprNode { bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(ref, other->ref) && - equal(value, other->value); + return equal(ref, other->ref) && equal(value, other->value); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 04b275431f2b3..559f9b854ee32 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -25,16 +25,16 @@ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ -#include #include +#include +#include #include #include -#include #include #include -#include #include +#include namespace tvm { namespace relay { @@ -54,15 +54,13 @@ template class ExprFunctor; // functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT \ +#define EXPR_FUNCTOR_DEFAULT \ { return VisitExprDefault_(op, std::forward(args)...); } -#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); template class ExprFunctor { @@ -81,9 +79,7 @@ class ExprFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Expr& n, Args... args) { - return VisitExpr(n, std::forward(args)...); - } + R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The expression node. @@ -96,22 +92,15 @@ class ExprFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitExpr_(const ConstantNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const TupleNode* op, - Args... 
args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const VarNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GlobalVarNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FunctionNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const IfNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const OpNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -154,8 +143,7 @@ class ExprFunctor { * ExprVisitor treats Expr as dataflow graph, * and only visit each Expr node once. */ -class ExprVisitor - : public ::tvm::relay::ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: void VisitExpr(const Expr& expr) override; void VisitExpr_(const VarNode* op) override; @@ -189,16 +177,13 @@ class ExprVisitor * The mutated results are memoized in a map and reused so that * local transformation on the dataflow preserves the graph structure. */ -class ExprMutator - : public ::tvm::relay::ExprFunctor { +class ExprMutator : public ::tvm::relay::ExprFunctor { public: /*! * \brief Mutate is alias for VisitExpr * \return expr. */ - Expr Mutate(const Expr& expr) { - return this->VisitExpr(expr); - } + Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const ConstantNode* op) override; @@ -283,7 +268,8 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions * of the graph and processes them iteratatively to prevent stack overflows * - * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior. + * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive + * behavior. */ class MixedModeMutator : public ::tvm::relay::ExprMutator { public: @@ -293,14 +279,14 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); }; /*! - * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be - * able to rewrite the op only with data about the original node `pre` and the same node with + * \brief Users should override Rewrite_ methods to implement their pass. 
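// Beyond the re-flowed signatures, ExprFunctor/ExprVisitor/ExprMutator above
// form the visitor hierarchy most Relay passes build on. A minimal usage
// sketch follows; CallCounter is a hypothetical helper, not part of this
// patch.
#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

class CallCounter : public ExprVisitor {
 public:
  size_t count = 0;

  void VisitExpr_(const CallNode* op) final {
    ++count;
    ExprVisitor::VisitExpr_(op);  // keep walking the operator and arguments
  }
};

size_t CountCalls(const Expr& e) {
  CallCounter counter;
  counter.VisitExpr(e);  // ExprVisitor memoizes, so shared nodes are visited once
  return counter.count;
}

}  // namespace relay
}  // namespace tvm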
Rewrite_ functions will + * be able to rewrite the op only with data about the original node `pre` and the same node with * modified inputs `post` and should not recurse. * * \param pre The expression node before rewriting. * \param post The expression with rewritten inputs. */ - virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;} + virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; } virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; } virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; } @@ -350,9 +336,7 @@ class ExprRewriter { * \param post The expression node with rewritten inputs. * \return The result of the call */ - Expr operator()(const Expr& pre, const Expr& post) { - return Rewrite(pre, post); - } + Expr operator()(const Expr& pre, const Expr& post) { return Rewrite(pre, post); } /*! * \brief The functor call. * \param pre The expression node before rewriting. diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 744d7c4e111cf..3783e320f57c4 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -24,9 +24,9 @@ #ifndef TVM_RELAY_FEATURE_H_ #define TVM_RELAY_FEATURE_H_ +#include #include #include -#include #include @@ -65,9 +65,7 @@ class FeatureSet { public: FeatureSet(const FeatureSet&) = default; /*! \brief A singleton set containing a single Feature. */ - explicit FeatureSet(Feature ft) { - bs_.set(static_cast(ft)); - } + explicit FeatureSet(Feature ft) { bs_.set(static_cast(ft)); } explicit FeatureSet(const tvm::Array& ft) { for (Integer i : ft) { (*this) += Feature(static_cast(i)); @@ -93,25 +91,25 @@ class FeatureSet { FeatureSet fs; return fs; } - template + template FeatureSet& operator+=(const T& rhs) { bs_ |= FeatureSet(rhs).bs_; return *this; } /*! \brief Set union. */ - template + template FeatureSet operator+(const T& rhs) const { FeatureSet fs(*this); fs += rhs; return fs; } - template + template FeatureSet& operator-=(const T& rhs) { bs_ &= ~(FeatureSet(rhs)).bs_; return *this; } /*! \brief Set difference. */ - template + template FeatureSet operator-(const T& rhs) const { FeatureSet fs(*this); fs -= rhs; @@ -124,14 +122,12 @@ class FeatureSet { * * \return true only if this is a subset of rhs. */ - bool is_subset_of(const FeatureSet& rhs) const { - return ((*this) - rhs).bs_.none(); - } + bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); } private: std::bitset bs_; FeatureSet() = default; - explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } + explicit FeatureSet(const std::bitset& bs) : bs_(bs) {} }; /*! diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 33b813b76f18f..ab9111bfe0846 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -26,8 +26,8 @@ #include #include -#include +#include namespace tvm { namespace relay { @@ -71,12 +71,9 @@ class FunctionNode : public BaseFuncNode { bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { // Important to make def equal first. 
equal->MarkGraphNode(); - return - equal.DefEqual(params, other->params) && - equal.DefEqual(type_params, other->type_params) && - equal(ret_type, other->ret_type) && - equal(attrs, other->attrs) && - equal(body, other->body); + return equal.DefEqual(params, other->params) && + equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) && + equal(attrs, other->attrs) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -100,7 +97,6 @@ class FunctionNode : public BaseFuncNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; - /*! * \brief Managed reference to FunctionNode. * \sa FunctionNode @@ -115,10 +111,7 @@ class Function : public BaseFunc { * \param ty_params The type parameters. * \param attrs Additional function attributes. */ - TVM_DLL Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array ty_params, + TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, tvm::DictAttrs attrs = NullValue()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index ae1f84a616a4a..bda73ed3a51bb 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -36,12 +36,11 @@ #include #include -#include #include +#include #include #include - namespace tvm { namespace relay { @@ -64,8 +63,8 @@ namespace relay { * \param target Compiler target flag to compile the functions on the context. * \return A function that takes in an expression and returns a value. */ -runtime::TypedPackedFunc -CreateInterpreter(IRModule mod, DLContext context, Target target); +runtime::TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, + Target target); /*! \brief The container type of Closures used by the interpreter. */ class InterpreterClosureObj : public runtime::vm::ClosureObj { @@ -96,8 +95,7 @@ class InterpreterClosureObj : public runtime::vm::ClosureObj { class InterpreterClosure : public runtime::vm::Closure { public: TVM_DLL InterpreterClosure(tvm::Map env, Function func); - TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, - InterpreterClosureObj); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, InterpreterClosureObj); }; /*! \brief The container type of RecClosure. 
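// The FeatureSet methods touched above implement a small set algebra over the
// Feature enum. A short sketch of how the API composes, assuming the fVar and
// fFunction enumerators from feature.h:
#include <tvm/relay/feature.h>

namespace tvm {
namespace relay {

inline bool ExampleFeatureCheck() {
  FeatureSet fs = FeatureSet(Feature::fVar) + Feature::fFunction;  // union via +
  FeatureSet without_var = fs - Feature::fVar;                     // difference via -
  return without_var.is_subset_of(fs);  // true by construction
}

}  // namespace relay
}  // namespace tvm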
*/ @@ -130,9 +128,7 @@ struct RefValueObj : Object { RefValueObj() {} - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("value", &value); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); } static constexpr const char* _type_key = "relay.RefValue"; TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); @@ -164,9 +160,7 @@ struct ConstructorValueObj : Object { class ConstructorValue : public ObjectRef { public: - TVM_DLL ConstructorValue(int32_t tag, - tvm::Array fields, - Constructor construtor = {}); + TVM_DLL ConstructorValue(int32_t tag, tvm::Array fields, Constructor construtor = {}); TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fa47da226dffc..12845158a22ff 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -25,8 +25,8 @@ #define TVM_RELAY_OP_H_ #include -#include #include +#include namespace tvm { namespace relay { @@ -34,8 +34,7 @@ namespace relay { using Op = tvm::Op; using OpNode = tvm::OpNode; -#define RELAY_REGISTER_OP(OpName) \ - TVM_REGISTER_OP(OpName) +#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName) } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 5b2fdd3ab4e19..b3e70f50cc03d 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -24,21 +24,22 @@ #ifndef TVM_RELAY_OP_ATTR_TYPES_H_ #define TVM_RELAY_OP_ATTR_TYPES_H_ -#include -#include -#include #include -#include +#include #include +#include +#include +#include #include + #include namespace tvm { namespace relay { +using tir::BijectiveLayoutNode; using tir::Layout; using tir::LayoutAxis; -using tir::BijectiveLayoutNode; /*! \brief operator pattern used in graph fusion */ enum OpPatternKind { @@ -104,10 +105,8 @@ using TShapeDataDependant = bool; & these are always placeholders. * \return The output compute description of the operator. */ -using FTVMCompute = runtime::TypedPackedFunc< - Array(const Attrs& attrs, - const Array& inputs, - const Type& out_type)>; +using FTVMCompute = runtime::TypedPackedFunc( + const Attrs& attrs, const Array& inputs, const Type& out_type)>; /*! * \brief Build the computation schedule for @@ -118,10 +117,8 @@ using FTVMCompute = runtime::TypedPackedFunc< * \param target The build target. * \return schedule The computation schedule. */ -using FTVMSchedule = runtime::TypedPackedFunc< - te::Schedule(const Attrs& attrs, - const Array& outs, - const Target& target)>; +using FTVMSchedule = runtime::TypedPackedFunc& outs, const Target& target)>; /*! * \brief Generate the strategy of operators. This function is a generic @@ -143,11 +140,9 @@ using FTVMStrategy = GenericFunc; * and dtype of the inputs. * \return new_expr The modified expression. */ -using FTVMAlterOpLayout = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& tinfos, - const Type& out_type)>; +using FTVMAlterOpLayout = + runtime::TypedPackedFunc& args, + const Array& tinfos, const Type& out_type)>; /*! * \brief Convert the layout of operators or replace the @@ -160,11 +155,9 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< * \param desired_layout The desired layout. * \return new_expr The modified expression. 
*/ -using FTVMConvertOpLayout = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& tinfos, - const std::string& desired_layout)>; +using FTVMConvertOpLayout = runtime::TypedPackedFunc& args, const Array& tinfos, + const std::string& desired_layout)>; /*! * \brief Legalizes an expression with another expression. This function will be * invoked in Legalize pass. It is a target-dependent pass. @@ -174,10 +167,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc< * and dtype of the inputs. * \return new_expr The modified expression. */ -using FTVMLegalize = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& arg_types)>; +using FTVMLegalize = runtime::TypedPackedFunc& args, + const Array& arg_types)>; /*! * \brief Annotates an expression to indicate if an op should be compiled using @@ -189,9 +180,8 @@ using FTVMLegalize = runtime::TypedPackedFunc< * \return true if this op should be registered to invoke a specific compiler * for codegen, otherwise, false. */ -using FTVMAnnotateTarget = runtime::TypedPackedFunc< - bool(const Attrs& attrs, // NOLINT(*) - const Array& args)>; +using FTVMAnnotateTarget = runtime::TypedPackedFunc& args)>; /*! * \brief Forward rewriting rule for a specific op. @@ -207,10 +197,8 @@ using FTVMAnnotateTarget = runtime::TypedPackedFunc< * \note When we register the function, we can register * a different signature with ctx to be a specific node type. */ -using FForwardRewrite = runtime::TypedPackedFunc< - Expr(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx)>; +using FForwardRewrite = runtime::TypedPackedFunc& new_args, const ObjectRef& ctx)>; /*! * \brief Gradient for a specific op. @@ -219,8 +207,8 @@ using FForwardRewrite = runtime::TypedPackedFunc< * \param output_grad the gradient of the Expr. * \return the gradient for each parameters. */ -using FPrimalGradient = runtime::TypedPackedFunc(const Expr& orig_call, - const Expr& output_grad)>; +using FPrimalGradient = + runtime::TypedPackedFunc(const Expr& orig_call, const Expr& output_grad)>; /*! * \brief The codegeneration strategy for dynamic dimensions. @@ -233,10 +221,8 @@ enum AnyCodegenStrategy { /*! \brief A runtime representation of shape. */ using Shape = Array; -using FShapeFunc = runtime::TypedPackedFunc< - Array(const Attrs& attrs, - const Array& inputs, - const Array& out_ndims)>; +using FShapeFunc = runtime::TypedPackedFunc( + const Attrs& attrs, const Array& inputs, const Array& out_ndims)>; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/op_strategy.h b/include/tvm/relay/op_strategy.h index a4da95a36b074..3f5876d9fcbaf 100644 --- a/include/tvm/relay/op_strategy.h +++ b/include/tvm/relay/op_strategy.h @@ -25,11 +25,12 @@ #ifndef TVM_RELAY_OP_STRATEGY_H_ #define TVM_RELAY_OP_STRATEGY_H_ -#include -#include #include #include #include +#include +#include + #include namespace tvm { @@ -70,8 +71,7 @@ class OpImplementation : public ObjectRef { * \param out_type The output type information. * \return The output compute description of the operator. */ - TVM_DLL Array Compute(const Attrs& attrs, - const Array& inputs, + TVM_DLL Array Compute(const Attrs& attrs, const Array& inputs, const Type& out_type); /*! * \brief Build the computation schedule. @@ -80,8 +80,7 @@ class OpImplementation : public ObjectRef { * \param target The build target. * \return The computation schedule. 
*/ - TVM_DLL te::Schedule Schedule(const Attrs& attrs, - const Array& outs, + TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array& outs, const Target& target); TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); @@ -119,8 +118,8 @@ class OpSpecialization : public ObjectRef { * \param name Name of the implementation * \param plevel Priority level of the implementation */ - TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, - std::string name, int plevel); + TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, + int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); }; @@ -133,9 +132,7 @@ class OpStrategyNode : public Object { /*! \brief List of operator specializations. */ Array specializations; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("specializations", &specializations); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); } static constexpr const char* _type_key = "relay.OpStrategy"; TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); @@ -153,8 +150,8 @@ class OpStrategy : public ObjectRef { * \param name Name of the implementation * \param plevel Priority level of the implementation */ - TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, - std::string name, int plevel); + TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, + int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); }; diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 6e0fb17ed233e..ada69c6772f36 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -25,16 +25,16 @@ #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ #define TVM_RELAY_PATTERN_FUNCTOR_H_ -#include #include +#include #include -#include #include +#include +#include "./adt.h" #include "./expr.h" #include "./op.h" -#include "./adt.h" namespace tvm { namespace relay { @@ -54,15 +54,13 @@ template class PatternFunctor; // functions to be overriden. -#define PATTERN_FUNCTOR_DEFAULT \ +#define PATTERN_FUNCTOR_DEFAULT \ { return VisitPatternDefault_(op, std::forward(args)...); } -#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.get()), std::forward(args)...); \ + }); template class PatternFunctor { @@ -96,14 +94,10 @@ class PatternFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitPattern_(const PatternWildcardNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternVarNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternConstructorNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternTupleNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternVarNode* op, Args... 
args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; @@ -144,8 +138,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor { +class PatternMutator : public ::tvm::relay::PatternFunctor { public: Pattern Mutate(const Pattern& pat); Pattern VisitPattern_(const PatternWildcardNode* op) override; @@ -163,6 +156,7 @@ class PatternMutator virtual Var VisitVar(const Var& v); /*! \brief Used to visit the vars inside of patterns. */ virtual Constructor VisitConstructor(const Constructor& c); + private: std::unordered_map var_map_; }; diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 3c1c4a33c3d11..4b5cd89f0b0cd 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -25,6 +25,7 @@ #define TVM_RELAY_QNN_ATTRS_H_ #include + #include namespace tvm { @@ -39,19 +40,20 @@ struct RequantizeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { TVM_ATTR_FIELD(axis) - .describe("The output channel axis for channel wise quantization. Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); - TVM_ATTR_FIELD(rounding).set_default("UPWARD") - .describe("Defines the rounding direction when the value is midway between" - "two representable values. There are two supported modes - UPWARD" - "or TONEAREST. Both modes behave exactly same except at the" - "midpoints between the two representable values. At the midpoint," - "UPWARD rounds towards positive infinity (for example -1.5 will be" - "rounded to -1). TONEAREST is the standard rounding where the" - "value is rounded away from zero at midpoints (for example, -1.5" - "rounds to -2). More context can be found at following gblic manual" - "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); + .describe( + "The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( + "Defines the rounding direction when the value is midway between" + "two representable values. There are two supported modes - UPWARD" + "or TONEAREST. Both modes behave exactly same except at the" + "midpoints between the two representable values. At the midpoint," + "UPWARD rounds towards positive infinity (for example -1.5 will be" + "rounded to -1). TONEAREST is the standard rounding where the" + "value is rounded away from zero at midpoints (for example, -1.5" + "rounds to -2). More context can be found at following gblic manual" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -64,12 +66,12 @@ struct QuantizeAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") { - TVM_ATTR_FIELD(out_dtype) - .describe("Output data type, can be one of [int8 or uint8]."); + TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [int8 or uint8]."); TVM_ATTR_FIELD(axis) - .describe("The output channel axis for channel wise quantization. 
Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); + .describe( + "The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); } }; diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h index 10cd19afe6f3a..d1f07c924d6b1 100644 --- a/include/tvm/relay/qnn/transform.h +++ b/include/tvm/relay/qnn/transform.h @@ -25,8 +25,8 @@ #ifndef TVM_RELAY_QNN_TRANSFORM_H_ #define TVM_RELAY_QNN_TRANSFORM_H_ -#include #include +#include namespace tvm { namespace relay { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index dc4097ab184b5..461276b795410 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -24,13 +24,13 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ -#include -#include #include +#include #include #include -#include #include +#include +#include #include @@ -56,11 +56,9 @@ using Sequential = tvm::transform::Sequential; * * \return The created function pass. */ -TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, IRModule, PassContext)>& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); +TVM_DLL Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, const tvm::Array& required); /*! \brief Remove expressions which does not effect the program result. * @@ -79,17 +77,17 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< TVM_DLL Pass DeadCodeElimination(bool inline_once = false); /*! -* \brief Convert all expressions of TensorType into GradCell, -* an algebraic data type defined in gradient.rly. -* -* This will delay or decrease memory usage. All calls to -* ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, -* rather only instantiate if needed. It also defines + and * operation -* between GradCell types which can increase performance when using -* zero-filled or one-filled tensors, which is the case in reverse mode ad. -* -* \return the pass -*/ + * \brief Convert all expressions of TensorType into GradCell, + * an algebraic data type defined in gradient.rly. + * + * This will delay or decrease memory usage. All calls to + * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, + * rather only instantiate if needed. It also defines + and * operation + * between GradCell types which can increase performance when using + * zero-filled or one-filled tensors, which is the case in reverse mode ad. + * + * \return the pass + */ TVM_DLL Pass LazyGradientInit(); /*! @@ -373,9 +371,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); * \return A type checked Function with its checked_type field populated. * \note this function mutates mod and is not thread-safe. */ -TVM_DLL Function InferType(const Function& f, - const IRModule& mod, - const GlobalVar& var); +TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This @@ -389,8 +385,7 @@ TVM_DLL Function InferType(const Function& f, * an Expr consumed by multiple callers. * \return The rewritten expression. 
*/ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, +TVM_DLL Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_attr_name, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); @@ -406,8 +401,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, * * \return The rewritten expression. */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, +TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e8f402ac961d6..105f74edcfa7b 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -24,18 +24,18 @@ #ifndef TVM_RELAY_TYPE_H_ #define TVM_RELAY_TYPE_H_ -#include +#include +#include #include +#include #include -#include #include -#include #include + #include #include "base.h" - namespace tvm { namespace relay { diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index abfc792d574f5..741b2807a2f1d 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -45,11 +45,8 @@ extern "C" { * * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. */ -typedef int (*TVMBackendPackedCFunc)(TVMValue* args, - int* type_codes, - int num_args, - TVMValue* out_ret_value, - int* out_ret_tcode); +typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, + TVMValue* out_ret_value, int* out_ret_tcode); /*! * \brief Backend function for modules to get function @@ -61,9 +58,7 @@ typedef int (*TVMBackendPackedCFunc)(TVMValue* args, * \param out The result function. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *out); +TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); /*! * \brief Backend function to register system-wide library symbol. * @@ -87,11 +82,8 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); * certain backends such as OpenGL. * \return nullptr when error is thrown, a valid ptr if success */ -TVM_DLL void* TVMBackendAllocWorkspace(int device_type, - int device_id, - uint64_t nbytes, - int dtype_code_hint, - int dtype_bits_hint); +TVM_DLL void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, + int dtype_code_hint, int dtype_bits_hint); /*! * \brief Backend function to free temporal workspace. @@ -103,9 +95,7 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type, * * \sa TVMBackendAllocWorkspace */ -TVM_DLL int TVMBackendFreeWorkspace(int device_type, - int device_id, - void* ptr); +TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); /*! * \brief Environment for TVM parallel task. @@ -125,8 +115,7 @@ typedef struct { * \param penv The parallel environment backs the execution. * \param cdata The supporting closure data. */ -typedef int (*FTVMParallelLambda)( - int task_id, TVMParallelGroupEnv* penv, void* cdata); +typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata); /*! * \brief Backend function for running parallel jobs. 
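The UPWARD/TONEAREST distinction documented for RequantizeAttrs in qnn/attrs.h above is easiest to see at a midpoint. A minimal standalone sketch (plain C++, not part of this patch) that mimics the two behaviors on doubles:

#include <cmath>
#include <cstdio>

// UPWARD: the midpoint goes toward +infinity, so -1.5 -> -1.
double round_upward(double x) { return std::floor(x + 0.5); }

// TONEAREST: the midpoint goes away from zero (std::round semantics),
// so -1.5 -> -2.
double round_tonearest(double x) { return std::round(x); }

int main() {
  std::printf("%.0f %.0f\n", round_upward(-1.5), round_tonearest(-1.5));  // prints: -1 -2
  return 0;
}

Away from midpoints the two functions agree, which is exactly what the attribute documentation claims.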
@@ -138,9 +127,7 @@ typedef int (*FTVMParallelLambda)( * * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, - void* cdata, - int num_task); +TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task); /*! * \brief BSP barrrier between parallel threads @@ -150,7 +137,6 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, */ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); - /*! * \brief Simple static initialization function. * Run f once and set handle to be not null. @@ -162,10 +148,7 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); * \param nbytes Number of bytes in the closure data. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendRunOnce(void** handle, - int (*f)(void*), - void *cdata, - int nbytes); +TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 920ecfbf9b130..bb38ad8a84df9 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -63,15 +63,14 @@ // TVM version #define TVM_VERSION "0.7.dev1" - // TVM Runtime is DLPack compatible. #include #ifdef __cplusplus extern "C" { #endif -#include #include +#include /*! \brief type of array index. */ typedef int64_t tvm_index_t; @@ -83,6 +82,7 @@ typedef enum { kOpenGL = 11, kDLMicroDev = 13, kDLHexagon = 14, + kDLWebGPU = 15 // AddExtraTVMType which is not in DLPack here } TVMDeviceExtType; @@ -179,7 +179,7 @@ TVM_DLL void TVMAPISetLastError(const char* msg); * this function is threadsafe and can be called by different thread * \return error info */ -TVM_DLL const char *TVMGetLastError(void); +TVM_DLL const char* TVMGetLastError(void); /*! * \brief Load module from file. * \param file_name The file name to load the module from. @@ -190,9 +190,7 @@ TVM_DLL const char *TVMGetLastError(void); * \note The resulting module do not contain import relation. * It can be reconstructed by TVMModImport. */ -TVM_DLL int TVMModLoadFromFile(const char* file_name, - const char* format, - TVMModuleHandle* out); +TVM_DLL int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out); /*! * \brief Add dep to mod's dependency. @@ -202,8 +200,7 @@ TVM_DLL int TVMModLoadFromFile(const char* file_name, * \param dep The dependent module to be imported. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMModImport(TVMModuleHandle mod, - TVMModuleHandle dep); +TVM_DLL int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep); /*! * \brief Get function from the module. @@ -213,10 +210,8 @@ TVM_DLL int TVMModImport(TVMModuleHandle mod, * \param out The result function, can be NULL if it is not available. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *out); +TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* out); /*! * \brief Free the Module @@ -258,12 +253,8 @@ TVM_DLL int TVMFuncFree(TVMFunctionHandle func); * The front-end need to call free function (e.g. TVMFuncFree) * to free these handles. 
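To make the parallel entry points from c_backend_api.h above concrete, here is a hedged sketch of a job matching the FTVMParallelLambda signature. AddCtx/AddTask are illustrative names; the num_task field on TVMParallelGroupEnv and num_task = 0 meaning "use the runtime's default concurrency" are assumptions, not shown in this hunk:

#include <tvm/runtime/c_backend_api.h>

struct AddCtx {  // closure data handed through the void* cdata parameter
  float* out;
  const float* a;
  const float* b;
  int n;
};

// Must match FTVMParallelLambda: return 0 on success, nonzero on failure.
static int AddTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
  AddCtx* ctx = static_cast<AddCtx*>(cdata);
  int chunk = (ctx->n + penv->num_task - 1) / penv->num_task;
  int begin = task_id * chunk;
  int end = begin + chunk < ctx->n ? begin + chunk : ctx->n;
  for (int i = begin; i < end; ++i) ctx->out[i] = ctx->a[i] + ctx->b[i];
  return 0;
}

// Launch side (error handling elided):
//   AddCtx ctx{out, a, b, n};
//   TVMBackendParallelLaunch(AddTask, &ctx, 0);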
*/ -TVM_DLL int TVMFuncCall(TVMFunctionHandle func, - TVMValue* arg_values, - int* type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code); +TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, + TVMValue* ret_val, int* ret_type_code); /*! * \brief Set the return value of TVMPackedCFunc. @@ -276,10 +267,7 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, * \param type_code The type of the value to be returned. * \param num_ret Number of return values, for now only 1 is supported. */ -TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret); +TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret); /*! * \brief Inplace translate callback argument value to return value. @@ -304,12 +292,8 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. * \sa TVMCFuncSetReturn */ -typedef int (*TVMPackedCFunc)( - TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, - void* resource_handle); +typedef int (*TVMPackedCFunc)(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, + void* resource_handle); /*! * \brief C callback to free the resource handle in C packed function. @@ -339,10 +323,8 @@ typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); * \param out the result function handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out); +TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, + TVMPackedCFuncFinalizer fin, TVMFunctionHandle* out); /*! * \brief Register the function to runtime's global table. @@ -353,8 +335,7 @@ TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, * \param f The function to be registered. * \param override Whether allow override already registered function. */ -TVM_DLL int TVMFuncRegisterGlobal( - const char* name, TVMFunctionHandle f, int override); +TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override); /*! * \brief Get a global function. @@ -373,8 +354,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); * \param out_array The array of function names. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMFuncListGlobalNames(int* out_size, - const char*** out_array); +TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); // Array related apis for quick proptyping /*! @@ -391,14 +371,8 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, * \param out The output handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out); +TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); /*! * \brief Free the TVM Array. @@ -414,9 +388,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); * \param nbytes The number of bytes to copy. 
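For reference, one plausible use of the array C API declared above: allocate a 2x3 float32 tensor on CPU, then release it. The dtype and device constants follow DLPack; error handling is reduced to early returns:

#include <tvm/runtime/c_runtime_api.h>

int AllocExample() {
  tvm_index_t shape[2] = {2, 3};
  TVMArrayHandle arr = nullptr;
  // float32 = (code kDLFloat, 32 bits, 1 lane), on CPU device 0.
  if (TVMArrayAlloc(shape, 2, kDLFloat, 32, 1, kDLCPU, 0, &arr) != 0) return -1;
  // ... fill the 24 bytes behind arr->data, e.g. via TVMArrayCopyFromBytes ...
  return TVMArrayFree(arr);
}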
* \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, - void* data, - size_t nbytes); +TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes); /*! * \brief Copy array data to CPU byte array. * @@ -425,9 +397,7 @@ TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, * \param nbytes The number of bytes to copy. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, - void* data, - size_t nbytes); +TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes); /*! * \brief Copy the array, both from and to must be valid during the copy. * @@ -436,9 +406,7 @@ TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, * \param stream The stream where the copy happens, can be NULL. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, - TVMArrayHandle to, - TVMStreamHandle stream); +TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream); /*! * \brief Produce an array from the DLManagedTensor that shares data memory * @@ -447,8 +415,7 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, * \param out The output array handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, - TVMArrayHandle* out); +TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out); /*! * \brief Produce a DLMangedTensor from the array that shares data memory with * @@ -457,8 +424,7 @@ TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, * \param out The DLManagedTensor handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, - DLManagedTensor** out); +TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out); /*! * \brief Delete (free) a DLManagedTensor's data. * @@ -518,9 +484,7 @@ TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle strea * \param dst The destination stream to synchronize. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMStreamStreamSynchronize(int device_type, - int device_id, - TVMStreamHandle src, +TVM_DLL int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, TVMStreamHandle dst); /*! @@ -550,6 +514,46 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); */ TVM_DLL int TVMObjectFree(TVMObjectHandle obj); +/*! + * \brief Allocate a data space on device. + * \param ctx The device context to perform operation. + * \param nbytes The number of bytes in memory. + * \param alignment The alignment of the memory. + * \param type_hint The type of elements. Only needed by certain backends such + * as OpenGL, as nbytes & alignment are sufficient for most backends. + * \param out_data The allocated device pointer. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint, void** out_data); + +/*! + * \brief Free a data space on device. + * \param ctx The device context to perform operation. + * \param ptr The data space. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr); + +/*! + * \brief Copy data from one place to another. + * \param from The source array. + * \param from_offset The byte offset in the from. + * \param to The target array.
+ * \param to_offset The byte offset in the to. + * \param num_bytes The size of the memory in bytes + * \param ctx_from The source context + * \param ctx_to The target context + * \param type_hint The type of elements, only needed by certain backends; + * can be useful for cross device endian conversion. + * \param stream Optional stream object. + * \return 0 when success, -1 when failure happens. + */ +TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t num_bytes, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, + TVMStreamHandle stream); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index cdb92ba6779a1..49c005e36d7d9 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -39,8 +39,7 @@ // string_view: // https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations // https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros -#if defined(__cpp_lib_experimental_string_view) && \ - __cpp_lib_experimental_string_view >= 201411 +#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 #define TVM_USE_CXX14_STRING_VIEW_HASH 1 #else #define TVM_USE_CXX14_STRING_VIEW_HASH 0 @@ -135,8 +134,7 @@ class InplaceArrayBase { * \brief Destroy the Inplace Array Base object */ ~InplaceArrayBase() { - if (!(std::is_standard_layout::value && - std::is_trivial::value)) { + if (!(std::is_standard_layout::value && std::is_trivial::value)) { size_t size = Self()->GetSize(); for (size_t i = 0; i < size; ++i) { ElemType* fp = reinterpret_cast(AddressOf(i)); @@ -179,10 +177,10 @@ class InplaceArrayBase { * \return Raw pointer to the element. */ void* AddressOf(size_t idx) const { - static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); size_t kDataStart = sizeof(ArrayType); ArrayType* self = Self(); @@ -242,8 +240,7 @@ class ADT : public ObjectRef { * \param fields The fields of the ADT object. * \return The constructed ADT object reference. */ - ADT(int32_t tag, std::vector fields) - : ADT(tag, fields.begin(), fields.end()){}; + ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; /*! * \brief construct an ADT object reference. * @@ -267,8 +264,7 @@ class ADT : public ObjectRef { * \param init The initializer list of fields. * \return The constructed ADT object reference. */ - ADT(int32_t tag, std::initializer_list init) - : ADT(tag, init.begin(), init.end()){}; + ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; /*! * \brief Access element at index. * @@ -276,9 +272,7 @@ class ADT : public ObjectRef { * \param idx The array index * \return const ObjectRef */ - const ObjectRef& operator[](size_t idx) const { - return operator->()->operator[](idx); - } + const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } /*! * \brief Return the ADT tag.
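The three TVMDevice* entry points introduced above compose in the obvious way. A hedged sketch (an alignment of 64 and a null stream meaning "default stream" are assumptions mirroring the C++ DeviceAPI; host_buf is hypothetical):

#include <tvm/runtime/c_runtime_api.h>

int DeviceRoundTrip(const float* host_buf, size_t nbytes) {
  TVMContext cpu{kDLCPU, 0};
  TVMContext gpu{kDLGPU, 0};
  DLDataType f32{kDLFloat, 32, 1};
  void* dev_ptr = nullptr;
  if (TVMDeviceAllocDataSpace(gpu, nbytes, 64, f32, &dev_ptr) != 0) return -1;
  // Offsets are in bytes; the copy goes host (cpu) -> device (gpu).
  if (TVMDeviceCopyDataFromTo(host_buf, 0, dev_ptr, 0, nbytes, cpu, gpu, f32,
                              nullptr) != 0) {
    return -1;
  }
  return TVMDeviceFreeDataSpace(gpu, dev_ptr);
}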
@@ -390,9 +384,7 @@ class String : public ObjectRef { * * \return the comparison result */ - bool operator==(const std::string& other) const { - return this->compare(other) == 0; - } + bool operator==(const std::string& other) const { return this->compare(other) == 0; } /*! * \brief Compare is not equal to other std::string @@ -512,11 +504,9 @@ class String : public ObjectRef { // This function falls back to string copy with c++11 compiler and is // recommended to be compiled with c++14 #if TVM_USE_CXX17_STRING_VIEW_HASH - return std::hash()( - std::string_view(data, size)); + return std::hash()(std::string_view(data, size)); #elif TVM_USE_CXX14_STRING_VIEW_HASH - return std::hash()( - std::experimental::string_view(data, size)); + return std::hash()(std::experimental::string_view(data, size)); #else return std::hash()(std::string(data, size)); #endif @@ -538,8 +528,7 @@ class String : public ObjectRef { * \return int zero if both char sequences compare equal. negative if this * appear before other, positive otherwise. */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, - size_t rhs_count); + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); }; /*! \brief An object representing string moved from std::string. */ @@ -575,8 +564,7 @@ inline String String::operator=(std::string other) { return Downcast(*this); } -inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, - size_t rhs_count) { +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { if (lhs == rhs && lhs_count == rhs_count) return 0; for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { @@ -592,7 +580,7 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, } } -template<> +template <> struct PackedFuncValueConverter<::tvm::runtime::String> { static String From(const TVMArgValue& val) { if (val.IsObjectRef()) { @@ -612,8 +600,7 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { }; /*! \brief Helper to represent nullptr for optional. */ -struct NullOptType { -}; +struct NullOptType {}; /*! * \brief Optional container that to represent to a Nullable variant of T. @@ -628,12 +615,11 @@ struct NullOptType { * * \endcode */ -template +template class Optional : public ObjectRef { public: using ContainerType = typename T::ContainerType; - static_assert(std::is_base_of::value, - "Optional is only defined for ObjectRef."); + static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); // default constructors. Optional() = default; Optional(const Optional&) = default; @@ -656,9 +642,8 @@ class Optional : public ObjectRef { return *this; } // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) { - } + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} Optional& operator=(T other) { ObjectRef::operator=(std::move(other)); return *this; @@ -680,20 +665,12 @@ class Optional : public ObjectRef { * \return The contained value if the Optional is not null * otherwise return the default_value. */ - T value_or(T default_value) const { - return data_ != nullptr ? T(data_) : default_value; - } + T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } /*! 
\return Whether the container is not nullptr.*/ - explicit operator bool() const { - return *this != nullptr; - } + explicit operator bool() const { return *this != nullptr; } // operator overloadings - bool operator==(std::nullptr_t) const { - return data_ == nullptr; - } - bool operator!=(std::nullptr_t) const { - return data_ != nullptr; - } + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } auto operator==(const Optional& other) const { // support case where sub-class returns a symbolic ref type. using RetType = decltype(value() == other.value()); @@ -722,16 +699,14 @@ class Optional : public ObjectRef { if (*this != nullptr) return value() == other; return RetType(false); } - auto operator!=(const T& other) const { - return !(*this == other); - } - template + auto operator!=(const T& other) const { return !(*this == other); } + template auto operator==(const U& other) const { using RetType = decltype(value() == other); if (*this == nullptr) return RetType(false); return value() == other; } - template + template auto operator!=(const U& other) const { using RetType = decltype(value() != other); if (*this == nullptr) return RetType(true); @@ -740,7 +715,7 @@ class Optional : public ObjectRef { static constexpr bool _type_is_nullable = true; }; -template +template struct PackedFuncValueConverter> { static Optional From(const TVMArgValue& val) { if (val.type_code() == kTVMNullptr) return Optional(nullptr); @@ -755,8 +730,8 @@ struct PackedFuncValueConverter> { } // namespace runtime // expose the functions to the root namespace. -using runtime::String; using runtime::Optional; +using runtime::String; constexpr runtime::NullOptType NullOpt{}; } // namespace tvm diff --git a/include/tvm/runtime/crt/memory.h b/include/tvm/runtime/crt/memory.h index 3e47060a86c4b..7b88b3123644f 100644 --- a/include/tvm/runtime/crt/memory.h +++ b/include/tvm/runtime/crt/memory.h @@ -32,7 +32,7 @@ static int vleak_size = 0; * \param size The size of memory * \return The virtual address */ -void * vmalloc(size_t size); +void* vmalloc(size_t size); /*! * \brief Reallocate memory from manager @@ -40,13 +40,13 @@ void * vmalloc(size_t size); * \param size The size of memory * \return The virtual address */ -void * vrealloc(void * ptr, size_t size); +void* vrealloc(void* ptr, size_t size); /*! * \brief Free the memory. * \param ptr The pointer to the memory to deallocate * \return The virtual address */ -void vfree(void * ptr); +void vfree(void* ptr); #endif // TVM_RUNTIME_CRT_MEMORY_H_ diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 44385d63263b6..a10b83fd321be 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -24,10 +24,11 @@ #ifndef TVM_RUNTIME_DATA_TYPE_H_ #define TVM_RUNTIME_DATA_TYPE_H_ -#include #include -#include +#include + #include +#include namespace tvm { namespace runtime { @@ -52,8 +53,7 @@ class DataType { * \brief Constructor * \param dtype The DLDataType */ - explicit DataType(DLDataType dtype) - : data_(dtype) {} + explicit DataType(DLDataType dtype) : data_(dtype) {} /*! * \brief Constructor * \param code The type code. @@ -66,106 +66,70 @@ class DataType { data_.lanes = static_cast(lanes); } /*! \return The type code. */ - int code() const { - return static_cast(data_.code); - } + int code() const { return static_cast(data_.code); } /*! \return number of bits in the data. 
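Since Optional<T> above is only defined for ObjectRef types, runtime::String makes a convenient example. A small sketch of the nullable behavior (values are illustrative):

#include <tvm/runtime/container.h>

using tvm::runtime::Optional;
using tvm::runtime::String;

void OptionalExample() {
  Optional<String> name;                    // default-constructed: holds nullptr
  name = String("relu");
  if (name != nullptr) {
    String s = name.value();                // only legal on a non-null Optional
  }
  String fallback = name.value_or(String("unknown"));  // "relu" here
}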
*/ - int bits() const { - return static_cast(data_.bits); - } + int bits() const { return static_cast(data_.bits); } /*! \return number of bytes to store each scalar. */ - int bytes() const { - return (bits() + 7) / 8; - } + int bytes() const { return (bits() + 7) / 8; } /*! \return number of lanes in the data. */ - int lanes() const { - return static_cast(data_.lanes); - } + int lanes() const { return static_cast(data_.lanes); } /*! \return whether type is a scalar type. */ - bool is_scalar() const { - return lanes() == 1; - } + bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ - bool is_bool() const { - return code() == DataType::kUInt && bits() == 1; - } + bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ - bool is_float() const { - return code() == DataType::kFloat; - } + bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float16 type. */ - bool is_float16() const { - return is_float() && bits() == 16; - } + bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is an int type. */ - bool is_int() const { - return code() == DataType::kInt; - } + bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ - bool is_uint() const { - return code() == DataType::kUInt; - } + bool is_uint() const { return code() == DataType::kUInt; } /*! \return whether type is a handle type. */ - bool is_handle() const { - return code() == DataType::kHandle; - } + bool is_handle() const { return code() == DataType::kHandle && !is_void(); } /*! \return whether type is a vector type. */ - bool is_vector() const { - return lanes() > 1; - } + bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { - return is_vector() && bits() == 1; - } + bool is_vector_bool() const { return is_vector() && bits() == 1; } + /*! \return whether type is a Void type. */ + bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. * \return the result type. */ - DataType with_lanes(int lanes) const { - return DataType(data_.code, data_.bits, lanes); - } + DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } /*! * \brief Create a new data type by change bits to a specified value. * \param bits The target number of bits. * \return the result type. */ - DataType with_bits(int bits) const { - return DataType(data_.code, bits, data_.lanes); - } + DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } /*! * \brief Get the scalar version of the type. * \return the result type. */ - DataType element_of() const { - return with_lanes(1); - } + DataType element_of() const { return with_lanes(1); } /*! * \brief Equal comparator. * \param other The data type to compre against. * \return The comparison resilt. */ bool operator==(const DataType& other) const { - return - data_.code == other.data_.code && - data_.bits == other.data_.bits && - data_.lanes == other.data_.lanes; + return data_.code == other.data_.code && data_.bits == other.data_.bits && + data_.lanes == other.data_.lanes; } /*! * \brief NotEqual comparator. * \param other The data type to compre against. * \return The comparison resilt. 
*/ - bool operator!=(const DataType& other) const { - return !operator==(other); - } + bool operator!=(const DataType& other) const { return !operator==(other); } /*! * \brief Converter to DLDataType * \return the result. */ - operator DLDataType () const { - return data_; - } + operator DLDataType() const { return data_; } /*! * \brief Construct an int type. @@ -173,44 +137,39 @@ class DataType { * \param lanes The number of lanes. * \return The constructed data type. */ - static DataType Int(int bits, int lanes = 1) { - return DataType(kDLInt, bits, lanes); - } + static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType UInt(int bits, int lanes = 1) { - return DataType(kDLUInt, bits, lanes); - } + static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Float(int bits, int lanes = 1) { - return DataType(kDLFloat, bits, lanes); - } + static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Bool(int lanes = 1) { - return DataType::UInt(1, lanes); - } + static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); } /*! * \brief Construct a handle type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Handle(int bits = 64, int lanes = 1) { - return DataType(kHandle, bits, lanes); - } + static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } + /*! + * \brief Construct a Void type. + * \return The constructed data type. + */ + static DataType Void() { return DataType(kHandle, 0, 0); } /*! * \brief Get the corresponding type of TVMShapeIndex. * \return The type of TVM shape index. 
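A quick sketch of the DataType helpers above, including the new Void() handle-with-zero-bits encoding and the tightened is_handle():

#include <tvm/runtime/data_type.h>

using tvm::runtime::DataType;

void DataTypeExample() {
  DataType f32 = DataType::Float(32);   // scalar float32
  DataType f32x4 = f32.with_lanes(4);   // float32x4: 4 lanes of 4 bytes each
  DataType v = DataType::Void();        // kHandle with bits == 0, lanes == 0
  bool is_void = v.is_void();           // true
  bool is_handle = v.is_handle();       // false: Void is excluded on purpose
  int nbytes = f32.bytes();             // (32 + 7) / 8 == 4
  (void)f32x4; (void)is_void; (void)is_handle; (void)nbytes;
}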
@@ -235,14 +194,11 @@ class DataType { inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist - if (dtype == DataType::Bool() || - dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || + if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { return 1; } - CHECK_EQ(data_bits % 8, 0U) - << "Need to load/store by multiple of bytes"; + CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } @@ -311,29 +267,49 @@ inline std::string DLDataType2String(DLDataType t); // implementation details inline const char* TypeCode2Str(int type_code) { switch (type_code) { - case kDLInt: return "int"; - case kDLUInt: return "uint"; - case kDLFloat: return "float"; - case kTVMStr: return "str"; - case kTVMBytes: return "bytes"; - case kTVMOpaqueHandle: return "handle"; - case kTVMNullptr: return "NULL"; - case kTVMDLTensorHandle: return "ArrayHandle"; - case kTVMDataType: return "DLDataType"; - case kTVMContext: return "TVMContext"; - case kTVMPackedFuncHandle: return "FunctionHandle"; - case kTVMModuleHandle: return "ModuleHandle"; - case kTVMNDArrayHandle: return "NDArrayContainer"; - case kTVMObjectHandle: return "Object"; - case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; - default: LOG(FATAL) << "unknown type_code=" - << static_cast(type_code); return ""; + case kDLInt: + return "int"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case kTVMStr: + return "str"; + case kTVMBytes: + return "bytes"; + case kTVMOpaqueHandle: + return "handle"; + case kTVMNullptr: + return "NULL"; + case kTVMDLTensorHandle: + return "ArrayHandle"; + case kTVMDataType: + return "DLDataType"; + case kTVMContext: + return "TVMContext"; + case kTVMPackedFuncHandle: + return "FunctionHandle"; + case kTVMModuleHandle: + return "ModuleHandle"; + case kTVMNDArrayHandle: + return "NDArrayContainer"; + case kTVMObjectHandle: + return "Object"; + case kTVMObjectRValueRefArg: + return "ObjectRValueRefArg"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + return ""; } } inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { - os << "bool"; return os; + os << "bool"; + return os; + } + if (DataType(t).is_void()) { + return os << "void"; } if (t.code < kTVMCustomBegin) { os << TypeCode2Str(t.code); @@ -348,7 +324,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) return os; } -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) return os << dtype.operator DLDataType(); } @@ -361,19 +337,23 @@ inline std::string DLDataType2String(DLDataType t) { inline DLDataType String2DLDataType(std::string s) { DLDataType t; - // handle None type + // handle void type if (s.length() == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + t = DataType::Void(); return t; } - t.bits = 32; t.lanes = 1; + t.bits = 32; + t.lanes = 1; const char* scan; if (s.substr(0, 3) == "int") { - t.code = kDLInt; scan = s.c_str() + 3; + t.code = kDLInt; + scan = s.c_str() + 3; } else if (s.substr(0, 4) == "uint") { - t.code = kDLUInt; scan = s.c_str() + 4; + t.code = kDLUInt; + scan = s.c_str() + 4; } else if (s.substr(0, 5) == "float") { - t.code = kDLFloat; scan = s.c_str() + 5; + t.code = kDLFloat; + 
scan = s.c_str() + 5; } else if (s.substr(0, 6) == "handle") { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index f2ddc84e9f98b..7fb2f9db96f9d 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -85,9 +86,7 @@ class TVM_DLL DeviceAPI { * as OpenGL, as nbytes & alignment are sufficient for most backends. * \return The allocated device pointer. */ - virtual void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + virtual void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) = 0; /*! * \brief Free a data space on device. @@ -108,16 +107,10 @@ class TVM_DLL DeviceAPI { * can be useful for cross device endian converison. * \param stream Optional stream object. */ - virtual void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t num_bytes, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, - TVMStreamHandle stream) = 0; - /*! + virtual void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) = 0; + /*! * \brief Create a new stream of execution. * * \param ctx The context of allocation. @@ -156,9 +149,8 @@ class TVM_DLL DeviceAPI { * \param event_src The source stream to synchronize. * \param event_dst The destination stream to synchronize. */ - virtual void SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, - TVMStreamHandle event_dst); + virtual void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, + TVMStreamHandle event_dst); /*! * \brief Allocate temporal workspace for backend execution. * @@ -175,9 +167,7 @@ class TVM_DLL DeviceAPI { * \param type_hint The type of elements. Only needed by certain backends such * as OpenGL, as nbytes is sufficient for most backends. */ - virtual void* AllocWorkspace(TVMContext ctx, - size_t nbytes, - DLDataType type_hint = {}); + virtual void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}); /*! * \brief Free temporal workspace in backend execution. 
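DLDataType2String and String2DLDataType above are inverses for the common spellings. A sketch (the "x4" lane-suffix parsing is assumed from the rest of the function, which this hunk elides):

#include <tvm/runtime/data_type.h>
#include <string>

void DTypeStringExample() {
  DLDataType t = tvm::runtime::String2DLDataType("float32x4");
  // t.code == kDLFloat, t.bits == 32, t.lanes == 4
  std::string s = tvm::runtime::DLDataType2String(t);  // "float32x4" again
  // The empty string now maps to DataType::Void() rather than an opaque handle.
  DLDataType v = tvm::runtime::String2DLDataType("");
  (void)s; (void)v;
}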
* @@ -214,21 +204,39 @@ constexpr int kRPCSessMask = 128; */ inline const char* DeviceName(int type) { switch (type) { - case kDLCPU: return "cpu"; - case kDLGPU: return "gpu"; - case kDLCPUPinned: return "cpu_pinned"; - case kDLOpenCL: return "opencl"; - case kDLSDAccel: return "sdaccel"; - case kDLAOCL: return "aocl"; - case kDLVulkan: return "vulkan"; - case kDLMetal: return "metal"; - case kDLVPI: return "vpi"; - case kDLROCM: return "rocm"; - case kOpenGL: return "opengl"; - case kDLExtDev: return "ext_dev"; - case kDLMicroDev: return "micro_dev"; - case kDLHexagon: return "hexagon"; - default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; + case kDLCPU: + return "cpu"; + case kDLGPU: + return "gpu"; + case kDLCPUPinned: + return "cpu_pinned"; + case kDLOpenCL: + return "opencl"; + case kDLSDAccel: + return "sdaccel"; + case kDLAOCL: + return "aocl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kOpenGL: + return "opengl"; + case kDLExtDev: + return "ext_dev"; + case kDLWebGPU: + return "webgpu"; + case kDLMicroDev: + return "micro_dev"; + case kDLHexagon: + return "hexagon"; + default: + LOG(FATAL) << "unknown type =" << type; + return "Unknown"; } } diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index b9b420ad02b6f..1199c420f212a 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -24,9 +24,10 @@ #define TVM_RUNTIME_MEMORY_H_ #include + #include -#include #include +#include namespace tvm { namespace runtime { @@ -36,7 +37,7 @@ namespace runtime { * \tparam T the node type. * \return The ObjectPtr to the allocated object. */ -template +template inline ObjectPtr make_object(Args&&... args); // Detail implementations after this @@ -55,7 +56,7 @@ inline ObjectPtr make_object(Args&&... args); * * \tparam Derived The derived class. */ -template +template class ObjAllocatorBase { public: /*! @@ -64,13 +65,11 @@ class ObjAllocatorBase { * \tparam Args The constructor signature. * \param args The arguments. */ - template + template inline ObjectPtr make_object(Args&&... args) { using Handler = typename Derived::template Handler; - static_assert(std::is_base_of::value, - "make can only be used to create Object"); - T* ptr = Handler::New(static_cast(this), - std::forward(args)...); + static_assert(std::is_base_of::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this), std::forward(args)...); ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); @@ -83,14 +82,13 @@ class ObjAllocatorBase { * \param num_elems The number of array elements. * \param args The arguments. */ - template + template inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { using Handler = typename Derived::template ArrayHandler; static_assert(std::is_base_of::value, "make_inplace_array can only be used to create Object"); - ArrayType* ptr = Handler::New(static_cast(this), - num_elems, - std::forward(args)...); + ArrayType* ptr = + Handler::New(static_cast(this), num_elems, std::forward(args)...); ptr->type_index_ = ArrayType::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); @@ -98,15 +96,14 @@ class ObjAllocatorBase { }; // Simple allocator that uses new/delete. 
-class SimpleObjAllocator : - public ObjAllocatorBase { +class SimpleObjAllocator : public ObjAllocatorBase { public: - template + template class Handler { public: using StorageType = typename std::aligned_storage::type; - template + template static T* New(SimpleObjAllocator*, Args&&... args) { // NOTE: the first argument is not needed for SimpleObjAllocator // It is reserved for special allocators that needs to recycle @@ -126,9 +123,7 @@ class SimpleObjAllocator : return reinterpret_cast(data); } - static Object::FDeleter Deleter() { - return Deleter_; - } + static Object::FDeleter Deleter() { return Deleter_; } private: static void Deleter_(Object* objptr) { @@ -146,16 +141,16 @@ class SimpleObjAllocator : }; // Array handler that uses new/delete. - template + template class ArrayHandler { public: using StorageType = typename std::aligned_storage::type; // for now only support elements that aligns with array header. static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, + sizeof(ArrayType) % alignof(ElemType) == 0, "element alignment constraint"); - template + template static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { // NOTE: the first argument is not needed for ArrayObjAllocator // It is reserved for special allocators that needs to recycle @@ -177,9 +172,7 @@ class SimpleObjAllocator : return reinterpret_cast(data); } - static Object::FDeleter Deleter() { - return Deleter_; - } + static Object::FDeleter Deleter() { return Deleter_; } private: static void Deleter_(Object* objptr) { @@ -193,20 +186,20 @@ class SimpleObjAllocator : // call a virtual destructor(which may not be available and is not required). tptr->ArrayType::~ArrayType(); StorageType* p = reinterpret_cast(tptr); - delete []p; + delete[] p; } }; }; -template +template inline ObjectPtr make_object(Args&&... args) { return SimpleObjAllocator().make_object(std::forward(args)...); } -template +template inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { - return SimpleObjAllocator().make_inplace_array( - num_elems, std::forward(args)...); + return SimpleObjAllocator().make_inplace_array(num_elems, + std::forward(args)...); } } // namespace runtime diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 3c43ae090efce..0e7cd2b087843 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,15 +27,14 @@ #define TVM_RUNTIME_MODULE_H_ #include - #include -#include #include +#include #include -#include #include #include +#include namespace tvm { namespace runtime { @@ -50,8 +49,7 @@ class Module : public ObjectRef { public: Module() {} // constructor from container. - explicit Module(ObjectPtr n) - : ObjectRef(n) {} + explicit Module(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Get packed function from current module by name. * @@ -82,8 +80,7 @@ class Module : public ObjectRef { * \note This function won't load the import relationship. * Re-create import relationship by calling Import. */ - TVM_DLL static Module LoadFromFile(const std::string& file_name, - const std::string& format = ""); + TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = ""); // refer to the corresponding container. using ContainerType = ModuleNode; friend class ModuleNode; @@ -137,16 +134,14 @@ class TVM_DLL ModuleNode : public Object { * If the function need resource from the module(e.g. late linking), * it should capture sptr_to_self. 
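The sptr_to_self parameter documented above is a lifetime trick: a returned PackedFunc captures it so the module cannot be destroyed while the closure is alive. A hedged sketch of a custom module (MyModuleNode and "add_one" are illustrative names):

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

namespace rt = tvm::runtime;

class MyModuleNode : public rt::ModuleNode {
 public:
  const char* type_key() const final { return "my_module"; }

  rt::PackedFunc GetFunction(const std::string& name,
                             const rt::ObjectPtr<rt::Object>& sptr_to_self) final {
    if (name != "add_one") return rt::PackedFunc();  // null func: name not found
    // Capturing sptr_to_self keeps this module alive as long as the closure.
    return rt::PackedFunc([sptr_to_self](rt::TVMArgs args, rt::TVMRetValue* rv) {
      int x = args[0];
      *rv = x + 1;
    });
  }
};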
*/ - virtual PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) = 0; + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) = 0; /*! * \brief Save the module to file. * \param file_name The file to be saved to. * \param format The format of the file. */ - virtual void SaveToFile(const std::string& file_name, - const std::string& format); + virtual void SaveToFile(const std::string& file_name, const std::string& format); /*! * \brief Save the module to binary stream. * \param stream The binary stream to save to. @@ -188,9 +183,7 @@ class TVM_DLL ModuleNode : public Object { */ const PackedFunc* GetFuncFromEnv(const std::string& name); /*! \return The module it imports from */ - const std::vector& imports() const { - return imports_; - } + const std::vector& imports() const { return imports_; } // integration with the existing components. static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule; @@ -207,8 +200,7 @@ class TVM_DLL ModuleNode : public Object { private: /*! \brief Cache used by GetImport */ - std::unordered_map > import_cache_; + std::unordered_map > import_cache_; }; /*! @@ -238,13 +230,9 @@ constexpr const char* tvm_module_main = "__tvm_main__"; // implementations of inline functions. -inline void Module::Import(Module other) { - return (*this)->Import(other); -} +inline void Module::Import(Module other) { return (*this)->Import(other); } -inline ModuleNode* Module::operator->() { - return static_cast(get_mutable()); -} +inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } inline const ModuleNode* Module::operator->() const { return static_cast(get()); @@ -254,4 +242,4 @@ inline const ModuleNode* Module::operator->() const { } // namespace tvm #include // NOLINT(*) -#endif // TVM_RUNTIME_MODULE_H_ +#endif // TVM_RUNTIME_MODULE_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 33f27f49fbbe0..8db93b46e934e 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -29,8 +29,8 @@ #include #include -#include #include +#include namespace tvm { namespace runtime { @@ -53,8 +53,7 @@ class NDArray : public ObjectRef { * \brief constructor. * \param data ObjectPtr to the data container. */ - explicit NDArray(ObjectPtr data) - : ObjectRef(data) {} + explicit NDArray(ObjectPtr data) : ObjectRef(data) {} /*! \brief reset the content of NDArray to be nullptr */ inline void reset(); @@ -76,13 +75,13 @@ class NDArray : public ObjectRef { inline void CopyFrom(const DLTensor* other); inline void CopyFrom(const NDArray& other); /*! - * \brief Copy data content from a byte buffer. - * \param data The source bytes to be copied from. - * \param nbytes The size of the buffer in bytes - * Must be equal to the size of the NDArray. - * \note The copy may happen asynchronously if it involves a GPU context. - * TVMSynchronize is necessary. - */ + * \brief Copy data content from a byte buffer. + * \param data The source bytes to be copied from. + * \param nbytes The size of the buffer in bytes + * Must be equal to the size of the NDArray. + * \note The copy may happen asynchronously if it involves a GPU context. + * TVMSynchronize is necessary. + */ TVM_DLL void CopyFromBytes(const void* data, size_t nbytes); /*! * \brief Copy data content into another array. @@ -124,8 +123,7 @@ class NDArray : public ObjectRef { * \param dtype The data type of the new array. * \note The memory size of new array must be smaller than the current one. 
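From the caller's side, the Module interface above is typically used like this (the .so name is hypothetical; GetFunction returns a null PackedFunc when the name cannot be resolved):

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

void RunExported() {
  tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile("deploy.so");
  tvm::runtime::PackedFunc f = mod.GetFunction("add_one");
  if (f != nullptr) {
    int y = f(41);  // arguments and the result travel as TVMArgs/TVMRetValue
    (void)y;
  }
}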
*/ - TVM_DLL NDArray CreateView( - std::vector shape, DLDataType dtype); + TVM_DLL NDArray CreateView(std::vector shape, DLDataType dtype); /*! * \brief Create a reference view of NDArray that * represents as DLManagedTensor. @@ -139,9 +137,7 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, - DLDataType dtype, - DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. * @@ -160,8 +156,8 @@ class NDArray : public ObjectRef { * \param to The target array. * \param stream The stream used in copy. */ - TVM_DLL static void CopyFromTo( - const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); + TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, + TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; // internal namespace @@ -244,9 +240,7 @@ class NDArray::ContainerBase { * \brief Object container class that backs NDArray. * \note do not use this function directly, use NDArray. */ -class NDArray::Container : - public Object, - public NDArray::ContainerBase { +class NDArray::Container : public Object, public NDArray::ContainerBase { public: /*! \brief default constructor */ Container() { @@ -259,10 +253,7 @@ class NDArray::Container : dl_tensor.byte_offset = 0; } - Container(void* data, - std::vector shape, - DLDataType dtype, - DLContext ctx) { + Container(void* data, std::vector shape, DLDataType dtype, DLContext ctx) { // Initialize the type index. type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; @@ -278,9 +269,7 @@ class NDArray::Container : * \brief Set the deleter field. * \param deleter The deleter. */ - void SetDeleter(FDeleter deleter) { - deleter_ = deleter; - } + void SetDeleter(FDeleter deleter) { deleter_ = deleter; } // Expose DecRef and IncRef as public function // NOTE: they are only for developer purposes only. @@ -360,53 +349,44 @@ inline void NDArray::CopyTo(const NDArray& other) const { inline NDArray NDArray::CopyTo(const DLContext& ctx) const { CHECK(data_ != nullptr); const DLTensor* dptr = operator->(); - NDArray ret = Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), - dptr->dtype, ctx); + NDArray ret = + Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, ctx); this->CopyTo(ret); return ret; } -inline int NDArray::use_count() const { - return data_.use_count(); -} +inline int NDArray::use_count() const { return data_.use_count(); } -inline const DLTensor* NDArray::operator->() const { - return &(get_mutable()->dl_tensor); -} +inline const DLTensor* NDArray::operator->() const { return &(get_mutable()->dl_tensor); } inline NDArray::Container* NDArray::get_mutable() const { return static_cast(data_.get()); } inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { - return GetObjectPtr(static_cast( - reinterpret_cast(handle))); + return GetObjectPtr( + static_cast(reinterpret_cast(handle))); } inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. 
- return reinterpret_cast( - static_cast( - static_cast( - const_cast(nd.get())))); + return reinterpret_cast(static_cast( + static_cast(const_cast(nd.get())))); } inline void NDArray::FFIDecRef(TVMArrayHandle handle) { - static_cast( - reinterpret_cast(handle))->DecRef(); + static_cast(reinterpret_cast(handle))->DecRef(); } inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) { - return static_cast( - reinterpret_cast(handle)); + return static_cast(reinterpret_cast(handle)); } /*! \brief Magic number for NDArray file */ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; -inline bool SaveDLTensor(dmlc::Stream* strm, - const DLTensor* tensor) { +inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { uint64_t header = kTVMNDArrayMagic, reserved = 0; strm->Write(header); strm->Write(reserved); @@ -435,16 +415,15 @@ inline bool SaveDLTensor(dmlc::Stream* strm, int64_t data_byte_size = type_bytes * num_elems; strm->Write(data_byte_size); - if (DMLC_IO_NO_ENDIAN_SWAP && - tensor->ctx.device_type == kDLCPU && - tensor->strides == nullptr && + if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDLCPU && tensor->strides == nullptr && tensor->byte_offset == 0) { // quick path strm->Write(tensor->data, data_byte_size); } else { std::vector bytes(data_byte_size); - CHECK_EQ(TVMArrayCopyToBytes( - const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), 0) + CHECK_EQ( + TVMArrayCopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), + 0) << TVMGetLastError(); if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); @@ -454,33 +433,23 @@ inline bool SaveDLTensor(dmlc::Stream* strm, return true; } -inline void NDArray::Save(dmlc::Stream* strm) const { - SaveDLTensor(strm, operator->()); -} +inline void NDArray::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); } inline bool NDArray::Load(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved)) - << "Invalid DLTensor file format"; - CHECK(header == kTVMNDArrayMagic) - << "Invalid DLTensor file format"; + CHECK(strm->Read(&header)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; + CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; DLContext ctx; int ndim; DLDataType dtype; - CHECK(strm->Read(&ctx)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&ndim)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&dtype)) - << "Invalid DLTensor file format"; - CHECK_EQ(ctx.device_type, kDLCPU) - << "Invalid DLTensor context: can only save as CPU tensor"; + CHECK(strm->Read(&ctx)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&ndim)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&dtype)) << "Invalid DLTensor file format"; + CHECK_EQ(ctx.device_type, kDLCPU) << "Invalid DLTensor context: can only save as CPU tensor"; std::vector shape(ndim); if (ndim != 0) { - CHECK(strm->ReadArray(&shape[0], ndim)) - << "Invalid DLTensor file format"; + CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } NDArray ret = NDArray::Empty(shape, dtype, ctx); int64_t num_elems = 1; @@ -489,12 +458,9 @@ inline bool NDArray::Load(dmlc::Stream* strm) { num_elems *= ret->shape[i]; } int64_t data_byte_size; - CHECK(strm->Read(&data_byte_size)) - << "Invalid DLTensor file format"; - CHECK(data_byte_size == num_elems * elem_bytes) - << "Invalid DLTensor file format"; - 
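Tying the NDArray pieces above together: Empty, CopyFromBytes (whose byte count must match the tensor exactly), and Shape. The shape and context are illustrative:

#include <tvm/runtime/ndarray.h>
#include <vector>

void NDArrayExample() {
  DLContext cpu{kDLCPU, 0};
  tvm::runtime::NDArray a =
      tvm::runtime::NDArray::Empty({2, 3}, tvm::runtime::DataType::Float(32), cpu);
  std::vector<float> host(6, 1.0f);                           // 2 * 3 elements
  a.CopyFromBytes(host.data(), host.size() * sizeof(float));  // sizes must match
  std::vector<int64_t> shape = a.Shape();                     // {2, 3}
  (void)shape;
}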
CHECK(strm->Read(ret->data, data_byte_size)) - << "Invalid DLTensor file format"; + CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; + CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; + CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format"; if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(ret->data, elem_bytes, num_elems); } diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 764dcdf3640c1..7d912c5db8210 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -25,8 +25,9 @@ #include #include -#include + #include +#include #include /*! @@ -100,8 +101,8 @@ struct TypeIndex { * Recommendation: set to estimate number of children needed. * - _type_child_slots_can_overflow: * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used. - * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * exceeds the _type_child_slots. A fallback mechanism to check global type table will be + * used. Recommendation: set to false for optimal runtime speed if we know exact number of children. * * Two macros are used to declare helper functions in the object: * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. @@ -163,28 +164,22 @@ class Object { */ typedef void (*FDeleter)(Object* self); /*! \return The internal runtime type index of the object. */ - uint32_t type_index() const { - return type_index_; - } + uint32_t type_index() const { return type_index_; } /*! * \return the type key of the object. * \note this operation is expensive, can be used for error reporting. */ - std::string GetTypeKey() const { - return TypeIndex2Key(type_index_); - } + std::string GetTypeKey() const { return TypeIndex2Key(type_index_); } /*! * \return A hash value of the return of GetTypeKey. */ - size_t GetTypeKeyHash() const { - return TypeIndex2KeyHash(type_index_); - } + size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. * \return Whether the target type is true. */ - template + template inline bool IsInstance() const; /*! @@ -214,12 +209,8 @@ class Object { static constexpr const char* _type_key = "runtime.Object"; - static uint32_t _GetOrAllocRuntimeTypeIndex() { - return TypeIndex::kRoot; - } - static uint32_t RuntimeTypeIndex() { - return TypeIndex::kRoot; - } + static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; } + static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; } // Default object type properties for sub-classes static constexpr bool _type_final = false; @@ -234,7 +225,6 @@ class Object { // The type index of Object is TypeIndex::kRoot static constexpr uint32_t _type_index = TypeIndex::kDynamic; - // Default constructor and copy constructor Object() {} // Override the copy and assign constructors to do nothing. @@ -246,10 +236,10 @@ class Object { } Object(Object&& other) { // NOLINT(*) } - Object& operator=(const Object& other) { //NOLINT(*) + Object& operator=(const Object& other) { // NOLINT(*) return *this; } - Object& operator=(Object&& other) { //NOLINT(*) + Object& operator=(Object&& other) { // NOLINT(*) return *this; } @@ -267,7 +257,7 @@ class Object { FDeleter deleter_ = nullptr; // Invariant checks. 
static_assert(sizeof(int32_t) == sizeof(RefCounterType) && - alignof(int32_t) == sizeof(RefCounterType), + alignof(int32_t) == sizeof(RefCounterType), "RefCounter ABI check."); /*! @@ -287,12 +277,10 @@ class Object { * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. * \return The allocated type index. */ - TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( - const std::string& key, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t type_child_slots, - bool type_child_slots_can_overflow); + TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t type_child_slots, + bool type_child_slots_can_overflow); // reference counter related operations /*! \brief developer function, increases reference counter. */ @@ -316,9 +304,9 @@ class Object { */ TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const; // friend classes - template + template friend class ObjAllocatorBase; - template + template friend class ObjectPtr; friend class TVMRetValue; friend class ObjectInternal; @@ -398,9 +386,7 @@ class ObjectPtr { other.data_ = nullptr; } /*! \brief destructor */ - ~ObjectPtr() { - this->reset(); - } + ~ObjectPtr() { this->reset(); } /*! * \brief Swap this array with another Object * \param other The other Object @@ -411,15 +397,11 @@ class ObjectPtr { /*! * \return Get the content of the pointer */ - T* get() const { - return static_cast(data_); - } + T* get() const { return static_cast(data_); } /*! * \return The pointer */ - T* operator->() const { - return get(); - } + T* operator->() const { return get(); } /*! * \return The reference */ @@ -455,29 +437,17 @@ class ObjectPtr { } } /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } + bool unique() const { return data_ != nullptr && data_->use_count() == 1; } /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { - return data_ == other.data_; - } + bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { - return data_ != other.data_; - } + bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } private: /*! \brief internal pointer field */ @@ -506,9 +476,9 @@ class ObjectPtr { friend class Object; friend class ObjectRef; friend struct ObjectHash; - template + template friend class ObjectPtr; - template + template friend class ObjAllocatorBase; friend class TVMPODValue_; friend class TVMArgsSetter; @@ -533,55 +503,37 @@ class ObjectRef { * \param other Another object ref. * \return the compare result. 
*/ - bool same_as(const ObjectRef& other) const { - return data_ == other.data_; - } + bool same_as(const ObjectRef& other) const { return data_ == other.data_; } /*! * \brief Comparator * \param other Another object ref. * \return the compare result. */ - bool operator==(const ObjectRef& other) const { - return data_ == other.data_; - } + bool operator==(const ObjectRef& other) const { return data_ == other.data_; } /*! * \brief Comparator * \param other Another object ref. * \return the compare result. */ - bool operator!=(const ObjectRef& other) const { - return data_ != other.data_; - } + bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } /*! * \brief Comparator * \param other Another object ref by address. * \return the compare result. */ - bool operator<(const ObjectRef& other) const { - return data_.get() < other.data_.get(); - } + bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } /*! * \return whether the object is defined(not null). */ - bool defined() const { - return data_ != nullptr; - } + bool defined() const { return data_ != nullptr; } /*! \return the internal object pointer */ - const Object* get() const { - return data_.get(); - } + const Object* get() const { return data_.get(); } /*! \return the internal object pointer */ - const Object* operator->() const { - return get(); - } + const Object* operator->() const { return get(); } /*! \return whether the reference is unique */ - bool unique() const { - return data_.unique(); - } + bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_.use_count(); - } + int use_count() const { return data_.use_count(); } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -605,16 +557,14 @@ class ObjectRef { /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object* get_mutable() const { - return data_.get(); - } + Object* get_mutable() const { return data_.get(); } /*! * \brief Internal helper function downcast a ref without check. * \note Only used for internal dev purposes. * \tparam T The target reference type. * \return The casted result. */ - template + template static T DowncastNoCheck(ObjectRef ref) { return T(std::move(ref.data_)); } @@ -623,16 +573,14 @@ class ObjectRef { * after we successfully moved the field. * \param ref The reference data. */ - static void FFIClearAfterMove(ObjectRef* ref) { - ref->data_.data_ = nullptr; - } + static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; } /*! * \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \note only used for internal dev purpose. * \tparam ObjectType The corresponding object type. * \return the corresponding type. */ - template + template static ObjectPtr GetDataPtr(const ObjectRef& ref) { return ObjectPtr(ref.data_.data_); } @@ -657,68 +605,56 @@ inline ObjectPtr GetObjectPtr(ObjectType* ptr); /*! \brief ObjectRef hash functor */ struct ObjectHash { - size_t operator()(const ObjectRef& a) const { - return operator()(a.data_); - } + size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } - template + template size_t operator()(const ObjectPtr& a) const { return std::hash()(a.get()); } }; - /*! 
\brief ObjectRef equal functor */
 struct ObjectEqual {
-  bool operator()(const ObjectRef& a, const ObjectRef& b) const {
-    return a.same_as(b);
-  }
+  bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); }

-  template<typename T>
+  template <typename T>
   size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
     return a == b;
   }
 };

-
 /*!
  * \brief helper macro to declare a base object type that can be inherited.
  * \param TypeName The name of the current type.
  * \param ParentType The name of the ParentType
  */
-#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                      \
-  static_assert(!ParentType::_type_final, "ParentObj marked as final");         \
-  static uint32_t RuntimeTypeIndex() {                                          \
-    static_assert(TypeName::_type_child_slots == 0 ||                           \
-                  ParentType::_type_child_slots == 0 ||                         \
-                  TypeName::_type_child_slots < ParentType::_type_child_slots,  \
-                  "Need to set _type_child_slots when parent specifies it.");   \
-    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {         \
-      return TypeName::_type_index;                                             \
-    }                                                                           \
-    return _GetOrAllocRuntimeTypeIndex();                                       \
-  }                                                                             \
-  static uint32_t _GetOrAllocRuntimeTypeIndex() {                               \
-    static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(                  \
-        TypeName::_type_key,                                                    \
-        TypeName::_type_index,                                                  \
-        ParentType::_GetOrAllocRuntimeTypeIndex(),                              \
-        TypeName::_type_child_slots,                                            \
-        TypeName::_type_child_slots_can_overflow);                              \
-    return tidx;                                                                \
-  }                                                                             \
-
+#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
+  static_assert(!ParentType::_type_final, "ParentObj marked as final");                        \
+  static uint32_t RuntimeTypeIndex() {                                                         \
+    static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||    \
+                      TypeName::_type_child_slots < ParentType::_type_child_slots,             \
+                  "Need to set _type_child_slots when parent specifies it.");                  \
+    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        \
+      return TypeName::_type_index;                                                            \
+    }                                                                                          \
+    return _GetOrAllocRuntimeTypeIndex();                                                      \
+  }                                                                                            \
+  static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              \
+    static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(                                 \
+        TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
+        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow);                \
+    return tidx;                                                                               \
+  }

 /*!
  * \brief helper macro to declare type information in a final class.
-  * \param TypeName The name of the current type.
-  * \param ParentType The name of the ParentType
-  */
-#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)  \
-  static const constexpr bool _type_final = true;            \
-  static const constexpr int _type_child_slots = 0;          \
-  TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)         \
-
+ * \param TypeName The name of the current type.
+ * \param ParentType The name of the ParentType
+ */
+#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
+  static const constexpr bool _type_final = true;           \
+  static const constexpr int _type_child_slots = 0;         \
+  TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)

 /*! \brief helper macro to suppress unused warning */
 #if defined(__GNUC__)
@@ -730,8 +666,7 @@ struct ObjectEqual {
 #define TVM_STR_CONCAT_(__x, __y) __x##__y
 #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)

-#define TVM_OBJECT_REG_VAR_DEF \
-  static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
+#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid

 /*!
  * \brief Helper macro to register the object type to runtime.
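Two practical notes on the functors above. First, every comparison in this header is pointer identity on the underlying Object (same_as), never structural equality. Second, because ObjectHash hashes the internal pointer and ObjectEqual delegates to same_as, the pair drops directly into the standard containers. A sketch of both:

#include <unordered_map>

using tvm::runtime::ObjectEqual;
using tvm::runtime::ObjectHash;
using tvm::runtime::ObjectRef;

// Keys collide only on identity: structurally equal but distinct objects
// occupy separate entries.
std::unordered_map<ObjectRef, int, ObjectHash, ObjectEqual> visit_count;

void Visit(const ObjectRef& ref) {
  ++visit_count[ref];
  // unique()/use_count() expose the reference count that the CopyOnWrite
  // pattern later in this header builds on.
  if (ref.defined() && ref.unique()) {
    // sole owner of the payload; in-place mutation would be safe
  }
}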
@@ -739,20 +674,18 @@ struct ObjectEqual {
  *
  * Use this macro in the cc file for each terminal class.
  */
-#define TVM_REGISTER_OBJECT_TYPE(TypeName)                 \
-  TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) =    \
-      TypeName::_GetOrAllocRuntimeTypeIndex()
-
+#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
+  TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()

 /*
  * \brief Define the default copy/move constructor and assign operator
  * \param TypeName The class typename.
  */
-#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)   \
-  TypeName(const TypeName& other) = default;                \
-  TypeName(TypeName&& other) = default;                     \
-  TypeName& operator=(const TypeName& other) = default;     \
-  TypeName& operator=(TypeName&& other) = default;          \
+#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
+  TypeName(const TypeName& other) = default;              \
+  TypeName(TypeName&& other) = default;                   \
+  TypeName& operator=(const TypeName& other) = default;   \
+  TypeName& operator=(TypeName&& other) = default;

 /*
  * \brief Define object reference methods.
@@ -760,15 +693,11 @@ struct ObjectEqual {
  * \param ParentType The parent type of the objectref
  * \param ObjectName The type name of the object.
  */
-#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)  \
-  TypeName() = default;                                                  \
-  explicit TypeName(                                                     \
-      ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)               \
-      : ParentType(n) {}                                                 \
-  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                     \
-  const ObjectName* operator->() const {                                 \
-    return static_cast<const ObjectName*>(data_.get());                  \
-  }                                                                      \
+#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)                        \
+  TypeName() = default;                                                                        \
+  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {}    \
+  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
+  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
   using ContainerType = ObjectName;

 /*
@@ -778,15 +707,11 @@ struct ObjectEqual {
  * \param ParentType The parent type of the objectref
  * \param ObjectName The type name of the object.
  */
-#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)  \
-  explicit TypeName(                                                                 \
-      ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                           \
-      : ParentType(n) {}                                                             \
-  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                 \
-  const ObjectName* operator->() const {                                             \
-    return static_cast<const ObjectName*>(data_.get());                              \
-  }                                                                                  \
-  static constexpr bool _type_is_nullable = false;                                   \
+#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)            \
+  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {}    \
+  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
+  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
+  static constexpr bool _type_is_nullable = false;                                             \
   using ContainerType = ObjectName;

 /*
@@ -797,15 +722,11 @@ struct ObjectEqual {
  * \note We recommend making objects immutable when possible.
  * This macro is only reserved for objects that store runtime states.
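Taken together, the declaration macros above support the usual object/reference pairing; a sketch with hypothetical names:

class MyCounterNode : public tvm::runtime::Object {
 public:
  int value{0};

  static constexpr const char* _type_key = "test.MyCounter";
  TVM_DECLARE_FINAL_OBJECT_INFO(MyCounterNode, Object);
};

class MyCounter : public tvm::runtime::ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(MyCounter, ObjectRef, MyCounterNode);
};

// In exactly one .cc file, to claim the runtime type index:
TVM_REGISTER_OBJECT_TYPE(MyCounterNode);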
*/ -#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ +#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /*! @@ -827,23 +748,21 @@ struct ObjectEqual { * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - ObjectName* CopyOnWrite() { \ - CHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ - } +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + ObjectName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } // Implementations details below // Object reference counting. #if TVM_OBJECT_ATOMIC_REF_COUNTER -inline void Object::IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); -} +inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } inline void Object::DecRef() { if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { @@ -854,15 +773,11 @@ inline void Object::DecRef() { } } -inline int Object::use_count() const { - return ref_counter_.load(std::memory_order_relaxed); -} +inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); } #else -inline void Object::IncRef() { - ++ref_counter_; -} +inline void Object::IncRef() { ++ref_counter_; } inline void Object::DecRef() { if (--ref_counter_ == 0) { @@ -872,13 +787,11 @@ inline void Object::DecRef() { } } -inline int Object::use_count() const { - return ref_counter_; -} +inline int Object::use_count() const { return ref_counter_; } #endif // TVM_OBJECT_ATOMIC_REF_COUNTER -template +template inline bool Object::IsInstance() const { const Object* self = this; // NOTE: the following code can be optimized by @@ -912,11 +825,9 @@ inline bool Object::IsInstance() const { } } - template inline const ObjectType* ObjectRef::as() const { - if (data_ != nullptr && - data_->IsInstance()) { + if (data_ != nullptr && data_->IsInstance()) { return static_cast(data_.get()); } else { return nullptr; @@ -944,12 +855,11 @@ template inline SubRef Downcast(BaseRef ref) { if (ref.defined()) { CHECK(ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key + << " failed."; } else { - CHECK(SubRef::_type_is_nullable) - << "Downcast from nullptr to not nullable reference of " - << SubRef::ContainerType::_type_key; + CHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index c56600730eda3..01f8e994347a6 100644 --- 
a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,25 +26,33 @@ #include #include +#include #include #include -#include #include + #include -#include -#include -#include #include #include -#include +#include +#include #include - +#include +#include // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY #define TVM_RUNTIME_HEADER_ONLY 0 #endif +// Always inline macro only use in template +// expansion cases where we know inline is important. +#ifdef _MSC_VER +#define TVM_ALWAYS_INLINE __forceinline inline +#else +#define TVM_ALWAYS_INLINE inline __attribute__((always_inline)) +#endif + namespace tvm { namespace runtime { @@ -83,7 +91,7 @@ class PackedFunc { * } * \endcode */ - using FType = std::function; + using FType = std::function; /*! \brief default constructor */ PackedFunc() {} /*! \brief constructor from null */ @@ -107,8 +115,8 @@ class PackedFunc { * } * \endcode */ - template - inline TVMRetValue operator()(Args&& ...args) const; + template + inline TVMRetValue operator()(Args&&... args) const; /*! * \brief Call the function in packed format. * \param args The arguments @@ -118,13 +126,9 @@ class PackedFunc { /*! \return the internal body function */ inline FType body() const; /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { - return body_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return body_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return body_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } private: /*! \brief internal container of packed function */ @@ -134,7 +138,7 @@ class PackedFunc { /*! * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc" */ -template +template class TypedPackedFunc; /*! @@ -169,7 +173,7 @@ class TypedPackedFunc; * \tparam R The return value of the function. * \tparam Args The argument signature of the function. */ -template +template class TypedPackedFunc { public: /*! \brief short hand for this function type */ @@ -226,11 +230,9 @@ class TypedPackedFunc { * \param typed_lambda typed lambda function. * \tparam FLambda the type of the lambda function. */ - template - >::value>::type> + template >::value>::type> TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); } @@ -250,11 +252,9 @@ class TypedPackedFunc { * \tparam FLambda the type of the lambda function. * \returns reference to self. */ - template - >::value>::type> + template >::value>::type> TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); return *this; @@ -273,28 +273,20 @@ class TypedPackedFunc { * \param args The arguments * \returns The return value. */ - inline R operator()(Args ...args) const; + TVM_ALWAYS_INLINE R operator()(Args... args) const; /*! * \brief convert to PackedFunc * \return the internal PackedFunc */ - operator PackedFunc() const { - return packed(); - } + operator PackedFunc() const { return packed(); } /*! * \return reference the internal PackedFunc */ - const PackedFunc& packed() const { - return packed_; - } + const PackedFunc& packed() const { return packed_; } /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { - return packed_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } /*! 
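To make both calling conventions concrete, a sketch of constructing and invoking each flavor:

using namespace tvm::runtime;

void Demo() {
  // Type-erased form: arguments arrive as TVMArgs, the result goes
  // through TVMRetValue.
  PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
    int x = args[0];
    *rv = x + 1;
  });
  int two = addone(1);  // variadic operator() packs 1 into TVMValue arrays

  // Typed wrapper: the signature is enforced at the C++ level.
  TypedPackedFunc<int(int)> typed_addone([](int x) { return x + 1; });
  int three = typed_addone(2);
  (void)two;
  (void)three;
}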
\return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return packed_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } private: friend class TVMRetValue; @@ -307,7 +299,7 @@ class TypedPackedFunc { * \tparam FLambda The lambda function type. * \note We capture the lambda when possible for maximum efficiency. */ - template + template inline void AssignTypedLambda(FLambda flambda); }; @@ -323,12 +315,8 @@ class TVMArgs { * \param type_codes The argument type codes * \param num_args number of arguments. */ - TVMArgs(const TVMValue* values, - const int* type_codes, - int num_args) - : values(values), - type_codes(type_codes), - num_args(num_args) { } + TVMArgs(const TVMValue* values, const int* type_codes, int num_args) + : values(values), type_codes(type_codes), num_args(num_args) {} /*! \return size of the arguments */ inline int size() const; /*! @@ -340,15 +328,14 @@ class TVMArgs { }; // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - CHECK_EQ(CODE, T) << " expected " \ - << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ +#define TVM_CHECK_TYPE_CODE(CODE, T) \ + CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. * \tparam T the type to be checked. */ -template +template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; @@ -402,61 +389,53 @@ class TVMPODValue_ { return value_.v_handle; } operator DLTensor*() const { - if (type_code_ == kTVMDLTensorHandle || - type_code_ == kTVMNDArrayHandle) { + if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { return static_cast(value_.v_handle); } else { if (type_code_ == kTVMNullptr) return nullptr; LOG(FATAL) << "Expect " - << "DLTensor* or NDArray but get " - << TypeCode2Str(type_code_); + << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_); return nullptr; } } operator NDArray() const { if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - return NDArray(NDArray::FFIDataFromHandle( - static_cast(value_.v_handle))); + return NDArray(NDArray::FFIDataFromHandle(static_cast(value_.v_handle))); } operator Module() const { if (type_code_ == kTVMNullptr) { return Module(ObjectPtr(nullptr)); } TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); - return Module( - ObjectPtr(static_cast(value_.v_handle))); + return Module(ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } - int type_code() const { - return type_code_; - } + int type_code() const { return type_code_; } /*! * \brief return handle as specific pointer type. * \tparam T the data type. * \return The pointer type. */ - template + template T* ptr() const { return static_cast(value_.v_handle); } // ObjectRef handling - template::value>::type> + template ::value>::type> inline bool IsObjectRef() const; - template + template inline TObjectRef AsObjectRef() const; protected: friend class TVMArgsSetter; friend class TVMRetValue; TVMPODValue_() : type_code_(kTVMNullptr) {} - TVMPODValue_(TVMValue value, int type_code) - : value_(value), type_code_(type_code) {} + TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} /*! 
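One consequence of the DLTensor* conversion above: a single packed function body can serve arguments passed either as a raw DLTensor handle or as an NDArray. A sketch:

using namespace tvm::runtime;

// Accepts kTVMDLTensorHandle and kTVMNDArrayHandle arguments alike.
PackedFunc num_elements([](TVMArgs args, TVMRetValue* rv) {
  DLTensor* t = args[0];
  int64_t elems = 1;
  for (int i = 0; i < t->ndim; ++i) {
    elems *= t->shape[i];
  }
  *rv = elems;
});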
\brief The value */ TVMValue value_; @@ -479,9 +458,7 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) - : TVMPODValue_(value, type_code) { - } + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -493,8 +470,8 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator Module; - using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -515,31 +492,27 @@ class TVMArgValue : public TVMPODValue_ { // None type if (type_code_ == kTVMNullptr) { DLDataType t; - t.code = kTVMOpaqueHandle; t.bits = 0; t.lanes = 0; + t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; return t; } TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); return value_.v_type; } - operator DataType() const { - return DataType(operator DLDataType()); - } + operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); return *ptr(); } - template + template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } - const TVMValue& value() const { - return value_; - } + const TVMValue& value() const { return value_; } - template::value>::type> + template ::value>::type> inline operator T() const; }; @@ -555,9 +528,7 @@ class TVMArgValue : public TVMPODValue_ { */ class TVMMovableArgValue_ : public TVMArgValue { public: - TVMMovableArgValue_(TVMValue value, int type_code) - : TVMArgValue(value, type_code) { - } + TVMMovableArgValue_(TVMValue value, int type_code) : TVMArgValue(value, type_code) {} // reuse converter from parent using TVMArgValue::operator double; using TVMArgValue::operator int64_t; @@ -576,9 +547,8 @@ class TVMMovableArgValue_ : public TVMArgValue { * Try to move out an argument if possible, * fall back to normal argument conversion rule otherwise. */ - template::value>::type> + template ::value>::type> inline operator T() const; }; @@ -598,15 +568,12 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from anoter return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) - : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! 
\brief destructor */ - ~TVMRetValue() { - this->Clear(); - } + ~TVMRetValue() { this->Clear(); } // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -618,12 +585,10 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; - using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { - this->Assign(other); - } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -641,15 +606,13 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); return value_.v_type; } - operator DataType() const { - return DataType(operator DLDataType()); - } + operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); return *ptr(); } - template + template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } @@ -696,9 +659,7 @@ class TVMRetValue : public TVMPODValue_ { value_.v_type = t; return *this; } - TVMRetValue& operator=(const DataType& other) { - return operator=(other.operator DLDataType()); - } + TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; @@ -728,10 +689,14 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(PackedFunc f) { - this->SwitchToClass(kTVMPackedFuncHandle, f); + if (f == nullptr) { + this->SwitchToPOD(kTVMNullptr); + } else { + this->SwitchToClass(kTVMPackedFuncHandle, f); + } return *this; } - template + template TVMRetValue& operator=(const TypedPackedFunc& f) { return operator=(f.packed()); } @@ -756,8 +721,7 @@ class TVMRetValue : public TVMPODValue_ { * \param ret_value The return value. * \param ret_type_code The return type code. */ - void MoveToCHost(TVMValue* ret_value, - int* ret_type_code) { + void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { // cannot move str; need specially handle. CHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes); *ret_value = value_; @@ -771,11 +735,9 @@ class TVMRetValue : public TVMPODValue_ { * \param type_code The type code. * \return The created TVMRetValue. */ - static TVMRetValue MoveFromCHost(TVMValue value, - int type_code) { + static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - CHECK(type_code <= kTVMPackedFuncHandle || - type_code == kTVMNDArrayHandle); + CHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -783,24 +745,20 @@ class TVMRetValue : public TVMPODValue_ { } /*! 
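The null check added to operator=(PackedFunc) above is a behavioral change worth spelling out: a null function is now stored as kTVMNullptr rather than being boxed as a packed-function handle. Sketch of the round trip:

using namespace tvm::runtime;

void NullRoundTrip() {
  TVMRetValue rv;
  rv = PackedFunc(nullptr);
  CHECK_EQ(rv.type_code(), kTVMNullptr);  // was kTVMPackedFuncHandle before
  PackedFunc f = rv;  // operator PackedFunc() hands back a null function
  CHECK(f == nullptr);
}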
\return The value field, if the data is POD */ const TVMValue& value() const { - CHECK(type_code_ != kTVMObjectHandle && - type_code_ != kTVMPackedFuncHandle && - type_code_ != kTVMModuleHandle && - type_code_ != kTVMStr) << "TVMRetValue.value can only be used for POD data"; + CHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle && + type_code_ != kTVMModuleHandle && type_code_ != kTVMStr) + << "TVMRetValue.value can only be used for POD data"; return value_; } // ObjectRef handling - template::value>::type> + template ::value>::type> inline TVMRetValue& operator=(TObjectRef other); - template::value>::type> + template ::value>::type> inline operator T() const; private: - template + template void Assign(const T& other) { switch (other.type_code()) { case kTVMStr: { @@ -825,9 +783,8 @@ class TVMRetValue : public TVMPODValue_ { } case kTVMObjectHandle: { // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject( - kTVMObjectHandle, GetObjectPtr( - static_cast(other.value_.v_handle))); + SwitchToObject(kTVMObjectHandle, + GetObjectPtr(static_cast(other.value_.v_handle))); break; } case kTVMObjectRValueRefArg: { @@ -848,7 +805,7 @@ class TVMRetValue : public TVMPODValue_ { type_code_ = type_code; } } - template + template void SwitchToClass(int type_code, T v) { if (type_code_ != type_code) { this->Clear(); @@ -872,8 +829,13 @@ class TVMRetValue : public TVMPODValue_ { void Clear() { if (type_code_ == kTVMNullptr) return; switch (type_code_) { - case kTVMStr: case kTVMBytes: delete ptr(); break; - case kTVMPackedFuncHandle: delete ptr(); break; + case kTVMStr: + case kTVMBytes: + delete ptr(); + break; + case kTVMPackedFuncHandle: + delete ptr(); + break; case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); break; @@ -900,24 +862,20 @@ class TVMRetValue : public TVMPODValue_ { * * \tparam TObjectRef the specific ObjectRefType. */ -template +template struct PackedFuncValueConverter { /*! * \brief Convert a TObjectRef from an argument value. * \param val The argument value. * \return the converted result. */ - static TObjectRef From(const TVMArgValue& val) { - return val.AsObjectRef(); - } + static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef(); } /*! * \brief Convert a TObjectRef from a return value. * \param val The argument value. * \return the converted result. */ - static TObjectRef From(const TVMRetValue& val) { - return val.AsObjectRef(); - } + static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef(); } }; /*! 
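PackedFuncValueConverter above is the intended hook for custom conversion rules. A hedged sketch for the hypothetical MyCounter reference from the object.h example, letting a plain integer stand in for a counter (MakeCounter is an assumed factory, not a real API):

namespace tvm {
namespace runtime {

template <>
struct PackedFuncValueConverter<MyCounter> {
  static MyCounter From(const TVMArgValue& val) {
    if (val.type_code() == kDLInt) {
      return MakeCounter(val.operator int64_t());  // assumed helper
    }
    return val.AsObjectRef<MyCounter>();
  }
  static MyCounter From(const TVMRetValue& val) { return val.AsObjectRef<MyCounter>(); }
};

}  // namespace runtime
}  // namespace tvm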
@@ -939,29 +897,22 @@ struct PackedFuncValueConverter { * * \endcode */ -#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code); \ - int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code) { \ - try { \ - ::tvm::runtime::TVMRetValue rv; \ - Function(::tvm::runtime::TVMArgs( \ - args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::runtime_error& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code); \ + int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code) { \ + try { \ + ::tvm::runtime::TVMRetValue rv; \ + Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ + rv.MoveToCHost(out_value, out_type_code); \ + return 0; \ + } catch (const ::std::runtime_error& _except_) { \ + TVMAPISetLastError(_except_.what()); \ + return -1; \ + } \ + } \ } /*! @@ -999,181 +950,170 @@ struct PackedFuncValueConverter { * * \endcode */ -#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code) { \ - try { \ - auto f = Function; \ - using FType = ::tvm::runtime::detail:: \ - function_signature::FType; \ - ::tvm::runtime::TVMRetValue rv; \ - ::tvm::runtime::detail::unpack_call_by_signature::run( \ - f, \ - ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::runtime_error& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code) { \ + try { \ + auto f = Function; \ + using FType = ::tvm::runtime::detail::function_signature::FType; \ + ::tvm::runtime::TVMRetValue rv; \ + ::tvm::runtime::detail::unpack_call_by_signature::run( \ + f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ + rv.MoveToCHost(out_value, out_type_code); \ + return 0; \ + } catch (const ::std::runtime_error& _except_) { \ + TVMAPISetLastError(_except_.what()); \ + return -1; \ + } \ + } \ } - inline TVMArgValue TVMArgs::operator[](int i) const { - CHECK_LT(i, num_args) - << "not enough argument passed, " - << num_args << " passed" - << " but request arg[" << i << "]."; + CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" + << " but request arg[" << i << "]."; return TVMArgValue(values[i], type_codes[i]); } -inline int TVMArgs::size() const { - return num_args; -} +inline int TVMArgs::size() const { return num_args; } -inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { - body_(args, rv); -} +inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } -inline PackedFunc::FType PackedFunc::body() const { - return body_; -} +inline PackedFunc::FType PackedFunc::body() const { return body_; } // internal namespace namespace detail { -template +template 
struct for_each_dispatcher { - template + template static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) f(I, std::forward(value)); - for_each_dispatcher - ::run(f, std::forward(args)...); + for_each_dispatcher::run(f, std::forward(args)...); } }; -template -struct for_each_dispatcher { +template +struct for_each_dispatcher { static void run(const F& f) {} // NOLINT(*) }; -template +template inline void for_each(const F& f, Args&&... args) { // NOLINT(*) - for_each_dispatcher - ::run(f, std::forward(args)...); + for_each_dispatcher::run(f, std::forward(args)...); } -template +template struct func_signature_helper { using FType = void; }; -template +template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; -template +template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; /*! * \brief template class to get function signature of a function or functor. * \tparam T The funtion/functor type. */ -template +template struct function_signature { using FType = typename func_signature_helper::FType; }; // handle case of function. -template +template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; // handle case of function ptr. -template +template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; } // namespace detail /* \brief argument settter to PackedFunc */ class TVMArgsSetter { public: - TVMArgsSetter(TVMValue* values, int* type_codes) - : values_(values), type_codes_(type_codes) {} + TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {} // setters for POD types - template::value>::type> - void operator()(size_t i, T value) const { + template ::value>::type> + TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } - void operator()(size_t i, uint64_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); - CHECK_LE(value, - static_cast(std::numeric_limits::max())); + CHECK_LE(value, static_cast(std::numeric_limits::max())); type_codes_[i] = kDLInt; } - void operator()(size_t i, double value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { values_[i].v_float64 = value; type_codes_[i] = kDLFloat; } - void operator()(size_t i, std::nullptr_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const { values_[i].v_handle = value; type_codes_[i] = kTVMNullptr; } - void operator()(size_t i, const TVMArgValue& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const { values_[i] = value.value_; type_codes_[i] = value.type_code_; } - void operator()(size_t i, void* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMOpaqueHandle; } - void operator()(size_t i, DLTensor* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMDLTensorHandle; } - void operator()(size_t i, TVMContext value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const { values_[i].v_ctx = value; 
type_codes_[i] = kTVMContext; } - void operator()(size_t i, DLDataType value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const { values_[i].v_type = value; type_codes_[i] = kTVMDataType; } - void operator()(size_t i, DataType dtype) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const { operator()(i, dtype.operator DLDataType()); } - void operator()(size_t i, const char* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kTVMStr; } // setters for container types - void operator()(size_t i, const std::string& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kTVMStr; } - void operator()(size_t i, const TVMByteArray& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMBytes; } - void operator()(size_t i, const PackedFunc& value) const { - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kTVMPackedFuncHandle; + TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { + if (value != nullptr) { + values_[i].v_handle = const_cast(&value); + type_codes_[i] = kTVMPackedFuncHandle; + } else { + values_[i].v_handle = nullptr; + type_codes_[i] = kTVMNullptr; + } } - template - void operator()(size_t i, const TypedPackedFunc& value) const { + template + TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } void operator()(size_t i, const TVMRetValue& value) const { @@ -1187,25 +1127,21 @@ class TVMArgsSetter { } } // ObjectRef handling - template::value> - ::type> - void operator()(size_t i, const TObjectRef& value) const { + template ::value>::type> + TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { this->SetObject(i, value); } - template::type>::value> - ::type> - void operator()(size_t i, TObjectRef&& value) const { + template ::type>::value>::type> + TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { this->SetObject(i, std::forward(value)); } private: - template + template inline void SetObject(size_t i, TObjectRef&& value) const; /*! \brief The values fields */ TVMValue* values_; @@ -1213,130 +1149,120 @@ class TVMArgsSetter { int* type_codes_; }; -template -inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { +template +inline TVMRetValue PackedFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; - detail::for_each(TVMArgsSetter(values, type_codes), - std::forward(args)...); + detail::for_each(TVMArgsSetter(values, type_codes), std::forward(args)...); TVMRetValue rv; body_(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } namespace detail { -template +template struct unpack_call_dispatcher { - template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. 
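For orientation, the variadic operator() above amounts to the following done by hand; note the named std::string, since the setter stores a pointer into it rather than copying:

using namespace tvm::runtime;

void CallByHand(const PackedFunc& f) {
  TVMValue values[2];
  int type_codes[2];
  TVMArgsSetter setter(values, type_codes);
  setter(0, 41);  // packed as kDLInt
  std::string name = "alpha";
  setter(1, name);  // kTVMStr keeps name.c_str(), so name must outlive the call
  TVMRetValue rv;
  f.CallPacked(TVMArgs(values, type_codes, 2), &rv);
}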
- unpack_call_dispatcher - ::run(f, args_pack, rv, - std::forward(unpacked_args)..., - TVMMovableArgValue_(args_pack.values[index], - args_pack.type_codes[index])); + unpack_call_dispatcher::run( + f, args_pack, rv, std::forward(unpacked_args)..., + TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index])); } }; -template +template struct unpack_call_dispatcher { - template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { - *rv = R(f(std::forward(unpacked_args)...)); + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + Args&&... unpacked_args) { + using RetType = decltype(f(std::forward(unpacked_args)...)); + if (std::is_same::value) { + *rv = f(std::forward(unpacked_args)...); + } else { + *rv = R(f(std::forward(unpacked_args)...)); + } } }; -template +template struct unpack_call_dispatcher { - template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + Args&&... unpacked_args) { f(std::forward(unpacked_args)...); } }; -template -inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { - CHECK_EQ(nargs, args.size()) - << "Expect " << nargs << " arguments but get " << args.size(); +template +TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { + CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); unpack_call_dispatcher::run(f, args, rv); } -template -struct unpack_call_by_signature { -}; +template +struct unpack_call_by_signature {}; -template +template struct unpack_call_by_signature { - template - static void run(const F& f, - const TVMArgs& args, - TVMRetValue* rv) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) { unpack_call(f, args, rv); } }; -template -inline R call_packed(const PackedFunc& pf, Args&& ...args) { +template +TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) { return R(pf(std::forward(args)...)); } -template +template struct typed_packed_call_dispatcher { - template - static inline R run(const PackedFunc& pf, Args&& ...args) { + template + TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) { return pf(std::forward(args)...); } }; -template<> +template <> struct typed_packed_call_dispatcher { - template - static inline void run(const PackedFunc& pf, Args&& ...args) { + template + TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... 
args) { pf(std::forward(args)...); } }; } // namespace detail -template -TypedPackedFunc::TypedPackedFunc(PackedFunc packed) - : packed_(packed) {} +template +TypedPackedFunc::TypedPackedFunc(PackedFunc packed) : packed_(packed) {} -template +template TypedPackedFunc::TypedPackedFunc(const TVMRetValue& value) : packed_(value.operator PackedFunc()) {} -template +template TypedPackedFunc::TypedPackedFunc(const TVMArgValue& value) : packed_(value.operator PackedFunc()) {} -template +template TypedPackedFunc::TypedPackedFunc(TVMMovableArgValue_&& value) : packed_(value.operator PackedFunc()) {} -template -template +template +template inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { - detail::unpack_call(flambda, args, rv); - }); + detail::unpack_call(flambda, args, rv); + }); } -template -inline R TypedPackedFunc::operator()(Args... args) const { - return detail::typed_packed_call_dispatcher - ::run(packed_, std::forward(args)...); +template +TVM_ALWAYS_INLINE R TypedPackedFunc::operator()(Args... args) const { + return detail::typed_packed_call_dispatcher::run(packed_, std::forward(args)...); } // ObjectRef related conversion handling @@ -1344,18 +1270,18 @@ inline R TypedPackedFunc::operator()(Args... args) const { // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle // // We use type traits to eliminate un-necessary checks. -template +template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { - using TObjectRef = typename std::remove_reference::type; + using ContainerType = typename std::remove_reference::type::ContainerType; if (value.defined()) { Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && + } else if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; @@ -1371,53 +1297,53 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { } } -template +template inline bool TVMPODValue_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + if (std::is_base_of::value) { return type_code_ == kTVMNDArrayHandle && - TVMArrayHandleToObjectHandle( - static_cast(value_.v_handle))->IsInstance(); + TVMArrayHandleToObjectHandle(static_cast(value_.v_handle)) + ->IsInstance(); } - if (std::is_base_of::value) { + if (std::is_base_of::value) { return type_code_ == kTVMModuleHandle && - static_cast(value_.v_handle)->IsInstance(); + static_cast(value_.v_handle)->IsInstance(); } // NOTE: we don't pass NDArray and runtime::Module as RValue ref. 
if (type_code_ == kTVMObjectRValueRefArg) { - return ObjectTypeChecker::Check( - *static_cast(value_.v_handle)); - } - return - (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) || - (std::is_base_of::value && type_code_ == kTVMModuleHandle) || - (type_code_ == kTVMObjectHandle && - ObjectTypeChecker::Check(static_cast(value_.v_handle))); + return ObjectTypeChecker::Check(*static_cast(value_.v_handle)); + } + return (std::is_base_of::value && + type_code_ == kTVMNDArrayHandle) || + (std::is_base_of::value && + type_code_ == kTVMModuleHandle) || + (type_code_ == kTVMObjectHandle && + ObjectTypeChecker::Check(static_cast(value_.v_handle))); } -template +template inline TObjectRef TVMPODValue_::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); + static_assert(std::is_base_of::value, + "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; + if (type_code_ == kTVMNullptr) { CHECK(TObjectRef::_type_is_nullable) << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + if (std::is_base_of::value) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - ObjectPtr data = NDArray::FFIDataFromHandle( - static_cast(value_.v_handle)); + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); CHECK(data->IsInstance()) << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + if (std::is_base_of::value) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -1429,22 +1355,22 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { // normal object type check. 
Object* ptr = static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); + << "Expect " << ObjectTypeChecker::TypeName() << " but get " + << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); } else if (type_code_ == kTVMObjectRValueRefArg) { Object* ptr = *static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); + << "Expect " << ObjectTypeChecker::TypeName() << " but get " + << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && + } else if (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) { // Casting to a base class that NDArray can sub-class - ObjectPtr data = NDArray::FFIDataFromHandle( - static_cast(value_.v_handle)); + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); return TObjectRef(data); - } else if (std::is_base_of::value && + } else if (std::is_base_of::value && type_code_ == kTVMModuleHandle) { // Casting to a base class that Module can sub-class return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); @@ -1454,17 +1380,18 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { } } -template +template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { + using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { return operator=(NDArray(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { return operator=(Module(std::move(other.data_))); } @@ -1475,13 +1402,12 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { return *this; } - -template +template inline TVMArgValue::operator T() const { return PackedFuncValueConverter::From(*this); } -template +template inline TVMMovableArgValue_::operator T() const { if (type_code_ == kTVMObjectRValueRefArg) { auto** ref = static_cast(value_.v_handle); @@ -1493,7 +1419,7 @@ inline TVMMovableArgValue_::operator T() const { return PackedFuncValueConverter::From(*this); } -template +template inline TVMRetValue::operator T() const { return PackedFuncValueConverter::From(*this); } diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 6faa7b7c84d70..4a5a21088222d 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -44,9 +44,10 @@ #define TVM_RUNTIME_REGISTRY_H_ #include + #include -#include #include +#include namespace tvm { namespace runtime { @@ -68,7 +69,8 @@ class Registry { } /*! * \brief set the body of the function to the given function. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -88,14 +90,15 @@ class Registry { * \param f The function to forward to. * \tparam FLambda The signature of the function. */ - template + template Registry& set_body_typed(FLambda f) { using FType = typename detail::function_signature::FType; return set_body(TypedPackedFunc(std::move(f)).packed()); } /*! * \brief set the body of the function to be the passed method pointer. 
- * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -113,9 +116,9 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template + template Registry& set_body_method(R (T::*f)(Args...)) { - auto fwrap =[f](T target, Args... params) -> R { + auto fwrap = [f](T target, Args... params) -> R { // call method pointer return (target.*f)(params...); }; @@ -124,7 +127,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -142,7 +146,7 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template + template Registry& set_body_method(R (T::*f)(Args...) const) { auto fwrap = [f](const T target, Args... params) -> R { // call method pointer @@ -154,7 +158,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. * Used when calling a method on a Node subclass through a ObjectRef subclass. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -181,8 +186,8 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template ::value>::type> Registry& set_body_method(R (TNode::*f)(Args...)) { auto fwrap = [f](TObjectRef ref, Args... params) { TNode* target = ref.operator->(); @@ -195,7 +200,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. * Used when calling a method on a Node subclass through a ObjectRef subclass. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -222,8 +228,8 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template ::value>::type> Registry& set_body_method(R (TNode::*f)(Args...) const) { auto fwrap = [f](TObjectRef ref, Args... params) { const TNode* target = ref.operator->(); @@ -270,8 +276,7 @@ class Registry { friend struct Manager; }; -#define TVM_FUNC_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM +#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM /*! * \brief Register a function globally. 
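The method-pointer overloads above save the hand-written lambda when exposing node methods. A sketch against the hypothetical MyCounterNode from the object.h example, assuming it gained an `int Get() const` method and using the usual Registry::Register entry point:

// Calls MyCounterNode::Get through the MyCounter reference type.
tvm::runtime::Registry::Register("test.counter_get")
    .set_body_method<MyCounter>(&MyCounterNode::Get);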
@@ -281,9 +286,8 @@ class Registry { * }); * \endcode */ -#define TVM_REGISTER_GLOBAL(OpName) \ - TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::runtime::Registry::Register(OpName) +#define TVM_REGISTER_GLOBAL(OpName) \ + TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName) } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index 37bb95f546554..f40c87ee07ec7 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -33,14 +33,14 @@ namespace dmlc { namespace serializer { -template<> +template <> struct Handler { - inline static void Write(Stream *strm, const DLDataType& dtype) { + inline static void Write(Stream* strm, const DLDataType& dtype) { Handler::Write(strm, dtype.code); Handler::Write(strm, dtype.bits); Handler::Write(strm, dtype.lanes); } - inline static bool Read(Stream *strm, DLDataType* dtype) { + inline static bool Read(Stream* strm, DLDataType* dtype) { if (!Handler::Read(strm, &(dtype->code))) return false; if (!Handler::Read(strm, &(dtype->bits))) return false; if (!Handler::Read(strm, &(dtype->lanes))) return false; @@ -48,14 +48,14 @@ struct Handler { } }; -template<> +template <> struct Handler { - inline static void Write(Stream *strm, const DLContext& ctx) { + inline static void Write(Stream* strm, const DLContext& ctx) { int32_t device_type = static_cast(ctx.device_type); Handler::Write(strm, device_type); Handler::Write(strm, ctx.device_id); } - inline static bool Read(Stream *strm, DLContext* ctx) { + inline static bool Read(Stream* strm, DLContext* ctx) { int32_t device_type = 0; if (!Handler::Read(strm, &(device_type))) return false; ctx->device_type = static_cast(device_type); diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h index f1984013e6a9a..95a64049fd45e 100644 --- a/include/tvm/runtime/threading_backend.h +++ b/include/tvm/runtime/threading_backend.h @@ -40,26 +40,25 @@ class ThreadGroup { public: class Impl; - /*! - * \brief Creates a collection of threads which run a provided function. - * - * \param num_workers The total number of worker threads in this group. - Includes main thread if `exclude_worker0 = true` - * \param worker_callback A callback which is run in its own thread. - Receives the worker_id as an argument. - * \param exclude_worker0 Whether to use the main thread as a worker. - * If `true`, worker0 will not be launched in a new thread and - * `worker_callback` will only be called for values >= 1. This - * allows use of the main thread as a worker. - */ - ThreadGroup(int num_workers, - std::function worker_callback, + /*! + * \brief Creates a collection of threads which run a provided function. + * + * \param num_workers The total number of worker threads in this group. + Includes main thread if `exclude_worker0 = true` + * \param worker_callback A callback which is run in its own thread. + Receives the worker_id as an argument. + * \param exclude_worker0 Whether to use the main thread as a worker. + * If `true`, worker0 will not be launched in a new thread and + * `worker_callback` will only be called for values >= 1. This + * allows use of the main thread as a worker. + */ + ThreadGroup(int num_workers, std::function worker_callback, bool exclude_worker0 = false); ~ThreadGroup(); - /*! - * \brief Blocks until all non-main threads in the pool finish. - */ + /*! + * \brief Blocks until all non-main threads in the pool finish. 
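With the macro in place, registration and lookup pair up as in this sketch (Registry::Get is the lookup counterpart to Register):

TVM_REGISTER_GLOBAL("test.add_one").set_body_typed([](int x) { return x + 1; });

void LookupAndCall() {
  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("test.add_one");
  CHECK(f != nullptr) << "test.add_one is not registered";
  int r = (*f)(41);  // r == 42
  (void)r;
}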
+ */ void Join(); enum AffinityMode : int { @@ -95,7 +94,6 @@ void Yield(); */ int MaxConcurrency(); - } // namespace threading } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 44d58987954c6..62534d9ca6a96 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -24,10 +24,11 @@ #ifndef TVM_RUNTIME_VM_H_ #define TVM_RUNTIME_VM_H_ -#include #include +#include #include #include + #include #include #include @@ -271,8 +272,8 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName storage, - const std::vector& shape, DLDataType dtype, RegName dst); + static Instruction AllocTensor(RegName storage, const std::vector& shape, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate tensor instruction with register. * \param storage The storage to allocate out of. @@ -281,8 +282,8 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensorReg(RegName storage, - RegName shape_register, DLDataType dtype, RegName dst); + static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, + RegName dst); /*! * \brief Construct an allocate datatype instruction. * \param tag The datatype tag. @@ -379,8 +380,8 @@ struct Instruction { * \param dst The destination to place the storage. * \return The alloc storage instruction. */ - static Instruction AllocStorage(RegName size, RegName alignment, - DLDataType dtype_hint, RegName dst); + static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint, + RegName dst); Instruction(); Instruction(const Instruction& instr); @@ -407,8 +408,7 @@ struct VMFunction { Index register_file_size; VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, - Index register_file_size) + const std::vector& instructions, Index register_file_size) : name(name), params(params), instructions(instructions), @@ -473,8 +473,7 @@ class Executable : public ModuleNode { * * \return PackedFunc or nullptr when it is not available. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief Serialize the executable into global section, constant section, and @@ -559,9 +558,7 @@ class Executable : public ModuleNode { virtual ~Executable() {} - const char* type_key() const final { - return "VMExecutable"; - } + const char* type_key() const final { return "VMExecutable"; } /*! \brief The runtime module/library that contains both the host and also the device * code when executing on non-CPU devices. */ @@ -668,14 +665,11 @@ class VirtualMachine : public runtime::ModuleNode { * If the function needs resource from the module(e.g. late linking), * it should capture sptr_to_self. 
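 *
 * A hedged sketch of the usual lookup pattern through the Module interface;
 * `vm_module` and the function name "init" are illustrative:
 *
 * \code
 *
 *   tvm::runtime::Module mod = vm_module;  // a Module backed by a VirtualMachine
 *   tvm::runtime::PackedFunc init = mod.GetFunction("init");
 *
 * \endcode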
*/ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); virtual ~VirtualMachine() {} - const char* type_key() const final { - return "VirtualMachine"; - } + const char* type_key() const final { return "VirtualMachine"; } VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {} @@ -763,11 +757,8 @@ class VirtualMachine : public runtime::ModuleNode { * * \note The return value will be stored in the last output_size slots of args. */ - virtual void InvokePacked(Index packed_index, - const PackedFunc& func, - Index arg_count, - Index output_size, - const std::vector& args); + virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args); /*! * \brief Initialize the virtual machine for a set of contexts. diff --git a/include/tvm/support/logging.h b/include/tvm/support/logging.h index 44b990e0d7db3..c318b89e5c51f 100644 --- a/include/tvm/support/logging.h +++ b/include/tvm/support/logging.h @@ -59,8 +59,8 @@ * a = ... * b = ... * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default behaviour) - * COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" + * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default + * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quitting" * ... * for (int i = 0; i < N; i++) { * a = ... @@ -84,29 +84,24 @@ // Not supposed to be used by users directly. #define COND_CHECK_OP(quit_on_assert, x, y, what, op) \ - if (!quit_on_assert) { \ - if (!((x) op (y))) \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + if (!((x)op(y))) what; \ + } else /* NOLINT(*) */ \ CHECK_##op(x, y) #define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==) #define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=) #define COND_CHECK_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - if (!(x)) \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + if (!(x)) what; \ + } else /* NOLINT(*) */ \ CHECK(x) #define COND_LOG_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + what; \ + } else /* NOLINT(*) */ \ LOG(x) #define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false) @@ -114,4 +109,4 @@ #define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false) #define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false) -#endif  // TVM_SUPPORT_LOGGING_H_ +#endif  // TVM_SUPPORT_LOGGING_H_ diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h index 46b091a68f345..90c82c4f3a06b 100644 --- a/include/tvm/support/with.h +++ b/include/tvm/support/with.h @@ -26,6 +26,7 @@ #define TVM_SUPPORT_WITH_H_ #include + #include namespace tvm { @@ -52,22 +53,19 @@ namespace tvm { * * \tparam ContextType Type of the context object. */ -template +template class With { public: /*! * \brief constructor. * Enter the scope of the context. */ - template - explicit With(Args&& ...args) - : ctx_(std::forward(args)...) { + template + explicit With(Args&&... args) : ctx_(std::forward(args)...) { ctx_.EnterWithScope(); } /*!
\brief destructor, leaves the scope of the context. */ - ~With() DMLC_THROW_EXCEPTION { - ctx_.ExitWithScope(); - } + ~With() DMLC_THROW_EXCEPTION { ctx_.ExitWithScope(); } private: /*! \brief internal context type. */ diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 4b7ea56e705de..e89d44dd4eb1f 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -24,14 +24,13 @@ #ifndef TVM_TARGET_CODEGEN_H_ #define TVM_TARGET_CODEGEN_H_ -#include #include -#include +#include #include +#include #include - namespace tvm { /*! \brief namespace for target translation and codegen. */ namespace codegen { @@ -71,8 +70,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib); * \param target_triple LLVM target triple * \return runtime::Module The generated LLVM module. */ -runtime::Module PackImportsToLLVM(const runtime::Module& m, - bool system_lib, +runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib, const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/include/tvm/target/generic_func.h b/include/tvm/target/generic_func.h index f2a361b3afaf1..a310173fa6eae 100644 --- a/include/tvm/target/generic_func.h +++ b/include/tvm/target/generic_func.h @@ -24,14 +24,14 @@ #ifndef TVM_TARGET_GENERIC_FUNC_H_ #define TVM_TARGET_GENERIC_FUNC_H_ -#include #include +#include #include -#include #include -#include #include +#include +#include namespace tvm { @@ -52,8 +52,7 @@ class GenericFunc : public ObjectRef { * false, an error will be logged if the call would override a previously registered function. * \return reference to self. */ - TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, - bool allow_override = false); + TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false); /*! * \brief Register a specialized function * \param tags The tags for this specialization @@ -63,8 +62,7 @@ * \return reference to self. */ TVM_DLL GenericFunc& register_func(const std::vector& tags, - const runtime::PackedFunc value, - bool allow_override = false); + const runtime::PackedFunc value, bool allow_override = false); /*! * \brief Call generic function by directly passing in unpacked format. * \param args Arguments to be passed. @@ -79,16 +77,15 @@ * } * \endcode */ - template - inline runtime::TVMRetValue operator()(Args&& ...args) const; + template + inline runtime::TVMRetValue operator()(Args&&... args) const; /*! * \brief Invoke the relevant function for the current target context, set by set_target_context. * Arguments are passed in packed format. * \param args The arguments to pass to the function. * \param ret The return value */ - TVM_DLL void CallPacked(runtime::TVMArgs args, - runtime::TVMRetValue* ret) const; + TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const; /*! * \brief Find or register the GenericFunc instance corresponding to the given name @@ -120,14 +117,14 @@ class GenericFunc : public ObjectRef { friend struct Manager; }; -template -inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const { +template +inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ?
kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes), - std::forward(args)...); + std::forward(args)...); runtime::TVMRetValue rv; CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv); return rv; @@ -155,8 +152,7 @@ inline GenericFuncNode* GenericFunc::operator->() { return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM +#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM /*! * \def TVM_REGISTER_GENERIC_FUNC @@ -165,9 +161,8 @@ inline GenericFuncNode* GenericFunc::operator->() { * * \param name The name of the function */ -#define TVM_REGISTER_GENERIC_FUNC(name) \ - TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::GenericFunc::Get(#name) +#define TVM_REGISTER_GENERIC_FUNC(name) \ + TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name) } // namespace tvm #endif // TVM_TARGET_GENERIC_FUNC_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 829de7381f00e..c28b0514dac1a 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -24,15 +24,15 @@ #ifndef TVM_TARGET_TARGET_H_ #define TVM_TARGET_TARGET_H_ -#include -#include #include #include +#include +#include #include -#include #include #include +#include namespace tvm { /*! @@ -99,9 +99,9 @@ class Target : public ObjectRef { Target() {} explicit Target(ObjectPtr n) : ObjectRef(n) {} /*! - * \brief Create a Target given a string - * \param target_str the string to parse - */ + * \brief Create a Target given a string + * \param target_str the string to parse + */ TVM_DLL static Target Create(const std::string& target_str); /*! * \brief Get the current target context from thread local storage. @@ -113,12 +113,11 @@ class Target : public ObjectRef { */ TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - const TargetNode* operator->() const { - return static_cast(get()); - } + const TargetNode* operator->() const { return static_cast(get()); } using ContainerType = TargetNode; class Internal; + private: // enable with syntax. friend class Internal; @@ -140,48 +139,37 @@ class Target : public ObjectRef { namespace target { /*! \return A target for LLVM */ -TVM_DLL Target llvm(const std::vector& options = - std::vector()); +TVM_DLL Target llvm(const std::vector& options = std::vector()); /*! \return A target for CUDA */ -TVM_DLL Target cuda(const std::vector& options = - std::vector()); +TVM_DLL Target cuda(const std::vector& options = std::vector()); /*! \return A target for ROCm */ -TVM_DLL Target rocm(const std::vector& options = - std::vector()); +TVM_DLL Target rocm(const std::vector& options = std::vector()); /*! \return A target for OpenCL */ -TVM_DLL Target opencl(const std::vector& options = - std::vector()); +TVM_DLL Target opencl(const std::vector& options = std::vector()); /*! \return A target for Metal */ -TVM_DLL Target metal(const std::vector& options = - std::vector()); +TVM_DLL Target metal(const std::vector& options = std::vector()); /*! \return A target for rasp */ -TVM_DLL Target rasp(const std::vector& options = - std::vector()); +TVM_DLL Target rasp(const std::vector& options = std::vector()); /*! \return A target for Mali */ -TVM_DLL Target mali(const std::vector& options = - std::vector()); +TVM_DLL Target mali(const std::vector& options = std::vector()); /*! 
\return A target for Intel Graphics */ -TVM_DLL Target intel_graphics(const std::vector& options = - std::vector()); +TVM_DLL Target intel_graphics(const std::vector& options = std::vector()); /*! \return A target for stackvm */ -TVM_DLL Target stackvm(const std::vector& options = - std::vector()); +TVM_DLL Target stackvm(const std::vector& options = std::vector()); /*! \return A target for external device */ -TVM_DLL Target ext_dev(const std::vector& options = - std::vector()); +TVM_DLL Target ext_dev(const std::vector& options = std::vector()); /*! \return A target for hexagon */ -TVM_DLL Target hexagon(const std::vector& options = - std::vector()); +TVM_DLL Target hexagon(const std::vector& options = std::vector()); } // namespace target /*! @@ -273,12 +261,8 @@ class BuildConfig : public ::tvm::ObjectRef { public: BuildConfig() {} explicit BuildConfig(ObjectPtr n) : ObjectRef(n) {} - const BuildConfigNode* operator->() const { - return static_cast(get()); - } - BuildConfigNode* operator->() { - return static_cast(get_mutable()); - } + const BuildConfigNode* operator->() const { return static_cast(get()); } + BuildConfigNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Construct a BuildConfig containing an empty build config node. * \return The new BuildConfig diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index 4466476a18dea..1de15a5bd5268 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -25,6 +25,7 @@ #define TVM_TARGET_TARGET_INFO_H_ #include + #include namespace tvm { diff --git a/include/tvm/te/autodiff.h b/include/tvm/te/autodiff.h index 180ec0bf676cc..e2d379969c656 100644 --- a/include/tvm/te/autodiff.h +++ b/include/tvm/te/autodiff.h @@ -27,6 +27,7 @@ #include #include + #include "tensor.h" namespace tvm { @@ -59,8 +60,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input); * * Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor * dot product. \p input must be an immediate dependency of \p output (must be called from within - * the body of \p output). That is, the function will compute one summand of the adjoint for \p input - * given the adjoint for \p output (which is called \p head here). + * the body of \p output). That is, the function will compute one summand of the adjoint for \p + * input given the adjoint for \p output (which is called \p head here). * * \param output The tensor to differentiate. * \param input The input tensor, which \p output should directly use. @@ -68,7 +69,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input); * \return The tensor of shape `prefix + input.shape` * representing the partial adjoint of \p input wrt one of its consumers (output) */ -Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head); +Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head); /*! * \brief Perform reverse mode automatic differentiation. @@ -82,14 +83,12 @@ Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Te * wrt all tensors the output depends on. * \param head The adjoint of the output, in other words, some tensor, by which the Jacobians * will be multiplied (using tensordot axes=`output.shape`). - * Its shape must be of the form `prefix + output.shape`. If the null pointer is provided, - * the identity tensor of shape `output.shape + output.shape` will be used.
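 *
 *        A hedged sketch; the tensor names and shapes are illustrative:
 *
 *        \code
 *
 *          Tensor x = placeholder({10}, DataType::Float(32), "x");
 *          Tensor y = compute({10}, [&x](Var i) { return x(i) * x(i); }, "y");
 *          Array<Tensor> dx = Gradient(y, {x});  // default identity head
 *
 *        \endcode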
- * \return An array of adjoints corresponding to \p inputs. + * Its shape must be of the form `prefix + output.shape`. If the null pointer is + * provided, the identity tensor of shape `output.shape + output.shape` will be used. \return An + * array of adjoints corresponding to \p inputs. */ -TVM_DLL Array Gradient( - const Tensor& output, - const Array& inputs, - const Tensor& head = Tensor()); +TVM_DLL Array Gradient(const Tensor& output, const Array& inputs, + const Tensor& head = Tensor()); } // namespace te } // namespace tvm diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 205589928f015..739ea8599179c 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -25,16 +25,15 @@ #define TVM_TE_OPERATION_H_ #include -#include #include - +#include +#include #include #include -#include #include -#include #include +#include namespace tvm { /*! \brief Tensor expression language DSL. */ namespace te { @@ -46,8 +45,7 @@ namespace te { */ struct TensorDom { // constructor - explicit TensorDom(int ndim) - : data(ndim) {} + explicit TensorDom(int ndim) : data(ndim) {} /*! \brief The domain data */ std::vector > data; }; @@ -64,9 +62,7 @@ class OperationNode : public tir::FunctionBaseNode { /*! \brief additional attributes of the operation*/ Map attrs; /*! \return name of the operation */ - const std::string& func_name() const final { - return name; - } + const std::string& func_name() const final { return name; } /*! * \return The list of iteration variable at root * \note root_iter_vars decides the shape of the outputs. @@ -96,9 +92,8 @@ class OperationNode : public tir::FunctionBaseNode { * \param rmap The replacement map. * \return self if nothing is replaced, otherwise return replaced op. */ - virtual Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const = 0; + virtual Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const = 0; /*! * \brief Propagate the bounds to inputs * \param self The reference to self. @@ -108,11 +103,9 @@ class OperationNode : public tir::FunctionBaseNode { * The function is only asked to fill the bounds for Tensors that * is already in the out_dom_map */ - virtual void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const = 0; + virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. * Set the range of each root_iter_vars in the op to out_dom_map @@ -121,10 +114,9 @@ * \param self The reference to self. * \param tensor_dom Domain map of Tensor->access set of each dimension. * \param out_dom_map The output domain map of each IterVar to be set. */ - virtual void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const = 0; + virtual void GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const = 0; /*! * \brief Build the Realize statement that realizes * the op's output tensors. @@ -133,10 +125,9 @@ * \param body The body that is going to get wrapped. * \return A realization statement that wraps body.
*/ - virtual Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + virtual Stmt BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -144,10 +135,8 @@ class OperationNode : public tir::FunctionBaseNode { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return A statement that add production and wraps consumer. */ - virtual Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const = 0; + virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const = 0; static constexpr const char* _type_key = "Operation"; @@ -169,26 +158,17 @@ class PlaceholderOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -197,9 +177,7 @@ class PlaceholderOpNode : public OperationNode { v->Visit("shape", &shape); v->Visit("dtype", &dtype); } - static Operation make(std::string name, - Array shape, - DataType dtype); + static Operation make(std::string name, Array shape, DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); @@ -219,21 +197,16 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { // override functions Array root_iter_vars() const final; Array output_shape(size_t idx) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = 
"BaseComputeOp"; TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; - /*! * \brief A Compute op that compute a tensor on certain domain. */ @@ -247,18 +220,13 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { int num_outputs() const final; DataType output_dtype(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; void VisitAttrs(AttrVisitor* v) { @@ -269,11 +237,8 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { v->Visit("reduce_axis", &reduce_axis); v->Visit("body", &body); } - static Operation make(std::string name, - std::string tag, - Map attrs, - Array axis, - Array body); + static Operation make(std::string name, std::string tag, Map attrs, + Array axis, Array body); static constexpr const char* _type_key = "ComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); @@ -300,18 +265,13 @@ class TensorComputeOpNode : public BaseComputeOpNode { int num_outputs() const final; DataType output_dtype(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; void VisitAttrs(AttrVisitor* v) { @@ -325,14 +285,9 @@ class TensorComputeOpNode : public BaseComputeOpNode { v->Visit("input_regions", &input_regions); v->Visit("scalar_inputs", &scalar_inputs); } - static Operation make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - int schedulable_ndim, - TensorIntrin intrin, - Array tensors, - Array regions, + static Operation make(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; @@ -375,26 +330,17 @@ class ScanOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - 
void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -407,14 +353,9 @@ class ScanOpNode : public OperationNode { v->Visit("inputs", &inputs); v->Visit("spatial_axis_", &spatial_axis_); } - static Operation make(std::string name, - std::string tag, - Map attrs, - IterVar axis, - Array init, - Array update, - Array state_placeholder, - Array input); + static Operation make(std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array input); static constexpr const char* _type_key = "ScanOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); @@ -442,26 +383,17 @@ class ExternOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -472,13 +404,10 @@ class ExternOpNode : public OperationNode { v->Visit("output_placeholders", &output_placeholders); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array input_placeholders, - 
Array output_placeholders, - Stmt body); + TVM_DLL static Operation make(std::string name, std::string tag, + Map attrs, Array inputs, + Array input_placeholders, Array output_placeholders, + Stmt body); static constexpr const char* _type_key = "ExternOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); @@ -510,26 +439,17 @@ class HybridOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -540,12 +460,9 @@ class HybridOpNode : public OperationNode { v->Visit("axis", &axis); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body); + TVM_DLL static Operation make(std::string name, std::string tag, + Map attrs, Array inputs, + Array outputs, Stmt body); static constexpr const char* _type_key = "HybridOp"; TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); @@ -575,10 +492,10 @@ TVM_DLL IterVar thread_axis(Range dom, std::string tag); TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); /*! \brief The compute function to specify the input source of a Tensor */ -using FCompute = std::function& i)>; +using FCompute = std::function& i)>; /*! \brief The compute function to specify the inputs source of Tensors */ -using FBatchCompute = std::function (const Array& i)>; +using FBatchCompute = std::function(const Array& i)>; /*! * \brief create a placeholder tensor. * \param shape the shape of the tensor. * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(Array shape, - DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Float(32), std::string name = "placeholder"); /*! @@ -599,11 +515,8 @@ TVM_DLL Tensor placeholder(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute.
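 *
 * A hedged sketch combining placeholder() and compute(); the extent `n` and
 * tensor names are illustrative:
 *
 * \code
 *
 *   Var n("n");
 *   Tensor A = placeholder({n}, DataType::Float(32), "A");
 *   Tensor B = compute({n}, [&A](Var i) { return A(i) + 1.0f; }, "B");
 *
 * \endcode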
*/ -TVM_DLL Tensor compute(Array shape, - FCompute fcompute, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}); +TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", + std::string tag = "", Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -614,10 +527,8 @@ TVM_DLL Tensor compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array compute(Array shape, - FBatchCompute fcompute, - std::string name = "tensor", - std::string tag = "", +TVM_DLL Array compute(Array shape, FBatchCompute fcompute, + std::string name = "tensor", std::string tag = "", Map attrs = {}); /*! @@ -632,45 +543,34 @@ TVM_DLL Array compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array scan(Array init, - Array update, - Array state_placeholder, - Array inputs = Array(), - std::string name = "scan", - std::string tag = "", +TVM_DLL Array scan(Array init, Array update, + Array state_placeholder, Array inputs = Array(), + std::string name = "scan", std::string tag = "", Map attrs = {}); // same as compute, specialized for different fcompute function -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0]); }; + FCompute fc = [f](const Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; + FCompute fc = [f](const Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; - return compute(shape, fc, name, tag, attrs); + FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2]); }; + return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2], i[3]); }; + FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); } diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index a8a02365fbdad..3667e1ed18fbd 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -25,10 +25,10 @@ #ifndef TVM_TE_SCHEDULE_H_ #define TVM_TE_SCHEDULE_H_ -#include +#include #include #include -#include +#include #include #include @@ -84,12 +84,12 @@ class Stage : public ObjectRef { * \param scope The iteration point to carry the schedule. * \return reference to self. 
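 *
 * A hedged sketch; the schedule `s`, stages `B`, `C`, and the axis name are
 * illustrative:
 *
 * \code
 *
 *   s[B].compute_at(s[C], C_outer_axis);  // nest B's computation in C's loop
 *
 * \endcode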
*/ - TVM_DLL Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*) + TVM_DLL Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*) /*! * \brief Compute the function inline. * \return reference to self. */ - TVM_DLL Stage& compute_inline();   // NOLINT(*) + TVM_DLL Stage& compute_inline();  // NOLINT(*) /*! * \brief Compute the function at group root. * \return reference to self. @@ -131,7 +131,8 @@ * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, + IterVar* p_inner);  // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -141,7 +142,8 @@ * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*) + TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, + IterVar* p_inner);  // NOLINT(*) /*! * \brief Fuse the inner and outer domains into the target * \param outer The outer domain to be fused. * \param inner The inner domain to be fused. @@ -169,7 +171,7 @@ * \param order The order of iteration variable. * \return reference to self. */ - TVM_DLL Stage& reorder(const Array& order);   // NOLINT(*) + TVM_DLL Stage& reorder(const Array& order);  // NOLINT(*) /*! * \brief Perform tiling on two dimensions * The final loop order from outermost to innermost is * [x_outer, y_outer, x_inner, y_inner] * \param x_parent The original x dimension * \param y_parent The original y dimension * \param x_factor The stride factor of x axis * \param y_factor The stride factor of y axis @@ -185,16 +187,15 @@ * \param p_x_outer Outer axis of x dimension * \param p_y_outer Outer axis of y dimension * \param p_x_inner Inner axis of x dimension * \param p_y_inner Inner axis of y dimension * \return reference to self. */ - TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*) - PrimExpr x_factor, PrimExpr y_factor, - IterVar* p_x_outer, IterVar* p_y_outer, - IterVar* p_x_inner, IterVar* p_y_inner); + TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,  // NOLINT(*) + PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer, + IterVar* p_x_inner, IterVar* p_y_inner); /*! * \brief Vectorize iteration. * \param var The axis to be vectorized. * \return reference to self. */ - TVM_DLL Stage& vectorize(IterVar var);   // NOLINT(*) + TVM_DLL Stage& vectorize(IterVar var);  // NOLINT(*) /*! * \brief Replace computation of the current stage by tensor intrinsic f. * \param var The axis marks beginning of tensorization. @@ -202,19 +203,19 @@ * \param f The Tensor compute intrinsics. * \return reference to self. */ - TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*) + TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);  // NOLINT(*) /*! * \brief Unroll iteration. * \param var The axis to be unrolled. * \return reference to self. */ - TVM_DLL Stage& unroll(IterVar var);   // NOLINT(*) + TVM_DLL Stage& unroll(IterVar var);  // NOLINT(*) /*! * \brief Parallelize iteration. * \param var The axis to be parallelized. * \return reference to self. */ - TVM_DLL Stage& parallel(IterVar var);   // NOLINT(*) + TVM_DLL Stage& parallel(IterVar var);  // NOLINT(*) /*! * \brief Annotate the iteration with pragma * * \param var The axis to be parallelized. * \param pragma_type The pragma type. * \param pragma_value The pragma value * * \return reference to self.
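 *
 * A hedged sketch chaining the loop transforms above; the schedule `s`, stage
 * `C`, and `axis0` are illustrative:
 *
 * \code
 *
 *   IterVar xo, xi;
 *   s[C].split(axis0, 32, &xo, &xi);
 *   s[C].parallel(xo);
 *   s[C].vectorize(xi);
 *
 * \endcode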
*/ - TVM_DLL Stage& pragma(IterVar var, - const std::string& pragma_type, - const PrimExpr& pragma_value = PrimExpr());   // NOLINT(*) + TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type, + const PrimExpr& pragma_value = PrimExpr());  // NOLINT(*) /*! * \brief Fetch data in advance. * \param domain the tensor to be prefetched * \param var the iteration point at which to apply prefetching * \param offset the number of iterations to be fetched in advance * \return reference to self */ - TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*) + TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset);  // NOLINT(*) /*! * \brief Set alignment requirement for specific dimension. * * Such that stride[axis] == k * factor + offset for some k. * * \param axis The dimension to be specified for alignment. * \param factor The factor multiple of alignment. * \param offset The required offset factor. * \return reference to self */ - TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*) + TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset);  // NOLINT(*) /*! * \brief Compute current stage with double buffering. * \return reference to self. */ - TVM_DLL Stage& double_buffer();   // NOLINT(*) + TVM_DLL Stage& double_buffer();  // NOLINT(*) /*! * \brief Schedule for OpenGL fragment shader. * \return reference to self. */ - Stage& opengl(); // NOLINT(*) + Stage& opengl();  // NOLINT(*) /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -297,9 +297,7 @@ class Schedule : public ObjectRef { * \param tensor The tensor * \return The stage corresponding to the tensor's op */ - TVM_DLL Stage operator[](const Tensor& tensor) { - return this->operator[](tensor->op); - } + TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } /*! * \brief Create a new stage group for all intermediate * operations between inputs and outputs. @@ -309,9 +307,8 @@ * \param include_inputs Whether include inputs if they are reachable from outputs. * \return The new grouped stage. */ - TVM_DLL Stage create_group(const Array& outputs, - const Array& inputs, - bool include_inputs = false); + TVM_DLL Stage create_group(const Array& outputs, const Array& inputs, + bool include_inputs = false); /*! * \brief create a cache read of original tensor for readers. * This will mutate the body of the readers. @@ -321,9 +318,8 @@ * \param readers The readers to redirect to the tensor. * \return The created tensor. */ - TVM_DLL Tensor cache_read(const Tensor& tensor, - const std::string& scope, - const Array& readers); + TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope, + const Array& readers); /*! * \brief Create a cache write tensor for producing tensor. * The tensor will take over the body of the original tensor op. @@ -371,9 +367,7 @@ * \param factor_axis The position where the new axis is placed. * \return The created factored tensors. */ - TVM_DLL Array rfactor(const Tensor& tensor, - const IterVar& axis, - int factor_axis = 0); + TVM_DLL Array rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0); /*! * \brief Normalize the schedule. * This is needed before bound inference. @@ -565,9 +559,7 @@ class ScheduleNode : public Object { * \param tensor The candidate tensor. * \return true if the schedule has the tensor. Otherwise, false.
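 *
 * A hedged sketch guarding a cache_read call; the names are illustrative:
 *
 * \code
 *
 *   if (sch->Contain(B)) {
 *     Tensor BB = sch.cache_read(B, "shared", {C->op});
 *   }
 *
 * \endcode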
*/ - TVM_DLL bool Contain(const Tensor& tensor) const { - return Contain(tensor->op); - } + TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); } /*! * \brief Create a schedule for an array of ops (and their dependencies). @@ -585,9 +577,7 @@ * \param ops The ops to be scheduled. * \return sch The created Schedule. */ -inline Schedule create_schedule(Array ops) { - return ScheduleNode::make(ops); -} +inline Schedule create_schedule(Array ops) { return ScheduleNode::make(ops); } /*! \brief node container for IterVar attr */ class IterVarAttrNode : public Object { @@ -666,10 +656,7 @@ class SplitNode : public IterVarRelationNode { v->Visit("nparts", &nparts); } - static IterVarRelation make(IterVar parent, - IterVar outer, - IterVar inner, - PrimExpr factor, + static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); static constexpr const char* _type_key = "Split"; @@ -694,8 +681,7 @@ class FuseNode : public IterVarRelationNode { v->Visit("fused", &fused); } - static IterVarRelation make( - IterVar outer, IterVar inner, IterVar fused); + static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused); static constexpr const char* _type_key = "Fuse"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); @@ -724,7 +710,6 @@ class RebaseNode : public IterVarRelationNode { TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; - /*! * \brief Singleton iterator [0, 1) */ @@ -733,9 +718,7 @@ class SingletonNode : public IterVarRelationNode { /*! \brief The singleton iterator */ IterVar iter; - void VisitAttrs(AttrVisitor* v) { - v->Visit("iter", &iter); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } static IterVarRelation make(IterVar iter); @@ -753,9 +736,7 @@ class SpecializedConditionNode : public Object { */ Array clauses; - void VisitAttrs(AttrVisitor* v) { - v->Visit("clauses", &clauses); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); } static constexpr const char* _type_key = "SpecializedCondition"; TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object); @@ -792,19 +773,13 @@ class SpecializedCondition : public ObjectRef { }; // implementations -inline const StageNode* Stage::operator->() const { - return static_cast(get()); -} -inline StageNode* Stage::operator->() { - return static_cast(get_mutable()); -} +inline const StageNode* Stage::operator->() const { return static_cast(get()); } +inline StageNode* Stage::operator->() { return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { return static_cast(get()); } -inline ScheduleNode* Schedule::operator->() { - return static_cast(get_mutable()); -} +inline ScheduleNode* Schedule::operator->() { return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { return static_cast(get()); diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index 618fc229d2f36..a4efa7a949908 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -90,10 +90,8 @@ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivia * buffer assignment of input and outputs. * \return Transformed stmt. */ -Stmt SchedulePostProcRewriteForTensorCore( - Stmt stmt, - Schedule schedule, - Map extern_buffer); +Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, + Map extern_buffer); /*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create @@ -111,8 +109,7 @@ Stmt SchedulePostProcRewriteForTensorCore( * \param body The body of the function. * \param bindings potential Tensor to Buffer bindings for the Tensors in the body. */ -PrimFunc SchedulePostProcToPrimFunc(Array arg_list, - Stmt body, +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> bindings); } // namespace te diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index c247dca3ff454..f82df8cf0aea5 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -24,15 +24,15 @@ #ifndef TVM_TE_TENSOR_H_ #define TVM_TE_TENSOR_H_ -#include #include +#include #include #include #include -#include -#include #include +#include +#include namespace tvm { namespace te { @@ -78,8 +78,8 @@ class Tensor : public ObjectRef { * \param args The indices * \return the result expression representing tensor read. */ - template - inline PrimExpr operator()(Args&& ...args) const { + template + inline PrimExpr operator()(Args&&... args) const { Array indices{std::forward(args)...}; return operator()(indices); } @@ -119,9 +119,7 @@ class Tensor : public ObjectRef { * This is only valid when all the coordinates are fully specified. * \return the corresponding expression of this slice. */ - inline operator PrimExpr() const { - return tensor_(indices_); - } + inline operator PrimExpr() const { return tensor_(indices_); } private: const Tensor& tensor_; @@ -132,9 +130,7 @@ class Tensor : public ObjectRef { * \param i the index of the coordinate * \return the subsequent slice. */ - inline Slice operator[](PrimExpr i) const { - return Slice(*this, {i}); - } + inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } /*! \brief specify container node */ using ContainerType = TensorNode; }; @@ -180,57 +176,46 @@ class TensorNode : public Object { v->Visit("op", &op); v->Visit("value_index", &value_index); } - TVM_DLL static Tensor make(Array shape, - DataType dtype, - Operation op, - int value_index); + TVM_DLL static Tensor make(Array shape, DataType dtype, Operation op, int value_index); static constexpr const char* _type_key = "Tensor"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object); }; - // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { return static_cast(get()); } -inline size_t Tensor::ndim() const { - return (*this)->shape.size(); -} +inline size_t Tensor::ndim() const { return (*this)->shape.size(); } inline bool Tensor::operator==(const Tensor& other) const { if (get() == other.get()) return true; if (get() == nullptr || other.get() == nullptr) return false; if ((*this)->op.defined() || other->op.defined()) { - return (*this)->op == other->op && - (*this)->value_index == other->value_index; + return (*this)->op == other->op && (*this)->value_index == other->value_index; } else { return false; } } -inline bool Tensor::operator!=(const Tensor& other) const { - return !(*this == other); -} +inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); } // macro to turn every operation of slice to expression -#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ - inline PrimExpr operator Op (const Tensor::Slice& a) { \ - return Op a.operator PrimExpr() ; \ - } \ +#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ + inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); } -#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ - template \ - inline PrimExpr operator Op (const Tensor::Slice& a, const 
T& b) { \ - return a.operator PrimExpr() Op b; \ - } \ - template \ - inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ - return a Op b.operator PrimExpr(); \ - } \ - inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ - return a.operator PrimExpr() Op b.operator PrimExpr(); \ +#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ + template \ + inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \ + return a.operator PrimExpr() Op b; \ + } \ + template \ + inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \ + return a Op b.operator PrimExpr(); \ + } \ + inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \ + return a.operator PrimExpr() Op b.operator PrimExpr(); \ } DEFINE_OVERLOAD_SLICE_UNARY_OP(!); @@ -254,8 +239,7 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<);  // NOLINT(*) namespace std { template <> -struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash { -}; +struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash {}; template <> struct hash<::tvm::te::Tensor> { @@ -263,7 +247,7 @@ struct hash<::tvm::te::Tensor> { ::tvm::ObjectHash hasher; if (k.defined() && k->op.defined()) { return hasher(k->op); - } else{ + } else { return hasher(k); } } diff --git a/include/tvm/te/tensor_intrin.h b/include/tvm/te/tensor_intrin.h index c964d3e5491b1..252c5f59acfff 100644 --- a/include/tvm/te/tensor_intrin.h +++ b/include/tvm/te/tensor_intrin.h @@ -100,14 +100,9 @@ class TensorIntrinNode : public Object { v->Visit("reduce_update", &reduce_update); } - TVM_DLL static TensorIntrin make(std::string name, - Operation op, - Array inputs, - Array buffers, - Array scalar_params, - Stmt body, - Stmt reduce_init, - Stmt reduce_update); + TVM_DLL static TensorIntrin make(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update); static constexpr const char* _type_key = "TensorIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); @@ -144,7 +139,6 @@ class TensorIntrinCallNode : public Object { /*! \brief regions of input tensors */ Array regions; - /*! * \brief IterVar on each reduction axis, if the * intrin will use the reduce axis @@ -161,11 +155,8 @@ v->Visit("reduce_axis", &reduce_axis); v->Visit("scalar_inputs", &scalar_inputs); } - static TensorIntrinCall make(TensorIntrin intrin, - Array tensors, - Array regions, - Array reduce_axis, - Array scalar_inputs); + static TensorIntrinCall make(TensorIntrin intrin, Array tensors, Array regions, + Array reduce_axis, Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index f7a89f50ef61e..b0f409c4a5654 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -29,6 +29,7 @@ #include #include #include + #include namespace tvm { @@ -75,8 +76,7 @@ TVM_DLL bool HasSideEffect(const PrimExpr& expr); * \param vset_contains The check function to see if var is in the vset. * \return Whether expr uses vset. */ -TVM_DLL bool ExprUseVar(const PrimExpr& expr, - std::function vset_contains); +TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function vset_contains); /*! * \brief Whether the expression uses the variable. * \param expr The expression to be checked. * \param var The variable. * \return Whether expr uses var.
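 *
 * A hedged sketch:
 *
 * \code
 *
 *   Var x("x"), y("y");
 *   PrimExpr e = x + 1;
 *   bool uses_x = ExprUseVar(e, x);  // true
 *   bool uses_y = ExprUseVar(e, y);  // false
 *
 * \endcode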
*/ inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { - return ExprUseVar(expr, [&](const VarNode* node) { - return var.get() == node; - }); + return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); } - /*! * \brief Verifies whether the IR stmt or Expr is in SSA form. * That is: each Var is defined and assigned once(in Let/For) @@ -133,8 +130,7 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); * \return valid Whether it is a valid GPU code * */ -TVM_DLL bool VerifyGPUCode(const PrimFunc& func, - Map constraints); +TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); // Pass variants of verification analysis // directly throws RuntimeError when verification fails. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 08a8e69a45326..5d4e86026b396 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -24,11 +24,11 @@ #ifndef TVM_TIR_BUFFER_H_ #define TVM_TIR_BUFFER_H_ -#include #include +#include #include -#include +#include namespace tvm { namespace tir { @@ -76,8 +76,7 @@ class Buffer : public ObjectRef { * \param content_lanes The number of lanes for the (data) type. * \param offset The offset of ptr. */ - TVM_DLL PrimExpr access_ptr(int access_mask, - DataType ptr_type = DataType::Handle(), + TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0)) const; /*! @@ -155,15 +154,10 @@ class BufferNode : public Object { bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { // Use DefEqual as buffer can define variables // in its semantics, skip name as name is not important. - return - equal.DefEqual(data, other->data) && - equal(dtype, other->dtype) && - equal.DefEqual(shape, other->shape) && - equal.DefEqual(strides, other->strides) && - equal.DefEqual(elem_offset, other->elem_offset) && - equal(scope, other->scope) && - equal(data_alignment, other->data_alignment) && - equal(buffer_type, other->buffer_type); + return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && + equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && + equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) && + equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } void SHashReduce(SHashReducer hash_reduce) const { @@ -184,15 +178,9 @@ class BufferNode : public Object { // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL static Buffer make(Var ptr, - DataType dtype, - Array shape, - Array strides, - PrimExpr elem_offset, - std::string name, - std::string scope, - int data_alignment, - int offset_factor, + TVM_DLL static Buffer make(Var ptr, DataType dtype, Array shape, + Array strides, PrimExpr elem_offset, std::string name, + std::string scope, int data_alignment, int offset_factor, BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; @@ -213,8 +201,7 @@ inline const BufferNode* Buffer::operator->() const { * \return The created buffer. * \sa BufferNode::make for complete constructor. 
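 *
 * A hedged sketch; the shape and name are illustrative:
 *
 * \code
 *
 *   Buffer buf = decl_buffer({128, 64}, DataType::Float(32), "A");
 *   PrimExpr rptr = buf.access_ptr(1);  // mask bit 1 == read access
 *
 * \endcode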
*/ -TVM_DLL Buffer decl_buffer(Array shape, - DataType dtype = DataType::Float(32), +TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 4343370571679..0a20db6a0a632 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -25,16 +25,14 @@ #ifndef TVM_TIR_DATA_LAYOUT_H_ #define TVM_TIR_DATA_LAYOUT_H_ - #include #include -#include +#include #include -#include +#include #include -#include - +#include namespace tvm { namespace tir { @@ -63,18 +61,12 @@ class LayoutAxis { } // return the primal axis. If it is already primal, return itself. - const LayoutAxis& ToPrimal() const { - return IsPrimal() ? *this : ToDual(); - } + const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); } // return the subordinate axis. If it is already subordinate, return itself. - const LayoutAxis& ToSubordinate() const { - return IsPrimal() ? ToDual() : *this; - } + const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; } - inline bool operator==(const LayoutAxis& rhs) const { - return name_ == rhs.name_; - } + inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; } friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) { os << l.name(); @@ -136,7 +128,7 @@ class Layout : public ObjectRef { explicit Layout(const Array& axes); /*! \brief construct from a string */ - Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) + Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) /*! * \brief construct from a string. @@ -146,23 +138,19 @@ class Layout : public ObjectRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - Layout(const std::string& name); // NOLINT(*) + Layout(const std::string& name); // NOLINT(*) /*! * \brief access the internal node container * \return the pointer to the internal node container */ - const LayoutNode* operator->() const { - return static_cast(get()); - } + const LayoutNode* operator->() const { return static_cast(get()); } /*! * \brief access the internal node container * \return the pointer to the internal node container */ - LayoutNode* operator->() { - return static_cast(get_mutable()); - } + LayoutNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Return an undefined layout. @@ -190,8 +178,7 @@ class Layout : public ObjectRef { * \param factor size of the sub-dimension. * \return A newly constructed Layout object. */ - Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const; - + Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const; /*! \return number of dimensions */ inline size_t ndim() const { @@ -292,9 +279,7 @@ class Layout : public ObjectRef { * \param rhs Another layout. * \return whether the two layouts are equal. */ - inline bool Equals(const Layout &rhs) const { - return name() == rhs.name(); - } + inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); } /*! 
* \brief allow output string of layout to ostream diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index bf0d4f985a92e..a9f34d2df8917 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -25,20 +25,20 @@ #ifndef TVM_TIR_EXPR_H_ #define TVM_TIR_EXPR_H_ -#include +#include #include #include +#include #include #include -#include -#include #include +#include -#include #include -#include #include #include +#include +#include #include namespace tvm { @@ -62,9 +62,7 @@ class StringImmNode : public PrimExprNode { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } TVM_DLL PrimExpr static make(std::string value); @@ -110,7 +108,7 @@ class CastNode : public PrimExprNode { * \brief Base template to implement binary ops. * \tparam T The type of the child class. */ -template +template class BinaryOpNode : public PrimExprNode { public: /*! \brief The left operand. */ @@ -125,10 +123,7 @@ class BinaryOpNode : public PrimExprNode { } bool SEqualReduce(const T* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -215,7 +210,7 @@ class MaxNode : public BinaryOpNode { * \brief Base template to implement comparison ops. * \tparam T The type of the child class. */ -template +template class CmpOpNode : public PrimExprNode { public: /*! \brief The left operand. */ @@ -230,10 +225,7 @@ class CmpOpNode : public PrimExprNode { } bool SEqualReduce(const T* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -307,10 +299,7 @@ class AndNode : public PrimExprNode { } bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -340,10 +329,7 @@ class OrNode : public PrimExprNode { } bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -408,11 +394,8 @@ class SelectNode : public PrimExprNode { } bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(condition, other->condition) && - equal(true_value, other->true_value) && - equal(false_value, other->false_value); + return equal(dtype, other->dtype) && equal(condition, other->condition) && + equal(true_value, other->true_value) && equal(false_value, other->false_value); } void SHashReduce(SHashReducer hash_reduce) const { @@ -452,10 +435,8 @@ class BufferLoadNode : public PrimExprNode { } bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(buffer, other->buffer) && - equal(indices, other->indices); + return equal(dtype, other->dtype) && equal(buffer, other->buffer) && + equal(indices, other->indices); 
} void SHashReduce(SHashReducer hash_reduce) const { @@ -470,8 +451,7 @@ class BufferLoadNode : public PrimExprNode { class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, - Array indices); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); }; @@ -507,11 +487,8 @@ class LoadNode : public PrimExprNode { } bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(buffer_var, other->buffer_var) && - equal(index, other->index) && - equal(predicate, other->predicate); + return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) && + equal(index, other->index) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { @@ -553,11 +530,8 @@ class RampNode : public PrimExprNode { } bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(base, other->base) && - equal(stride, other->stride) && - equal(lanes, other->lanes); + return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && + equal(lanes, other->lanes); } void SHashReduce(SHashReducer hash_reduce) const { @@ -588,10 +562,7 @@ class BroadcastNode : public PrimExprNode { } bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(value, other->value) && - equal(lanes, other->lanes); + return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes); } void SHashReduce(SHashReducer hash_reduce) const { @@ -626,11 +597,8 @@ class LetNode : public PrimExprNode { } bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) && + equal(value, other->value) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -668,8 +636,7 @@ class FunctionBaseNode : public Object { return this == other; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -731,13 +698,9 @@ class CallNode : public PrimExprNode { } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(name, other->name) && - equal(args, other->args) && - equal(call_type, other->call_type) && - equal(func, other->func) && - equal(value_index, other->value_index); + return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) && + equal(call_type, other->call_type) && equal(func, other->func) && + equal(value_index, other->value_index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -749,18 +712,13 @@ class CallNode : public PrimExprNode { hash_reduce(value_index); } - TVM_DLL static PrimExpr make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func = FunctionRef(), + TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array args, + CallType call_type, FunctionRef func = FunctionRef(), int value_index = 0); /*! \return Whether call node is pure. 
*/ bool is_pure() const { - return (call_type == PureExtern || - call_type == PureIntrinsic || - call_type == Halide); + return (call_type == PureExtern || call_type == PureIntrinsic || call_type == Halide); } /*! @@ -768,10 +726,7 @@ class CallNode : public PrimExprNode { * \param intrin_name The name of the intrinsic. */ bool is_intrinsic(const char* intrin_name) const { - return - ((call_type == Intrinsic || - call_type == PureIntrinsic) && - name == intrin_name); + return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name); } /*! \return Whether call node can be vectorized. */ @@ -818,10 +773,8 @@ class ShuffleNode : public PrimExprNode { } bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(vectors, other->vectors) && - equal(indices, other->indices); + return equal(dtype, other->dtype) && equal(vectors, other->vectors) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -880,9 +833,7 @@ class CommReducerNode : public Object { /*! \brief Function call operator to combine a and b */ Array operator()(Array a, Array b) const; /*! \brief construct CommReducer from args, result and identity_element */ - TVM_DLL static CommReducer make(Array lhs, - Array rhs, - Array result, + TVM_DLL static CommReducer make(Array lhs, Array rhs, Array result, Array identity_element); void VisitAttrs(AttrVisitor* v) { @@ -893,11 +844,8 @@ class CommReducerNode : public Object { } bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { - return - equal.DefEqual(lhs, other->lhs) && - equal.DefEqual(rhs, other->rhs) && - equal(result, other->result) && - equal(identity_element, other->identity_element); + return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) && + equal(result, other->result) && equal(identity_element, other->identity_element); } void SHashReduce(SHashReducer hash_reduce) const { @@ -916,9 +864,7 @@ class CommReducerNode : public Object { inline const CommReducerNode* CommReducer::get() const { return static_cast(data_.get()); } -inline const CommReducerNode* CommReducer::operator->() const { - return get(); -} +inline const CommReducerNode* CommReducer::operator->() const { return get(); } /*! \brief Reduction operator node */ class ReduceNode : public PrimExprNode { public: @@ -938,11 +884,8 @@ class ReduceNode : public PrimExprNode { int value_index; /*! \brief construct expr from op and rdom */ - TVM_DLL static PrimExpr make(CommReducer combiner, - Array src, - Array rdom, - PrimExpr condition, - int value_index); + TVM_DLL static PrimExpr make(CommReducer combiner, Array src, Array rdom, + PrimExpr condition, int value_index); void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -955,13 +898,9 @@ class ReduceNode : public PrimExprNode { bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { // check axis first so IterVars can define the necessary variables. 
- return - equal(dtype, other->dtype) && - equal(axis, other->axis) && - equal(combiner, other->combiner) && - equal(source, other->source) && - equal(condition, other->condition) && - equal(value_index, other->value_index); + return equal(dtype, other->dtype) && equal(axis, other->axis) && + equal(combiner, other->combiner) && equal(source, other->source) && + equal(condition, other->condition) && equal(value_index, other->value_index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -982,17 +921,12 @@ class AnyNode : public PrimExprNode { public: void VisitAttrs(AttrVisitor* v) {} - bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { - return true; - } + bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} /*! \brief Convert to var. */ - Var ToVar() const { - return Var("any_dim", DataType::Int(32)); - } + Var ToVar() const { return Var("any_dim", DataType::Int(32)); } TVM_DLL static PrimExpr make(); @@ -1000,7 +934,6 @@ class AnyNode : public PrimExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; - /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map * \param dmap The dmap to be converted. * \tparam K the key of the Map. * \tparam V the value of the Map. */ -template +template inline std::unordered_map as_unordered_map(const Map& dmap) { std::unordered_map ret; for (auto kv : dmap) { @@ -1176,7 +1109,7 @@ constexpr const char* tvm_call_packed = "tvm_call_packed"; * return 0; * } */ -constexpr const char *tvm_call_trace_packed = "tvm_call_trace_packed"; +constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed"; /*! * \brief See pseudo code * Mark the content as thread local context, can get optimized * by only calling the call once at thread start. @@ -1223,8 +1156,7 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered"; * TVMRetValue(value_stack + end, tcode_stack + end)); * } */ -constexpr const char *tvm_call_trace_packed_lowered = - "tvm_call_trace_packed_lowered"; +constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered"; /*! * \brief See pseudo code * @@ -1234,22 +1166,43 @@ constexpr const char *tvm_call_trace_packed_lowered = * } */ constexpr const char* tvm_storage_sync = "tvm_storage_sync"; + /*! * \brief See pseudo code * - * Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) { - * return (value passed in by warp indicated by warp_id); + * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id); + * } + * + * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id - offset); + * } + * + * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id + offset); + * } + * + * unsigned tvm_warp_activemask() { + * return (32-bit mask of currently active threads in the calling warp); * } * * Parameter warp_id indicates the source thread ID in a warp. * + * Parameter offset indicates the relative distance to this_warp_id. + * * Parameter width indicates the number of threads involved in one - * shuffle. See CUDA document for __shfl. + * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, + * __shfl_down_sync and __activemask. 
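 * * As an illustration only (one plausible CUDA lowering, assumed here rather than specified by this header): tvm_warp_shuffle(mask, value, warp_id, width, warp_size) could lower to __shfl_sync(mask, value, warp_id, width); tvm_warp_shuffle_up and tvm_warp_shuffle_down to __shfl_up_sync and __shfl_down_sync with the offset argument; and tvm_warp_activemask() to __activemask().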
* * Parameter warp_size is the size of a warp, which helps a backend * to determine whether the width parameter is legal. + */ constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle"; +constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up"; +constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down"; +constexpr const char* tvm_warp_activemask = "tvm_warp_activemask"; + /*! * \brief Initialize the global barrier. * Call this at the beginning of a kernel that needs the global barrier. @@ -1347,7 +1300,7 @@ enum TVMStructFieldKind : int { kTVMValueContent, kTVMValueKindBound_ }; -} // namespace intrinsic +} // namespace intrinsic } // namespace tir } // namespace tvm @@ -1356,7 +1309,7 @@ namespace tvm { namespace runtime { // Additional implementation overloads for PackedFunc. -template<> +template <> struct PackedFuncValueConverter { // common rule for RetValue and ArgValue static tvm::Integer From(const TVMPODValue_& val) { @@ -1374,7 +1327,6 @@ struct PackedFuncValueConverter { namespace std { template <> -struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash { -}; -} +struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash {}; +} // namespace std #endif // TVM_TIR_EXPR_H_ diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index dcf04c346454f..15ec3d2ae0bf2 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -71,22 +71,19 @@ namespace tir { * \tparam FType function signature * This type is only defined for FType with function signature R(const Expr&, Args...) */ -template +template class ExprFunctor; // functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT { \ - return VisitExprDefault_(op, std::forward(args)...); \ - } +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } -#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); -template +template class ExprFunctor { private: using TSelf = ExprFunctor; @@ -152,7 +149,7 @@ class ExprFunctor { virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Object* op, Args ...) { + virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -205,8 +202,7 @@ class ExprFunctor { /*! * \brief ExprVisitor */ -class TVM_DLL ExprVisitor : - public ExprFunctor { +class TVM_DLL ExprVisitor : public ExprFunctor { public: using ExprFunctor::operator(); @@ -251,8 +247,7 @@ class TVM_DLL ExprVisitor : /*! * \brief ExprMutator that mutates expressions. 
*/ -class TVM_DLL ExprMutator : - protected ExprFunctor { +class TVM_DLL ExprMutator : protected ExprFunctor { public: using ExprFunctor::operator(); diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1866f2f1f891a..919391e36b96c 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -25,11 +25,11 @@ #define TVM_TIR_FUNCTION_H_ #include -#include #include +#include #include -#include +#include namespace tvm { namespace tir { @@ -104,12 +104,9 @@ class PrimFuncNode : public BaseFuncNode { bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contain defs. - return - equal.DefEqual(params, other->params) && - equal(buffer_map, other->buffer_map) && - equal(ret_type, other->ret_type) && - equal(body, other->body) && - equal(attrs, other->attrs); + return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && + equal(ret_type, other->ret_type) && equal(body, other->body) && + equal(attrs, other->attrs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -146,9 +143,7 @@ class PrimFunc : public BaseFunc { * \param buffer_map The buffer map for parameter buffer unpacking. * \param attrs Additional function attributes. */ - TVM_DLL PrimFunc(Array params, - Stmt body, - Type ret_type = VoidType(), + TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = NullValue>(), DictAttrs attrs = NullValue()); diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index b54aa9aaf7cc0..5884942ebef11 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -33,9 +33,8 @@ #include #include -#include #include - +#include namespace tvm { @@ -551,7 +550,7 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ return tir::CallNode::make(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ - } \ + } TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp2); @@ -570,7 +569,12 @@ TVM_DECLARE_INTRIN_UNARY(cos); TVM_DECLARE_INTRIN_UNARY(cosh); TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(sinh); +TVM_DECLARE_INTRIN_UNARY(asin); +TVM_DECLARE_INTRIN_UNARY(acos); TVM_DECLARE_INTRIN_UNARY(atan); +TVM_DECLARE_INTRIN_UNARY(acosh); +TVM_DECLARE_INTRIN_UNARY(asinh); +TVM_DECLARE_INTRIN_UNARY(atanh); namespace tir { /*! @@ -580,8 +584,8 @@ namespace tir { * \return the result expression. * \tparam ValueType The constant value type */ -template::value>::type> +template ::value>::type> inline PrimExpr make_const(DataType t, ValueType value); /*! * \brief Make a const zero expr. @@ -594,17 +598,13 @@ inline PrimExpr make_zero(DataType t); * \param lanes The number of lanes in the bool * \return The result expression. */ -inline PrimExpr const_true(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 1); -} +inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); } /*! * \brief Make a constant false expression. * \param lanes The number of lanes in the bool * \return The result expression. */ -inline PrimExpr const_false(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 0); -} +inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); } /*! * \brief Get x as constant int expression. * \param x The expression @@ -641,9 +641,7 @@ inline bool is_no_op(const tir::Stmt& stmt); * \note This only returns true for integer types. 
* \return whether x is constant 1 */ -inline bool is_one(const PrimExpr& x) { - return is_const_int(x, 1); -} +inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } /*! * \brief Check whether x is a constant integer 0 @@ -651,9 +649,7 @@ inline bool is_one(const PrimExpr& x) { * \return whether x is constant 0 * \note This only returns true for integer types. */ -inline bool is_zero(const PrimExpr& x) { - return is_const_int(x, 0); -} +inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } /*! * \brief Check whether x is a constant. @@ -724,7 +720,7 @@ inline bool is_no_op(const tir::Stmt& stmt) { return false; } -template +template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { if (t.is_int()) return IntImm(t, static_cast(value)); if (t.is_uint()) { @@ -751,13 +747,12 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { return PrimExpr(); } -template +template inline PrimExpr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { - return tir::BroadcastNode::make( - MakeConstScalar(t.element_of(), value), t.lanes()); + return tir::BroadcastNode::make(MakeConstScalar(t.element_of(), value), t.lanes()); } } @@ -770,44 +765,34 @@ inline PrimExpr make_zero(DataType t) { } // namespace tir // additional const expression overloading -#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ - inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\ - a = OpFunc(a, b); \ - return a; \ +#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ + inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \ + a = OpFunc(a, b); \ + return a; \ } -#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, float b) { \ - return Name(a, PrimExpr(b)); \ - } \ - inline PrimExpr Name(float a, const PrimExpr& b) { \ - return Name(PrimExpr(a), b); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ - } \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, tir::make_const(DataType::Float(64), b)); \ +#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { \ + return Name(tir::make_const(b.dtype(), a), b); \ + } \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tir::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(const PrimExpr& a, double b) { \ + return Name(a, tir::make_const(DataType::Float(64), b)); \ } -#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, bool b) { \ - return Name(a, PrimExpr(b)); \ - } \ - inline PrimExpr Name(bool a, const PrimExpr& b) { \ - return Name(PrimExpr(a), b); \ - } +#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ - } +#define 
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tir::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-); @@ -829,8 +814,8 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod); -TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) -TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); @@ -843,7 +828,7 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); * \note The call to this function will always result in a compiler error. * \tparam TA Any class type. */ -template +template inline void DivAmbiguityError(const TA& a) { constexpr bool div_ambiguity = !std::is_class::value; static_assert(div_ambiguity, @@ -859,19 +844,19 @@ inline void DivAmbiguityError(const TA& a) { // to use the specific division function. // The second template argument is necessary to make sure the // code compiles lazily by the compiler during invocation. -template +template inline PrimExpr operator/(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } -template +template inline PrimExpr operator/=(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } -template +template inline PrimExpr operator%(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index aed8b5c77ae58..115d05c48d1fd 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -26,10 +26,10 @@ #include -#include #include -#include +#include #include +#include namespace tvm { namespace tir { @@ -69,10 +69,8 @@ class LetStmtNode : public StmtNode { } bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { - return - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal.DefEqual(var, other->var) && equal(value, other->value) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -116,11 +114,8 @@ class AttrStmtNode : public StmtNode { } bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { - return - equal(node, other->node) && - equal(attr_key, other->attr_key) && - equal(value, other->value) && - equal(body, other->body); + return equal(node, other->node) && equal(attr_key, other->attr_key) && + equal(value, other->value) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -130,10 +125,7 @@ class AttrStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(ObjectRef node, - std::string type_key, - PrimExpr value, - Stmt body); + TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body); static constexpr const char* _type_key = "AttrStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); @@ -161,10 +153,8 @@ class AssertStmtNode : public StmtNode { } bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { - return - 
equal(condition, other->condition) && - equal(message, other->message) && - equal(body, other->body); + return equal(condition, other->condition) && equal(message, other->message) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -216,11 +206,8 @@ class StoreNode : public StmtNode { } bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const { - return - equal(buffer_var, other->buffer_var) && - equal(value, other->value) && - equal(index, other->index) && - equal(predicate, other->predicate); + return equal(buffer_var, other->buffer_var) && equal(value, other->value) && + equal(index, other->index) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { @@ -230,10 +217,7 @@ class StoreNode : public StmtNode { hash_reduce(predicate); } - TVM_DLL static Stmt make(Var buffer_var, - PrimExpr value, - PrimExpr index, - PrimExpr predicate); + TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate); static constexpr const char* _type_key = "Store"; TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); @@ -265,10 +249,8 @@ class BufferStoreNode : public StmtNode { } bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(value, other->value) && - equal(indices, other->indices); + return equal(buffer, other->buffer) && equal(value, other->value) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -287,9 +269,7 @@ class BufferStoreNode : public StmtNode { */ class BufferStore : public Stmt { public: - TVM_DLL explicit BufferStore(Buffer buffer, - PrimExpr value, - Array indices); + TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; @@ -323,11 +303,8 @@ class BufferRealizeNode : public StmtNode { } bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(bounds, other->bounds) && - equal(condition, other->condition) && - equal(body, other->body); + return equal(buffer, other->buffer) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -338,12 +315,8 @@ class BufferRealizeNode : public StmtNode { } BufferRealizeNode() = default; - BufferRealizeNode(Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body) - : buffer(buffer), bounds(bounds), - condition(condition), body(body) {} + BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) + : buffer(buffer), bounds(bounds), condition(condition), body(body) {} static constexpr const char* _type_key = "BufferRealize"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); @@ -355,10 +328,7 @@ class BufferRealizeNode : public StmtNode { */ class BufferRealize : public Stmt { public: - TVM_DLL explicit BufferRealize(Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body); + TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); }; @@ -387,11 +357,8 @@ class ProvideNode : public StmtNode { } bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(value, other->value) && - equal(args, 
other->args); + return equal(func, other->func) && equal(value_index, other->value_index) && + equal(value, other->value) && equal(args, other->args); } void SHashReduce(SHashReducer hash_reduce) const { @@ -401,10 +368,7 @@ class ProvideNode : public StmtNode { hash_reduce(args); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - PrimExpr value, - Array args); + TVM_DLL static Stmt make(FunctionRef func, int value_index, PrimExpr value, Array args); static constexpr const char* _type_key = "Provide"; TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); @@ -435,12 +399,9 @@ class AllocateNode : public StmtNode { } bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { - return - equal.DefEqual(buffer_var, other->buffer_var) && - equal(dtype, other->dtype) && - equal(extents, other->extents) && - equal(condition, other->condition) && - equal(body, other->body); + return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && + equal(extents, other->extents) && equal(condition, other->condition) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -451,28 +412,22 @@ class AllocateNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body); + TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array extents, + PrimExpr condition, Stmt body); /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \return The result. */ - int32_t constant_allocation_size() const { - return constant_allocation_size(extents); - } + int32_t constant_allocation_size() const { return constant_allocation_size(extents); } /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \param extents The extents of the buffer. * \return The result. */ - TVM_DLL static int32_t constant_allocation_size( - const Array& extents); + TVM_DLL static int32_t constant_allocation_size(const Array& extents); static constexpr const char* _type_key = "Allocate"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); @@ -484,18 +439,13 @@ class FreeNode : public StmtNode { /*! \brief The buffer variable. 
*/ Var buffer_var; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); } bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const { - return - equal(buffer_var, other->buffer_var); + return equal(buffer_var, other->buffer_var); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer_var); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); } TVM_DLL static Stmt make(Var buffer_var); @@ -533,21 +483,13 @@ class RealizeNode : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body); + TVM_DLL static Stmt make(FunctionRef func, int value_index, DataType dtype, Region bounds, + PrimExpr condition, Stmt body); bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(dtype, other->dtype) && - equal(bounds, other->bounds) && - equal(condition, other->condition) && - equal(body, other->body); + return equal(func, other->func) && equal(value_index, other->value_index) && + equal(dtype, other->dtype) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -573,27 +515,19 @@ class SeqStmtNode : public StmtNode { Array seq; /*! \return get the size of the sequence */ - size_t size() const { - return seq.size(); - } + size_t size() const { return seq.size(); } /*! * \brief Get the index-th element in the sequence. */ - Stmt operator[](size_t index) const { - return seq[index]; - } + Stmt operator[](size_t index) const { return seq[index]; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("seq", &seq); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("seq", &seq); } bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { return equal(seq, other->seq); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(seq); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); } static constexpr const char* _type_key = "SeqStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); @@ -609,15 +543,11 @@ class SeqStmt : public Stmt { TVM_DLL explicit SeqStmt(Array seq); /*! \return get the size of the sequence */ - size_t size() const { - return operator->()->size(); - } + size_t size() const { return operator->()->size(); } /*! * \brief Get the index-th element in the sequence. */ - Stmt operator[](size_t index) const { - return (*(operator->()))[index]; - } + Stmt operator[](size_t index) const { return (*(operator->()))[index]; } /*! * \brief Construct a sequence statement by flattening * all the arrays and sequences in the arguments @@ -634,19 +564,17 @@ class SeqStmt : public Stmt { * \tparam Args arguments * \return The constructed statement */ - template + template static Stmt Flatten(Args&&... seq_args) { Array seq; - runtime::detail::for_each( - Flattener(&seq), std::forward(seq_args)...); + runtime::detail::for_each(Flattener(&seq), std::forward(seq_args)...); if (seq.size() == 1) return seq[0]; return SeqStmt(seq); } /*! \brief Helper class to flatten sequence of arguments into Array. 
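 * For illustration (hypothetical statements a, b, c, not from this header): per the Flatten() body shown above, SeqStmt::Flatten(a, b, c) collects {a, b, c} into one Array and wraps it in a SeqStmt, while a single surviving element is returned directly instead of being wrapped.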
*/ class Flattener { public: - explicit Flattener(Array* seq) - : seq_(seq) {} + explicit Flattener(Array* seq) : seq_(seq) {} void operator()(size_t i, const Stmt& stmt) const { if (!stmt.defined()) return; @@ -657,7 +585,7 @@ class SeqStmt : public Stmt { } } - template + template void operator()(size_t i, const T& seq) const { for (auto v : seq) { this->operator()(0, v); } } @@ -690,10 +618,8 @@ class IfThenElseNode : public StmtNode { } bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { - return - equal(condition, other->condition) && - equal(then_case, other->then_case) && - equal(else_case, other->else_case); + return equal(condition, other->condition) && equal(then_case, other->then_case) && + equal(else_case, other->else_case); } void SHashReduce(SHashReducer hash_reduce) const { @@ -719,17 +645,13 @@ class EvaluateNode : public StmtNode { /*! \brief The expression to be evaluated. */ PrimExpr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); } bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } TVM_DLL static Stmt make(PrimExpr v); @@ -752,9 +674,7 @@ enum class ForType : int { // Device API of for loop // kept for backward compatibility // consider refactor and remove later. -enum class DeviceAPI: int { - None = 0 -}; +enum class DeviceAPI : int { None = 0 }; /*! * \brief A for loop, with possible type annotations. @@ -784,12 +704,8 @@ class ForNode : public StmtNode { /*! \brief The body of the for loop. */ Stmt body; - TVM_DLL static Stmt make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body); + TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, + DeviceAPI device_api, Stmt body); void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); @@ -801,13 +717,9 @@ class ForNode : public StmtNode { } bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { - return - equal.DefEqual(loop_var, other->loop_var) && - equal(min, other->min) && - equal(extent, other->extent) && - equal(for_type, other->for_type) && - equal(device_api, other->device_api) && - equal(body, other->body); + return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && + equal(extent, other->extent) && equal(for_type, other->for_type) && + equal(device_api, other->device_api) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -819,7 +731,6 @@ class ForNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "For"; TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; @@ -840,9 +751,7 @@ class PrefetchNode : public StmtNode { } bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(bounds, other->bounds); + return equal(buffer, other->buffer) && equal(bounds, other->bounds); } void SHashReduce(SHashReducer hash_reduce) const { @@ -851,8 +760,7 @@ class PrefetchNode : public StmtNode { } PrefetchNode() = default; - PrefetchNode(Buffer buffer, Array bounds) - : buffer(buffer), bounds(bounds) {} + PrefetchNode(Buffer buffer, Array bounds) : buffer(buffer), bounds(bounds) {} static constexpr const char* _type_key = "Prefetch"; 
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); @@ -930,6 +838,8 @@ constexpr const char* loop_scope = "loop_scope"; constexpr const char* reduce_scope = "reduce_scope"; /*! \brief Mark region is guarded by the pragma extension */ constexpr const char* pragma_scope_prefix = "pragma_"; +/*! \brief Import C source or file into the final code gen module */ +constexpr const char* pragma_import_c = "pragma_import_c"; /*! \brief Import llvm source or file into the final code gen module */ constexpr const char* pragma_import_llvm = "pragma_import_llvm"; /*! \brief Try to modify the AST to support Tensor Core */ @@ -1022,9 +932,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \return Expr a expression with dtype. */ inline PrimExpr TypeAnnotation(DataType dtype) { - return tir::CallNode::make(dtype, - "type_annotation", {}, - tir::CallNode::PureIntrinsic); + return tir::CallNode::make(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); } // overload printing of for type. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 0f8038e13ca62..052ea92ce41ec 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -26,14 +26,14 @@ #ifndef TVM_TIR_STMT_FUNCTOR_H_ #define TVM_TIR_STMT_FUNCTOR_H_ -#include #include +#include #include -#include #include +#include -#include #include +#include namespace tvm { namespace tir { @@ -42,22 +42,18 @@ namespace tir { * \tparam FType The function signature. * \sa ExprFunctor */ -template +template class StmtFunctor; -#define STMT_FUNCTOR_DEFAULT { \ - return VisitStmtDefault_(op, std::forward(args)...); \ - } - -#define IR_STMT_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitStmt_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define STMT_FUNCTOR_DEFAULT \ + { return VisitStmtDefault_(op, std::forward(args)...); } +#define IR_STMT_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.get()), std::forward(args)...); \ + }); -template +template class StmtFunctor { private: using TSelf = StmtFunctor; @@ -74,9 +70,7 @@ class StmtFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Stmt& n, Args... args) { - return VisitStmt(n, std::forward(args)...); - } + R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The stmt node. @@ -103,7 +97,7 @@ class StmtFunctor { virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmtDefault_(const Object* op, Args ...) { + virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -137,8 +131,7 @@ class StmtFunctor { /*! * \brief StmtVisitor. */ -class TVM_DLL StmtVisitor : - protected StmtFunctor { +class TVM_DLL StmtVisitor : protected StmtFunctor { public: using StmtFunctor::operator(); @@ -173,8 +166,7 @@ class TVM_DLL StmtVisitor : /*! * \brief StmtMutator that mutates the statements. */ -class TVM_DLL StmtMutator : - protected StmtFunctor { +class TVM_DLL StmtMutator : protected StmtFunctor { public: /*! * \brief Mutate stmt. 
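A brief aside for readers of this refactor: the dispatch machinery above is normally consumed by subclassing. The sketch below is illustrative only; it assumes StmtVisitor exposes a protected VisitStmt_(const ForNode*) overload whose base implementation recurses into children, mirroring the overload lists declared in this header.

#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

// Counts For loops in a TIR statement: override one VisitStmt_ overload and
// delegate back to StmtVisitor so traversal continues into loop bodies.
class ForLoopCounter : public tvm::tir::StmtVisitor {
 public:
  size_t count = 0;

 protected:
  void VisitStmt_(const tvm::tir::ForNode* op) override {
    ++count;                                // record this loop
    tvm::tir::StmtVisitor::VisitStmt_(op);  // default traversal of op->body
  }
};

// Usage sketch: ForLoopCounter counter; counter(stmt); then read counter.count.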
@@ -210,7 +202,7 @@ class TVM_DLL StmtMutator : * * \return The result object pointer. */ - template + template ObjectPtr CopyOnWrite(const TNode* node) { if (allow_copy_on_write_) { // return the old node. @@ -244,9 +236,7 @@ class TVM_DLL StmtMutator : * or have a class sub-class both StmtMutator and ExprMutator * and redirect Mutate to ExprMutator::Mutate(Expr) */ - virtual PrimExpr VisitExpr(const PrimExpr& e) { - return e; - } + virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; } // statement visitor Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; @@ -275,8 +265,7 @@ class TVM_DLL StmtMutator : * \param fmutate The mutate function, can be nullptr, which defaults to Visit. * \return The mutated result. */ - Stmt VisitSeqStmt_(const SeqStmtNode* op, - bool flatten_before_visit, + Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate = nullptr); // internal helper. class Internal; @@ -285,39 +274,31 @@ class TVM_DLL StmtMutator : /*! * \brief Visitor that recursively visit stmts and exprs on them. */ -class StmtExprVisitor : - public StmtVisitor, - public ExprVisitor { +class StmtExprVisitor : public StmtVisitor, public ExprVisitor { public: using StmtVisitor::operator(); using ExprVisitor::operator(); protected: - using StmtVisitor::VisitStmt; using ExprVisitor::VisitExpr; + using StmtVisitor::VisitStmt; - void VisitExpr(const PrimExpr& e) override { - return ExprVisitor::VisitExpr(e); - } + void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); } }; /*! * \brief Mutator that recursively mutates stmts and exprs on them. */ -class StmtExprMutator : - public StmtMutator, - public ExprMutator { +class StmtExprMutator : public StmtMutator, public ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); protected: - using StmtMutator::VisitExpr; using ExprMutator::VisitExpr; + using StmtMutator::VisitExpr; - PrimExpr VisitExpr(const PrimExpr& e) override { - return ExprMutator::VisitExpr(e); - } + PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); } }; /*! @@ -335,8 +316,7 @@ class StmtExprMutator : * If it is not null, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -TVM_DLL Stmt IRTransform(Stmt stmt, - const runtime::PackedFunc& preorder, +TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, Optional> only_enable = NullOpt); @@ -354,8 +334,7 @@ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function(const Var& var)> vmap); +TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& var)> vmap); /*! * \brief Substitute the var specified by vmap. @@ -363,8 +342,7 @@ TVM_DLL Stmt Substitute(Stmt stmt, * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. * \return The result. */ -TVM_DLL PrimExpr Substitute(PrimExpr expr, - std::function(const Var& var)> vmap); +TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(const Var& var)> vmap); /*! * \brief Sugar for substitute via a given map. @@ -373,7 +351,7 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, * \return The result. * \tparam T the input type, can be PrimExpr or Stmt. */ -template +template inline T Substitute(T input, const Map& value_map) { auto vmap = [&](const Var& var) -> Optional { auto it = value_map.find(var); @@ -390,9 +368,8 @@ inline T Substitute(T input, const Map& value_map) { * \return The result. 
* \tparam T the input type, can be PrimExpr or Stmt. */ -template -inline T Substitute(T input, - const std::unordered_map& value_map) { +template +inline T Substitute(T input, const std::unordered_map& value_map) { auto vmap = [&](const Var& var) -> Optional { auto it = value_map.find(var.get()); if (it != value_map.end()) return (*it).second; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index abf8b1ce99655..13e1e2510e29d 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -35,11 +35,11 @@ namespace tir { namespace transform { using tvm::transform::Pass; -using tvm::transform::PassNode; -using tvm::transform::PassInfo; -using tvm::transform::PassInfoNode; using tvm::transform::PassContext; using tvm::transform::PassContextNode; +using tvm::transform::PassInfo; +using tvm::transform::PassInfoNode; +using tvm::transform::PassNode; using tvm::transform::Sequential; /* * \brief Create a function pass that optimizes PrimFuncs. * * \param pass_func The packed function that contains the optimization. * \param opt_level The optimization level of the function pass. * \param name The name of the function pass. * \param required The list of the passes that the function pass is dependent on. * * \return The created function pass. */ -TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< - PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); - +TVM_DLL Pass CreatePrimFuncPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, const std::string& name, const tvm::Array& required); /*! * \brief Inject prefetch instructions into stmt. @@ -76,8 +73,7 @@ TVM_DLL Pass InjectPrefetch(); * * \return The Pass */ -TVM_DLL Pass StorageFlatten(int cache_line_size, - bool create_bound_attribute = false); +TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); /*! * \brief Inject copy intrinsics with optional pad. * * \param pragma_key The pragma key for hint of copy. * \param fintrin The function with signature * * Stmt fintrin(Buffer src, * Buffer dst, * Array pad_before, * Array pad_after, * Expr pad_value) * \return The pass. */ -TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, - runtime::PackedFunc fintrin); +TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, runtime::PackedFunc fintrin); /*! * \brief Detect and insert sync points to co-processor. * * \return The pass. */ @@ -164,9 +159,7 @@ TVM_DLL Pass StorageRewrite(); * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. * \return The pass. */ -TVM_DLL Pass UnrollLoop(int auto_max_step, - int auto_max_depth, - int auto_max_extent, +TVM_DLL Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll); /*! @@ -184,17 +177,17 @@ TVM_DLL Pass RemoveNoOp(); TVM_DLL Pass RewriteUnsafeSelect(); /*! -* \brief Run arithmetic simplifications on the statements and expressions. -* -* \return The pass. -*/ + * \brief Run arithmetic simplifications on the statements and expressions. + * + * \return The pass. + */ TVM_DLL Pass Simplify(); /*! -* \brief Instruments bound checkers. -* -* \return The pass. -*/ + * \brief Instruments bound checkers. + * + * \return The pass. + */ TVM_DLL Pass InstrumentBoundCheckers(); /*! @@ -278,7 +271,6 @@ TVM_DLL Pass SkipAssert(); */ TVM_DLL Pass ThreadSync(std::string storage_scope); - /*! * \brief Lower cross thread allreduce. * @@ -328,7 +320,6 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo(); */ TVM_DLL Pass CombineContextCall(); - /*! * \brief Narrow down PrimExpr datatype in stmt to target_bits. 
* diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index bb73bf03d88b5..a89c665b93776 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -24,9 +24,10 @@ #ifndef TVM_TIR_VAR_H_ #define TVM_TIR_VAR_H_ +#include #include #include -#include + #include namespace tvm { @@ -91,8 +92,7 @@ class Var : public PrimExpr { * \param name_hint variable name * \param dtype data type */ - TVM_DLL explicit Var(std::string name_hint = "v", - DataType dtype = DataType::Int(32)); + TVM_DLL explicit Var(std::string name_hint = "v", DataType dtype = DataType::Int(32)); /*! * \brief Constructor which provides a more detailed type annotation. * \param name_hint variable name. @@ -109,16 +109,12 @@ class Var : public PrimExpr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* operator->() const { - return get(); - } + const VarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* get() const { - return static_cast(data_.get()); - } + const VarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicates the container type */ using ContainerType = VarNode; }; @@ -142,27 +138,21 @@ class SizeVar : public Var { * \param name_hint variable name * \param t data type */ - TVM_DLL explicit SizeVar(std::string name_hint = "s", - DataType t = DataType::Int(32)); + TVM_DLL explicit SizeVar(std::string name_hint = "s", DataType t = DataType::Int(32)); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const SizeVarNode* operator->() const { - return get(); - } + const SizeVarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const SizeVarNode* get() const { - return static_cast(data_.get()); - } + const SizeVarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicates the container type */ using ContainerType = SizeVarNode; }; - /*! \brief container class of iteration variable. 
*/ class IterVarNode; @@ -292,11 +282,8 @@ class IterVarNode : public Object { } bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { - return - equal(dom, other->dom) && - equal.DefEqual(var, other->var) && - equal(iter_type, other->iter_type) && - equal(thread_tag, other->thread_tag); + return equal(dom, other->dom) && equal.DefEqual(var, other->var) && + equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag); } void SHashReduce(SHashReducer hash_reduce) const { @@ -306,8 +293,7 @@ class IterVarNode : public Object { hash_reduce(thread_tag); } - TVM_DLL static IterVar make(Range dom, Var var, - IterVarType iter_type, + TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; @@ -321,21 +307,28 @@ inline const IterVarNode* IterVar::operator->() const { return static_cast(data_.get()); } -inline IterVar::operator PrimExpr() const { - return (*this)->var; -} +inline IterVar::operator PrimExpr() const { return (*this)->var; } inline const char* IterVarType2String(IterVarType t) { switch (t) { - case kDataPar: return "DataPar"; - case kThreadIndex: return "ThreadIndex"; - case kCommReduce: return "CommReduce"; - case kOrdered: return "Ordered"; - case kOpaque: return "Opaque"; - case kUnrolled: return "Unrolled"; - case kVectorized: return "Vectorized"; - case kParallelized: return "Parallelized"; - case kTensorized: return "Tensorized"; + case kDataPar: + return "DataPar"; + case kThreadIndex: + return "ThreadIndex"; + case kCommReduce: + return "CommReduce"; + case kOrdered: + return "Ordered"; + case kOpaque: + return "Opaque"; + case kUnrolled: + return "Unrolled"; + case kVectorized: + return "Vectorized"; + case kParallelized: + return "Parallelized"; + case kTensorized: + return "Tensorized"; } return "Unknown"; } diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java index c31c67f283af5..61ff966eaf380 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java @@ -38,53 +38,14 @@ public class GraphRuntime { * @return Runtime graph module that can be used to execute the graph. */ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) { - Module graphModule = null; - if (ctx.deviceType >= RPC.RPC_SESS_MASK) { - if (!(ctx instanceof TVMRemoteContext)) { - throw new IllegalArgumentException( - "Looks like you are using remote context with no RPCSession bind." - + "Use session.context instead."); - } - RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession; - // check arguments - if (!"rpc".equals(libmod.typeKey())) { - throw new IllegalArgumentException("libmod.typeKey != rpc"); - } - final int sessIndex = (int) ((Function) reflectionStaticCall( - RPC.class, "getApi", "_SessTableIndex")) - .pushArg(libmod).invoke().asLong(); - if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) { - throw new IllegalArgumentException(String.format( - "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d", - sessIndex, reflectionGetField(rpcSession, "tblIndex"))); - } - - Function rpcModuleHandle = (Function) reflectionStaticCall( - RPC.class, "getApi","_ModuleHandle"); - if (rpcModuleHandle == null) { - throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle." 
- + "Did you compile tvm_runtime with the correct version?"); - } - - Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create." - + "Did you compile tvm_runtime with correct version?"); - } - - TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke(); - graphModule = fcreate.call(graphJson, hmod, - ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule(); - } else { - Function fcreate = Function.getFunction("tvm.graph_runtime.create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." - + "Did you compile tvm_runtime with correct version?"); - } - graphModule = fcreate.pushArg(graphJson) - .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) - .invoke().asModule(); + Function fcreate = Function.getFunction("tvm.graph_runtime.create"); + if (fcreate == null) { + throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." + + "Did you compile tvm_runtime with correct version?"); } + Module graphModule = fcreate.pushArg(graphJson) + .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) + .invoke().asModule(); return new GraphModule(graphModule, ctx); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 5178ac900a367..69321c3b51c80 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -29,7 +29,7 @@ public class Client { * @return The connected session. */ public static RPCSession connect(String url, int port, String key) { - Function doConnect = RPC.getApi("_Connect"); + Function doConnect = RPC.getApi("Connect"); if (doConnect == null) { throw new RuntimeException("Please compile with USE_RPC=1"); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java index 29a457f39a40c..1f3191fb2e8ca 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); + RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 92b328488b403..b9f621473cf4d 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -39,7 +39,7 @@ public class RPCSession { RPCSession(Module sess) { session = sess; - tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong(); } /** @@ -237,7 +237,7 @@ public byte[] download(String path) { * @return The remote module containing remote function. 
*/ public Module loadModule(String path) { - return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); } diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index ce1979c6618b4..0f202004f99d8 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -26,7 +26,7 @@ #define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ // Helper functions for RefXXX getter & setter -jlong getLongField(JNIEnv *env, jobject obj) { +jlong getLongField(JNIEnv* env, jobject obj) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); jlong ret = env->GetLongField(obj, refFid); @@ -34,7 +34,7 @@ jlong getLongField(JNIEnv *env, jobject obj) { return ret; } -jint getIntField(JNIEnv *env, jobject obj) { +jint getIntField(JNIEnv* env, jobject obj) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); jint ret = env->GetIntField(obj, refFid); @@ -42,21 +42,21 @@ jint getIntField(JNIEnv *env, jobject obj) { return ret; } -void setIntField(JNIEnv *env, jobject obj, jint value) { +void setIntField(JNIEnv* env, jobject obj, jint value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); env->SetIntField(obj, refFid, value); env->DeleteLocalRef(refClass); } -void setLongField(JNIEnv *env, jobject obj, jlong value) { +void setLongField(JNIEnv* env, jobject obj, jlong value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); env->SetLongField(obj, refFid, value); env->DeleteLocalRef(refClass); } -void setStringField(JNIEnv *env, jobject obj, const char *value) { +void setStringField(JNIEnv* env, jobject obj, const char* value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefString"); jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;"); env->SetObjectField(obj, refFid, env->NewStringUTF(value)); @@ -64,8 +64,8 @@ void setStringField(JNIEnv *env, jobject obj, const char *value) { } // Helper functions for TVMValue -jlong getTVMValueLongField(JNIEnv *env, jobject obj, - const char *clsname = "org/apache/tvm/TVMValueLong") { +jlong getTVMValueLongField(JNIEnv* env, jobject obj, + const char* clsname = "org/apache/tvm/TVMValueLong") { jclass cls = env->FindClass(clsname); jfieldID fid = env->GetFieldID(cls, "value", "J"); jlong ret = env->GetLongField(obj, fid); @@ -73,7 +73,7 @@ jlong getTVMValueLongField(JNIEnv *env, jobject obj, return ret; } -jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) { +jdouble getTVMValueDoubleField(JNIEnv* env, jobject obj) { jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble"); jfieldID fid = env->GetFieldID(cls, "value", "D"); jdouble ret = env->GetDoubleField(obj, fid); @@ -81,7 +81,7 @@ jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) { return ret; } -jstring getTVMValueStringField(JNIEnv *env, jobject obj) { +jstring getTVMValueStringField(JNIEnv* env, jobject obj) { jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;"); jstring ret = static_cast(env->GetObjectField(obj, fid)); @@ -89,7 +89,7 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) { return 
ret; } -jobject newTVMValueHandle(JNIEnv *env, jlong value) { +jobject newTVMValueHandle(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueHandle"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -97,7 +97,7 @@ jobject newTVMValueHandle(JNIEnv *env, jlong value) { return object; } -jobject newTVMValueLong(JNIEnv *env, jlong value) { +jobject newTVMValueLong(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueLong"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -105,7 +105,7 @@ jobject newTVMValueLong(JNIEnv *env, jlong value) { return object; } -jobject newTVMValueDouble(JNIEnv *env, jdouble value) { +jobject newTVMValueDouble(JNIEnv* env, jdouble value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V"); jobject object = env->NewObject(cls, constructor, value); @@ -113,7 +113,7 @@ jobject newTVMValueDouble(JNIEnv *env, jdouble value) { return object; } -jobject newTVMValueString(JNIEnv *env, const char *value) { +jobject newTVMValueString(JNIEnv* env, const char* value) { jstring jvalue = env->NewStringUTF(value); jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(Ljava/lang/String;)V"); @@ -123,10 +123,10 @@ jobject newTVMValueString(JNIEnv *env, const char *value) { return object; } -jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) { +jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) { jbyteArray jarr = env->NewByteArray(arr->size); env->SetByteArrayRegion(jarr, 0, arr->size, - reinterpret_cast<jbyte *>(const_cast<char *>(arr->data))); + reinterpret_cast<jbyte*>(const_cast<char*>(arr->data))); jclass cls = env->FindClass("org/apache/tvm/TVMValueBytes"); jmethodID constructor = env->GetMethodID(cls, "<init>", "([B)V"); jobject object = env->NewObject(cls, constructor, jarr); @@ -135,7 +135,7 @@ jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) { return object; } -jobject newModule(JNIEnv *env, jlong value) { +jobject newModule(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/Module"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -143,7 +143,7 @@ jobject newModule(JNIEnv *env, jlong value) { return object; } -jobject newFunction(JNIEnv *env, jlong value) { +jobject newFunction(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/Function"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -151,7 +151,7 @@ jobject newFunction(JNIEnv *env, jlong value) { return object; } -jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { +jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) { jclass cls = env->FindClass("org/apache/tvm/NDArrayBase"); jmethodID constructor = env->GetMethodID(cls, "<init>", "(JZ)V"); jobject object = env->NewObject(cls, constructor, handle, isview); @@ -159,7 +159,7 @@ jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { return object; } -jobject newObject(JNIEnv *env, const char *clsname) { +jobject newObject(JNIEnv* env, const char* clsname) { jclass cls = env->FindClass(clsname); jmethodID constructor = env->GetMethodID(cls, "<init>", "()V"); jobject object = env->NewObject(cls, constructor); @@ -167,7 +167,7 @@ jobject
newObject(JNIEnv *env, const char *clsname) { return object; } -void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) { +void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) { jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType"); dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I"))); dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I"))); @@ -175,16 +175,16 @@ void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) { env->DeleteLocalRef(tvmTypeClass); } -void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) { +void fromJavaContext(JNIEnv* env, jobject jctx, TVMContext* ctx) { jclass tvmContextClass = env->FindClass("org/apache/tvm/TVMContext"); - ctx->device_type = static_cast<DLDeviceType>(env->GetIntField(jctx, - env->GetFieldID(tvmContextClass, "deviceType", "I"))); - ctx->device_id = static_cast<int>(env->GetIntField(jctx, - env->GetFieldID(tvmContextClass, "deviceId", "I"))); + ctx->device_type = static_cast<DLDeviceType>( + env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceType", "I"))); + ctx->device_id = + static_cast<int>(env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceId", "I"))); env->DeleteLocalRef(tvmContextClass); } -jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { +jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { switch (tcode) { case kDLUInt: case kDLInt: @@ -204,7 +204,7 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { case kTVMStr: return newTVMValueString(env, value.v_str); case kTVMBytes: - return newTVMValueBytes(env, reinterpret_cast<TVMByteArray *>(value.v_handle)); + return newTVMValueBytes(env, reinterpret_cast<TVMByteArray*>(value.v_handle)); case kTVMNullptr: return newObject(env, "org/apache/tvm/TVMValueNull"); default: diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index b59956824d263..6fc316ca8739c 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -29,28 +29,28 @@ #include #include #endif -#include #include -#include +#include #include +#include #include "jni_helper_func.h" -JavaVM *_jvm; -void *_tvmHandle = nullptr; +JavaVM* _jvm; +void* _tvmHandle = nullptr; struct TVMFuncArgsThreadLocalEntry { std::vector<TVMValue> tvmFuncArgValues; std::vector<int> tvmFuncArgTypes; // for later release - std::vector<std::pair<jstring, const char *> > tvmFuncArgPushedStrs; - std::vector<std::pair<jbyteArray, TVMByteArray *> > tvmFuncArgPushedBytes; + std::vector<std::pair<jstring, const char*> > tvmFuncArgPushedStrs; + std::vector<std::pair<jbyteArray, TVMByteArray*> > tvmFuncArgPushedBytes; }; typedef dmlc::ThreadLocalStore<TVMFuncArgsThreadLocalEntry> TVMFuncArgsThreadLocalStore; -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit - (JNIEnv *env, jobject obj, jstring jtvmLibFile) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj, + jstring jtvmLibFile) { if (_tvmHandle == NULL && !env->IsSameObject(jtvmLibFile, NULL)) { - const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); + const char* tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); _tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL); env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile); if (!_tvmHandle) { @@ -61,70 +61,70 @@ return env->GetJavaVM(&_jvm); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv *env, jobject obj) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject obj)
{ if (_tvmHandle) { dlclose(_tvmHandle); } return 0; } -JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv * env, jobject obj) { +JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv* env, jobject obj) { return env->NewStringUTF(TVMGetLastError()); } // Function -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong( - JNIEnv *env, jobject obj, jlong arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(JNIEnv* env, jobject obj, + jlong arg) { TVMValue value; value.v_int64 = static_cast<int64_t>(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kDLInt); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble( - JNIEnv *env, jobject obj, jdouble arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(JNIEnv* env, jobject obj, + jdouble arg) { TVMValue value; value.v_float64 = static_cast<double>(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kDLFloat); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString( - JNIEnv *env, jobject obj, jstring arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(JNIEnv* env, jobject obj, + jstring arg) { TVMValue value; jstring garg = reinterpret_cast<jstring>(env->NewGlobalRef(arg)); value.v_str = env->GetStringUTFChars(garg, 0); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kTVMStr); // release string args later e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle( - JNIEnv *env, jobject obj, jlong arg, jint argType) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* env, jobject obj, + jlong arg, jint argType) { TVMValue value; - value.v_handle = reinterpret_cast<void *>(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + value.v_handle = reinterpret_cast<void*>(arg); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(static_cast<int>(argType)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( - JNIEnv *env, jobject obj, jbyteArray arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, + jbyteArray arg) { jbyteArray garg = reinterpret_cast<jbyteArray>(env->NewGlobalRef(arg)); - jbyte *data = env->GetByteArrayElements(garg, 0); + jbyte* data = env->GetByteArrayElements(garg, 0); - TVMByteArray *byteArray = new TVMByteArray(); + TVMByteArray* byteArray = new TVMByteArray(); byteArray->size = static_cast<size_t>(env->GetArrayLength(garg)); - byteArray->data = reinterpret_cast<char *>(data); + byteArray->data = reinterpret_cast<char*>(data); TVMValue value; - value.v_handle = reinterpret_cast<void *>(byteArray); + value.v_handle = reinterpret_cast<void*>(byteArray); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kTVMBytes);
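The push functions above stage each argument as a TVMValue plus a type code in the thread-local entry, and tvmFuncCall below drains both arrays into one C API call. For orientation, the equivalent sequence against the plain TVM C API (from tvm/runtime/c_runtime_api.h) looks like this minimal sketch; "my.packed.func" is a hypothetical registration name and error checks are elided:

#include <tvm/runtime/c_runtime_api.h>

TVMFunctionHandle f;
TVMFuncGetGlobal("my.packed.func", &f);  // resolve a registered global function
TVMValue args[1];
int tcodes[1];
args[0].v_int64 = 42;                    // staged exactly like tvmFuncPushArgLong
tcodes[0] = kDLInt;
TVMValue ret;
int ret_tcode;
TVMFuncCall(f, args, tcodes, 1, &ret, &ret_tcode);  // drains the staged arguments
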
@@ -132,10 +132,10 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( // release (garg, data), byteArray later } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames( - JNIEnv *env, jobject obj, jobject jfuncNames) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(JNIEnv* env, jobject obj, + jobject jfuncNames) { int outSize; - const char **outArray; + const char** outArray; int ret = TVMFuncListGlobalNames(&outSize, &outArray); if (ret) { @@ -157,24 +157,25 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMFuncFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal( - JNIEnv *env, jobject obj, jstring jname, jobject jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jobject jhandle) { TVMFunctionHandle handle; - const char *name = env->GetStringUTFChars(jname, 0); + const char* name = env->GetStringUTFChars(jname, 0); int ret = TVMFuncGetGlobal(name, &handle); env->ReleaseStringUTFChars(jname, name); setLongField(env, jhandle, reinterpret_cast(handle)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( - JNIEnv *env, jobject obj, jlong jhandle, jobject jretVal) { - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(JNIEnv* env, jobject obj, + jlong jhandle, jobject jretVal) { + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); int numArgs = e->tvmFuncArgValues.size(); TVMValue retVal; @@ -192,8 +193,8 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( e->tvmFuncArgTypes.clear(); e->tvmFuncArgValues.clear(); - int ret = TVMFuncCall(reinterpret_cast(jhandle), - &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode); + int ret = TVMFuncCall(reinterpret_cast(jhandle), &argValues[0], &argTypes[0], + numArgs, &retVal, &retTypeCode); if (ret != 0) { return ret; @@ -204,16 +205,15 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( env->DeleteGlobalRef(iter->first); } for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { - env->ReleaseByteArrayElements(iter->first, - reinterpret_cast(const_cast(iter->second->data)), 0); + env->ReleaseByteArrayElements( + iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); env->DeleteGlobalRef(iter->first); delete iter->second; } // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue"); - jfieldID refTVMValueFid - = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); + jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode)); @@ -223,16 +223,16 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( } // Callback function -extern "C" int funcInvokeCallback(TVMValue *args, - int *typeCodes, int numArgs, TVMRetValueHandle ret, void *resourceHandle) { - JNIEnv *env; - int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); +extern "C" int funcInvokeCallback(TVMValue* args, int* 
typeCodes, int numArgs, + TVMRetValueHandle ret, void* resourceHandle) { + JNIEnv* env; + int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { - #ifdef TVM4J_ANDROID +#ifdef TVM4J_ANDROID _jvm->AttachCurrentThread(&env, nullptr); - #else - _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - #endif +#else + _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); +#endif } else { CHECK(jniStatus == JNI_OK); } @@ -242,10 +242,8 @@ extern "C" int funcInvokeCallback(TVMValue *args, for (int i = 0; i < numArgs; ++i) { TVMValue arg = args[i]; int tcode = typeCodes[i]; - if (tcode == kTVMObjectHandle || - tcode == kTVMPackedFuncHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMModuleHandle) { + if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) { TVMCbArgToReturn(&arg, &tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); @@ -253,15 +251,16 @@ extern "C" int funcInvokeCallback(TVMValue *args, } jclass clsFunc = env->FindClass("org/apache/tvm/Function"); - jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(clsFunc, "invokeRegisteredCbFunc", + jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID( + clsFunc, "invokeRegisteredCbFunc", "(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;"); - jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack", - "(Ljava/lang/Object;)V"); + jmethodID pushArgToStack = + env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V"); jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc, - reinterpret_cast(resourceHandle), jargs); + reinterpret_cast(resourceHandle), jargs); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); @@ -279,16 +278,16 @@ extern "C" int funcInvokeCallback(TVMValue *args, // release allocated strings. if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) { - const auto &pairArg = e->tvmFuncArgPushedStrs.back(); + const auto& pairArg = e->tvmFuncArgPushedStrs.back(); env->ReleaseStringUTFChars(pairArg.first, pairArg.second); env->DeleteGlobalRef(pairArg.first); e->tvmFuncArgPushedStrs.pop_back(); } // release allocated bytes. 
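  // (prevNumStrArg / prevNumBytesArg were sampled before the Java callback ran;
  // an entry recorded past them was pushed while materializing the callback's
  // return value, so its global ref and pinned buffer can be released here.)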
if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) { - const auto &pairArg = e->tvmFuncArgPushedBytes.back(); - env->ReleaseByteArrayElements(pairArg.first, - reinterpret_cast(const_cast(pairArg.second->data)), 0); + const auto& pairArg = e->tvmFuncArgPushedBytes.back(); + env->ReleaseByteArrayElements( + pairArg.first, reinterpret_cast(const_cast(pairArg.second->data)), 0); env->DeleteGlobalRef(pairArg.first); delete pairArg.second; e->tvmFuncArgPushedBytes.pop_back(); @@ -301,62 +300,64 @@ extern "C" int funcInvokeCallback(TVMValue *args, } // Free callback function -extern "C" void funcFreeCallback(void *resourceHandle) { - JNIEnv *env; - int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); +extern "C" void funcFreeCallback(void* resourceHandle) { + JNIEnv* env; + int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { - #ifdef TVM4J_ANDROID +#ifdef TVM4J_ANDROID _jvm->AttachCurrentThread(&env, nullptr); - #else - _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - #endif +#else + _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); +#endif } else { CHECK(jniStatus == JNI_OK); } env->DeleteGlobalRef(reinterpret_cast(resourceHandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc( - JNIEnv *env, jobject obj, jobject jfunction, jobject jretHandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(JNIEnv* env, jobject obj, + jobject jfunction, + jobject jretHandle) { TVMFunctionHandle out; - int ret = TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), - reinterpret_cast(env->NewGlobalRef(jfunction)), - reinterpret_cast(&funcFreeCallback), - &out); + int ret = + TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), + reinterpret_cast(env->NewGlobalRef(jfunction)), + reinterpret_cast(&funcFreeCallback), &out); setLongField(env, jretHandle, reinterpret_cast(out)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal( - JNIEnv *env, jobject obj, jstring jname, jlong jhandle, jint joverride) { - const char *name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncRegisterGlobal( - name, reinterpret_cast(jhandle), reinterpret_cast(joverride)); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(JNIEnv* env, jobject obj, + jstring jname, + jlong jhandle, + jint joverride) { + const char* name = env->GetStringUTFChars(jname, 0); + int ret = TVMFuncRegisterGlobal(name, reinterpret_cast(jhandle), + reinterpret_cast(joverride)); env->ReleaseStringUTFChars(jname, name); return ret; } // Module -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMModFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport( - JNIEnv *env, jobject obj, jlong jmod, jlong jdep) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(JNIEnv* env, jobject obj, + jlong jmod, jlong jdep) { return TVMModImport(reinterpret_cast(jmod), reinterpret_cast(jdep)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction( - JNIEnv *env, jobject obj, jlong jhandle, jstring jname, jint jimport, jobject jret) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(JNIEnv* env, jobject obj, + jlong jhandle, jstring jname, + jint jimport, jobject jret) { 
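  // Note: GetStringUTFChars pins (or copies) the Java string; it must be paired
  // with ReleaseStringUTFChars once the TVM C call below has consumed the name.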
TVMFunctionHandle retFunc; - const char *name = env->GetStringUTFChars(jname, 0); - int ret = TVMModGetFunction(reinterpret_cast(jhandle), - name, - reinterpret_cast(jimport), - &retFunc); + const char* name = env->GetStringUTFChars(jname, 0); + int ret = TVMModGetFunction(reinterpret_cast(jhandle), name, + reinterpret_cast(jimport), &retFunc); env->ReleaseStringUTFChars(jname, name); setLongField(env, jret, reinterpret_cast(retFunc)); @@ -365,28 +366,25 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction( } // NDArray -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMArrayFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( - JNIEnv *env, jobject obj, jlongArray jshape, jint jdtypeCode, - jint jdtypeBits, jint jdtypeLanes, jint jdeviceType, jint jdeviceId, jobject jret) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(JNIEnv* env, jobject obj, + jlongArray jshape, jint jdtypeCode, + jint jdtypeBits, jint jdtypeLanes, + jint jdeviceType, jint jdeviceId, + jobject jret) { int ndim = static_cast(env->GetArrayLength(jshape)); TVMArrayHandle out; - jlong *shapeArray = env->GetLongArrayElements(jshape, NULL); - int ret = TVMArrayAlloc( - reinterpret_cast(shapeArray), - ndim, - static_cast(jdtypeCode), - static_cast(jdtypeBits), - static_cast(jdtypeLanes), - static_cast(jdeviceType), - static_cast(jdeviceId), - &out); + jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); + int ret = TVMArrayAlloc(reinterpret_cast(shapeArray), ndim, + static_cast(jdtypeCode), static_cast(jdtypeBits), + static_cast(jdtypeLanes), static_cast(jdeviceType), + static_cast(jdeviceId), &out); env->ReleaseLongArrayElements(jshape, shapeArray, 0); setLongField(env, jret, reinterpret_cast(out)); @@ -394,10 +392,10 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( - JNIEnv *env, jobject obj, jlong jhandle, jobject jshape) { - DLTensor *array = reinterpret_cast(jhandle); - int64_t *shape = array->shape; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, jobject obj, + jlong jhandle, jobject jshape) { + DLTensor* array = reinterpret_cast(jhandle); + int64_t* shape = array->shape; int ndim = array->ndim; // fill shape buffer @@ -417,18 +415,19 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( return 0; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo( - JNIEnv *env, jobject obj, jlong jfrom, jlong jto) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(JNIEnv* env, jobject obj, + jlong jfrom, jlong jto) { return TVMArrayCopyFromTo(reinterpret_cast(jfrom), reinterpret_cast(jto), NULL); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( - JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) { - jbyte *data = env->GetByteArrayElements(jarr, NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(JNIEnv* env, jobject obj, + jbyteArray jarr, + jlong jfrom, jlong jto) { + jbyte* data = env->GetByteArrayElements(jarr, NULL); - DLTensor *from = reinterpret_cast(jfrom); - from->data = static_cast(data); + DLTensor* from = reinterpret_cast(jfrom); + from->data = static_cast(data); int ret = 
TVMArrayCopyFromTo(static_cast<TVMArrayHandle>(from), reinterpret_cast<TVMArrayHandle>(jto), NULL); @@ -439,13 +438,14 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( - JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) { - DLTensor *from = reinterpret_cast<DLTensor *>(jfrom); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(JNIEnv* env, jobject obj, + jlong jfrom, + jbyteArray jarr) { + DLTensor* from = reinterpret_cast<DLTensor*>(jfrom); int size = static_cast<int>(env->GetArrayLength(jarr)); - jbyte *pdata = env->GetByteArrayElements(jarr, NULL); + jbyte* pdata = env->GetByteArrayElements(jarr, NULL); int ret = 0; - if (memcpy(static_cast<char *>(pdata), from->data, size) == NULL) { + if (memcpy(static_cast<char*>(pdata), from->data, size) == NULL) { ret = 1; } env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically @@ -453,7 +453,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( } // Context -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize( - JNIEnv *env, jint deviceType, jint deviceId) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jint deviceType, + jint deviceId) { return TVMSynchronize(static_cast<int>(deviceType), static_cast<int>(deviceId), NULL); } diff --git a/nnvm/include/nnvm/base.h b/nnvm/include/nnvm/base.h index 678ed4d4a9421..b8c5c6c5ed413 100644 --- a/nnvm/include/nnvm/base.h +++ b/nnvm/include/nnvm/base.h @@ -24,13 +24,13 @@ #ifndef NNVM_BASE_H_ #define NNVM_BASE_H_ +#include +#include #include #include -#include -#include #include +#include #include -#include namespace nnvm { @@ -52,7 +52,7 @@ enum TypeFlag { kFloat16 = 2, kUint8 = 3, kInt32 = 4, - kInt8 = 5, + kInt8 = 5, kInt64 = 6, // kBool = 7, // 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index b35e4da343f70..e6efb79e8626b 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -41,11 +41,11 @@ typedef unsigned int nn_uint; /*! \brief handle to a function that takes param and creates symbol */ -typedef void *OpHandle; +typedef void* OpHandle; /*! \brief handle to a symbol that can be bound as operator */ -typedef void *SymbolHandle; +typedef void* SymbolHandle; /*! \brief handle to Graph */ -typedef void *GraphHandle; +typedef void* GraphHandle; #ifdef __cplusplus extern "C" { @@ -65,7 +65,7 @@ NNVM_DLL void NNAPISetLastError(const char* msg); * this function is threadsafe and can be called by different threads * \return error info */ -NNVM_DLL const char *NNGetLastError(void); +NNVM_DLL const char* NNGetLastError(void); /*! * \brief list all the available operator names, include entries. @@ -73,16 +73,14 @@ NNVM_DLL const char *NNGetLastError(void); * \param out_array the output operator name array.
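 * A minimal call sketch (hedged: assumes the returned name array stays owned by the library; includes and error handling elided):
 * \code
 *   nn_uint n;
 *   const char** names;
 *   if (NNListAllOpNames(&n, &names) == 0) {
 *     for (nn_uint i = 0; i < n; ++i) printf("%s\n", names[i]);
 *   }
 * \endcode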
* \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNListAllOpNames(nn_uint *out_size, - const char*** out_array); +NNVM_DLL int NNListAllOpNames(nn_uint* out_size, const char*** out_array); /*! * \brief Get operator handle given name. * \param op_name The name of the operator. * \param op_out The returnning op handle. */ -NNVM_DLL int NNGetOpHandle(const char* op_name, - OpHandle* op_out); +NNVM_DLL int NNGetOpHandle(const char* op_name, OpHandle* op_out); /*! * \brief list all the available operators. @@ -93,8 +91,7 @@ NNVM_DLL int NNGetOpHandle(const char* op_name, * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array); +NNVM_DLL int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array); /*! * \brief Get the detailed information about atomic symbol. @@ -109,14 +106,10 @@ NNVM_DLL int NNListUniqueOps(nn_uint *out_size, * \param return_type Return type of the function, if any. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGetOpInfo(OpHandle op, - const char **real_name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); +NNVM_DLL int NNGetOpInfo(OpHandle op, const char** real_name, const char** description, + nn_uint* num_doc_args, const char*** arg_names, + const char*** arg_type_infos, const char*** arg_descriptions, + const char** return_type); /*! * \brief Create an AtomicSymbol functor. * \param op The operator handle @@ -126,18 +119,15 @@ NNVM_DLL int NNGetOpInfo(OpHandle op, * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out); +NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, nn_uint num_param, const char** keys, + const char** vals, SymbolHandle* out); /*! * \brief Create a Variable Symbol. * \param name name of the variable * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); +NNVM_DLL int NNSymbolCreateVariable(const char* name, SymbolHandle* out); /*! * \brief Create a Symbol by grouping list of symbols together * \param num_symbols number of symbols to be grouped @@ -145,16 +135,13 @@ NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out); +NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* symbols, SymbolHandle* out); /*! * \brief Add src_dep to the handle as control dep. * \param handle The symbol to add dependency edges on. * \param src_dep the source handles. */ -NNVM_DLL int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep); +NNVM_DLL int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep); /*! * \brief Free the symbol handle. 
* \param symbol the symbol @@ -167,14 +154,14 @@ NNVM_DLL int NNSymbolFree(SymbolHandle symbol); * \param out used to hold the result of copy * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); +NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Print the content of symbol, used for debug. * \param symbol the symbol * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); +NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char** out_str); /*! * \brief Get string attribute from symbol * \param symbol the source symbol @@ -183,13 +170,11 @@ NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int *success); +NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success); /*! * \brief Set string attribute from symbol. - * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. + * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic + * graph. * * Safe recommendaton: use immutable graph * - Only allow set attributes during creation of new symbol as optional parameter @@ -204,9 +189,7 @@ NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, * \param values The value to be set * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, +NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** values); /*! * \brief Get all attributes from symbol, including all descendents. @@ -216,9 +199,7 @@ NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, * \param out 2*out_size strings representing key value pairs. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, - int recursive_option, - nn_uint *out_size, +NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, int recursive_option, nn_uint* out_size, const char*** out); /*! @@ -232,9 +213,7 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, * \param out_sym_array the output array. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, +NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size, SymbolHandle** out_sym_array); /*! @@ -248,10 +227,8 @@ NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array); +NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size, + const char*** out_str_array); /*! * \brief List returns names in the symbol. 
* \param symbol the symbol @@ -259,10 +236,8 @@ NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array); - +NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, + const char*** out_str_array); /*! * \brief Supply number of outputs of the symbol. @@ -270,8 +245,7 @@ NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, * \param output_count number of outputs * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count); +NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count); /*! * \brief Get a symbol that contains all the internals. @@ -279,16 +253,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, * \param out The output symbol whose outputs are all the internals. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Get a symbol that contains only direct children. * \param symbol The symbol * \param out The output symbol whose outputs are the direct children. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Get index-th outputs of the symbol. * \param symbol The symbol @@ -296,9 +268,7 @@ NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, * \param out The output symbol whose outputs are the index-th symbol. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out); /*! * \brief Compose the symbol on other symbols. @@ -314,11 +284,8 @@ NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, * \param args arguments to sym * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCompose(SymbolHandle sym, - const char* name, - nn_uint num_args, - const char** keys, - SymbolHandle* args); +NNVM_DLL int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, + const char** keys, SymbolHandle* args); // Graph IR API /*! @@ -327,7 +294,7 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym, * \param graph The graph handle created. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph); +NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph); /*! * \brief free the graph handle * \param handle The handle to be freed. @@ -339,7 +306,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle); * \param symbol The corresponding symbol * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); +NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol); /*! * \brief Get Set a attribute in json format. @@ -351,9 +318,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); * Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. 
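 * For example, an int value registered via DMLC_JSON_ENABLE_ANY would be passed as a serialized [type_name, value] pair (the key here is hypothetical):
 * \code
 *   NNGraphSetJSONAttr(handle, "my_attr", "[\"int\", 42]");
 * \endcode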
* \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value); +NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value); /*! * \brief Get a serialized attribute from graph. @@ -367,10 +332,8 @@ NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success); +NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, + int* success); /*! * \brief Set an attribute whose type is std::vector<NodeEntry> in C++ @@ -383,9 +346,7 @@ NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, * \param list The symbol whose outputs represents the list of NodeEntry to be passed. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list); +NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list); /*! * \brief Apply passes on the src graph. * \param src The source graph handle. @@ -394,10 +355,8 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, * \param dst The result graph. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst); +NNVM_DLL int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names, + GraphHandle* dst); #ifdef __cplusplus } /* end extern "C" */ diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 1911a0337ac2f..475494e62c4d3 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -24,13 +24,14 @@ #ifndef NNVM_GRAPH_H_ #define NNVM_GRAPH_H_ -#include -#include -#include #include #include +#include #include #include +#include +#include + #include "base.h" #include "node.h" #include "symbolic.h" @@ -64,7 +65,7 @@ class Graph { * \return the reference to corresponding attribute * \tparam T the type of the attribute. */ - template<typename T> + template <typename T> inline const T& GetAttr(const std::string& attr_name) const; /*! * \brief Check whether has a specific attribute. @@ -81,7 +82,7 @@ class Graph { * \return a new copy of the corresponding attribute. * \tparam T the type of the attribute. */ - template<typename T> + template <typename T> inline T MoveCopyAttr(const std::string& attr_name); /*! * \brief get an indexed graph of the current graph; if it does not exist, create it on demand @@ -127,13 +128,9 @@ class IndexedGraph { std::weak_ptr<nnvm::Node> weak_ref; }; /*! \return number of nodes in the graph */ - inline size_t num_nodes() const { - return nodes_.size(); - } + inline size_t num_nodes() const { return nodes_.size(); } /*! \return total number of NodeEntry in the graph */ - inline size_t num_node_entries() const { - return entry_rptr_.back(); - } + inline size_t num_node_entries() const { return entry_rptr_.back(); } /*! * \brief Get a unique entry id between 0 and num_node_entries() * for a given IndexedGraph::NodeEntry @@ -150,9 +147,7 @@ class IndexedGraph { * \param e The entry to query for index. * \return the unique index. */ - inline uint32_t entry_id(const NodeEntry& e) const { - return entry_rptr_[e.node_id] + e.index; - } + inline uint32_t entry_id(const NodeEntry& e) const { return entry_rptr_[e.node_id] + e.index; } /*!
* \brief Get a unique entry id between 0 to num_node_entries() * for a given NodeEntry. @@ -167,42 +162,30 @@ class IndexedGraph { * \param node The Node to query for index. * \return the node index. */ - inline uint32_t node_id(const nnvm::Node* node) const { - return node2index_.at(node); - } + inline uint32_t node_id(const nnvm::Node* node) const { return node2index_.at(node); } /*! * \brief Get the corresponding Node structure for a given node_id. * \param node_id The node id * \return const reference to the corresponding IndexedGraph::Node */ - inline const Node& operator[](uint32_t node_id) const { - return nodes_[node_id]; - } + inline const Node& operator[](uint32_t node_id) const { return nodes_[node_id]; } /*! * \brief Get the corresponding Node structure * \param node The pointer to the Node structure * \return const reference to the corresponding IndexedGraph::Node */ - inline const Node& operator[](const nnvm::Node* node) const { - return nodes_[node_id(node)]; - } + inline const Node& operator[](const nnvm::Node* node) const { return nodes_[node_id(node)]; } /*! \return list of argument nodes */ - inline const std::vector& input_nodes() const { - return input_nodes_; - } + inline const std::vector& input_nodes() const { return input_nodes_; } /*! \return list of mutable nodes */ inline const std::unordered_set& mutable_input_nodes() const { return mutable_input_nodes_; } /*! \return list of output entries */ - inline const std::vector& outputs() const { - return outputs_; - } + inline const std::vector& outputs() const { return outputs_; } /*! \return whether a node is existed in the indexed graph */ - inline bool exist(const nnvm::Node* node) const { - return node2index_.count(node); - } + inline bool exist(const nnvm::Node* node) const { return node2index_.count(node); } // disalllow copy assign IndexedGraph(const IndexedGraph&) = delete; @@ -239,15 +222,14 @@ class IndexedGraph { * \param fvisit a function of type std::function&)> * \tparam FVisit The function type to perform the visit. 
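 * A minimal sketch (hedged: assumes a Graph g, with ObjectPtr being nnvm's shared node pointer):
 * \code
 *   DFSVisit(g.outputs, [](const ObjectPtr& n) {
 *     if (n) LOG(INFO) << n->attrs.name;  // visit each reachable node post-order
 *   });
 * \endcode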
*/ -template +template inline void DFSVisit(const std::vector& heads, FVisit fvisit); // inline function implementations -template +template inline const T& Graph::GetAttr(const std::string& attr_name) const { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; + CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph"; return nnvm::unsafe_get(*it->second); } @@ -256,11 +238,10 @@ inline bool Graph::HasAttr(const std::string& attr_name) const { return it != attrs.end(); } -template +template inline T Graph::MoveCopyAttr(const std::string& attr_name) { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; + CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph"; std::shared_ptr sptr = it->second; attrs.erase(it); if (sptr.unique()) { @@ -270,14 +251,10 @@ inline T Graph::MoveCopyAttr(const std::string& attr_name) { } } -template -void PostOrderDFSVisit(const std::vector& heads, - FVisit fvisit, - HashFunc hash, - InDegree indegree, - GetInput getinput) { +template +void PostOrderDFSVisit(const std::vector& heads, FVisit fvisit, HashFunc hash, + InDegree indegree, GetInput getinput) { std::vector > stack; std::unordered_set visited; for (auto& head : heads) { @@ -303,28 +280,20 @@ void PostOrderDFSVisit(const std::vector& heads, } } -template -inline void DFSVisit(const std::vector& heads, - FVisit fvisit) { +template +inline void DFSVisit(const std::vector& heads, FVisit fvisit) { typedef const ObjectPtr* GNode; std::vector head_nodes(heads.size()); std::transform(heads.begin(), heads.end(), head_nodes.begin(), - [](const NodeEntry& e)->GNode { - return &e.node; - }); + [](const NodeEntry& e) -> GNode { return &e.node; }); PostOrderDFSVisit( - head_nodes, - [fvisit](GNode n) { - fvisit(*n); - }, // FVisit - [](GNode n)->Node* { - return n->get(); - }, // HashFunc - [](GNode n)->uint32_t { // InDegree + head_nodes, [fvisit](GNode n) { fvisit(*n); }, // FVisit + [](GNode n) -> Node* { return n->get(); }, // HashFunc + [](GNode n) -> uint32_t { // InDegree if (!(*n)) return 0; return (*n)->inputs.size() + (*n)->control_deps.size(); - }, - [](GNode n, uint32_t index)->GNode { // GetInput + }, + [](GNode n, uint32_t index) -> GNode { // GetInput if (index < (*n)->inputs.size()) { return &(*n)->inputs.at(index).node; } else { diff --git a/nnvm/include/nnvm/graph_attr_types.h b/nnvm/include/nnvm/graph_attr_types.h index acc52a2ae1db8..9e0185526eef6 100644 --- a/nnvm/include/nnvm/graph_attr_types.h +++ b/nnvm/include/nnvm/graph_attr_types.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,11 +24,12 @@ #ifndef NNVM_GRAPH_ATTR_TYPES_H_ #define NNVM_GRAPH_ATTR_TYPES_H_ -#include #include #include -#include "tuple.h" +#include + #include "layout.h" +#include "tuple.h" namespace nnvm { diff --git a/nnvm/include/nnvm/layout.h b/nnvm/include/nnvm/layout.h index 3a81b84b24871..e2e99784c99e4 100644 --- a/nnvm/include/nnvm/layout.h +++ b/nnvm/include/nnvm/layout.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,11 +31,12 @@ #define NNVM_LAYOUT_H_ #include -#include + +#include #include -#include +#include #include -#include +#include namespace nnvm { @@ -44,7 +45,7 @@ class Layout { using LayoutDim = char; /*! \brief default constructor */ - Layout() : name_("__undef__") {} // NOLINT(*) + Layout() : name_("__undef__") {} // NOLINT(*) /*! * \brief construct from a string. @@ -54,21 +55,21 @@ class Layout { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - inline Layout(const std::string& layout) { // NOLINT(*) + inline Layout(const std::string& layout) { // NOLINT(*) parse(layout); } /*! * \brief copy constructor from another layout * \param s the source layout */ - inline Layout(const Layout& s) { // NOLINT(*) + inline Layout(const Layout& s) { // NOLINT(*) this->parse(s.name_); } /*! * \brief move constructor from Layout * \param src the source layout */ - inline Layout(Layout&& src) { // NOLINT(*) + inline Layout(Layout&& src) { // NOLINT(*) this->swap(src); } /*! @@ -86,7 +87,7 @@ class Layout { * \return reference of self */ inline Layout& operator=(Layout&& src) { - Layout(std::move(src)).swap(*this); // NOLINT(*) + Layout(std::move(src)).swap(*this); // NOLINT(*) return *this; } /*! @@ -102,16 +103,12 @@ class Layout { * \return whether two layout equals * \param s the layout to compare against */ - inline bool operator==(const Layout& s) const { - return name_ == s.name_; - } + inline bool operator==(const Layout& s) const { return name_ == s.name_; } /*! * \return whether two layout not equal * \param s the layout to compare against */ - inline bool operator!=(const Layout& s) const { - return !(*this == s); - } + inline bool operator!=(const Layout& s) const { return !(*this == s); } /*! * \brief Append the current layout by another. @@ -134,18 +131,14 @@ class Layout { * \param dim input dimension * \return Whether a given dimension is a super-dimension. */ - static inline bool is_superdim(LayoutDim dim) { - return dim >= 'A' && dim <= 'Z'; - } + static inline bool is_superdim(LayoutDim dim) { return dim >= 'A' && dim <= 'Z'; } /*! * \brief Check whether a given dimension is a sub-dimension. * \param dim input dimension * \return Whether a given dimension is a sub-dimension. */ - static inline bool is_subdim(LayoutDim dim) { - return dim >= 'a' && dim <= 'z'; - } + static inline bool is_subdim(LayoutDim dim) { return dim >= 'a' && dim <= 'z'; } /*! 
* \brief Convert a given dimension to super-dimension. @@ -200,7 +193,7 @@ class Layout { * \param dst the target layout * \return Whether can be converted to dst layout. */ - inline bool convertible(const Layout &dst) const { + inline bool convertible(const Layout& dst) const { if (!this->defined() || !dst.defined()) return false; for (size_t i = 0; i < kUniqueDim; ++i) { if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) || @@ -258,13 +251,12 @@ class Layout { * \return A newly constructed Layout object. */ inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const { - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name_; + CHECK(target_pos <= this->ndim()) + << "Invalid split position " << target_pos << " for layout " << name_; CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim; CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_; - CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim - << " has already been split in " - << name_; + CHECK(!this->contains(to_subdim(dim))) + << "Dimension " << dim << " has already been split in " << name_; CHECK(size > 0) << "Invalid split size " << size; std::ostringstream new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { @@ -282,26 +274,16 @@ class Layout { using reverse_iterator = std::vector::const_reverse_iterator; /*! \return begin iterator */ - inline iterator begin() const { - return layout_simplified_.begin(); - } + inline iterator begin() const { return layout_simplified_.begin(); } /*! \return end iterator */ - inline iterator end() const { - return layout_simplified_.end(); - } + inline iterator end() const { return layout_simplified_.end(); } /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return layout_simplified_.rbegin(); - } + inline reverse_iterator rbegin() const { return layout_simplified_.rbegin(); } /*! \return rend iterator */ - inline reverse_iterator rend() const { - return layout_simplified_.rend(); - } + inline reverse_iterator rend() const { return layout_simplified_.rend(); } /*! \return number of dimensions */ - inline size_t ndim() const { - return layout_simplified_.size(); - } + inline size_t ndim() const { return layout_simplified_.size(); } /*! * \brief The description of the \p i-th dimension. @@ -311,8 +293,7 @@ class Layout { * \return the description of the dimension. */ inline std::string at(size_t i) const { - CHECK_LT(i, this->ndim()) << "position " << i - << " exceeds ndim=" << this->ndim(); + CHECK_LT(i, this->ndim()) << "position " << i << " exceeds ndim=" << this->ndim(); std::ostringstream repr; if (is_subdim(layout_simplified_[i])) { auto factor = subsizeof(layout_simplified_[i]); @@ -331,9 +312,12 @@ class Layout { * \return the index or -1 if not found. 
*/ inline int32_t indexof(LayoutDim dim) const { - if (!this->defined()) return -1; - else if (is_superdim(dim)) return superdim_pos_[dim - 'A']; - else if (is_subdim(dim)) return subdim_pos_[dim - 'a']; + if (!this->defined()) + return -1; + else if (is_superdim(dim)) + return superdim_pos_[dim - 'A']; + else if (is_subdim(dim)) + return subdim_pos_[dim - 'a']; return -1; } @@ -359,34 +343,26 @@ class Layout { */ inline bool contains(LayoutDim dim) const { if (is_superdim(dim)) { - return superdim_pos_[dim-'A'] >= 0; + return superdim_pos_[dim - 'A'] >= 0; } else if (is_subdim(dim)) { - return subdim_pos_[dim-'a'] >= 0; + return subdim_pos_[dim - 'a'] >= 0; } return false; } - inline LayoutDim operator[](size_t i) const { - return layout_simplified_[i]; - } + inline LayoutDim operator[](size_t i) const { return layout_simplified_[i]; } /*! \return whether the layout is defined */ - inline bool defined() const { - return name_ != "__undef__"; - } + inline bool defined() const { return name_ != "__undef__"; } /*! \return the string description of the layout */ - inline const std::string& name() const { - return name_; - } + inline const std::string& name() const { return name_; } /*! * \brief Write layout in JSON format. * \param writer JSONWriter */ - inline void Save(dmlc::JSONWriter* writer) const { - writer->Write(name_); - } + inline void Save(dmlc::JSONWriter* writer) const { writer->Write(name_); } /*! * \brief Load layout from JSON. @@ -433,21 +409,20 @@ class Layout { const LayoutDim c = layout.at(i); if (is_superdim(c)) { int pos = c - 'A'; - CHECK_EQ(factor, 0) << "Invalid layout " << layout - << ": invalid factor size " << factor + CHECK_EQ(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor << " before dimension " << c; - CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; + CHECK_EQ(superdim_pos_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; superdim_pos_[pos] = curr++; layout_simplified_.push_back(c); } else if (is_subdim(c)) { int pos = c - 'a'; - CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " - << factor << " for dimension " << c; - CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; + CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor + << " for dimension " << c; + CHECK_EQ(subdim_pos_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; + CHECK_EQ(subdim_size_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; subdim_pos_[pos] = curr++; subdim_size_[pos] = factor; layout_simplified_.push_back(c); @@ -461,9 +436,8 @@ class Layout { } CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout; for (LayoutDim dim : layout_simplified_) { - CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0) - << "Invalid layout " << layout << ": missing axis " - << static_cast(dim - 'a' + 'A'); + CHECK(is_superdim(dim) || superdim_pos_[dim - 'a'] >= 0) + << "Invalid layout " << layout << ": missing axis " << static_cast(dim - 'a' + 'A'); } } }; diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 2155481373fd5..91d13e569d1e3 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -26,12 +26,13 @@ #include #include -#include -#include #include +#include +#include + #include "base.h" 
-#include "op.h" #include "c_api.h" +#include "op.h" namespace nnvm { @@ -49,27 +50,16 @@ using ObjectPtr = std::shared_ptr; /*! \brief an entry that represents output data from a node */ struct NodeEntry { - NodeEntry(ObjectPtr node, uint32_t index, uint32_t version): - node(std::move(node)), - index(index), - version(version) - {} - - explicit NodeEntry(ObjectPtr node): - node(std::move(node)), - index(), - version() - {} + NodeEntry(ObjectPtr node, uint32_t index, uint32_t version) + : node(std::move(node)), index(index), version(version) {} + + explicit NodeEntry(ObjectPtr node) : node(std::move(node)), index(), version() {} /** * MXNet assumes that a node with a null ptr doesn't have a gradient attached. Don't change this * constructor. */ - NodeEntry(): - node(nullptr), - index(), - version() - {} + NodeEntry() : node(nullptr), index(), version() {} /*! \brief the source node of this data */ ObjectPtr node; @@ -79,7 +69,8 @@ struct NodeEntry { * \brief version of input Variable. * This field can only be nonzero when this->node is a Variable node. * version is increased by one each time a Variable get composed to a mutation Op. - * This information can be helpful to decide order of operations when sequence of mutation happens. + * This information can be helpful to decide order of operations when sequence of mutation + * happens. */ uint32_t version; }; @@ -90,9 +81,8 @@ struct NodeEntry { */ struct NodeEntryHash { size_t operator()(const NodeEntry& e) const { - return std::hash()(e.node.get()) ^ - (std::hash()(e.index) << 1 >> 1) ^ - (std::hash()(e.version) << 1); + return std::hash()(e.node.get()) ^ (std::hash()(e.index) << 1 >> 1) ^ + (std::hash()(e.version) << 1); } }; @@ -102,14 +92,12 @@ struct NodeEntryHash { */ struct NodeEntryEqual { size_t operator()(const NodeEntry& a, const NodeEntry& b) const { - return (a.node.get() == b.node.get()) && - (a.index == b.index) && - (a.version == b.version); + return (a.node.get() == b.node.get()) && (a.index == b.index) && (a.version == b.version); } }; /*! use NodeEntry as key in unordered_map */ -template +template using NodeEntryMap = std::unordered_map; /*! @@ -121,7 +109,7 @@ struct NodeAttrs { * \brief The operator this node uses. * For place holder variable, op == nullptr. */ - const Op *op{nullptr}; + const Op* op{nullptr}; /*! \brief name of the node */ std::string name; /*! \brief The dictionary representation of attributes */ @@ -188,7 +176,7 @@ class NNVM_DLL Node { * \brief create a new empty shared_ptr of Node. * \return a created empty node. */ - template + template static ObjectPtr Create(Args&&... args) { return std::make_shared(std::forward(args)...); } @@ -202,12 +190,9 @@ class NNVM_DLL Node { * \param attrs The attributes * \return The created node entry. */ -inline NodeEntry MakeNode( - const char* op_name, - std::string node_name, - std::vector inputs, - std::unordered_map attrs = - std::unordered_map()) { +inline NodeEntry MakeNode(const char* op_name, std::string node_name, std::vector inputs, + std::unordered_map attrs = + std::unordered_map()) { ObjectPtr p = Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); @@ -220,13 +205,9 @@ inline NodeEntry MakeNode( } // implementation of functions. 
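// A minimal usage sketch of MakeNode (illustrative only; "elemwise_add" is an
// assumed operator name, not something this header registers):
//
//   ObjectPtr a = Node::Create();
//   a->attrs.name = "a";   // op stays nullptr, so this is a variable node
//   ObjectPtr b = Node::Create();
//   b->attrs.name = "b";
//   NodeEntry out = MakeNode("elemwise_add", "add0",
//                            {NodeEntry(a), NodeEntry(b)});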
-inline const Op* Node::op() const { - return this->attrs.op; -} +inline const Op* Node::op() const { return this->attrs.op; } -inline bool Node::is_variable() const { - return this->op() == nullptr; -} +inline bool Node::is_variable() const { return this->op() == nullptr; } inline uint32_t Node::num_outputs() const { if (is_variable()) return 1; diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index 645804cd361f0..f53e0f25ee370 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,12 +25,14 @@ #define NNVM_OP_H_ #include + +#include +#include #include -#include -#include #include -#include -#include +#include +#include + #include "base.h" #include "c_api.h" @@ -39,7 +41,7 @@ namespace nnvm { // forward declarations class Node; struct NodeAttrs; -template +template class OpMap; class OpGroup; class OpRegistryEntry; @@ -193,15 +195,14 @@ class NNVM_DLL Op { * \param description Description of the argument. * \return reference to self. */ - inline Op& add_argument(const std::string &name, - const std::string &type, - const std::string &description); + inline Op& add_argument(const std::string& name, const std::string& type, + const std::string& description); /*! * \brief Append list if arguments to the end. * \param args Additional list of arguments. * \return reference to self. */ - inline Op& add_arguments(const std::vector &args); + inline Op& add_arguments(const std::vector& args); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -219,7 +220,7 @@ class NNVM_DLL Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_inputs(std::function fn); // NOLINT(*) + inline Op& set_num_inputs(std::function fn); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. @@ -231,13 +232,13 @@ class NNVM_DLL Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_outputs(std::function fn); // NOLINT(*) + inline Op& set_num_outputs(std::function fn); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. * \return reference to self. */ - inline Op& set_attr_parser(std::function fn); // NOLINT(*) + inline Op& set_attr_parser(std::function fn); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. @@ -251,10 +252,9 @@ class NNVM_DLL Op { * * \tparam ValueType The type of the value to be set. */ - template + template inline Op& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 10); + const ValueType& value, int plevel = 10); /*! * \brief Add another alias to this operator. * The same Op can be queried with Op::Get(alias) @@ -284,11 +284,11 @@ class NNVM_DLL Op { * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. 
*/ - template + template static const OpMap& GetAttr(const std::string& attr_name); private: - template + template friend class OpMap; friend class OpGroup; friend class dmlc::Registry; @@ -300,15 +300,13 @@ class NNVM_DLL Op { // get const reference to certain attribute static const any* GetAttrMap(const std::string& key); // update the attribute OpMap - static void UpdateAttrMap(const std::string& key, - std::function updater); + static void UpdateAttrMap(const std::string& key, std::function updater); // add a trigger based on tag matching on certain tag attribute // This will apply trigger on all the op such that // include the corresponding group. // The trigger will also be applied to all future registrations // that calls include - static void AddGroupTrigger(const std::string& group_name, - std::function trigger); + static void AddGroupTrigger(const std::string& group_name, std::function trigger); }; /*! @@ -316,7 +314,7 @@ class NNVM_DLL Op { * and returns ValueType * \tparam ValueType The type of the value stored in map. */ -template +template class OpMap { public: /*! @@ -351,7 +349,7 @@ class OpMap { // internal attribute name std::string attr_name_; // internal data - std::vector > data_; + std::vector> data_; OpMap() = default; }; @@ -376,18 +374,17 @@ class OpGroup { * * \tparam ValueType The type of the value to be set. */ - template + template inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 1); + const ValueType& value, int plevel = 1); }; // internal macros to make -#define NNVM_REGISTER_VAR_DEF(OpName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName +#define NNVM_REGISTER_VAR_DEF(OpName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op& __make_##NnvmOp##_##OpName -#define NNVM_REGISTER_GVAR_DEF(TagName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName +#define NNVM_REGISTER_GVAR_DEF(TagName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_##NnvmOpGroup##_##TagName /*! * \def NNVM_REGISTER_OP @@ -404,8 +401,8 @@ class OpGroup { * * \endcode */ -#define NNVM_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ +#define NNVM_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) /*! @@ -429,85 +426,72 @@ class OpGroup { * * \endcode */ -#define NNVM_REGISTER_OP_GROUP(GroupName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ - ::nnvm::OpGroup {#GroupName} +#define NNVM_REGISTER_OP_GROUP(GroupName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = ::nnvm::OpGroup { #GroupName } // implementations of template functions after this. 
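// Illustrative round trip through the attribute registry (a sketch; the
// attribute name "Doc" and the op name "add" are made up for the example):
//
//   NNVM_REGISTER_OP(add)
//       .set_num_inputs(2)
//       .set_attr<std::string>("Doc", "elementwise addition", 10);
//
//   const OpMap<std::string>& doc = Op::GetAttr<std::string>("Doc");
//   const Op* op = Op::Get("add");
//   if (doc.contains(op)) {
//     LOG(INFO) << doc[op];
//   }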
// member function of Op -template +template inline const OpMap& Op::GetAttr(const std::string& key) { const any* ref = GetAttrMap(key); if (ref == nullptr) { // update the attribute map of the key by creating new empty OpMap UpdateAttrMap(key, [key](any* pmap) { - // use callback so it is in lockscope - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = key; - *pmap = std::move(pm); - } - }); + // use callback so it is in lockscope + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = key; + *pmap = std::move(pm); + } + }); ref = GetAttrMap(key); } - return nnvm::get >(*ref); + return nnvm::get>(*ref); } -template +template inline Op& Op::set_attr( // NOLINT(*) - const std::string& attr_name, - const ValueType& value, - int plevel) { - CHECK_GT(plevel, 0) - << "plevel in set_attr must be greater than 0"; + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; // update the attribute map of the key by creating new empty if needed. - UpdateAttrMap(attr_name, - [this, attr_name, value, plevel](any* pmap) { - // the callback is in lockscope so is threadsafe. - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = attr_name; - *pmap = std::move(pm); - } - CHECK(pmap->type() == typeid(OpMap)) - << "Attribute " << attr_name - << " of operator " << this->name - << " is registered as inconsistent types" - << " previously " << pmap->type().name() - << " current " << typeid(OpMap).name(); - std::vector >& vec = - nnvm::get >(*pmap).data_; - // resize the value type. - if (vec.size() <= index_) { - vec.resize(index_ + 1, - std::make_pair(ValueType(), 0)); - } - std::pair& p = vec[index_]; - CHECK(p.second != plevel) - << "Attribute " << attr_name - << " of operator " << this->name - << " is already registered with same plevel=" << plevel; - if (p.second < plevel) { - vec[index_] = std::make_pair(value, plevel); - } - }); + UpdateAttrMap(attr_name, [this, attr_name, value, plevel](any* pmap) { + // the callback is in lockscope so is threadsafe. + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = attr_name; + *pmap = std::move(pm); + } + CHECK(pmap->type() == typeid(OpMap)) + << "Attribute " << attr_name << " of operator " << this->name + << " is registered as inconsistent types" + << " previously " << pmap->type().name() << " current " << typeid(OpMap).name(); + std::vector>& vec = nnvm::get>(*pmap).data_; + // resize the value type. 
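+    // (one (value, plevel) slot per op index; plevel 0 means "not set yet")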
+ if (vec.size() <= index_) { + vec.resize(index_ + 1, std::make_pair(ValueType(), 0)); + } + std::pair& p = vec[index_]; + CHECK(p.second != plevel) << "Attribute " << attr_name << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + vec[index_] = std::make_pair(value, plevel); + } + }); return *this; } - inline Op& Op::describe(const std::string& descr) { // NOLINT(*) this->description = descr; return *this; } -inline Op& Op::add_argument(const std::string &name, - const std::string &type, - const std::string &description) { +inline Op& Op::add_argument(const std::string& name, const std::string& type, + const std::string& description) { arguments.push_back({name, type, type, description}); return *this; } -inline Op& Op::add_arguments(const std::vector &args) { +inline Op& Op::add_arguments(const std::vector& args) { this->arguments.insert(arguments.end(), args.begin(), args.end()); return *this; } @@ -522,7 +506,7 @@ inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) this->get_num_inputs = fn; return *this; } @@ -532,18 +516,18 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) this->get_num_outputs = fn; return *this; } -inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) +inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) this->attr_parser = fn; return *this; } // member functions of OpMap -template +template inline int OpMap::count(const Op* op) const { if (contains(op)) { return 1; @@ -552,7 +536,7 @@ inline int OpMap::count(const Op* op) const { } } -template +template inline bool OpMap::contains(const Op* op) const { if (op == nullptr) { return false; @@ -561,17 +545,16 @@ inline bool OpMap::contains(const Op* op) const { return idx < data_.size() ? 
(data_[idx].second != 0) : false; } -template +template inline const ValueType& OpMap::operator[](const Op* op) const { CHECK(op != nullptr); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second) - << "Attribute " << attr_name_ - << " has not been registered for Operator " << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; return data_[idx].first; } -template +template inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { if (op == nullptr) return def_value; const uint32_t idx = op->index_; @@ -582,9 +565,8 @@ inline const ValueType& OpMap::get(const Op* op, const ValueType& def } } -template -inline OpGroup& OpGroup::set_attr(const std::string& attr_name, - const ValueType& value, +template +inline OpGroup& OpGroup::set_attr(const std::string& attr_name, const ValueType& value, int plevel) { auto trigger = [attr_name, value, plevel](Op* op) { op->set_attr(attr_name, value, plevel); diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index c2af989ba3e0c..84095368886eb 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -24,15 +24,16 @@ #ifndef NNVM_OP_ATTR_TYPES_H_ #define NNVM_OP_ATTR_TYPES_H_ -#include -#include -#include #include +#include #include +#include +#include + #include "base.h" +#include "layout.h" #include "node.h" #include "tuple.h" -#include "layout.h" namespace nnvm { @@ -48,7 +49,7 @@ namespace nnvm { * * FListInputNames enables automatic variable creation for missing arguments. */ -using FListInputNames = std::function (const NodeAttrs& attrs)>; +using FListInputNames = std::function(const NodeAttrs& attrs)>; /*! * \brief Return number of visible outputs by the user. @@ -60,7 +61,7 @@ using FListInputNames = std::function (const NodeAttrs& * but the additional outputs can be used to pass information from * forward to gradient pass. */ -using FNumVisibleOutputs = std::function; +using FNumVisibleOutputs = std::function; /*! * \brief Return list of output arguments names of each operator. @@ -71,7 +72,7 @@ using FNumVisibleOutputs = std::function; * * FListOutputNames customized naming for operator outputs. */ -using FListOutputNames = std::function (const NodeAttrs& attrs)>; +using FListOutputNames = std::function(const NodeAttrs& attrs)>; /*! * \brief Check whether operator will mutate k-th input. @@ -81,17 +82,16 @@ using FListOutputNames = std::function (const NodeAttrs * \note Register under "FMutateInputs", default return false * FMutateInputs enables mutation order handling correctly. */ -using FMutateInputs = std::function (const NodeAttrs& attrs)>; +using FMutateInputs = std::function(const NodeAttrs& attrs)>; /*! * \brief Inference function of certain type. * \tparam AttrType The type of the attribute to be infered. * \return whether all attributes are inferred. */ -template -using FInferNodeEntryAttr = std::function *in_attrs, - std::vector *out_attrs)>; +template +using FInferNodeEntryAttr = std::function* in_attrs, std::vector* out_attrs)>; /*! * \brief Get attribute dictionary from node. @@ -100,9 +100,8 @@ using FInferNodeEntryAttr = std::function - (const NodeAttrs& attrs)>; +using FGetAttrDict = + std::function(const NodeAttrs& attrs)>; /*! * \brief Shape inference function. @@ -155,8 +154,7 @@ using TIsGhost = bool; * * \note Register under "FInplaceOption", by default no inplace can happen. 
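 * For example (illustrative): a unary elementwise op could return {{0, 0}},
 * meaning output 0 may share the buffer of input 0.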
*/ -using FInplaceOption = std::function< - std::vector > (const NodeAttrs& attrs)>; +using FInplaceOption = std::function >(const NodeAttrs& attrs)>; /*! * \brief Get if the inplace option is an identity @@ -168,7 +166,7 @@ using FInplaceOption = std::function< * * \note Register under "FInplaceIdentity", by default no identities. */ -using FInplaceIdentity = std::function (const NodeAttrs& attrs)>; +using FInplaceIdentity = std::function(const NodeAttrs& attrs)>; /*! * \brief Get list of inputs in the op whose content are actually not used by the operator @@ -179,8 +177,7 @@ using FInplaceIdentity = std::function (const NodeAttrs& attrs * * \note Register under "FIgnoreInputs". */ -using FIgnoreInputs = std::function< - std::vector (const NodeAttrs& attrs)>; +using FIgnoreInputs = std::function(const NodeAttrs& attrs)>; /*! * \brief Get the gradient node of the op node @@ -191,9 +188,8 @@ using FIgnoreInputs = std::function< * * \note Register under "FGradient" */ -using FGradient = std::function( - const ObjectPtr& nodeptr, - const std::vector& out_grads)>; +using FGradient = std::function(const ObjectPtr& nodeptr, + const std::vector& out_grads)>; /*! * \brief Set the attributes of input variable. @@ -202,10 +198,8 @@ using FGradient = std::function( * \param var the input variable * \param index index of var in all inputs */ -using FSetInputVarAttrOnCompose = std::function; +using FSetInputVarAttrOnCompose = + std::function; /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention @@ -226,11 +220,9 @@ using FSetInputVarAttrOnCompose = std::function *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts)>; +using FCorrectLayout = + std::function* ilayouts, + const std::vector* last_ilayouts, std::vector* olayouts)>; /*! * \brief Get a list of inputs that represent graphs instead of data. diff --git a/nnvm/include/nnvm/pass.h b/nnvm/include/nnvm/pass.h index a6158df5ffdf9..0bccdccd07912 100644 --- a/nnvm/include/nnvm/pass.h +++ b/nnvm/include/nnvm/pass.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,8 +24,9 @@ #ifndef NNVM_PASS_H_ #define NNVM_PASS_H_ -#include #include +#include + #include "base.h" #include "graph.h" @@ -42,7 +43,7 @@ namespace nnvm { * \param src The graph to be transformed. * \return The generated graph. */ -typedef std::function PassFunction; +typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on the input graph. @@ -50,8 +51,7 @@ typedef std::function PassFunction; * \param passes A list of pass names to be applied. * \return The transformed graph */ -Graph ApplyPasses(Graph src, - const std::vector& passes); +Graph ApplyPasses(Graph src, const std::vector& passes); /*! * \brief Apply one pass to the graph. @@ -59,17 +59,12 @@ Graph ApplyPasses(Graph src, * \param pass The name of pass to be applied. * \return The transformed graph. */ -inline Graph ApplyPass(Graph src, const std::string& pass) { - return ApplyPasses(src, {pass}); -} - +inline Graph ApplyPass(Graph src, const std::string& pass) { return ApplyPasses(src, {pass}); } /*! * \brief Registry entry for pass functions. 
*/ -struct PassFunctionReg - : public dmlc::FunctionRegEntryBase { +struct PassFunctionReg : public dmlc::FunctionRegEntryBase { /*! * \brief Whether the pass will change graph structure * If this is false, the pass will only change attributes. @@ -138,7 +133,7 @@ struct PassFunctionReg * }); * \endcode */ -#define NNVM_REGISTER_PASS(name) \ +#define NNVM_REGISTER_PASS(name) \ DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) } // namespace nnvm diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index a7893c6fec567..3097e20223d58 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,13 +28,14 @@ #ifndef NNVM_PASS_FUNCTIONS_H_ #define NNVM_PASS_FUNCTIONS_H_ -#include #include -#include +#include #include +#include + #include "base.h" -#include "pass.h" #include "graph_attr_types.h" +#include "pass.h" namespace nnvm { namespace pass { @@ -60,7 +61,6 @@ inline std::string SaveJSON(Graph graph) { return ret.GetAttr("json"); } - /*! * \brief Print graph ir * \param graph The graph to be printed @@ -81,9 +81,7 @@ inline std::string PrintGraphIR(Graph graph) { * \param src The input graph. * \return A graph with proper control flow dependencies added. */ -inline Graph OrderMutation(Graph src) { - return ApplyPass(std::move(src), "OrderMutation"); -} +inline Graph OrderMutation(Graph src) { return ApplyPass(std::move(src), "OrderMutation"); } /*! * \brief Infer shapes in the graph given the information. @@ -94,9 +92,7 @@ inline Graph OrderMutation(Graph src) { * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * The index of ShapeVector is given by graph.indexed_graph().entry_id. */ -inline Graph InferShape(Graph graph, - ShapeVector shape_inputs, - std::string shape_attr_key = "") { +inline Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key = "") { if (shape_inputs.size() != 0) { graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); } @@ -115,9 +111,7 @@ inline Graph InferShape(Graph graph, * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. * The index of ShapeVector is given by graph.indexed_graph().entry_id. */ -inline Graph InferType(Graph graph, - DTypeVector dtype_inputs, - std::string dtype_attr_key = "") { +inline Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key = "") { if (dtype_inputs.size() != 0) { graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); } @@ -141,10 +135,8 @@ inline Graph InferType(Graph graph, * \param device_copy_op The name of copy op to be inserted when cross device copy happened. * \return A graph with new attribute "device", cotaining device information of each node. 
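 * An illustrative call (all names here are made up):
 *   PlaceDevice(g, "device_group", {{"group_0", 0}, {"group_1", 1}}, "device_copy");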
*/ -inline Graph PlaceDevice(Graph graph, - std::string device_group_attr_key, - DeviceAssignMap device_assign_map, - std::string device_copy_op) { +inline Graph PlaceDevice(Graph graph, std::string device_group_attr_key, + DeviceAssignMap device_assign_map, std::string device_copy_op) { graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key)); graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map)); graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op)); @@ -159,22 +151,18 @@ inline Graph PlaceDevice(Graph graph, * \param ys_out_grad The symbol for additional gradient to be propagate back to y. * \param aggregate_fun Aggregation function applied to aggregate the inputs. * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. - * \param zero_ops Optional, list of operators that outputs a single zero array. The first one - * must be zeros_like. - * \param copy_op_str Optional, name of the copy operation required to handle duplicates - * on the edge of the graph - * \return A new graph, whose outputs correspond to inputs of xs. + * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same + * as like. \param zero_ops Optional, list of operators that outputs a single zero array. The first + * one must be zeros_like. \param copy_op_str Optional, name of the copy operation required to + * handle duplicates on the edge of the graph \return A new graph, whose outputs correspond to + * inputs of xs. */ inline Graph Gradient( - Graph graph, - std::vector ys, - std::vector xs, + Graph graph, std::vector ys, std::vector xs, std::vector ys_out_grad, std::function&& inputs)> aggregate_fun = nullptr, std::function mirror_fun = nullptr, - std::function - attr_hint_fun = nullptr, + std::function attr_hint_fun = nullptr, std::vector zero_ops = std::vector(), std::string copy_op_str = std::string()) { graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); @@ -198,7 +186,7 @@ inline Graph Gradient( } if (copy_op_str != std::string()) { - graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); + graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); } return ApplyPass(std::move(graph), "Gradient"); diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h index d3555ec726b27..77d385505845f 100644 --- a/nnvm/include/nnvm/symbolic.h +++ b/nnvm/include/nnvm/symbolic.h @@ -29,10 +29,10 @@ #define NNVM_SYMBOLIC_H_ #include -#include #include -#include #include +#include +#include #include "base.h" #include "node.h" @@ -81,13 +81,13 @@ class NNVM_DLL Symbol { * \brief Print the symbol info to output stream. * \param os The output stream to print to. */ - void Print(std::ostream &os) const; // NOLINT(*) + void Print(std::ostream& os) const; // NOLINT(*) /*! * \brief Get the index-th element from the returned tuple. * \param index Index of multi output. * \return The symbol corresponds to the indexed element. */ - Symbol operator[] (size_t index) const; + Symbol operator[](size_t index) const; /*! * \brief List the input variable nodes. * @@ -139,9 +139,9 @@ class NNVM_DLL Symbol { * \param name Name of returned symbol. * \return A new Symbol which is the composition of current symbol with its arguments. 
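 * For example (illustrative): given variable symbols x and y and a functor
 * symbol created from Op::Get("add"), calling the functor with x and y as
 * positional args and the name "add0" returns a symbol whose output is the
 * composed node named "add0".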
*/ - Symbol operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const; + Symbol operator()(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const; /*! * \brief Add control flow dependencies to the operators in symbols. * @@ -201,16 +201,14 @@ class NNVM_DLL Symbol { * * \return The created attribute in format . */ - std::vector > - ListAttrsRecursive() const; + std::vector > ListAttrsRecursive() const; /*! * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. * \param op The operator. * \param attrs The additional attributes. * \return Symbol that can be used to call compose further. */ - static Symbol CreateFunctor(const Op* op, - std::unordered_map attrs); + static Symbol CreateFunctor(const Op* op, std::unordered_map attrs); /*! * \brief Create symbolic functor(AtomicSymbol) by given node attributes. * \param attrs pre-initialized Node attributes. diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h index a7f2d26030936..c6d6125aa1948 100644 --- a/nnvm/include/nnvm/tuple.h +++ b/nnvm/include/nnvm/tuple.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,12 +24,13 @@ #ifndef NNVM_TUPLE_H_ #define NNVM_TUPLE_H_ -#include -#include #include -#include #include #include +#include +#include +#include + #include "base.h" namespace nnvm { @@ -47,29 +48,23 @@ typedef int64_t dim_t; * \tparam ValueType The type of data stored inside tuple. * \sa TShape */ -template +template class Tuple { public: /*! \brief default constructor */ Tuple() = default; /*! \brief destructor */ - inline ~Tuple() { - delete [] data_heap_; - } + inline ~Tuple() { delete[] data_heap_; } /*! * \brief copy constructor from another tuple * \param s the source tuple */ - inline Tuple(const Tuple& s) { - this->assign(s.begin(), s.end()); - } + inline Tuple(const Tuple& s) { this->assign(s.begin(), s.end()); } /*! * \brief constructor from initializer list * \param init the initializer_list */ - inline Tuple(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } + inline Tuple(std::initializer_list init) { this->assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init the vector @@ -82,7 +77,7 @@ class Tuple { * \param src the source shape */ - inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) + inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) this->swap(src); } /*! @@ -91,9 +86,8 @@ class Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline Tuple(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); } /*! 
@@ -102,9 +96,8 @@ class Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline void assign(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline void assign(RandomAccessIterator begin, RandomAccessIterator end) { this->SetDim(end - begin); std::copy(begin, end, this->begin()); } @@ -141,7 +134,7 @@ class Tuple { * \param init the source initializer list * \return reference of self */ - inline Tuple &operator=(std::initializer_list init) { + inline Tuple& operator=(std::initializer_list init) { this->assign(init.begin(), init.end()); return *this; } @@ -149,7 +142,7 @@ class Tuple { * \return whether two tuple equals * \param s the tuple to compare against */ - inline bool operator==(const Tuple &s) const { + inline bool operator==(const Tuple& s) const { if (ndim_ != s.ndim_) return false; return std::equal(begin(), end(), s.begin()); } @@ -157,45 +150,33 @@ class Tuple { * \return whether two tuple not equal * \param s the tuple to compare against */ - inline bool operator!=(const Tuple &s) const { - return !(*this == s); - } + inline bool operator!=(const Tuple& s) const { return !(*this == s); } /*! \return the begin data pointer to content of the tuple */ - inline const ValueType *begin() const { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } + inline const ValueType* begin() const { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the begin data pointer to content of the tuple */ - inline ValueType *begin() { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } + inline ValueType* begin() { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the data pointer to end of the tuple */ inline const ValueType* end() const { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } /*! \return the data pointer to end the tuple */ inline ValueType* end() { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } /*! \return number of dimension of the tuple */ - inline uint32_t ndim() const { - return ndim_; - } + inline uint32_t ndim() const { return ndim_; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ - inline ValueType& operator[](size_t i) { - return begin()[i]; - } + inline ValueType& operator[](size_t i) { return begin()[i]; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ - inline const ValueType& operator[](size_t i) const { - return begin()[i]; - } + inline const ValueType& operator[](size_t i) const { return begin()[i]; } /*! * \brief Save Tuple to JSON. 
* \param writer JSONWriter @@ -219,7 +200,7 @@ class Tuple { * \param t the tuple * \return the ostream */ - friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { + friend std::ostream& operator<<(std::ostream& os, const Tuple& t) { os << '['; const ValueType* begin = t.begin(); const ValueType* end = t.end(); @@ -236,7 +217,7 @@ class Tuple { * \param t The tuple * \return the istream */ - friend std::istream &operator>>(std::istream &is, Tuple &t) { + friend std::istream& operator>>(std::istream& is, Tuple& t) { // get ( while (true) { char ch = is.peek(); @@ -252,7 +233,7 @@ class Tuple { if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; - } + } } // Handle empty tuple while (isspace(is.peek())) { @@ -278,10 +259,12 @@ class Tuple { while (true) { ch = is.peek(); if (isspace(ch)) { - is.get(); continue; + is.get(); + continue; } if (ch == ')' || ch == ']') { - is.get(); break; + is.get(); + break; } break; } @@ -302,8 +285,8 @@ class Tuple { * \tparam DType data type that save to * \tparam TStream any stream type that have write */ - template - inline void Save(TStream *strm) const; + template + inline void Save(TStream* strm) const; /*! * \brief load the content from binary stream * \param strm the output stream @@ -311,8 +294,8 @@ class Tuple { * \tparam TStream any stream type that have write * \return whether the load is successful */ - template - inline bool Load(TStream *strm); + template + inline bool Load(TStream* strm); protected: // stack cache size @@ -327,9 +310,8 @@ class Tuple { ValueType* data_heap_{nullptr}; // internal function to change the dimension inline void SetDim(uint32_t ndim) { - if (ndim > kStackCache && - ndim > num_heap_allocated_) { - delete [] data_heap_; + if (ndim > kStackCache && ndim > num_heap_allocated_) { + delete[] data_heap_; data_heap_ = new ValueType[ndim]; num_heap_allocated_ = ndim; } @@ -356,16 +338,14 @@ class TShape : public Tuple { * \brief copy constructor of TShape * \param s source shape. */ - inline TShape(const Tuple& s) { // NOLINT(*) + inline TShape(const Tuple& s) { // NOLINT(*) this->assign(s.begin(), s.end()); } /*! * \brief constructor from initializer list * \param init the initializer_list */ - inline TShape(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } + inline TShape(std::initializer_list init) { this->assign(init.begin(), init.end()); } /*! * \brief move constructor. * \param s source shape. @@ -379,9 +359,8 @@ class TShape : public Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline TShape(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline TShape(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); } /*! @@ -399,13 +378,13 @@ class TShape : public Tuple { * \return self. */ inline TShape& operator=(Tuple&& src) { // NOLINT(*) - TShape(std::move(src)).swap(*this); // NOLINT(*) + TShape(std::move(src)).swap(*this); // NOLINT(*) return *this; } /*! 
\return total number of elements in the shape */ inline size_t Size() const { dim_t size = 1; - const dim_t* start = begin(), *fin = end(); + const dim_t *start = begin(), *fin = end(); for (const dim_t* it = start; it != fin; ++it) { size *= *it; } @@ -418,28 +397,24 @@ class TShape : public Tuple { */ inline size_t ProdShape(int dimstart, int dimend) const { dim_t num = 1; - const dim_t *d = this->data(); + const dim_t* d = this->data(); for (int i = dimstart; i < dimend; ++i) { num *= d[i]; } return num; } /*! \return the begin data pointer to content of the tuple */ - inline const dim_t *data() const { - return begin(); - } + inline const dim_t* data() const { return begin(); } /*! \return the begin data pointer to content of the tuple */ - inline dim_t *data() { - return begin(); - } + inline dim_t* data() { return begin(); } #ifdef MSHADOW_XINLINE - template - inline TShape(const mshadow::Shape &s) {// NOLINT(*) + template + inline TShape(const mshadow::Shape& s) { // NOLINT(*) this->assign(s.shape_, s.shape_ + dim); } - template - inline TShape(mshadow::Shape &&s) {// NOLINT(*) + template + inline TShape(mshadow::Shape&& s) { // NOLINT(*) this->assign(s.shape_, s.shape_ + dim); } /*! @@ -448,8 +423,8 @@ class TShape : public Tuple { * \tparam dim shape dimension * \return reference of self */ - template - inline TShape &operator=(const mshadow::Shape &shape) { + template + inline TShape& operator=(const mshadow::Shape& shape) { this->assign(shape.shape_, shape.shape_ + dim); return *this; } @@ -458,11 +433,11 @@ class TShape : public Tuple { * \return the shape requested * \tparam dim dimension of the tensor */ - template + template inline mshadow::Shape get() const { CHECK_EQ(dim, static_cast(ndim())) << "dimension do not match target dimension " << dim << " vs " << ndim(); - const dim_t *d = this->data(); + const dim_t* d = this->data(); mshadow::Shape s; for (int i = 0; i < dim; ++i) { s[i] = d[i]; @@ -476,7 +451,7 @@ class TShape : public Tuple { inline mshadow::Shape<2> FlatTo2D(void) const { mshadow::Shape<2> s; if (ndim() == 0) return mshadow::Shape2(0, 0); - const dim_t *d = this->data(); + const dim_t* d = this->data(); s.shape_[1] = d[ndim() - 1]; dim_t ymax = 1; for (size_t i = 1; i < ndim(); ++i) { @@ -495,7 +470,7 @@ class TShape : public Tuple { CHECK(axis_end >= axis_begin); mshadow::Shape<3> s; if (ndim() == 0) return mshadow::Shape3(0, 0, 0); - const dim_t *d = this->data(); + const dim_t* d = this->data(); s.shape_[0] = 1; s.shape_[1] = 1; s.shape_[2] = 1; @@ -516,25 +491,21 @@ class TShape : public Tuple { * \param axis The axis specified. * \return the flat 3d shape */ - inline mshadow::Shape<3> FlatTo3D(size_t axis) const { - return FlatTo3D(axis, axis); - } - inline bool operator==(const TShape &s) const { + inline mshadow::Shape<3> FlatTo3D(size_t axis) const { return FlatTo3D(axis, axis); } + inline bool operator==(const TShape& s) const { if (ndim() != s.ndim()) return false; return std::equal(begin(), end(), s.begin()); } - inline bool operator!=(const TShape &s) const { - return !(*this == s); - } + inline bool operator!=(const TShape& s) const { return !(*this == s); } /*! * \return whether two shape equals * \param s the shape to compare against * \tparam dim dimension of the shape */ - template - inline bool operator==(const mshadow::Shape &s) const { + template + inline bool operator==(const mshadow::Shape& s) const { if (ndim_ != dim) return false; - const dim_t *d = dim <= kStackCache ? 
data_stack_ : data_heap_; + const dim_t* d = dim <= kStackCache ? data_stack_ : data_heap_; for (size_t i = 0; i < dim; ++i) { if (d[i] != s.shape_[i]) return false; } @@ -545,18 +516,16 @@ class TShape : public Tuple { * \param s the shape to compare against * \tparam dim dimension of the shape */ - template - inline bool operator!=(const mshadow::Shape &s) const { + template + inline bool operator!=(const mshadow::Shape& s) const { return !(*this == s); } #endif }; /*! \brief helper function to cast type of container elements */ -template -inline DstIter ShapeTypeCast(const SrcIter begin, - const SrcIter end, - DstIter dst_begin) { +template +inline DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin) { typedef typename std::iterator_traits::value_type SrcDType; typedef typename std::iterator_traits::value_type DstDType; auto cast = [](const SrcDType& dim) { return static_cast(dim); }; @@ -564,7 +533,7 @@ inline DstIter ShapeTypeCast(const SrcIter begin, } /*! \brief helper function to transform a container to TShape with type cast */ -template +template inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { size_t ndim = std::distance(begin, end); TShape res(ndim); @@ -573,9 +542,9 @@ inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { } /*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline void Tuple::Save(TStream *strm) const { +template +template +inline void Tuple::Save(TStream* strm) const { strm->Write(&ndim_, sizeof(ndim_)); if (typeid(DType) == typeid(ValueType)) { strm->Write(begin(), sizeof(ValueType) * ndim_); @@ -587,9 +556,9 @@ inline void Tuple::Save(TStream *strm) const { } /*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline bool Tuple::Load(TStream *strm) { +template +template +inline bool Tuple::Load(TStream* strm) { if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; this->SetDim(ndim_); size_t nread = sizeof(DType) * ndim_; @@ -607,7 +576,7 @@ inline bool Tuple::Load(TStream *strm) { namespace std { /*! \brief hash function for Tuple. */ -template +template struct hash > { /*! \brief hash a Tuple into unsigned int */ size_t operator()(const nnvm::Tuple& val) const { @@ -621,7 +590,7 @@ struct hash > { }; /*! \brief hash function for TShape. */ -template<> +template <> struct hash { /*! \brief hash a TShape into unsigned int */ size_t operator()(const nnvm::TShape& val) const { @@ -640,11 +609,9 @@ namespace dmlc { DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); // avoid low version of MSVC #if !defined(_MSC_VER) -template +template struct type_name_helper > { - static inline std::string value() { - return "tuple of <" + type_name() + ">"; - } + static inline std::string value() { return "tuple of <" + type_name() + ">"; } }; #endif } // namespace dmlc diff --git a/nnvm/src/c_api/c_api_common.h b/nnvm/src/c_api/c_api_common.h index b3ff36ae606f0..1291947156494 100644 --- a/nnvm/src/c_api/c_api_common.h +++ b/nnvm/src/c_api/c_api_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,23 +29,34 @@ #include #include #include -#include + #include -#include #include +#include +#include /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { /*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END() \ + } \ + catch (dmlc::Error & _except_) { \ + return NNAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ -#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*) - +#define API_END_HANDLE_ERROR(Finalize) \ + } \ + catch (dmlc::Error & _except_) { \ + Finalize; \ + return NNAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! \brief entry to to easily hold returning information */ struct NNAPIThreadLocalEntry { @@ -54,9 +65,9 @@ struct NNAPIThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; /*! \brief result holder for returning handles */ - std::vector ret_handles; + std::vector ret_handles; /*! \brief argument holder to hold symbol */ std::unordered_map kwarg_symbol; }; @@ -69,7 +80,7 @@ typedef dmlc::ThreadLocalStore NNAPIThreadLocalStore; * \param e the exception * \return the return value of API after exception is handled */ -inline int NNAPIHandleException(const dmlc::Error &e) { +inline int NNAPIHandleException(const dmlc::Error& e) { NNAPISetLastError(e.what()); return -1; } diff --git a/nnvm/src/c_api/c_api_error.cc b/nnvm/src/c_api/c_api_error.cc index ba6e1cd37c8a6..c2f90b162e1f9 100644 --- a/nnvm/src/c_api/c_api_error.cc +++ b/nnvm/src/c_api/c_api_error.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief C error handling */ #include + #include "c_api_common.h" struct ErrorEntry { @@ -30,10 +31,6 @@ struct ErrorEntry { typedef dmlc::ThreadLocalStore NNAPIErrorStore; -const char *NNGetLastError() { - return NNAPIErrorStore::Get()->last_error.c_str(); -} +const char* NNGetLastError() { return NNAPIErrorStore::Get()->last_error.c_str(); } -void NNAPISetLastError(const char* msg) { - NNAPIErrorStore::Get()->last_error = msg; -} +void NNAPISetLastError(const char* msg) { NNAPIErrorStore::Get()->last_error = msg; } diff --git a/nnvm/src/c_api/c_api_graph.cc b/nnvm/src/c_api/c_api_graph.cc index cc5449b0fbbe6..a547476e4c7e0 100644 --- a/nnvm/src/c_api/c_api_graph.cc +++ b/nnvm/src/c_api/c_api_graph.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,17 +21,18 @@ * \file c_api_graph.cc * \brief C API related to Graph IR. */ +#include #include -#include -#include #include +#include #include -#include +#include + #include "c_api_common.h" using namespace nnvm; -int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) { +int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph) { Graph* g = new Graph(); API_BEGIN(); g->outputs = static_cast(symbol)->outputs; @@ -45,7 +46,7 @@ int NNGraphFree(GraphHandle handle) { API_END(); } -int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { +int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol) { Symbol* s = new Symbol(); API_BEGIN(); s->outputs = static_cast(graph)->outputs; @@ -53,20 +54,15 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { API_END_HANDLE_ERROR(delete s); } -int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list) { +int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list) { API_BEGIN(); Symbol* s = static_cast(list); Graph* g = static_cast(handle); - g->attrs[std::string(key)] - = std::make_shared(s->outputs); + g->attrs[std::string(key)] = std::make_shared(s->outputs); API_END(); } -int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value) { +int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value) { API_BEGIN(); Graph* g = static_cast(handle); std::string temp(json_value); @@ -78,11 +74,8 @@ int NNGraphSetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success) { - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, int* success) { + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); Graph* g = static_cast(handle); std::string skey(key); @@ -100,10 +93,8 @@ int NNGraphGetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - 
GraphHandle *dst) { +int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names, + GraphHandle* dst) { Graph* g = new Graph(); API_BEGIN(); std::vector vpass; diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index 7ca56035acaeb..2127997da05a6 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -24,14 +24,14 @@ #include #include #include + #include "c_api_common.h" using namespace nnvm; -int NNListAllOpNames(nn_uint *out_size, - const char*** out_array) { +int NNListAllOpNames(nn_uint* out_size, const char*** out_array) { API_BEGIN(); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); ret->ret_vec_str = dmlc::Registry::ListAllNames(); ret->ret_vec_charp.resize(0); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); @@ -43,40 +43,31 @@ int NNListAllOpNames(nn_uint *out_size, API_END(); } -int NNGetOpHandle(const char* op_name, - OpHandle* op_out) { +int NNGetOpHandle(const char* op_name, OpHandle* op_out) { API_BEGIN(); *op_out = (OpHandle)Op::Get(op_name); // NOLINT(*) API_END(); } -int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array) { +int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); + auto& vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } -int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep) { +int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep) { API_BEGIN(); - static_cast(handle)->AddControlDeps( - *static_cast(src_dep)); + static_cast(handle)->AddControlDeps(*static_cast(src_dep)); API_END(); } -int NNGetOpInfo(OpHandle handle, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - const Op *op = static_cast(handle); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNGetOpInfo(OpHandle handle, const char** name, const char** description, nn_uint* num_doc_args, + const char*** arg_names, const char*** arg_type_infos, + const char*** arg_descriptions, const char** return_type) { + const Op* op = static_cast(handle); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); *name = op->name.c_str(); @@ -100,12 +91,9 @@ int NNGetOpInfo(OpHandle handle, API_END(); } -int NNSymbolCreateAtomicSymbol(OpHandle creator, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCreateAtomicSymbol(OpHandle creator, nn_uint num_param, const char** keys, + const char** vals, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); const Op* op = static_cast(creator); std::unordered_map kwargs; @@ -117,19 +105,17 @@ int NNSymbolCreateAtomicSymbol(OpHandle creator, API_END_HANDLE_ERROR(delete s;); } -int NNSymbolCreateVariable(const char *name, SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCreateVariable(const char* name, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = Symbol::CreateVariable(name); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out) { - Symbol *s = new Symbol(); - Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) +int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* 
symbols, SymbolHandle* out) { + Symbol* s = new Symbol(); + Symbol** sym_arr = (Symbol**)symbols; // NOLINT(*) API_BEGIN(); std::vector syms; for (nn_uint i = 0; i < num_symbols; ++i) { @@ -140,28 +126,24 @@ int NNSymbolCreateGroup(nn_uint num_symbols, API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = (*static_cast(symbol))[index]; *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetInternals(); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetChildren(); *out = s; @@ -174,17 +156,17 @@ int NNSymbolFree(SymbolHandle symbol) { API_END(); } -int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->Copy(); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolPrint(SymbolHandle symbol, const char** out_str) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; s->Print(os); @@ -193,12 +175,9 @@ int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { API_END(); } -int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int* success) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); if (s->GetAttr(key, &(ret->ret_str))) { *out = (ret->ret_str).c_str(); @@ -210,27 +189,20 @@ int NNSymbolGetAttr(SymbolHandle symbol, API_END(); } -int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, - const char** vals) { - Symbol *s = static_cast(symbol); +int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** vals) { + Symbol* s = static_cast(symbol); API_BEGIN(); std::vector > kwargs; for (nn_uint i = 0; i < num_param; ++i) { - kwargs.emplace_back( - std::make_pair(std::string(keys[i]), std::string(vals[i]))); + kwargs.emplace_back(std::make_pair(std::string(keys[i]), std::string(vals[i]))); } s->SetAttrs(kwargs); API_END(); } -int NNSymbolListAttrs(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char*** out) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListAttrs(SymbolHandle symbol, int option, nn_uint* out_size, const char*** out) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::unordered_map attr = s->ListAttrs(static_cast(option)); // NOLINT(*) @@ -252,12 
+224,10 @@ int NNSymbolListAttrs(SymbolHandle symbol, API_END(); } -int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, +int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size, SymbolHandle** out_sym_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); ret->ret_handles.resize(0); @@ -272,15 +242,12 @@ int NNSymbolListInputVariables(SymbolHandle symbol, API_END(); } -int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size, + const char*** out_str_array) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_vec_str = - s->ListInputNames(Symbol::ListInputOption(option)); + ret->ret_vec_str = s->ListInputNames(Symbol::ListInputOption(option)); ret->ret_vec_charp.resize(0); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { @@ -291,11 +258,9 @@ int NNSymbolListInputNames(SymbolHandle symbol, API_END(); } -int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, const char*** out_str_array) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str = s->ListOutputNames(); ret->ret_vec_charp.resize(0); @@ -308,24 +273,19 @@ int NNSymbolListOutputNames(SymbolHandle symbol, API_END(); } -int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count) { - Symbol *s = static_cast(symbol); +int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count) { + Symbol* s = static_cast(symbol); API_BEGIN(); *output_count = static_cast(s->outputs.size()); API_END(); } -int NNSymbolCompose(SymbolHandle sym, - const char *name, - nn_uint num_args, - const char** keys, +int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, const char** keys, SymbolHandle* args) { API_BEGIN(); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); std::string& s_name = ret->ret_str; - std::unordered_map& kwargs - = ret->kwarg_symbol; + std::unordered_map& kwargs = ret->kwarg_symbol; kwargs.clear(); if (name != nullptr) { s_name = name; @@ -335,8 +295,7 @@ int NNSymbolCompose(SymbolHandle sym, Symbol* s = static_cast(sym); if (keys == nullptr && num_args != 0) { kwargs.clear(); - array_view parg( - (Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) + array_view parg((Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) s->Compose(parg, kwargs, s_name); } else { for (nn_uint i = 0; i < num_args; ++i) { diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index c3ae60e999370..fd5b64f4777d8 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -23,6 +23,7 @@ */ #include #include + #include namespace nnvm { @@ -39,23 +40,22 @@ const IndexedGraph& Graph::indexed_graph() const { // e.g. 
the main graph is level 0 // subgraphs of the main graph is level 1 // subgraphs of the subgraphs of the main graph is level 2 -static void SubgraphSanityCheck(const std::vector> &subgraphs) { +static void SubgraphSanityCheck(const std::vector>& subgraphs) { std::vector*> curr_level; std::vector*> next_level; std::unordered_map node2level; - for (auto &subgraph : subgraphs) - next_level.push_back(&subgraph->outputs); + for (auto& subgraph : subgraphs) next_level.push_back(&subgraph->outputs); for (uint32_t level = 0; !next_level.empty(); ++level) { curr_level.swap(next_level); next_level.clear(); - for (const std::vector *graph_ptr : curr_level) { - const std::vector &graph = *graph_ptr; + for (const std::vector* graph_ptr : curr_level) { + const std::vector& graph = *graph_ptr; DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) { - nnvm::Node *node = n.get(); + nnvm::Node* node = n.get(); // if the node is visited, but on a different level, then check failed // if check failed here or before, we stop doing anything, but raise an error CHECK(!node2level.count(node) || node2level[node] == level) - << "A subgraph should not depend on the outputs of nodes on higher levels"; + << "A subgraph should not depend on the outputs of nodes on higher levels"; // otherwise, this node belongs to the current level node2level[node] = level; // subgraphs of current node belongs to next level @@ -68,55 +68,51 @@ static void SubgraphSanityCheck(const std::vector> &subg } // implement constructor from graph -IndexedGraph::IndexedGraph(const Graph &g) { +IndexedGraph::IndexedGraph(const Graph& g) { entry_rptr_.push_back(0); std::vector inputs_rptr{0}, control_rptr{0}; std::vector> subgraphs; - DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] - (const ObjectPtr& n) { - const auto& is_ghost = Op::GetAttr("TIsGhost"); - if (!n->is_variable() && is_ghost.get(n->op(), false)) return; - CHECK_LT(nodes_.size(), std::numeric_limits::max()); - uint32_t nid = static_cast(nodes_.size()); - CHECK(n); - for (const auto &subgraph : n->attrs.subgraphs) - subgraphs.push_back(subgraph); - // nodes_ - IndexedGraph::Node new_node; - new_node.source = n.get(); - new_node.weak_ref = n; - nodes_.emplace_back(std::move(new_node)); - // arg_nodes_ - if (n->is_variable()) { - input_nodes_.push_back(nid); - } - // node2index_ - node2index_[n.get()] = nid; - // entry rptr - entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); - // input entries - for (const auto& e : n->inputs) { - auto it = node2index_.find(e.node.get()); - if (it == node2index_.end() || it->first != e.node.get()) continue; - input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); - } - inputs_rptr.push_back(input_entries_.size()); - // control deps - for (const auto& nptr : n->control_deps) { - if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; - auto it = node2index_.find(nptr.get()); - CHECK(it != node2index_.end()) << "control dep not found in graph"; - control_deps_.push_back(it->second); - } - control_rptr.push_back(control_deps_.size()); + DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs](const ObjectPtr& n) { + const auto& is_ghost = Op::GetAttr("TIsGhost"); + if (!n->is_variable() && is_ghost.get(n->op(), false)) return; + CHECK_LT(nodes_.size(), std::numeric_limits::max()); + uint32_t nid = static_cast(nodes_.size()); + CHECK(n); + for (const auto& subgraph : n->attrs.subgraphs) subgraphs.push_back(subgraph); + // nodes_ + IndexedGraph::Node new_node; + 
new_node.source = n.get(); + new_node.weak_ref = n; + nodes_.emplace_back(std::move(new_node)); + // arg_nodes_ + if (n->is_variable()) { + input_nodes_.push_back(nid); + } + // node2index_ + node2index_[n.get()] = nid; + // entry rptr + entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); + // input entries + for (const auto& e : n->inputs) { + auto it = node2index_.find(e.node.get()); + if (it == node2index_.end() || it->first != e.node.get()) continue; + input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); + } + inputs_rptr.push_back(input_entries_.size()); + // control deps + for (const auto& nptr : n->control_deps) { + if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; + auto it = node2index_.find(nptr.get()); + CHECK(it != node2index_.end()) << "control dep not found in graph"; + control_deps_.push_back(it->second); + } + control_rptr.push_back(control_deps_.size()); }); - if (!subgraphs.empty()) - SubgraphSanityCheck(subgraphs); + if (!subgraphs.empty()) SubgraphSanityCheck(subgraphs); for (const auto& e : g.outputs) { - outputs_.emplace_back(NodeEntry{ - node2index_.at(e.node.get()), e.index, e.version}); + outputs_.emplace_back(NodeEntry{node2index_.at(e.node.get()), e.index, e.version}); } static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -124,10 +120,9 @@ IndexedGraph::IndexedGraph(const Graph &g) { // input_entries_ and control_rptr must not change after this step. const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); for (size_t nid = 0; nid < nodes_.size(); ++nid) { - nodes_[nid].inputs = array_view( - iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); - if (nodes_[nid].source->op() != nullptr && - fmutate_inputs.count(nodes_[nid].source->op())) { + nodes_[nid].inputs = + array_view(iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); + if (nodes_[nid].source->op() != nullptr && fmutate_inputs.count(nodes_[nid].source->op())) { for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) { mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); } @@ -135,8 +130,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { } const uint32_t* cptr = dmlc::BeginPtr(control_deps_); for (size_t nid = 0; nid < nodes_.size(); ++nid) { - nodes_[nid].control_deps = array_view( - cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); + nodes_[nid].control_deps = + array_view(cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); } } diff --git a/nnvm/src/core/op.cc b/nnvm/src/core/op.cc index eb51d4b3cd743..08a11dff9a028 100644 --- a/nnvm/src/core/op.cc +++ b/nnvm/src/core/op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,8 +24,8 @@ #include #include -#include #include +#include #include #include @@ -46,7 +46,7 @@ struct OpManager { // storage of additional attribute table. std::unordered_map > attr; // storage of existing triggers - std::unordered_map > > tmap; + std::unordered_map > > tmap; // group of each operator. 
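For orientation, a minimal sketch of consuming the IndexedGraph that the constructor above builds; CountInputEntries is a hypothetical helper, not part of this patch, and uses only accessors visible in this diff (num_nodes, operator[], entry_id):

#include <cstddef>
#include <cstdint>
#include <nnvm/graph.h>

// Hypothetical helper: walk nodes in the DFS topological order stored in
// nodes_ and resolve each input entry to its flat entry id.
std::size_t CountInputEntries(const nnvm::Graph& g) {
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  std::size_t total = 0;
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    for (const auto& e : inode.inputs) {  // array_view over input_entries_
      (void)idx.entry_id(e);              // flat id derived from entry_rptr_
      ++total;
    }
  }
  return total;
}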
std::vector > op_group; // get singleton of the @@ -70,14 +70,13 @@ Op& Op::add_alias(const std::string& alias) { // NOLINT(*) // find operator by name const Op* Op::Get(const std::string& name) { const Op* op = dmlc::Registry::Find(name); - CHECK(op != nullptr) - << "Operator " << name << " is not registered"; + CHECK(op != nullptr) << "Operator " << name << " is not registered"; return op; } // Get attribute map by key const any* Op::GetAttrMap(const std::string& key) { - auto& dict = OpManager::Global()->attr; + auto& dict = OpManager::Global()->attr; auto it = dict.find(key); if (it != dict.end()) { return it->second.get(); @@ -87,8 +86,7 @@ const any* Op::GetAttrMap(const std::string& key) { } // update attribute map -void Op::UpdateAttrMap(const std::string& key, - std::function updater) { +void Op::UpdateAttrMap(const std::string& key, std::function updater) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); std::unique_ptr& value = mgr->attr[key]; @@ -96,16 +94,14 @@ void Op::UpdateAttrMap(const std::string& key, if (updater != nullptr) updater(value.get()); } -void Op::AddGroupTrigger(const std::string& group_name, - std::function trigger) { +void Op::AddGroupTrigger(const std::string& group_name, std::function trigger) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); auto& tvec = mgr->tmap[group_name]; tvec.push_back(trigger); auto& op_group = mgr->op_group; for (const Op* op : dmlc::Registry::List()) { - if (op->index_ < op_group.size() && - op_group[op->index_].count(group_name) != 0) { + if (op->index_ < op_group.size() && op_group[op->index_].count(group_name) != 0) { trigger((Op*)op); // NOLINT(*) } } diff --git a/nnvm/src/core/pass.cc b/nnvm/src/core/pass.cc index b43d470f3eb3b..974cd2b35918a 100644 --- a/nnvm/src/core/pass.cc +++ b/nnvm/src/core/pass.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief Support for pass registry. 
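The registry that Op::Get, GetAttrMap and AddGroupTrigger above operate on is populated through NNVM_REGISTER_OP. A hedged sketch; "my_relu" is a made-up operator that this patch does not touch:

#include <nnvm/op.h>

// Illustrative registration only: after this, nnvm::Op::Get("my_relu")
// succeeds, whereas Op::Get on an unregistered name CHECK-fails as above.
NNVM_REGISTER_OP(my_relu)
    .describe("Hypothetical unary op used only to illustrate the registry.")
    .set_num_inputs(1)
    .set_num_outputs(1);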
*/ #include + #include namespace dmlc { @@ -31,7 +32,7 @@ DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg); namespace nnvm { -const PassFunctionReg* FindPassDep(const std::string&attr_name) { +const PassFunctionReg* FindPassDep(const std::string& attr_name) { for (auto* r : dmlc::Registry::List()) { for (auto& s : r->graph_attr_targets) { if (s == attr_name) return r; @@ -40,13 +41,11 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { return nullptr; } -Graph ApplyPasses(Graph g, - const std::vector& pass) { +Graph ApplyPasses(Graph g, const std::vector& pass) { std::vector fpass; for (auto& name : pass) { auto* reg = dmlc::Registry::Find(name); - CHECK(reg != nullptr) - << "Cannot find pass " << name << " in the registry"; + CHECK(reg != nullptr) << "Cannot find pass " << name << " in the registry"; fpass.push_back(reg); } @@ -58,10 +57,8 @@ Graph ApplyPasses(Graph g, if (pass_dep != nullptr) { msg = " The attribute is provided by pass " + pass_dep->name; } - LOG(FATAL) << "Graph attr dependency " << dep - << " is required by pass " << r->name - << " but is not available " - << msg; + LOG(FATAL) << "Graph attr dependency " << dep << " is required by pass " << r->name + << " but is not available " << msg; } } g = r->body(std::move(g)); diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 86dc7e63c4034..12b8675d0bd70 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -22,13 +22,13 @@ * \brief Symbolic graph composition API. */ #include -#include #include +#include namespace nnvm { namespace symbol_constants { -const char *kNamespaceSeparator = "$"; +const char* kNamespaceSeparator = "$"; } // namespace symbol_constants // auxililary version attribute in variable. @@ -48,7 +48,7 @@ ObjectPtr CreateVariableNode(const std::string& name) { // If the node's op mutates a certain input variable, // The version of that varaible will increase // version is used to implicitly order the mutation sequences -inline void UpdateNodeVersion(Node *n) { +inline void UpdateNodeVersion(Node* n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); for (NodeEntry& e : n->inputs) { if (e.node->is_variable()) { @@ -58,16 +58,14 @@ inline void UpdateNodeVersion(Node *n) { if (fmutate_inputs.count(n->op()) != 0) { for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) { NodeEntry& e = n->inputs[i]; - CHECK(e.node->is_variable()) - << "Mutation target can only be Variable"; + CHECK(e.node->is_variable()) << "Mutation target can only be Variable"; // increase the version of the variable. e.version = ++nnvm::get(e.node->attrs.parsed).version; } } } -inline std::string DefaultVarName(const std::string &op_name, - const std::string &arg_name) { +inline std::string DefaultVarName(const std::string& op_name, const std::string& arg_name) { if (op_name.length() == 0) { return arg_name; } else { @@ -75,8 +73,7 @@ inline std::string DefaultVarName(const std::string &op_name, } } -inline void KeywordArgumentMismatch(const char *source, - const std::vector& user_args, +inline void KeywordArgumentMismatch(const char* source, const std::vector& user_args, const array_view& args) { std::unordered_set keys(args.begin(), args.end()); std::ostringstream head, msg; @@ -87,16 +84,13 @@ inline void KeywordArgumentMismatch(const char *source, for (const auto& key : user_args) { if (keys.count(key) == 0) { - LOG(FATAL) << source - << "Keyword argument name " << key << " not found." - << msg.str(); + LOG(FATAL) << source << "Keyword argument name " << key << " not found." 
<< msg.str(); } } } -template -inline std::vector GetKeys( - const std::unordered_map& kwargs) { +template +inline std::vector GetKeys(const std::unordered_map& kwargs) { std::vector keys(kwargs.size()); std::transform(kwargs.begin(), kwargs.end(), keys.begin(), [](decltype(*kwargs.begin())& kv) { return kv.first; }); @@ -117,14 +111,14 @@ Symbol Symbol::Copy() const { std::unordered_map old_new; // use DFSVisit to copy all the nodes DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) { - ObjectPtr np = Node::Create(); - np->attrs = node->attrs; - old_new[node.get()] = std::move(np); - }); + ObjectPtr np = Node::Create(); + np->attrs = node->attrs; + old_new[node.get()] = std::move(np); + }); // connect nodes of new graph - for (const auto &kv : old_new) { + for (const auto& kv : old_new) { for (const NodeEntry& e : kv.first->inputs) { - Node *ptr = e.node.get(); + Node* ptr = e.node.get(); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } for (const ObjectPtr& p : kv.first->control_deps) { @@ -133,66 +127,64 @@ Symbol Symbol::Copy() const { } // set the head Symbol ret; - for (const NodeEntry &e : outputs) { + for (const NodeEntry& e : outputs) { ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); } return ret; } -void Symbol::Print(std::ostream &os) const { - if (outputs.size() == 1 && - outputs[0].node->inputs.size() == 0 && +void Symbol::Print(std::ostream& os) const { + if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0 && outputs[0].node->control_deps.size() == 0) { if (outputs[0].node->is_variable()) { os << "Variable:" << outputs[0].node->attrs.name << '\n'; } else { - os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n'; + os << "AtomicFunctor " + << " Op:" << outputs[0].node->op()->name << '\n'; } } else { // use DFSVisit to copy all the nodes os << "Symbol Outputs:\n"; for (size_t i = 0; i < outputs.size(); ++i) { - os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name - << '(' << outputs[i].index << ")\n"; + os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name << '(' << outputs[i].index + << ")\n"; } DFSVisit(this->outputs, [&os](const ObjectPtr& node) { - if (node->is_variable()) { - os << "Variable:" << node->attrs.name << '\n'; - } else { - os << "--------------------\n"; - os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' - << "Inputs:\n"; - for (size_t i = 0; i < node->inputs.size(); ++i) { - const NodeEntry& e = node->inputs[i]; - os << "\targ[" << i << "]=" << e.node->attrs.name - << '(' << e.index << ")"; - if (e.node->is_variable()) { - os << " version=" << e.version << '\n'; - } else { - os << '\n'; - } + if (node->is_variable()) { + os << "Variable:" << node->attrs.name << '\n'; + } else { + os << "--------------------\n"; + os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' << "Inputs:\n"; + for (size_t i = 0; i < node->inputs.size(); ++i) { + const NodeEntry& e = node->inputs[i]; + os << "\targ[" << i << "]=" << e.node->attrs.name << '(' << e.index << ")"; + if (e.node->is_variable()) { + os << " version=" << e.version << '\n'; + } else { + os << '\n'; } - if (!node->attrs.dict.empty()) { - os << "Attrs:\n"; - // make an ordered copy because unordered_map doesn't guarantee order. 
- std::map sorted_dict( - node->attrs.dict.begin(), node->attrs.dict.end()); - for (auto &kv : sorted_dict) { - os << '\t' << kv.first << '=' << kv.second << '\n'; - } + } + if (!node->attrs.dict.empty()) { + os << "Attrs:\n"; + // make an ordered copy because unordered_map doesn't guarantee order. + std::map sorted_dict(node->attrs.dict.begin(), + node->attrs.dict.end()); + for (auto& kv : sorted_dict) { + os << '\t' << kv.first << '=' << kv.second << '\n'; } - if (node->control_deps.size() != 0) { - os << "Control deps:\n"; - for (size_t i = 0; i < node->control_deps.size(); ++i) { - os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; - } + } + if (node->control_deps.size() != 0) { + os << "Control deps:\n"; + for (size_t i = 0; i < node->control_deps.size(); ++i) { + os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; } } - }); + } + }); } } -Symbol Symbol::operator[] (size_t index) const { +Symbol Symbol::operator[](size_t index) const { size_t nreturn = outputs.size(); CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; if (nreturn == 1) { @@ -208,25 +200,25 @@ std::vector Symbol::ListInputs(ListInputOption option) const { std::vector ret; if (option == kAll) { ret.reserve(this->outputs.size()); - DFSVisit(this->outputs, [&ret](const ObjectPtr &node) { - if (node->is_variable()) { - ret.push_back(node); - } - }); + DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { + if (node->is_variable()) { + ret.push_back(node); + } + }); } else { std::unordered_set mutable_set; std::vector vlist; vlist.reserve(this->outputs.size()); static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) { - if (node->is_variable()) { - vlist.push_back(node); - } else if (fmutate_inputs.count(node->op())) { - for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){ - mutable_set.insert(node->inputs[i].node.get()); - } + DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr& node) { + if (node->is_variable()) { + vlist.push_back(node); + } else if (fmutate_inputs.count(node->op())) { + for (uint32_t i : fmutate_inputs[node->op()](node->attrs)) { + mutable_set.insert(node->inputs[i].node.get()); } - }); + } + }); ret.reserve(vlist.size()); for (const ObjectPtr& node : vlist) { if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || @@ -252,7 +244,7 @@ std::vector Symbol::ListOutputNames() const { std::vector ret; ret.reserve(outputs.size()); - for (auto &head : outputs) { + for (auto& head : outputs) { if (head.node->is_variable()) { ret.push_back(head.node->attrs.name); } else { @@ -291,8 +283,7 @@ void Symbol::Compose(const array_view& args, Node* n = outputs[0].node.get(); FInputGraph fng = fgraph.get(n->op(), nullptr); std::vector garg_idx; - if (fng != nullptr) - garg_idx = fng(n->attrs); + if (fng != nullptr) garg_idx = fng(n->attrs); // The names of the arguments that contain graphs. FListInputNames name_fn = flist_inputs.get(n->op(), nullptr); @@ -300,8 +291,7 @@ void Symbol::Compose(const array_view& args, std::vector garg_names(garg_idx.size()); for (size_t i = 0; i < garg_idx.size(); i++) { size_t idx = garg_idx[i]; - if (idx < arg_names.size()) - garg_names[i] = arg_names[idx]; + if (idx < arg_names.size()) garg_names[i] = arg_names[idx]; } // parameter check. @@ -309,13 +299,13 @@ void Symbol::Compose(const array_view& args, // If the argument isn't a graph, it should have only one output. 
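The parameter checks here guard the Compose path. A compact usage sketch, assuming a registered two-input operator "my_add" (hypothetical); the pointer-pair array_view constructor is the same one the C API wrapper uses earlier in this diff:

#include <nnvm/op.h>
#include <nnvm/symbolic.h>
#include <vector>

// Hypothetical usage: z = my_add(x, y). operator() copies the functor symbol
// and then routes through the same Compose(args, kwargs, name) shown here.
nnvm::Symbol BuildAdd() {
  nnvm::Symbol x = nnvm::Symbol::CreateVariable("x");
  nnvm::Symbol y = nnvm::Symbol::CreateVariable("y");
  nnvm::Symbol z = nnvm::Symbol::CreateFunctor(nnvm::Op::Get("my_add"), {});
  std::vector<const nnvm::Symbol*> args = {&x, &y};
  nnvm::array_view<const nnvm::Symbol*> view(args.data(), args.data() + args.size());
  return z(view, {}, "z");  // kwargs left empty; name becomes "z"
}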
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end()) CHECK_EQ(args[i]->outputs.size(), 1U) - << "Argument " << i << " is a tuple, single value is required"; + << "Argument " << i << " is a tuple, single value is required"; } for (const auto& kv : kwargs) { - if (garg_names.empty() - || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) + if (garg_names.empty() || + std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) CHECK_EQ(kv.second->outputs.size(), 1U) - << "Keyword Argument " << kv.first << " is a tuple, single value is required"; + << "Keyword Argument " << kv.first << " is a tuple, single value is required"; } // assign new name if (!name.empty()) outputs[0].node->attrs.name = name; @@ -323,14 +313,14 @@ void Symbol::Compose(const array_view& args, // Atomic functor composition. if (IsAtomic(outputs)) { uint32_t n_req = n->num_inputs(); - std::vector arg_vec(args.begin(), args.end()); + std::vector arg_vec(args.begin(), args.end()); std::unordered_map kwarg_map(kwargs.begin(), kwargs.end()); // If one of the input arguments is a graph, we need to remove it from the // list. if (fng != nullptr) { std::vector idxes = fng(n->attrs); for (auto idx : idxes) { - const Symbol *sym; + const Symbol* sym; if (idx < arg_vec.size()) { sym = arg_vec[idx]; } else { @@ -339,8 +329,7 @@ void Symbol::Compose(const array_view& args, sym = it->second; kwarg_map.erase(it); } - if (n_req != kVarg) - n_req--; + if (n_req != kVarg) n_req--; n->attrs.subgraphs.push_back(std::make_shared(*sym)); } // Because idxes does not contain duplicates, the loop below functions well. @@ -358,8 +347,7 @@ void Symbol::Compose(const array_view& args, if (n_req != kVarg) { n->inputs.resize(n_req); CHECK_LE(arg_vec.size(), n_req) - << "Incorrect number of arguments, requires " << n_req - << ", provided " << arg_vec.size(); + << "Incorrect number of arguments, requires " << n_req << ", provided " << arg_vec.size(); for (size_t i = 0; i < arg_vec.size(); ++i) { n->inputs[i] = arg_vec[i]->outputs[0]; } @@ -375,8 +363,7 @@ void Symbol::Compose(const array_view& args, n->inputs[i] = it->second->outputs[0]; ++nmatched; } else { - n->inputs[i] = NodeEntry{ - CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; + n->inputs[i] = NodeEntry{CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; // copy attribute of parent over automatically created variables n->inputs[i].node->attrs.dict = n->attrs.dict; } @@ -409,20 +396,19 @@ void Symbol::Compose(const array_view& args, } } else { // general composition - CHECK_EQ(args.size(), 0U) - << "General composition only support kwargs for now"; + CHECK_EQ(args.size(), 0U) << "General composition only support kwargs for now"; size_t nmatched = 0; size_t arg_counter = 0; - std::unordered_map replace_map; + std::unordered_map replace_map; // replace map stores the existing replacement plan for arguments node - auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] - (const ObjectPtr &node) { + auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, + &replace_map](const ObjectPtr& node) { if (node->is_variable()) { if (arg_counter < args.size()) { replace_map[node.get()] = &(args[arg_counter]->outputs[0]); ++arg_counter; } else { - // match kwargs + // match kwargs auto kit = kwargs.find(node->attrs.name); if (kit != kwargs.end()) { replace_map[node.get()] = &(kit->second->outputs[0]); @@ -436,12 +422,11 @@ void Symbol::Compose(const array_view& args, 
if (nmatched == kwargs.size() && arg_counter <= args.size()) { std::vector update_nodes; std::vector > replace_plan; - auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] - (const ObjectPtr &node) { + auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes](const ObjectPtr& node) { // visit all the childs, find possible replacement bool repl = false; for (size_t i = 0; i < node->inputs.size(); ++i) { - NodeEntry *e = &(node->inputs[i]); + NodeEntry* e = &(node->inputs[i]); if (e->node->is_variable()) { auto iter = replace_map.find(e->node.get()); if (iter != replace_map.end()) { @@ -479,17 +464,16 @@ void Symbol::Compose(const array_view& args, } } -Symbol Symbol::operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const { +Symbol Symbol::operator()(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const { Symbol s = this->Copy(); s.Compose(args, kwargs, name); return s; } void Symbol::AddControlDeps(const Symbol& src) { - CHECK_EQ(outputs.size(), 1U) - << "AddControlDeps only works for nongrouped symbol"; + CHECK_EQ(outputs.size(), 1U) << "AddControlDeps only works for nongrouped symbol"; Node* n = outputs[0].node.get(); for (const NodeEntry& sp : src.outputs) { n->control_deps.push_back(sp.node); @@ -500,21 +484,21 @@ Symbol Symbol::GetInternals() const { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { - Node* n = node.get(); - if (n->is_variable()) { - // grab version from variable. - VariableParam& param = nnvm::get(n->attrs.parsed); - ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); - } else { - uint32_t nout = n->num_outputs(); - if (fnum_vis_output.count(n->op())) { - nout = fnum_vis_output[n->op()](n->attrs); - } - for (uint32_t i = 0; i < nout; ++i) { - ret.outputs.emplace_back(NodeEntry{node, i, 0}); - } + Node* n = node.get(); + if (n->is_variable()) { + // grab version from variable. 
+ VariableParam& param = nnvm::get(n->attrs.parsed); + ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); + } else { + uint32_t nout = n->num_outputs(); + if (fnum_vis_output.count(n->op())) { + nout = fnum_vis_output[n->op()](n->attrs); } - }); + for (uint32_t i = 0; i < nout; ++i) { + ret.outputs.emplace_back(NodeEntry{node, i, 0}); + } + } + }); return ret; } @@ -533,8 +517,7 @@ Symbol Symbol::GetChildren() const { void Symbol::SetAttrs(const std::vector >& attrs) { Node* node = outputs[0].node.get(); for (const NodeEntry& e : outputs) { - CHECK(node == e.node.get()) - << "Symbol.SetAttrs only works for non-grouped symbol"; + CHECK(node == e.node.get()) << "Symbol.SetAttrs only works for non-grouped symbol"; } for (const auto& kv : attrs) { if (kv.first == "name") { @@ -583,29 +566,27 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op if (option == kRecursive) { std::unordered_map ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { - for (const auto& it : n->attrs.dict) { - ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; - } - }); + for (const auto& it : n->attrs.dict) { + ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; + } + }); return ret; } else { return outputs[0].node->attrs.dict; } } -std::vector > - Symbol::ListAttrsRecursive() const { +std::vector > Symbol::ListAttrsRecursive() const { std::vector > ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { - for (const auto& it : n->attrs.dict) { - ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); - } - }); + for (const auto& it : n->attrs.dict) { + ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); + } + }); return ret; } -Symbol Symbol::CreateFunctor(const Op* op, - std::unordered_map attrs) { +Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; ObjectPtr n = Node::Create(); @@ -641,9 +622,9 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { return s; } -Symbol Symbol::CreateGroup(const std::vector &symbols) { +Symbol Symbol::CreateGroup(const std::vector& symbols) { Symbol ret; - for (const auto &s : symbols) { + for (const auto& s : symbols) { ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end()); } return ret; diff --git a/nnvm/src/pass/correct_layout.cc b/nnvm/src/pass/correct_layout.cc index bdb7dbab6aba3..b9024a56d1436 100644 --- a/nnvm/src/pass/correct_layout.cc +++ b/nnvm/src/pass/correct_layout.cc @@ -22,16 +22,15 @@ * \brief Infer and correct layout. */ #include -#include #include -#include #include +#include +#include namespace nnvm { namespace pass { -nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, - const Layout& dst) { +nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, const Layout& dst) { static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); static int count = 0; nnvm::ObjectPtr n = nnvm::Node::Create(); @@ -50,8 +49,7 @@ using LayoutAttrDict = std::unordered_map >; * insert layout transform nodes automatically. 
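Before the CorrectLayout body that follows, a sketch of how the pass is typically driven; seeding "layout_inputs" is optional, since the variable branch below falls back to Layout::Undef():

#include <memory>
#include <nnvm/graph.h>
#include <nnvm/layout.h>
#include <nnvm/pass.h>
#include <utility>
#include <vector>

// Sketch only, not part of the patch: attach per-input layout hints, then
// apply the registered pass by name; the result provides the "layout" attr.
nnvm::Graph RunCorrectLayout(nnvm::Graph g, std::vector<nnvm::Layout> input_layouts) {
  g.attrs["layout_inputs"] = std::make_shared<nnvm::any>(std::move(input_layouts));
  return nnvm::ApplyPass(std::move(g), "CorrectLayout");
}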
*/ nnvm::Graph CorrectLayout(nnvm::Graph src) { - static auto& op_correct_layout = - nnvm::Op::GetAttr("FCorrectLayout"); + static auto& op_correct_layout = nnvm::Op::GetAttr("FCorrectLayout"); const IndexedGraph& idx = src.indexed_graph(); std::vector mirror_vec(idx.num_nodes(), nullptr); @@ -65,13 +63,12 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { *new_node = *(inode.source); if (new_node->is_variable()) { // Variable node. No operator. Only one output entry. - auto input_iter = std::find( - idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); + auto input_iter = std::find(idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); CHECK(input_iter != idx.input_nodes().cend()); int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter); if (src.HasAttr("layout_inputs")) { - new_layouts[new_node.get()] = - {src.GetAttr >("layout_inputs")[input_id]}; + new_layouts[new_node.get()] = { + src.GetAttr >("layout_inputs")[input_id]}; } else { new_layouts[new_node.get()] = {Layout::Undef()}; } @@ -110,9 +107,9 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { } if (op_correct_layout.count(new_node->op())) { - const auto &flayout = op_correct_layout[new_node->op()]; + const auto& flayout = op_correct_layout[new_node->op()]; CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts)) - << "Layout infer fail"; + << "Layout infer fail"; CHECK_EQ(request_ilayouts.size(), num_inputs); CHECK_EQ(produce_olayouts.size(), num_outputs); } @@ -175,10 +172,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { // register pass NNVM_REGISTER_PASS(CorrectLayout) -.describe("Return a layout-transformed graph of src.") -.set_body(CorrectLayout) -.provide_graph_attr("layout") -.set_change_graph(true); + .describe("Return a layout-transformed graph of src.") + .set_body(CorrectLayout) + .provide_graph_attr("layout") + .set_change_graph(true); DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout); diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 9c30a785cac22..1df3af7ffaaf5 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -22,8 +22,9 @@ * \brief Passes that takes gradient of the graph * This code code was modified based on mxnet codebase by Min Lin */ -#include #include +#include + #include #include @@ -53,8 +54,7 @@ NodeEntry DefaultAggregateGradient(std::vector&& v) { } } -bool CheckGradAllZero(const std::vector& grads, - const std::vector& zero_ops) { +bool CheckGradAllZero(const std::vector& grads, const std::vector& zero_ops) { if (!grads.size() || !zero_ops.size()) return false; for (const auto& g : grads) { bool found = false; @@ -82,22 +82,18 @@ struct GradEntry { Graph Gradient(Graph src) { using nnvm::FGradient; - using MirrorFun = std::function; - using AttrHintFun = std::function; + using MirrorFun = std::function; + using AttrHintFun = std::function; - CHECK_NE(src.attrs.count("grad_ys"), 0U) - << "Gradient require grad_ys to be presented."; + CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented."; CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) << "Gradient require grad_ys_out_grad to be presented."; - CHECK_NE(src.attrs.count("grad_xs"), 0U) - << "Gradient require grad_xs to be presented."; - const std::vector& ys = - src.GetAttr >("grad_ys"); + CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented."; + const std::vector& ys = src.GetAttr >("grad_ys"); const std::vector& ys_out_grad = src.GetAttr >("grad_ys_out_grad"); - const std::vector& 
xs = - src.GetAttr >("grad_xs"); - using AggFun = std::function&& inputs)>; + const std::vector& xs = src.GetAttr >("grad_xs"); + using AggFun = std::function && inputs)>; AggFun agg_fun = DefaultAggregateGradient; if (src.attrs.count("grad_aggregate_fun") != 0) { agg_fun = src.GetAttr("grad_aggregate_fun"); @@ -114,31 +110,30 @@ Graph Gradient(Graph src) { if (src.attrs.count("zero_ops") != 0) { zero_ops = src.GetAttr >("zero_ops"); } - const Op* copy_op = (src.attrs.count("copy_op") != 0) ? - Op::Get(src.GetAttr("copy_op")) : - nullptr; + const Op* copy_op = + (src.attrs.count("copy_op") != 0) ? Op::Get(src.GetAttr("copy_op")) : nullptr; // topo sort std::vector topo_order; std::unordered_map > output_grads; DFSVisit(ys, [&](const ObjectPtr& node) { - if (output_grads.count(node.get()) == 0) { - output_grads[node.get()].resize(node->num_outputs()); - } - topo_order.push_back(node); - }); + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); CHECK_EQ(ys.size(), ys_out_grad.size()); for (size_t i = 0; i < ys.size(); ++i) { NodeEntry ograd = ys_out_grad[i]; - output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; + output_grads[ys[i].node.get()][ys[i].index].grads = {ograd}; } // Check that all xs are reachable from ys for (size_t i = 0; i < xs.size(); ++i) { CHECK(output_grads.find(xs[i].node.get()) != output_grads.end()) - << "Cannot differentiate with respect to the " << i+1 << "-th variable " + << "Cannot differentiate with respect to the " << i + 1 << "-th variable " << "because it is unreachable from the outputs."; } @@ -211,8 +206,7 @@ Graph Gradient(Graph src) { LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable " << "because it didn't register FGradient attribute."; } - for (const auto& nodeEntry : input_grads) - CHECK(nodeEntry.node); + for (const auto& nodeEntry : input_grads) CHECK(nodeEntry.node); auto git = input_grads.begin(); CHECK((*rit)->inputs.size() <= input_grads.size()); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { @@ -252,12 +246,12 @@ Graph Gradient(Graph src) { copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { - copy_node->attrs.op->attr_parser(&(copy_node->attrs)); + copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); } } else { - ret.outputs[counter] = entry.sum; + ret.outputs[counter] = entry.sum; } ++counter; } @@ -271,12 +265,12 @@ Graph Gradient(Graph src) { // register pass NNVM_REGISTER_PASS(Gradient) -.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") -.set_body(Gradient) -.set_change_graph(true) -.depend_graph_attr("grad_ys") -.depend_graph_attr("grad_xs") -.depend_graph_attr("grad_ys_out_grad"); + .describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") + .set_body(Gradient) + .set_change_graph(true) + .depend_graph_attr("grad_ys") + .depend_graph_attr("grad_xs") + .depend_graph_attr("grad_ys_out_grad"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/graph_algorithm.h b/nnvm/src/pass/graph_algorithm.h index 1d274ff3b96d0..b305c08bc05f7 100644 --- a/nnvm/src/pass/graph_algorithm.h +++ b/nnvm/src/pass/graph_algorithm.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,11 +22,12 @@ * \brief This header contains graph algorithms on StaticGraph. * It is used compute informations such as whether two * operations can run in parallel, and helps allocation. -*/ + */ #ifndef NNVM_PASS_GRAPH_ALGORITHM_H_ #define NNVM_PASS_GRAPH_ALGORITHM_H_ #include + #include namespace nnvm { @@ -41,10 +42,8 @@ namespace pass { * \param path the output path of nodes. * \return the total reward of best path. */ -inline uint32_t FindBestPath( - const IndexedGraph& graph, - const std::vector& node_reward, - std::vector* path) { +inline uint32_t FindBestPath(const IndexedGraph& graph, const std::vector& node_reward, + std::vector* path) { const uint32_t num_nodes = static_cast(graph.num_nodes()); CHECK_EQ(num_nodes, node_reward.size()); @@ -71,7 +70,8 @@ inline uint32_t FindBestPath( path->clear(); uint32_t reward = 0; for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) { - path->push_back(nid); reward += node_reward[nid]; + path->push_back(nid); + reward += node_reward[nid]; } CHECK_EQ(reward, best_solution); return best_solution; @@ -88,11 +88,8 @@ inline uint32_t FindBestPath( * \param color the color index of each of the node. * \return the total number of colors. */ -inline uint32_t ColorNodeGroup( - const IndexedGraph &graph, - std::vector node_importance, - uint32_t max_ncolor, - std::vector *color) { +inline uint32_t ColorNodeGroup(const IndexedGraph& graph, std::vector node_importance, + uint32_t max_ncolor, std::vector* color) { CHECK_NE(max_ncolor, 0U); CHECK_EQ(graph.num_nodes(), node_importance.size()); diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 876dce1c113d0..fde1691ee96a5 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -21,33 +21,24 @@ * \file infer_shape.cc * \brief Inference the shapes given existin information. */ -#include -#include #include +#include +#include namespace nnvm { namespace pass { namespace { -template -Graph InferAttr(Graph &&ret, - const AttrType empty_val, - const char* infer_name, - const char* input_name, - const char* attr_key_name, - const char* attr_name, - const char* unknown_name, - IsNone fis_none, - FDefault fdefault) { +template +Graph InferAttr(Graph&& ret, const AttrType empty_val, const char* infer_name, + const char* input_name, const char* attr_key_name, const char* attr_name, + const char* unknown_name, IsNone fis_none, FDefault fdefault) { using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); - static auto& finfer_shape = - Op::GetAttr >(infer_name); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + static auto& finfer_shape = Op::GetAttr>(infer_name); + static auto& is_backward = Op::GetAttr("TIsBackward"); // gradient function, used to get node correspondence. 
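Stepping back to the Gradient pass reformatted above: a hedged sketch of how a caller supplies the three graph attributes that its registration declares via depend_graph_attr:

#include <memory>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <nnvm/pass.h>
#include <utility>
#include <vector>

// Sketch: the resulting graph's outputs are the gradients w.r.t. xs, in order.
nnvm::Graph TakeGradient(nnvm::Graph fwd,
                         std::vector<nnvm::NodeEntry> ys,
                         std::vector<nnvm::NodeEntry> xs,
                         std::vector<nnvm::NodeEntry> ys_out_grad) {
  fwd.attrs["grad_ys"] = std::make_shared<nnvm::any>(std::move(ys));
  fwd.attrs["grad_xs"] = std::make_shared<nnvm::any>(std::move(xs));
  fwd.attrs["grad_ys_out_grad"] = std::make_shared<nnvm::any>(std::move(ys_out_grad));
  return nnvm::ApplyPass(std::move(fwd), "Gradient");
}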
- static auto& fgrad = - Op::GetAttr("FGradient"); + static auto& fgrad = Op::GetAttr("FGradient"); // reshape shape vector AttrVector rshape; if (ret.attrs.count(attr_name) != 0) { @@ -70,8 +61,7 @@ Graph InferAttr(Graph &&ret, // get the shape hints std::string shape_hints_key = std::string(attr_name) + "_hints"; if (ret.attrs.count(shape_hints_key)) { - NodeEntryMap shape_hints = - ret.GetAttr>(shape_hints_key); + NodeEntryMap shape_hints = ret.GetAttr>(shape_hints_key); for (const auto& kv : shape_hints) { NodeEntry e = kv.first; if (idx.exist(e.node.get())) { @@ -110,7 +100,7 @@ Graph InferAttr(Graph &&ret, } } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) { CHECK_GE(inode.control_deps.size(), 1U) - << "BackwardOp need to have control_deps to its forward op"; + << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; @@ -141,7 +131,7 @@ Graph InferAttr(Graph &&ret, } // out grad entries CHECK(igrad_node != nullptr) - << "Cannot find matching backward op for " << inode.source->attrs.name; + << "Cannot find matching backward op for " << inode.source->attrs.name; for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { const NodeEntry& e = igrad_node->inputs[i]; if (e.node == nullptr) { @@ -174,10 +164,9 @@ Graph InferAttr(Graph &&ret, throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); } } else { - CHECK(!last_iter) - << "Attribute " << infer_name - << " is not registered by op " << inode.source->op()->name - << " we are not able to complete the inference because of this"; + CHECK(!last_iter) << "Attribute " << infer_name << " is not registered by op " + << inode.source->op()->name + << " we are not able to complete the inference because of this"; } } // Save to the result map. 
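The InferAttr template above is instantiated by the InferShape and InferType registrations just below. A sketch of exercising InferShape, using only the attribute names visible in this diff ("shape_inputs", "shape", "shape_num_unknown_nodes"):

#include <memory>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>
#include <utility>

// Sketch: seed input shapes (aligned with idx.input_nodes()), run the pass,
// then check how many entries remained unknown.
nnvm::Graph RunInferShape(nnvm::Graph g, nnvm::ShapeVector input_shapes) {
  g.attrs["shape_inputs"] = std::make_shared<nnvm::any>(std::move(input_shapes));
  g = nnvm::ApplyPass(std::move(g), "InferShape");
  size_t unknown = g.GetAttr<size_t>("shape_num_unknown_nodes");
  (void)unknown;  // 0 means the "shape" attribute is fully concrete
  return g;
}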
@@ -221,32 +210,30 @@ Graph InferAttr(Graph &&ret, } NNVM_REGISTER_PASS(InferShape) -.describe("Infer the shape of each node entries.") -.set_body([](Graph ret) { - return InferAttr( - std::move(ret), TShape(), - "FInferShape", "shape_inputs", "shape_attr_key", - "shape", "shape_num_unknown_nodes", - [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, - nullptr); - }) -.set_change_graph(false) -.provide_graph_attr("shape"); + .describe("Infer the shape of each node entries.") + .set_body([](Graph ret) { + return InferAttr( + std::move(ret), TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape", + "shape_num_unknown_nodes", [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + nullptr); + }) + .set_change_graph(false) + .provide_graph_attr("shape"); // inference function for same type -inline bool SameType(const NodeAttrs& attrs, - std::vector *iattr, - std::vector *oattr) { +inline bool SameType(const NodeAttrs& attrs, std::vector* iattr, std::vector* oattr) { int def_v = -1; for (int v : *oattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } if (def_v == -1) { for (int v : *iattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } } @@ -261,17 +248,14 @@ inline bool SameType(const NodeAttrs& attrs, } NNVM_REGISTER_PASS(InferType) -.describe("Infer the dtype of each node entries.") -.set_body([](Graph ret) { - return InferAttr( - std::move(ret), -1, - "FInferType", "dtype_inputs", "dtype_attr_key", - "dtype", "dtype_num_unknown_nodes", - [](const int t) { return t == -1; }, - SameType); - }) -.set_change_graph(false) -.provide_graph_attr("dtype"); + .describe("Infer the dtype of each node entries.") + .set_body([](Graph ret) { + return InferAttr( + std::move(ret), -1, "FInferType", "dtype_inputs", "dtype_attr_key", "dtype", + "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, SameType); + }) + .set_change_graph(false) + .provide_graph_attr("dtype"); DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape); DMLC_JSON_ENABLE_ANY(DTypeVector, list_int); diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index b2fa2ca33e07a..2575a03ace03a 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -23,17 +23,15 @@ * To correctly order mutation and read to resolve * write after read problem and read after write problems. */ -#include #include +#include namespace nnvm { namespace pass { namespace { -template -inline T get_with_default(const std::unordered_map &map, - Node* key, - const T& def) { +template +inline T get_with_default(const std::unordered_map& map, Node* key, const T& def) { auto it = map.find(key); if (it != map.end()) return it->second; return def; @@ -46,19 +44,19 @@ inline bool IsMutate(const std::vector& mutate_inputs, uint32_t i) { Graph OrderMutation(const Graph& src) { std::unordered_map > version_hist; DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) { - for (const NodeEntry& e : n->inputs) { - if (e.node->is_variable()) { - if (e.version != 0 && version_hist.count(e.node.get()) == 0) { - version_hist[e.node.get()] = std::vector{}; - } + for (const NodeEntry& e : n->inputs) { + if (e.node->is_variable()) { + if (e.version != 0 && version_hist.count(e.node.get()) == 0) { + version_hist[e.node.get()] = std::vector{}; } } - }); + } + }); // no mutation happens, everything if fine. if (version_hist.size() == 0) return src; // start preparing for remapping the nodes. 
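OrderMutation, below, keys entirely off the FMutateInputs attribute. An illustrative registration (the op name is hypothetical) showing how an operator declares an in-place mutated input so the pass can serialize reads and writes:

#include <cstdint>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>

// Illustration only: input 0 of "my_assign" is mutated in place; per the
// UpdateNodeVersion check earlier in this diff, that input must be a Variable.
NNVM_REGISTER_OP(my_assign)
    .set_num_inputs(2)
    .set_num_outputs(1)
    .set_attr<nnvm::FMutateInputs>(
        "FMutateInputs",
        [](const nnvm::NodeAttrs& attrs) { return std::vector<uint32_t>{0}; });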
std::unordered_map old_new; - auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) { + auto prepare = [&version_hist, &old_new](const ObjectPtr& n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (!n->is_variable() && fmutate_inputs.count(n->op())) { @@ -91,17 +89,17 @@ Graph OrderMutation(const Graph& src) { }; DFSVisit(src.outputs, prepare); // comparator of history entry - auto comparator = [](const NodeEntry& a, const NodeEntry &b) { + auto comparator = [](const NodeEntry& a, const NodeEntry& b) { if (a.version < b.version) return true; if (a.version > b.version) return false; return a.index > b.index; }; - for (auto &kv : version_hist) { + for (auto& kv : version_hist) { std::sort(kv.second.begin(), kv.second.end(), comparator); } // copy the nodes, as well as add control deps - for (auto &kv : old_new) { + for (auto& kv : old_new) { // copy the nodes for (const NodeEntry& e : kv.first->inputs) { auto it = old_new.find(e.node.get()); @@ -112,8 +110,7 @@ Graph OrderMutation(const Graph& src) { } } for (const ObjectPtr& p : kv.first->control_deps) { - kv.second->control_deps.emplace_back( - get_with_default(old_new, p.get(), p)); + kv.second->control_deps.emplace_back(get_with_default(old_new, p.get(), p)); } // add control deps static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -127,9 +124,8 @@ Graph OrderMutation(const Graph& src) { const NodeEntry& e = kv.first->inputs[i]; if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) { std::vector& vec = version_hist.at(e.node.get()); - auto it = std::lower_bound(vec.begin(), vec.end(), - NodeEntry{nullptr, 1, e.version}, - comparator); + auto it = + std::lower_bound(vec.begin(), vec.end(), NodeEntry{nullptr, 1, e.version}, comparator); if (IsMutate(mutate_inputs, i)) { int read_dep = 0; while (it != vec.begin()) { @@ -137,37 +133,35 @@ Graph OrderMutation(const Graph& src) { if (it->index != 0) break; ++read_dep; // depend on previous read - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } if (read_dep == 0 && it->index != 0) { // depend on last write - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } } else { // depend on last write if (it->index != 0) { - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } } } } } Graph ret; - for (const NodeEntry &e : src.outputs) { - ret.outputs.emplace_back(NodeEntry{ - get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); + for (const NodeEntry& e : src.outputs) { + ret.outputs.emplace_back( + NodeEntry{get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); } return ret; } NNVM_REGISTER_PASS(OrderMutation) -.describe("Return a new graph that adds control dependencies, "\ - "to order the mutation and reads if mutation exists.") -.set_body(OrderMutation) -.set_change_graph(true); + .describe( + "Return a new graph that adds control dependencies, " + "to order the mutation and reads if mutation exists.") + .set_body(OrderMutation) + .set_change_graph(true); } // namespace } // namespace pass diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc index 6d6866e472d6a..d45658ae24abd 
100644 --- a/nnvm/src/pass/place_device.cc +++ b/nnvm/src/pass/place_device.cc @@ -22,9 +22,9 @@ * \brief Inference the device of each operator given known information. * Insert a copy node automatically when there is a cross device. */ -#include -#include #include +#include +#include namespace nnvm { namespace pass { @@ -43,8 +43,7 @@ Graph PlaceDevice(Graph src) { const Op* copy_op = Op::Get(src.GetAttr("device_copy_op")); auto& device_assign_map = src.GetAttr("device_assign_map"); const IndexedGraph& idx = src.indexed_graph(); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + static auto& is_backward = Op::GetAttr("TIsBackward"); DeviceVector device; // copy on write semanatics if (src.attrs.count("device") != 0) { @@ -65,15 +64,15 @@ Graph PlaceDevice(Graph src) { << "The device assignment not found for group " << device_group; device[nid] = dit->second; } else { - if (!inode.source->is_variable() && - is_backward.get(inode.source->op(), false)) { + if (!inode.source->is_variable() && is_backward.get(inode.source->op(), false)) { if (device[inode.control_deps[0]] != -1) { device[nid] = device[inode.control_deps[0]]; } } else { for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (device[e.node_id] != -1) { - device[nid] = device[e.node_id]; break; + device[nid] = device[e.node_id]; + break; } } } @@ -121,20 +120,21 @@ Graph PlaceDevice(Graph src) { auto e = inode.inputs[index]; if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { LOG(FATAL) << " mutable state cannot go across device" - << " op=" << inode.source->op()->name - << " input_state_index=" << index; + << " op=" << inode.source->op()->name << " input_state_index=" << index; } } } for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { - need_mutate = true; break; + need_mutate = true; + break; } } if (!need_mutate) { for (const uint32_t cid : inode.control_deps) { - if (new_node_map[cid] != nullptr) { - need_mutate = true; break; + if (new_node_map[cid] != nullptr) { + need_mutate = true; + break; } } } @@ -151,17 +151,15 @@ Graph PlaceDevice(Graph src) { auto copy_key = std::make_tuple(e.node_id, e.index, dev_id); auto it = copy_map.find(copy_key); if (it != copy_map.end() && it->first == copy_key) { - new_node->inputs.emplace_back( - NodeEntry{it->second, 0, 0}); + new_node->inputs.emplace_back(NodeEntry{it->second, 0, 0}); } else { ObjectPtr copy_node = Node::Create(); std::ostringstream os; - os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; + os << inode.source->inputs[i].node->attrs.name << "_" << e.index << "_copy"; copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); if (new_node_map[e.node_id] != nullptr) { - copy_node->inputs.emplace_back( - NodeEntry{new_node_map[e.node_id], e.index, 0}); + copy_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0}); } else { copy_node->inputs.push_back(inode.source->inputs[i]); } @@ -170,13 +168,11 @@ Graph PlaceDevice(Graph src) { } copy_map[copy_key] = copy_node; new_device_map[copy_node.get()] = dev_id; - new_node->inputs.emplace_back( - NodeEntry{std::move(copy_node), 0, 0}); + new_node->inputs.emplace_back(NodeEntry{std::move(copy_node), 0, 0}); } } else { if (new_node_map[e.node_id] != nullptr) { - new_node->inputs.emplace_back( - NodeEntry{new_node_map[e.node_id], e.index, 0}); + new_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0}); } else { 
new_node->inputs.push_back(inode.source->inputs[i]); } @@ -220,14 +216,15 @@ Graph PlaceDevice(Graph src) { } NNVM_REGISTER_PASS(PlaceDevice) -.describe("Infer the device type of each operator."\ - "Insert a copy node when there is cross device copy") -.set_body(PlaceDevice) -.set_change_graph(true) -.provide_graph_attr("device") -.depend_graph_attr("device_group_attr_key") -.depend_graph_attr("device_assign_map") -.depend_graph_attr("device_copy_op"); + .describe( + "Infer the device type of each operator." + "Insert a copy node when there is cross device copy") + .set_body(PlaceDevice) + .set_change_graph(true) + .provide_graph_attr("device") + .depend_graph_attr("device_group_attr_key") + .depend_graph_attr("device_assign_map") + .depend_graph_attr("device_copy_op"); DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int); diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 83d8f87fa9f1a..2c36cd2eef5a6 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -22,10 +22,12 @@ * \brief Assign memory tag to each of the data entries. */ #include -#include #include #include +#include + #include + #include "graph_algorithm.h" namespace nnvm { @@ -82,10 +84,10 @@ class GraphAllocator { auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // find a exact match, erase from map and return @@ -95,10 +97,10 @@ class GraphAllocator { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // erase from map and return @@ -112,7 +114,7 @@ class GraphAllocator { void Release(StorageID id, uint32_t node_id) { CHECK_NE(id, kBadStorageID); if (id == kExternalStorageID || id == kDynamicStorageID) return; - StorageEntry *e = data_[id].get(); + StorageEntry* e = data_[id].get(); e->released_by_node = node_id; free_.insert({e->max_bytes, e}); } @@ -120,7 +122,7 @@ class GraphAllocator { // totoal number of bytes allocated size_t TotalAllocBytes() const { size_t total = 0; - for (auto &p : data_) { + for (auto& p : data_) { total += p->max_bytes; } return total; @@ -142,8 +144,7 @@ class GraphAllocator { if ((*idx_)[nid].source->is_variable()) continue; importance[nid] = 1; } - num_match_color_ = pass::ColorNodeGroup( - *idx_, importance, num_match_color_, &node_color_); + num_match_color_ = pass::ColorNodeGroup(*idx_, importance, num_match_color_, &node_color_); } } @@ -187,18 +188,16 @@ class GraphAllocator { * Internal method to perform the memory allocation for a graph * */ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, - const std::pair& node_range, - StorageVector* storage_ptr, + const std::pair& node_range, StorageVector* storage_ptr, std::vector* 
storage_inplace_index_ptr, - const std::vector& entry_ref_count, - GraphAllocator* allocator) { + const std::vector& entry_ref_count, GraphAllocator* allocator) { static auto& finplace_option = Op::GetAttr("FInplaceOption"); static auto& finplace_identity = Op::GetAttr("FInplaceIdentity"); static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); // Get reference - auto &storage = *storage_ptr; - auto &storage_inplace_index = *storage_inplace_index_ptr; + auto& storage = *storage_ptr; + auto& storage_inplace_index = *storage_inplace_index_ptr; // Get attributes from the graph const ShapeVector& shape_vec = ret.GetAttr("shape"); @@ -234,19 +233,16 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, auto sid_out = storage[eid_out]; auto sid_in = storage[eid_in]; bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 && - fignore_inputs[inode.source->op()]( - inode.source->attrs).size() == inode.source->num_inputs()); + fignore_inputs[inode.source->op()](inode.source->attrs).size() == + inode.source->num_inputs()); // Identity should only be true if shape.Size() and types match bool real_identity = identity[ipair] && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && dtype_vec[eid_out] == dtype_vec[eid_in]; - if (taken[kv.first] == false && - sid_out == GraphAllocator::kBadStorageID && - sid_in >= 0 && + if (taken[kv.first] == false && sid_out == GraphAllocator::kBadStorageID && sid_in >= 0 && ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) && - entry_ref_count[eid_out] > 0 && - shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && - (dtype_vec[eid_out] == dtype_vec[eid_in] || + entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && + (dtype_vec[eid_out] == dtype_vec[eid_in] || GetDTypeSize(dtype_vec[eid_out]) == GetDTypeSize(dtype_vec[eid_in]))) { // inplace optimization taken[kv.first] = true; @@ -267,19 +263,19 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, uint32_t eid = idx.entry_id(nid, index); // only request memory for kBadStorageID if (storage[eid] == GraphAllocator::kBadStorageID) { - auto &eshape = shape_vec[eid]; + auto& eshape = shape_vec[eid]; size_t esize = 0; if (eshape.ndim() != 0) esize = eshape.Size(); eids.insert(std::make_pair(esize, eid)); } } for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { - uint32_t eid = rit->second; - auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); - if (sid >= 0) { - storage_ref_count[sid] = entry_ref_count[eid]; - } - storage[eid] = sid; + uint32_t eid = rit->second; + auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); + if (sid >= 0) { + storage_ref_count[sid] = entry_ref_count[eid]; + } + storage[eid] = sid; } // check if certain inputs is ignored. std::vector ignore_inputs; @@ -320,7 +316,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, return num_not_allocated; } - // function to plan memory Graph PlanMemory(Graph ret) { // setup ref counter @@ -368,7 +363,7 @@ Graph PlanMemory(Graph ret) { size_t min_allocated_bytes = -1; size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); size_t min_match_range = - dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; + dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 
1 : max_match_range; for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) { // Make a copy of related fields StorageVector storage_vec(storage); @@ -378,9 +373,8 @@ Graph PlanMemory(Graph ret) { GraphAllocator allocator(&idx, match_range); // number of entries that are not statically allocated. - size_t storage_num_not_allocated = - AllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index, - ref_count, &allocator); + size_t storage_num_not_allocated = AllocMemory(ret, idx, node_range, &storage_vec, + &storage_inplace_index, ref_count, &allocator); size_t storage_allocated_bytes = allocator.TotalAllocBytes(); // Choose the plan which leads to minimal memory usage @@ -400,13 +394,13 @@ Graph PlanMemory(Graph ret) { } NNVM_REGISTER_PASS(PlanMemory) -.describe("Plan the memory allocation of each node entries.") -.set_body(PlanMemory) -.set_change_graph(false) -.depend_graph_attr("dtype") -.depend_graph_attr("shape") -.provide_graph_attr("storage_id") -.provide_graph_attr("storage_inplace_index"); + .describe("Plan the memory allocation of each node entries.") + .set_body(PlanMemory) + .set_change_graph(false) + .depend_graph_attr("dtype") + .depend_graph_attr("shape") + .provide_graph_attr("storage_id") + .provide_graph_attr("storage_inplace_index"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc index a0127abe10f48..4fe92e6659617 100644 --- a/nnvm/src/pass/print_graph_ir.cc +++ b/nnvm/src/pass/print_graph_ir.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,6 +24,7 @@ #include #include #include + #include namespace nnvm { @@ -31,47 +32,39 @@ namespace pass { using AttrPrinter = std::function; // NOLINT(*) -template +template AttrPrinter GetVectorPrinter_(const T& vec) { return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*) os << vec[index]; }; } -AttrPrinter GetVectorPrinter(const Graph& graph, - const std::string& key) { +AttrPrinter GetVectorPrinter(const Graph& graph, const std::string& key) { auto it = graph.attrs.find(key); - CHECK(it != graph.attrs.end()) - << "Cannot find " << key << " in graph attr"; + CHECK(it != graph.attrs.end()) << "Cannot find " << key << " in graph attr"; const any& value = *(it->second); if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else { LOG(FATAL) << "Cannot handle type " << value.type().name(); return nullptr; } } - // print the graph ir in readable format -void PrintGraphIR_(Graph src, - const std::vector& join_entry_attrs, +void PrintGraphIR_(Graph src, const std::vector& join_entry_attrs, const std::vector& join_node_attrs, - std::ostream& os) { // NOLINT(*) + std::ostream& os) { // NOLINT(*) const IndexedGraph& idx = src.indexed_graph(); std::vector > 
trigger; // NOLINT(*) for (const std::string& key : join_entry_attrs) { AttrPrinter fp = GetVectorPrinter(src, key); - auto fprint = [&idx, key, fp]( - uint32_t nid, std::ostream& os) { // NOLINT(*) + auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*) const IndexedGraph::Node& inode = idx[nid]; os << ", " << key << "="; if (inode.source->num_outputs() != 1) { @@ -89,8 +82,7 @@ void PrintGraphIR_(Graph src, } for (const std::string& key : join_node_attrs) { AttrPrinter fp = GetVectorPrinter(src, key); - auto fprint = [&idx, key, fp]( - uint32_t nid, std::ostream& os) { // NOLINT(*) + auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*) os << ", " << key << "="; fp(idx.entry_id(nid, 0), os); }; @@ -101,7 +93,7 @@ void PrintGraphIR_(Graph src, if (idx.input_nodes().size() < 4) { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { uint32_t nid = idx.input_nodes()[i]; - if (i != 0) { + if (i != 0) { os << ", "; } os << '%' << idx[nid].source->attrs.name; @@ -109,7 +101,7 @@ void PrintGraphIR_(Graph src, } else { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { uint32_t nid = idx.input_nodes()[i]; - if (i != 0) { + if (i != 0) { os << ",\n "; } os << '%' << idx[nid].source->attrs.name; @@ -141,8 +133,8 @@ void PrintGraphIR_(Graph src, for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; - os << " " << "%" << nid << " = " - << inode.source->op()->name << "("; + os << " " + << "%" << nid << " = " << inode.source->op()->name << "("; bool first = true; for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (first) { @@ -213,12 +205,10 @@ Graph PrintGraphIRPass(Graph src) { std::ostringstream os; std::vector join_entry_attrs, join_node_attrs; if (src.attrs.count("join_entry_attrs") != 0) { - join_entry_attrs = src.MoveCopyAttr >( - "join_entry_attrs"); + join_entry_attrs = src.MoveCopyAttr >("join_entry_attrs"); } if (src.attrs.count("join_node_attrs") != 0) { - join_node_attrs = src.MoveCopyAttr >( - "join_node_attrs"); + join_node_attrs = src.MoveCopyAttr >("join_node_attrs"); } PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os); Graph ret; @@ -228,8 +218,8 @@ Graph PrintGraphIRPass(Graph src) { // register pass NNVM_REGISTER_PASS(PrintGraphIR) -.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") -.set_body(PrintGraphIRPass); + .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") + .set_body(PrintGraphIRPass); } // namespace pass } // namespace nnvm diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 9389995d05211..3916da43618dc 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -21,20 +21,21 @@ * \file saveload_json.cc * \brief Save and load graph to/from JSON file. 
*/ +#include #include #include -#include + #include namespace dmlc { namespace json { // overload handler for shared ptr -template<> -struct Handler > { - inline static void Write(JSONWriter *writer, const std::shared_ptr &data) { +template <> +struct Handler> { + inline static void Write(JSONWriter* writer, const std::shared_ptr& data) { writer->Write(*data); } - inline static void Read(JSONReader *reader, std::shared_ptr *data) { + inline static void Read(JSONReader* reader, std::shared_ptr* data) { any v; reader->Read(&v); *data = std::make_shared(std::move(v)); @@ -60,17 +61,16 @@ struct JSONNode { uint32_t index; uint32_t version; Entry() = default; - Entry(uint32_t node_id, uint32_t index, uint32_t version): - node_id(node_id), index(index), version(version) { - } - void Save(dmlc::JSONWriter *writer) const { + Entry(uint32_t node_id, uint32_t index, uint32_t version) + : node_id(node_id), index(index), version(version) {} + void Save(dmlc::JSONWriter* writer) const { writer->BeginArray(false); writer->WriteArrayItem(node_id); writer->WriteArrayItem(index); writer->WriteArrayItem(version); writer->EndArray(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginArray(); CHECK(reader->NextArrayItem()) << "invalid json format"; reader->Read(&node_id); @@ -95,7 +95,7 @@ struct JSONNode { std::vector subgraphs; // function to save JSON node. - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); if (node->op() != nullptr) { writer->WriteObjectKeyValue("op", node->op()->name); @@ -106,8 +106,7 @@ struct JSONNode { writer->WriteObjectKeyValue("name", node->attrs.name); if (node->attrs.dict.size() != 0) { // write attributes in order; - std::map dict( - node->attrs.dict.begin(), node->attrs.dict.end()); + std::map dict(node->attrs.dict.begin(), node->attrs.dict.end()); writer->WriteObjectKeyValue("attrs", dict); } writer->WriteObjectKeyValue("inputs", inputs); @@ -120,7 +119,7 @@ struct JSONNode { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { node = Node::Create(); control_deps.clear(); dmlc::JSONObjectReadHelper helper; @@ -143,10 +142,10 @@ struct JSONNode { if (op_type_str != "null") { try { node->attrs.op = Op::Get(op_type_str); - } catch (const dmlc::Error &err) { + } catch (const dmlc::Error& err) { std::ostringstream os; - os << "Failed loading Op " << node->attrs.name - << " of type " << op_type_str << ": " << err.what(); + os << "Failed loading Op " << node->attrs.name << " of type " << op_type_str << ": " + << err.what(); throw dmlc::Error(os.str()); } } else { @@ -161,9 +160,9 @@ struct JSONGraph { std::vector arg_nodes; std::vector node_row_ptr; std::vector heads; - std::unordered_map > attrs; + std::unordered_map> attrs; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("arg_nodes", arg_nodes); @@ -175,7 +174,7 @@ struct JSONGraph { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("nodes", &nodes); @@ -187,7 +186,7 @@ struct JSONGraph { } }; -void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { +void Symbol2JSONGraph(std::shared_ptr src, JSONGraph* jgraph) { std::unordered_map node2index; jgraph->node_row_ptr.push_back(0); DFSVisit(src->outputs, 
[&node2index, jgraph](const ObjectPtr& n) { @@ -212,10 +211,10 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } // recursively construct subgraphs - for (JSONNode &jnode : jgraph->nodes) { + for (JSONNode& jnode : jgraph->nodes) { // construct jnode's subgraphs - const std::vector> &subgraphs = jnode.node->attrs.subgraphs; - std::vector &jsubgraphs = jnode.subgraphs; + const std::vector>& subgraphs = jnode.node->attrs.subgraphs; + std::vector& jsubgraphs = jnode.subgraphs; jsubgraphs.resize(subgraphs.size()); for (uint32_t i = 0; i < subgraphs.size(); ++i) { Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]); @@ -223,10 +222,10 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { } } -std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { - for (const JSONNode &n : jgraph.nodes) { +std::shared_ptr JSONGraph2Symbol(const JSONGraph& jgraph, bool no_parse) { + for (const JSONNode& n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); - for (const JSONNode::Entry &e : n.inputs) { + for (const JSONNode::Entry& e : n.inputs) { CHECK(e.node_id < jgraph.nodes.size()); n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } @@ -235,7 +234,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) CHECK(nid < jgraph.nodes.size()); n.node->control_deps.push_back(jgraph.nodes[nid].node); } - for (const JSONGraph &subgraph : n.subgraphs) { + for (const JSONGraph& subgraph : n.subgraphs) { // The "no_parse" option here, is to be compatible with // commit cfd3075e85807dcd8f9534c37e053583dee87524 // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524), @@ -248,7 +247,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) n.node->op()->attr_parser(&(n.node->attrs)); } else if (!no_parse && n.node->is_variable()) { n.node->attrs.parsed = - Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; + Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; } } // consistency check @@ -258,7 +257,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) } std::shared_ptr symbol = std::make_shared(); symbol->outputs.reserve(jgraph.heads.size()); - for (const JSONNode::Entry &e : jgraph.heads) { + for (const JSONNode::Entry& e : jgraph.heads) { CHECK(e.node_id < jgraph.nodes.size()); symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } @@ -267,10 +266,8 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) // Load a graph from JSON file. 
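Before the LoadJSON body below, a quick usage sketch of how these two passes are normally driven. This assumes nnvm's Python bindings (nnvm.symbol / nnvm.graph) with their create, apply, json_attr, and load_json helpers; per the registrations below, SaveJSON stores the serialized graph in ret.attrs["json"] and LoadJSON rebuilds a graph from src.attrs["json"]:

import nnvm.symbol as sym
import nnvm.graph as graph

x = sym.Variable("x")
y = sym.Variable("y")
net = sym.elemwise_add(x, y)     # any symbolic graph

# SaveJSON returns a fresh graph whose "json" attribute holds the
# serialized form of g.
g = graph.create(net)
json_str = g.apply("SaveJSON").json_attr("json")

# load_json wraps the LoadJSON pass: it seeds attrs["json"] on a new
# graph and applies the pass to reconstruct the symbol.
g2 = graph.load_json(json_str)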
Graph LoadJSON(Graph src) { - CHECK_NE(src.attrs.count("json"), 0U) - << "Load JSON require json to be presented."; - const std::string &json_str = - nnvm::get(*src.attrs.at("json")); + CHECK_NE(src.attrs.count("json"), 0U) << "Load JSON require json to be presented."; + const std::string& json_str = nnvm::get(*src.attrs.at("json")); bool no_parse = false; if (src.attrs.count("load_json_no_parse")) { no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); @@ -305,17 +302,16 @@ Graph SaveJSON(Graph src) { // register pass NNVM_REGISTER_PASS(LoadJSON) -.describe("Return a new Graph, loaded from src.attrs[\"json\"]") -.set_body(LoadJSON) -.set_change_graph(true) -.depend_graph_attr("json"); + .describe("Return a new Graph, loaded from src.attrs[\"json\"]") + .set_body(LoadJSON) + .set_change_graph(true) + .depend_graph_attr("json"); NNVM_REGISTER_PASS(SaveJSON) -.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") -.set_body(SaveJSON) -.set_change_graph(true) -.provide_graph_attr("json"); - + .describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") + .set_body(SaveJSON) + .set_change_graph(true) + .provide_graph_attr("json"); DMLC_JSON_ENABLE_ANY(std::string, str); DMLC_JSON_ENABLE_ANY(std::vector, list_int); diff --git a/nnvm/tests/cpp/op_test.cc b/nnvm/tests/cpp/op_test.cc index 4c771655d87b6..2ebd14688f466 100644 --- a/nnvm/tests/cpp/op_test.cc +++ b/nnvm/tests/cpp/op_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,16 +20,15 @@ #include #include #include -#include -NNVM_REGISTER_OP(add) -.describe("add two data together") -.set_num_inputs(2) -.set_attr("inplace_pair", std::make_pair(0, 0)); +#include NNVM_REGISTER_OP(add) -.set_attr("nick_name", "plus"); + .describe("add two data together") + .set_num_inputs(2) + .set_attr("inplace_pair", std::make_pair(0, 0)); +NNVM_REGISTER_OP(add).set_attr("nick_name", "plus"); TEST(Op, GetAttr) { using namespace nnvm; @@ -39,7 +38,7 @@ TEST(Op, GetAttr) { CHECK_EQ(nick[add], "plus"); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/nnvm/tests/cpp/tuple_test.cc b/nnvm/tests/cpp/tuple_test.cc index 7bf59b5db7c86..2c2c307aadce0 100644 --- a/nnvm/tests/cpp/tuple_test.cc +++ b/nnvm/tests/cpp/tuple_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,8 +22,8 @@ #include TEST(Tuple, Basic) { - using nnvm::Tuple; using nnvm::TShape; + using nnvm::Tuple; Tuple x{1, 2, 3}; Tuple y{1, 2, 3, 5, 6}; x = std::move(y); @@ -42,7 +42,7 @@ CHECK((s == TShape{1, 2, 3})); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index dc2dc1944f30f..b17174a7c6bf9 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -141,7 +141,13 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) type_codes[i] = TypeCode.TVM_CONTEXT - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytearray, bytes)): + # from_buffer only takes in bytearray. + if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 1f68df1885db2..45bcf64a616d8 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -142,7 +142,13 @@ cdef inline int make_arg(object arg, value[0].v_ctx = (( ctypes.addressof(arg)))[0] tcode[0] = kTVMContext - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytes, bytearray)): + # from_buffer only takes in bytearray.
+ if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 8d3ce19f9444d..8674e31c3b844 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -48,7 +48,6 @@ def _load_lib(): """Load libary by searching possible path.""" lib_path = libinfo.find_lib_path() lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) - # DMatrix functions lib.TVMGetLastError.restype = ctypes.c_char_p return lib, os.path.basename(lib_path[0]) diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 0d1a4e214791e..a1483a1b012b6 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -88,6 +88,10 @@ def find_lib_path(name=None, search_path=None, optional=False): dll_path.append(install_lib_dir) + if os.path.isdir(source_dir): + dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) + dll_path.append(os.path.join(source_dir, "web", "dist")) + dll_path = [os.path.realpath(x) for x in dll_path] if search_path is not None: if isinstance(search_path, list): @@ -154,6 +158,7 @@ def find_include_path(name=None, search_path=None, optional=False): ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) source_dir = os.path.join(ffi_dir, "..", "..", "..") install_include_dir = os.path.join(ffi_dir, "..", "..", "..", "..") + third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 6b06ad01c9ff5..0d6e5ac18fb3a 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -147,6 +147,7 @@ class TVMContext(ctypes.Structure): 12: 'ext_dev', 13: 'micro_dev', 14: 'hexagon', + 15: 'webgpu' } STR2MASK = { 'llvm': 1, @@ -169,6 +170,7 @@ class TVMContext(ctypes.Structure): 'ext_dev': 12, 'micro_dev': 13, 'hexagon': 14, + 'webgpu': 15, } def __init__(self, device_type, device_id): super(TVMContext, self).__init__() diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index cf81e2b50e50c..a0a826abccf62 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -145,7 +145,7 @@ def submit(self, func, *args, **kwargs): if not self.do_fork: return LocalFutureNoFork(func(*args, **kwargs)) - queue = Queue(2) + queue = Queue(2) # Size of 2 to avoid a race condition with size 1. 
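The capacity-2 queue above is worth a sketch. Below is a simplified, self-contained version of the submit/collect pattern from local_executor (the real call_with_timeout also enforces the timeout and kills the child; the diff's comment attributes the capacity-2 choice to a race observed with a single-slot queue):

import multiprocessing

def call_with_timeout(queue, timeout, func, args, kwargs):
    # child process: run the function and ship the result back
    queue.put(func(*args, **kwargs))

def run_async(func, timeout, *args, **kwargs):
    # capacity 2 rather than 1, mirroring the fix in the hunk above
    queue = multiprocessing.Queue(2)
    proc = multiprocessing.Process(
        target=call_with_timeout, args=(queue, timeout, func, args, kwargs))
    proc.start()
    return queue, proc

if __name__ == "__main__":
    q, p = run_async(pow, 10, 2, 10)
    print(q.get(timeout=10))   # 1024
    p.join()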
process = Process(target=call_with_timeout, args=(queue, self.timeout, func, args, kwargs)) process.start() diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 6533e75eef933..185ed7d050194 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -245,6 +245,8 @@ def get_build_kwargs(self): if 'cuda' in self.task.target.keys: kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) + if self.task.target.device_name == 'micro_dev': + kwargs.setdefault('build_option', {})['disable_vectorize'] = True return kwargs @@ -273,8 +275,9 @@ def run(self, measure_inputs, build_results): if isinstance(res, Exception): # executor error or timeout results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time())) - else: - results.append(res) + raise Exception(f'encountered exception during measurement: {results}') + + results.append(res) return results diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index de183db41e2c6..f3edfb01dc077 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -48,6 +48,7 @@ def _lower(mod, grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) grc.codegen(mod["main"]) return + # default case # Try graph codegen first to extract autotvm tasks. # If failed to compile, then fallback to use VM compiler. diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index f13ba5289ce54..3fbccfe80dede 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -50,7 +50,7 @@ 'llvm': "v0.04", 'cuda': "v0.08", - 'rocm': "v0.04", + 'rocm': "v0.05", 'opencl': "v0.04", 'mali': "v0.06", 'intel_graphics': "v0.02", @@ -66,6 +66,7 @@ def _alias(name): 'vtacpu': 'vta', 'metal': 'opencl', + 'webgpu': 'opencl', 'vulkan': 'opencl', 'nvptx': 'cuda', } diff --git a/python/tvm/autotvm/tuner/callback.py b/python/tvm/autotvm/tuner/callback.py index eede450eaeafa..cfc1b2c38f852 100644 --- a/python/tvm/autotvm/tuner/callback.py +++ b/python/tvm/autotvm/tuner/callback.py @@ -149,7 +149,7 @@ def _callback(tuner, inputs, results): if res.error_no == 0: flops = inp.task.flop / np.mean(res.costs) - if logger.level < logging.DEBUG: # only print progress bar in non-debug mode + if not logger.isEnabledFor(logging.DEBUG): # only print progress bar in non-debug mode ctx.cur_flops = flops ctx.best_flops = tuner.best_flops diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py index a4c36bcd385e4..da10f73d5a532 100644 --- a/python/tvm/autotvm/tuner/ga_tuner.py +++ b/python/tvm/autotvm/tuner/ga_tuner.py @@ -50,7 +50,11 @@ def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1): # space info self.space = task.config_space - self.dims = [len(x) for x in self.space.space_map.values()] + self.dim_keys = [] + self.dims = [] + for k, v in self.space.space_map.items(): + self.dim_keys.append(k) + self.dims.append(len(v)) self.visited = set([]) @@ -123,7 +127,7 @@ def update(self, inputs, results): if len(self.visited) < len(self.space): while knob2point(tmp_gene, self.dims) in self.visited: j = np.random.randint(len(self.dims)) - tmp_gene[j] = np.random.randint(self.dims[j]) + tmp_gene[j] = np.random.randint(self.dims[j]) # pylint: disable=invalid-sequence-index next_genes.append(tmp_gene) self.visited.add(knob2point(tmp_gene, self.dims)) else: diff --git 
a/python/tvm/contrib/binutil.py b/python/tvm/contrib/binutil.py index 521e0885548ce..21e06df9f7f0b 100644 --- a/python/tvm/contrib/binutil.py +++ b/python/tvm/contrib/binutil.py @@ -21,7 +21,9 @@ import tvm._ffi from . import util +# TODO does this file still belong in `contrib`. is it too µTVM-specific? +# TODO shouldn't need so many `ALIGN` directives RELOCATION_LD_SCRIPT_TEMPLATE = """ /* linker symbol for use in UTVMInit */ _utvm_stack_pointer_init = 0x{stack_pointer_init:x}; @@ -118,7 +120,7 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix): size of the section in bytes """ if not os.path.isfile(binary_path): - raise RuntimeError("no such file \"{}\"".format(binary_path)) + raise RuntimeError('no such file "{}"'.format(binary_path)) # We use the "-A" flag here to get the ".rodata" section's size, which is # not included by default. size_output = run_cmd(["{}size".format(toolchain_prefix), "-A", binary_path]) @@ -160,6 +162,10 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix): # padding for most cases, but symbols can be arbitrarily large, so this # isn't bulletproof. return section_size + 32 + + # NOTE: in the past, section_size has been wrong on x86. it may be + # inconsistent. TODO: maybe stop relying on `*size` to give us the size and + # instead read the section with `*objcopy` and count the bytes. return section_size @@ -206,11 +212,13 @@ def tvm_callback_relocate_binary( rel_bin : bytearray the relocated binary """ + assert text_start < rodata_start < data_start < bss_start < stack_end stack_pointer_init = stack_end - word_size ld_script_contents = "" # TODO(weberlo): There should be a better way to configure this for different archs. + # TODO is this line even necessary? if "riscv" in toolchain_prefix: - ld_script_contents += "OUTPUT_ARCH( \"riscv\" )\n\n" + ld_script_contents += 'OUTPUT_ARCH( "riscv" )\n\n' ld_script_contents += RELOCATION_LD_SCRIPT_TEMPLATE.format( word_size=word_size, text_start=text_start, @@ -221,7 +229,7 @@ def tvm_callback_relocate_binary( tmp_dir = util.tempdir() rel_obj_path = tmp_dir.relpath("relocated.obj") - rel_ld_script_path = tmp_dir.relpath("relocated.lds") + rel_ld_script_path = tmp_dir.relpath("relocate.lds") with open(rel_ld_script_path, "w") as f: f.write(ld_script_contents) run_cmd([ @@ -229,8 +237,23 @@ def tvm_callback_relocate_binary( binary_path, "-T", rel_ld_script_path, "-o", rel_obj_path]) + with open(rel_obj_path, "rb") as f: rel_bin = bytearray(f.read()) + + gdb_init_dir = os.environ.get("MICRO_GDB_INIT_DIR") + if gdb_init_dir is not None: + gdb_init_path = f"{gdb_init_dir}/.gdbinit" + with open(gdb_init_path, "r") as f: + gdbinit_contents = f.read().split("\n") + new_contents = [] + for line in gdbinit_contents: + new_contents.append(line) + if line.startswith("target"): + new_contents.append(f"add-symbol-file {rel_obj_path}") + with open(gdb_init_path, "w") as f: + f.write("\n".join(new_contents)) + return rel_bin diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index ae37923a1dcfb..8ad47acfe989c 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -90,7 +90,8 @@ def get_target_triple(): def cross_compiler(compile_func, options=None, output_format=None, - get_target_triple=None): + get_target_triple=None, + add_files=None): """Create a cross compiler function by specializing compile_func with options. 
This function can be used to construct compile functions that @@ -111,6 +112,10 @@ def cross_compiler(compile_func, get_target_triple: Optional[Callable] Function that can target triple according to dumpmachine option of compiler. + add_files: Optional[List[str]] + List of paths to additional object, source, library files + to pass as part of the compilation. + Returns ------- fcompile : Callable[[str, str, Optional[str]], None] @@ -133,6 +138,7 @@ def cross_compiler(compile_func, """ base_options = [] if options is None else options kwargs = {} + add_files = [] if add_files is None else add_files # handle case where compile_func is the name of the cc if isinstance(compile_func, str): @@ -144,7 +150,7 @@ def _fcompile(outputs, objects, options=None): all_options = base_options if options is not None: all_options += options - compile_func(outputs, objects, options=all_options, **kwargs) + compile_func(outputs, objects + add_files, options=all_options, **kwargs) if not output_format and hasattr(compile_func, "output_format"): output_format = compile_func.output_format diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emcc.py similarity index 65% rename from python/tvm/contrib/emscripten.py rename to python/tvm/contrib/emcc.py index 7f31273451f75..6e7e997d43ad5 100644 --- a/python/tvm/contrib/emscripten.py +++ b/python/tvm/contrib/emcc.py @@ -16,18 +16,16 @@ # under the License. """Util to invoke emscripten compilers in the system.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs - import subprocess -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path +from tvm._ffi.base import py_str +from tvm._ffi.libinfo import find_lib_path + -def create_js(output, - objects, - options=None, - side_module=False, - cc="emcc"): - """Create emscripten javascript library. +def create_tvmjs_wasm(output, + objects, + options=None, + cc="emcc"): + """Create wasm that is supposed to run with the tvmjs. Parameters ---------- @@ -44,25 +42,28 @@ def create_js(output, The compile string. 
""" cmd = [cc] - cmd += ["-Oz"] - if not side_module: - cmd += ["-s", "RESERVED_FUNCTION_POINTERS=2"] - cmd += ["-s", "NO_EXIT_RUNTIME=1"] - extra_methods = ['cwrap', 'getValue', 'setValue', 'addFunction'] - cfg = "[" + (','.join("\'%s\'" % x for x in extra_methods)) + "]" - cmd += ["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=" + cfg] - else: - cmd += ["-s", "SIDE_MODULE=1"] - cmd += ["-o", output] + cmd += ["-O3"] + + cmd += ["-std=c++14"] + cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] + cmd += ["-s", "STANDALONE_WASM=1"] + cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"] + + objects = [objects] if isinstance(objects, str) else objects + with_runtime = False for obj in objects: - if obj.find("libtvm_web_runtime.bc") != -1: + if obj.find("wasm_runtime.bc") != -1: with_runtime = True - if not with_runtime and not side_module: - objects += [find_lib_path("libtvm_web_runtime.bc")[0]] + if not with_runtime: + objects += [find_lib_path("wasm_runtime.bc")[0]] + objects += [find_lib_path("tvmjs_support.bc")[0]] + objects += [find_lib_path("webgpu_runtime.bc")[0]] + + cmd += ["-o", output] cmd += objects if options: @@ -79,4 +80,4 @@ def create_js(output, msg += py_str(out) raise RuntimeError(msg) -create_js.object_format = "bc" +create_tvmjs_wasm.object_format = "bc" diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 73235f71c77ba..740d1c3f19f3c 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -18,9 +18,10 @@ import numpy as np import tvm._ffi -from .._ffi.base import string_types -from .._ffi.runtime_ctypes import TVMContext -from ..rpc import base as rpc_base +from tvm.rpc import _ffi_api as _rpc_ffi_api +from tvm.rpc import base as rpc_base +from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import TVMContext def create(graph_json_str, libmod, ctx): @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx): device_type = cur_ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: assert libmod.type_key == "rpc" - assert rpc_base._SessTableIndex( + assert _rpc_ffi_api.SessTableIndex( libmod) == cur_ctx._rpc_sess._tbl_index num_rpc_ctx += 1 device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK diff --git a/python/tvm/contrib/tf_op/module.py b/python/tvm/contrib/tf_op/module.py index f13670e398952..fd0ed0c8fb2f0 100644 --- a/python/tvm/contrib/tf_op/module.py +++ b/python/tvm/contrib/tf_op/module.py @@ -17,6 +17,7 @@ """Module container of TensorFlow TVMDSO op""" import tensorflow as tf from tensorflow.python.framework import load_library +from tensorflow.python import platform class OpModule: @@ -67,7 +68,7 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape): elif output_shape is not None: self.dynamic_output_shape = self._pack_shape_tensor(output_shape) - self.module = load_library.load_op_library('tvm_dso_op.so') + self.module = self._load_platform_specific_library("tvm_dso_op") self.tvm_dso_op = self.module.tvm_dso_op def apply(self, *params): @@ -82,6 +83,16 @@ def apply(self, *params): def __call__(self, *params): return self.apply(*params) + def _load_platform_specific_library(self, lib_name): + system = platform.system() + if system == "Darwin": + lib_file_name = lib_name + ".dylib" + elif system == "Windows": + lib_file_name = lib_name + ".dll" + else: + lib_file_name = lib_name + ".so" + return load_library.load_op_library(lib_file_name) + def _is_static_shape(self, shape): if shape is None or not isinstance(shape, list): return False diff --git a/python/tvm/error.py 
b/python/tvm/error.py index 4c3e6060c25a8..b3502f6b0eada 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -57,6 +57,11 @@ def __init__(self, msg): register_error("KeyError", KeyError) +@register_error +class RPCError(RuntimeError): + """Error thrown by the remote server handling the RPC call.""" + + @register_error class OpError(TVMError): """Base class of all operator errors in frontends.""" diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 4cf341335ea7b..eb802866efba0 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -29,17 +29,22 @@ def find_example_resource(): """Find resource examples.""" curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - base_path = os.path.join(curr_path, "../../../") - index_page = os.path.join(base_path, "web/example_rpc.html") - js_files = [ - os.path.join(base_path, "web/tvm_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js.mem") + base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) + index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") + resource_files = [ + os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"), + os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js") ] - for fname in [index_page] + js_files: + resource_base = os.path.join(base_path, "web", "dist", "www") + if os.path.isdir(resource_base): + for fname in os.listdir(resource_base): + full_name = os.path.join(resource_base, fname) + if os.path.isfile(full_name): + resource_files.append(full_name) + for fname in [index_page] + resource_files: if not os.path.exists(fname): raise RuntimeError("Cannot find %s" % fname) - return index_page, js_files + return index_page, resource_files def main(args): @@ -69,7 +74,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", + parser.add_argument('--host', type=str, default="localhost", help='the hostname of the server') parser.add_argument('--port', type=int, default=9090, help='The port of the RPC') diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index dbb690267e2a0..e281e58e3879f 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -20,6 +20,7 @@ import argparse import ast +import json import multiprocessing import sys import logging @@ -41,7 +42,7 @@ def main(args): tracker_addr = (url, port) if not args.key: raise RuntimeError( - "Need key to present type of resource when tracker is available") + 'Need key to present type of resource when tracker is available') else: tracker_addr = None @@ -75,8 +76,8 @@ def init_utvm(args): dev_config = json.load(dev_conf_file) else: dev_config_args = ast.literal_eval(args.utvm_dev_config_args) - default_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['default_config'] - dev_config = default_config_func(*dev_config_args) + generate_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['generate_config'] + dev_config = generate_config_func(*dev_config_args) if args.utvm_dev_config or args.utvm_dev_id: # add MicroTVM overrides @@ -100,8 +101,8 @@ def server_shutdown(): parser.add_argument('--port-end', type=int, default=9199, help='The end search port of the RPC') parser.add_argument('--tracker', type=str, - help="The address of RPC tracker in host:port format. " - "e.g. 
(10.77.1.234:9190)")) parser.add_argument('--key', type=str, default="", help="The key used to identify the device type in tracker.") parser.add_argument('--silent', action='store_true', @@ -110,17 +111,24 @@ help="Additional library to load") parser.add_argument('--no-fork', dest='fork', action='store_false', help="Use spawn mode to avoid fork. This option \ - is able to avoid potential fork problems with Metal, OpenCL \ - and ROCM compilers.") + is able to avoid potential fork problems with Metal, OpenCL \ + and ROCM compilers.") parser.add_argument('--custom-addr', type=str, help="Custom IP Address to Report to RPC Tracker") parser.add_argument('--utvm-dev-config', type=str, - help='JSON config file for the target device (if using MicroTVM)') - parser.add_argument('--utvm-dev-id', type=str, - help='Unique ID for the target device (if using MicroTVM)') + help=('JSON config file for the target device (if using MicroTVM). ' + 'This file should contain serialized output similar to that returned ' + "from the device module's generate_config. Can't be specified when " + '--utvm-dev-config-args is specified.')) parser.add_argument('--utvm-dev-config-args', type=str, - help=('Python list of literals required to generate a default' - ' MicroTVM config (if --utvm-dev-id is specified)')) + help=("Arguments to the device module's generate_config function. " + 'Must be a python literal parseable by literal_eval. If specified, ' + "the device configuration is generated using the device module's " + "generate_config. Can't be specified when --utvm-dev-config is " + "specified.")) + parser.add_argument('--utvm-dev-id', type=str, + help=('Unique ID for the target device (if using MicroTVM). Should ' + 'match the name of a module underneath tvm.micro.device.')) parser.set_defaults(fork=True) args = parser.parse_args() diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 9a881cfa6d5be..fcea9d8212220 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -16,6 +16,9 @@ # under the License. """Tool to upgrade json from historical versions.""" import json +import tvm.ir +import tvm.runtime + def create_updater(node_map, from_ver, to_ver): """Create an updater to update json loaded data.
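The hunk that follows teaches create_updater to accept either a single transform or a list of transforms per type_key, applied left to right. A minimal standalone sketch of that dispatch (names mirror the diff; the from_ver/to_ver bookkeeping is elided, and the two sample passes are made up for illustration):

def make_updater(node_map):
    def _updater(data):
        nodes = data["nodes"]
        for idx, item in enumerate(nodes):
            f = node_map.get(item["type_key"], None)
            if isinstance(f, list):
                for fpass in f:          # chain of passes, applied in order
                    item = fpass(item, nodes)
            elif f:
                item = f(item, nodes)
            nodes[idx] = item
        return data
    return _updater

def rename(new_name):
    def _convert(item, _nodes):
        item["type_key"] = new_name
        return item
    return _convert

def mark(item, _nodes):
    item.setdefault("attrs", {})["upgraded"] = "1"
    return item

updater = make_updater({"relay.TypeVar": [rename("TypeVar"), mark]})
data = {"nodes": [{"type_key": "relay.TypeVar"}], "attrs": {}}
print(updater(data)["nodes"][0])  # {'type_key': 'TypeVar', 'attrs': {'upgraded': '1'}}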
@@ -41,8 +44,12 @@ def _updater(data): nodes = data["nodes"] for idx, item in enumerate(nodes): f = node_map.get(item["type_key"], None) - if f: - nodes[idx] = f(item, nodes) + if isinstance(f, list): + for fpass in f: + item = fpass(item, nodes) + elif f: + item = f(item, nodes) + nodes[idx] = item data["attrs"]["tvm_version"] = to_ver return data return _updater @@ -84,12 +91,26 @@ def _update_global_key(item, _): del item["global_key"] return item + def _update_from_std_str(key): + def _convert(item, nodes): + str_val = item["attrs"][key] + jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val))) + root_idx = jdata["root"] + val = jdata["nodes"][root_idx] + sidx = len(nodes) + nodes.append(val) + item["attrs"][key] = '%d' % sidx + return item + + return _convert + + node_map = { # Base IR "SourceName": _update_global_key, "EnvFunc": _update_global_key, "relay.Op": _update_global_key, - "relay.TypeVar": _ftype_var, + "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")], "relay.GlobalTypeVar": _ftype_var, "relay.Type": _rename("Type"), "relay.TupleType": _rename("TupleType"), diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index 9e984c08fe2c1..7c1389cc4eef7 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -17,6 +17,7 @@ """MicroTVM module for bare-metal backends""" from ..contrib import binutil -from .base import Session, create_micro_mod, cross_compiler -from .base import LibType, get_micro_host_driven_dir, get_micro_device_dir +from .base import DEVICE_SECTIONS +from .base import Session, create_micro_mod, cross_compiler, LibType +from .base import get_micro_host_driven_dir, get_micro_device_dir from . import device diff --git a/python/tvm/micro/base.py b/python/tvm/micro/base.py index 9f50f9855303d..bf4fd0ac9b767 100644 --- a/python/tvm/micro/base.py +++ b/python/tvm/micro/base.py @@ -19,6 +19,7 @@ from __future__ import absolute_import import os +import re import sys from enum import Enum @@ -28,6 +29,18 @@ from tvm.contrib import util as _util from tvm.contrib import cc as _cc +# all sections that comprise a device's memory layout, in order from lowest +# starting address to highest +DEVICE_SECTIONS = [ + "text", + "rodata", + "data", + "bss", + "args", + "heap", + "workspace", + "stack", +] class LibType(Enum): """Enumeration of library types that can be compiled and loaded onto a device""" @@ -51,9 +64,9 @@ class Session: .. code-block:: python c_mod = ... # some module generated with "c" as the target - dev_config = micro.device.arm.stm32f746xx.default_config("127.0.0.1", 6666) + dev_config = micro.device.arm.stm32f746xx.generate_config('127.0.0.1', 6666) with tvm.micro.Session(dev_config) as sess: - micro_mod = create_micro_mod(c_mod, dev_config) + micro_mod = sess.create_micro_mod(c_mod) """ def __init__(self, config): @@ -62,19 +75,20 @@ def __init__(self, config): # grab a binutil instance from the ID in the config dev_funcs = tvm.micro.device.get_device_funcs(config["device_id"]) - self.create_micro_lib = dev_funcs["create_micro_lib"] self.toolchain_prefix = config["toolchain_prefix"] self.mem_layout = config["mem_layout"] - self.word_size = config["word_size"] + self.word_size_bits = config["word_size_bits"] self.thumb_mode = config["thumb_mode"] + self.use_device_timer = config["use_device_timer"] self.comms_method = config["comms_method"] # First, find and compile runtime library.
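For orientation, a sketch of the config dict Session.__init__ unpacks here. The keys follow the generate_config implementations later in this diff; the mem_layout numbers below are the STM32F746XX section defaults laid out contiguously from its 0x20000000 RAM base, but treat them as illustrative rather than authoritative:

dev_config = {
    "device_id": "arm.stm32f746xx",
    "toolchain_prefix": "arm-none-eabi-",
    # one {"start", "size"} entry per name in DEVICE_SECTIONS, in order
    "mem_layout": {
        "text":      {"start": 0x20000000, "size": 18000},
        "rodata":    {"start": 0x20004650, "size": 100},
        "data":      {"start": 0x200046b4, "size": 100},
        "bss":       {"start": 0x20004718, "size": 600},
        "args":      {"start": 0x20004970, "size": 4096},
        "heap":      {"start": 0x20005970, "size": 233072},
        "workspace": {"start": 0x2003e7e0, "size": 64000},
        "stack":     {"start": 0x2004e1e0, "size": 32},
    },
    "word_size_bits": 32,
    "thumb_mode": True,
    "use_device_timer": True,
    "comms_method": "openocd",
    "server_addr": "127.0.0.1",
    "server_port": 6666,
}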
runtime_src_path = os.path.join(get_micro_host_driven_dir(), "utvm_runtime.c") tmp_dir = _util.tempdir() runtime_obj_path = tmp_dir.relpath("utvm_runtime.obj") - self.create_micro_lib(runtime_obj_path, runtime_src_path, LibType.RUNTIME) - #input(f"check {runtime_obj_path}: ") + options = ["-I{}".format(get_micro_host_driven_dir())] + dev_funcs["create_micro_lib"]( + runtime_obj_path, runtime_src_path, LibType.RUNTIME, options=options) comms_method = config["comms_method"] if comms_method == "openocd": @@ -86,6 +100,8 @@ def __init__(self, config): else: raise RuntimeError(f"unknown communication method: f{self.comms_method}") + assert all(map(lambda sec: sec in self.mem_layout, DEVICE_SECTIONS)), \ + "not all sections have an assigned memory layout" self.module = _CreateSession( comms_method, runtime_obj_path, @@ -106,12 +122,15 @@ def __init__(self, config): self.mem_layout["workspace"]["size"], self.mem_layout["stack"].get("start", 0), self.mem_layout["stack"]["size"], - self.word_size, + self.word_size_bits, self.thumb_mode, + self.use_device_timer, server_addr, server_port) self._enter = self.module["enter"] self._exit = self.module["exit"] + self.get_last_batch_time = self.module["get_last_batch_time"] + self.get_last_batch_cycles = self.module["get_last_batch_cycles"] def _check_system(self): """Check if the user's system is supported by MicroTVM. @@ -119,7 +138,7 @@ def _check_system(self): Raises error if not supported. """ if not sys.platform.startswith("linux"): - raise RuntimeError("MicroTVM is currently only supported on Linux hosts") + raise RuntimeError("MicroTVM is currently only supported on Linux") # TODO(weberlo): Add 32-bit support. # It's primarily the compilation pipeline that isn't compatible. if sys.maxsize <= 2**32: @@ -133,44 +152,91 @@ def __exit__(self, exc_type, exc_value, exc_traceback): self._exit() -def create_micro_mod(c_mod, dev_config): +def _calc_max_workspace_usage(src): + # TODO factor in alignment to the calculation (alloc sizes will be aligned up to the word size) + alloc_re = re.compile( + r'.*\* ?(.+) = (\(.+\))? TVMBackendAllocWorkspace\(.+, .+, \(uint64_t\)(.+), .+, .+\).*') + free_re = re.compile(r'.*if \(TVMBackendFreeWorkspace\(.+, .+, (\(void\*\))? (.+)\) != 0\) {.*') + max_usage = 0 + alloc_map = {} + for line in src.split("\n"): + if line.strip().startswith("//"): + continue + match = alloc_re.match(line) + if match is not None: + alloc_map[match.group(1)] = int(match.group(3)) + max_usage = max(max_usage, sum(alloc_map.values())) + else: + match = free_re.match(line) + if match is not None: + print(alloc_map) + del alloc_map[match.group(2)] + return max_usage + + +def create_micro_mod(c_mod, dev_config, lib_src_paths=None, lib_headers=None, + lib_include_paths=None): """Produces a micro module from a given module. 
Parameters ---------- - c_mod : tvm.runtime.Module + c_mod : tvm.module.Module module with "c" as its target backend - dev_config : Dict[str, Any] - MicroTVM config dict for the target device + dev_config : Dict[str, Any] + MicroTVM config dict for the target device + + lib_src_paths: TODO + TODO + + lib_headers: TODO + TODO + + lib_include_paths: TODO + TODO Return ------ - micro_mod : tvm.runtim.Module + micro_mod : tvm.module.Module micro module for the target device """ temp_dir = _util.tempdir() lib_obj_path = temp_dir.relpath("dev_lib.obj") + # TODO use dev config to dispatch on the type of C codegen to run through + # (e.g., CodeGenCArm, CodeGenCHost, CodeGenCRiscV) c_mod.export_library( lib_obj_path, - fcompile=cross_compiler(dev_config, LibType.OPERATOR)) + fcompile=cross_compiler( + dev_config, + LibType.OPERATOR, + lib_src_paths=lib_src_paths, + lib_headers=lib_headers, + lib_include_paths=lib_include_paths)) micro_mod = tvm.runtime.load_module(lib_obj_path) return micro_mod -def cross_compiler(dev_config, lib_type): - """Create a cross-compile function that wraps `create_lib` for a `Binutil` instance. +def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None, + lib_include_paths=None): + """Create a cross compile function that wraps `create_lib` for a `Binutil` instance. For use in `tvm.runtime.Module.export_library`. Parameters ---------- - dev_config : Dict[str, Any] - MicroTVM config dict for the target device + dev_config : Dict[str, Any] + MicroTVM config dict for the target device, used to look up the device's + create_micro_lib (e.g., + `tvm.micro.device.get_device_funcs('arm.stm32f746xx')['create_micro_lib']`) lib_type : micro.LibType whether to compile a MicroTVM runtime or operator library + lib_src_paths: TODO + TODO + + lib_headers: TODO + e.g., `['cmsis_gcc.h', 'arm_math.h']` + + lib_include_paths: TODO + TODO + Return ------ func : Callable[[str, str, Optional[str]], None] @@ -183,16 +249,49 @@ def cross_compiler(dev_config, lib_type): c_mod = ...
# some module generated with "c" as the target fcompile = tvm.micro.cross_compiler(dev_config, LibType.OPERATOR) - c_mod.export_library("dev_lib.obj", fcompile=fcompile) + c_mod.export_library('dev_lib.obj', fcompile=fcompile) """ - dev_funcs = tvm.micro.device.get_device_funcs(dev_config['device_id']) - create_micro_lib = dev_funcs['create_micro_lib'] + assert (lib_headers is None) == (lib_include_paths is None), \ + "must specify both `lib_headers` and `lib_include_paths` or neither" + + if lib_src_paths is None: + lib_src_paths = [] + if lib_include_paths is None: + lib_include_paths = [] + include_options = [] + for include_path in lib_include_paths: + include_options.append("-I") + include_options.append(include_path) + create_micro_lib = tvm.micro.device.get_device_funcs( + dev_config["device_id"])["create_micro_lib"] + mem_layout = dev_config["mem_layout"] + def compile_func(obj_path, src_path, **kwargs): if isinstance(obj_path, list): obj_path = obj_path[0] if isinstance(src_path, list): src_path = src_path[0] - create_micro_lib(obj_path, src_path, lib_type, kwargs.get("options", None)) + options = kwargs.get("options", []) + options += include_options + + # check that workspace allocations don't exceed available workspace memory + with open(src_path) as f: + src_contents = f.read() + max_ws_usage = _calc_max_workspace_usage(src_contents) + available_mem = mem_layout["workspace"]["size"] + if max_ws_usage > available_mem: + raise RuntimeError(f"workspace allocations in library ({max_ws_usage}) " + f"exceed available memory ({available_mem})") + # inject headers into new source path, if requested + if lib_headers: + headers_to_inject = "\n".join(map(lambda s: f"#include <{s}>", lib_headers)) + "\n" + new_src_contents = headers_to_inject + src_contents + tmp_dir = _util.tempdir() + src_path = tmp_dir.relpath(os.path.basename(src_path)) + with open(src_path, "w") as f: + f.write(new_src_contents) + + create_micro_lib(obj_path, src_path, lib_type, options, lib_src_paths=lib_src_paths) return _cc.cross_compiler(compile_func, output_format="obj") diff --git a/python/tvm/micro/device/__init__.py b/python/tvm/micro/device/__init__.py index 1ccd6847edd81..89731b9aa7974 100644 --- a/python/tvm/micro/device/__init__.py +++ b/python/tvm/micro/device/__init__.py @@ -16,7 +16,8 @@ # under the License. """Device-specific configuration for MicroTVM""" -from .base import register_device, get_device_funcs, create_micro_lib_base +from .base import create_micro_lib_base, gen_mem_layout +from .base import MemConstraint, register_device, get_device_funcs from . import host from . import arm from . import riscv_spike diff --git a/python/tvm/micro/device/arm/stm32f746xx.py b/python/tvm/micro/device/arm/stm32f746xx.py index 31b44cf9d36b5..746958504edaf 100644 --- a/python/tvm/micro/device/arm/stm32f746xx.py +++ b/python/tvm/micro/device/arm/stm32f746xx.py @@ -14,13 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Compilation and config definitions for ARM STM32F746XX devices""" -from .. import create_micro_lib_base, register_device +"""Compilation and config definitions for Arm STM32F746XX devices""" +import os +from .. 
import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint DEVICE_ID = "arm.stm32f746xx" TOOLCHAIN_PREFIX = "arm-none-eabi-" +WORD_SIZE_BITS = 32 +# +# [Device Memory Layout] +# RAM (rwx) : START = 0x20000000, LENGTH = 320K +# Flash (rx) : START = 0x8000000, LENGTH = 1024K +# +BASE_ADDR = 0x20000000 +AVAILABLE_MEM = 320000 +DEFAULT_SECTION_CONSTRAINTS = { + "text": (18000, MemConstraint.ABSOLUTE_BYTES), + "rodata": (100, MemConstraint.ABSOLUTE_BYTES), + "data": (100, MemConstraint.ABSOLUTE_BYTES), + "bss": (600, MemConstraint.ABSOLUTE_BYTES), + "args": (4096, MemConstraint.ABSOLUTE_BYTES), + "heap": (100.0, MemConstraint.WEIGHT), + "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), + "stack": (32, MemConstraint.ABSOLUTE_BYTES), +} -def create_micro_lib(obj_path, src_path, lib_type, options=None): +def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): """Wrapper over `create_micro_lib_base` to add device-specific options Parameters @@ -36,23 +55,40 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): options : Optional[List[str]] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + TODO """ if options is None: options = [] + else: + options = list(options) + options += [ + # TODO(weberlo): make a debug flag + "-O2", "-mcpu=cortex-m7", "-mlittle-endian", "-mfloat-abi=hard", "-mfpu=fpv5-sp-d16", "-mthumb", + "-ffast-math", "-gdwarf-5", + "-DARM_MATH_CM7", + "-D__FPU_PRESENT=1U", + "-DARM_MATH_DSP", + "-Wno-unused-variable", + "-Wno-unused-parameter", + "-I{}".format(os.environ["CMSIS_ST_PATH"]), + "-I{}/Core/Include".format(os.environ["CMSIS_ST_PATH"]) ] create_micro_lib_base( - obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options) + obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options, + lib_src_paths=lib_src_paths) -def default_config(server_addr, server_port): - """Generates a default configuration for ARM STM32F746XX devices +def generate_config(server_addr, server_port, section_constraints=None): + """Generates a configuration for Arm STM32F746XX devices Parameters ---------- @@ -62,55 +98,23 @@ def default_config(server_addr, server_port): server_port : int port of OpenOCD server to connect to + section_constraints: Optional[Dict[str, [Number, MemConstraint]]] + maps section name to the quantity of available memory + Return ------ config : Dict[str, Any] MicroTVM config dict for this device """ + if section_constraints is None: + section_constraints = DEFAULT_SECTION_CONSTRAINTS return { "device_id": DEVICE_ID, "toolchain_prefix": TOOLCHAIN_PREFIX, - # - # [Device Memory Layout] - # RAM (rwx) : START = 0x20000000, LENGTH = 320K - # FLASH (rx) : START = 0x8000000, LENGTH = 1024K - # - "mem_layout": { - "text": { - "start": 0x20000180, - "size": 20480, - }, - "rodata": { - "start": 0x20005180, - "size": 20480, - }, - "data": { - "start": 0x2000a180, - "size": 768, - }, - "bss": { - "start": 0x2000a480, - "size": 768, - }, - "args": { - "start": 0x2000a780, - "size": 1280, - }, - "heap": { - "start": 0x2000ac80, - "size": 262144, - }, - "workspace": { - "start": 0x2004ac80, - "size": 20480, - }, - "stack": { - "start": 0x2004fc80, - "size": 80, - }, - }, - "word_size": 4, + "mem_layout": gen_mem_layout(BASE_ADDR, AVAILABLE_MEM, WORD_SIZE_BITS, section_constraints), + "word_size_bits": WORD_SIZE_BITS, "thumb_mode": True, + "use_device_timer": True, "comms_method": "openocd", "server_addr": server_addr, "server_port": server_port, @@ -119,5 +123,5 @@ def 
default_config(server_addr, server_port): register_device(DEVICE_ID, { "create_micro_lib": create_micro_lib, - "default_config": default_config, + "generate_config": generate_config, }) diff --git a/python/tvm/micro/device/base.py b/python/tvm/micro/device/base.py index ae53b9cc539fc..767284c9c254b 100644 --- a/python/tvm/micro/device/base.py +++ b/python/tvm/micro/device/base.py @@ -17,12 +17,13 @@ """Base definitions for MicroTVM config""" import glob import os -from pathlib import Path +import enum +import pathlib from tvm.contrib import util as _util from tvm.contrib.binutil import run_cmd from tvm._ffi.libinfo import find_include_path -from tvm.micro import LibType, get_micro_host_driven_dir, get_micro_device_dir +from tvm.micro import DEVICE_SECTIONS, LibType, get_micro_host_driven_dir, get_micro_device_dir _DEVICE_REGISTRY = {} @@ -38,7 +39,7 @@ def register_device(device_id, device_funcs): dictionary with compilation and config generation functions as values """ if device_id in _DEVICE_REGISTRY: - raise RuntimeError(f"\"{device_id}\" already exists in the device registry") + raise RuntimeError(f'"{device_id}" already exists in the device registry') _DEVICE_REGISTRY[device_id] = device_funcs @@ -56,7 +57,7 @@ def get_device_funcs(device_id): dictionary with compilation and config generation functions as values """ if device_id not in _DEVICE_REGISTRY: - raise RuntimeError(f"\"{device_id}\" does not exist in the binutil registry") + raise RuntimeError(f'"{device_id}" does not exist in the binutil registry') device_funcs = _DEVICE_REGISTRY[device_id] return device_funcs @@ -67,7 +68,9 @@ def create_micro_lib_base( toolchain_prefix, device_id, lib_type, - options=None): + options=None, + lib_src_paths=None, + ): """Compiles code into a binary for the target micro device. Parameters @@ -92,7 +95,12 @@ def create_micro_lib_base( options : List[str] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + paths to additional source files to be compiled into the library """ + # look at these (specifically `strip`): + # https://stackoverflow.com/questions/15314581/g-compiler-flag-to-minimize-binary-size base_compile_cmd = [ f"{toolchain_prefix}gcc", "-std=c11", @@ -100,7 +108,6 @@ def create_micro_lib_base( "-Wextra", "--pedantic", "-c", - "-O0", "-g", "-nostartfiles", "-nodefaultlibs", @@ -114,40 +121,48 @@ def create_micro_lib_base( src_paths = [] include_paths = find_include_path() + [get_micro_host_driven_dir()] tmp_dir = _util.tempdir() - # we might transform the src path in one of the branches below + # we need to create a new src file in the operator branch new_in_src_path = in_src_path if lib_type == LibType.RUNTIME: dev_dir = _get_device_source_dir(device_id) + dev_src_paths = glob.glob(f"{dev_dir}/*.[csS]") # there needs to at least be a utvm_timer.c file assert dev_src_paths assert "utvm_timer.c" in map(os.path.basename, dev_src_paths) + src_paths += dev_src_paths elif lib_type == LibType.OPERATOR: - # create a temporary copy of the source, so we can inject the dev lib + # create a temporary copy of the operator source, so we can inject the dev lib # header without modifying the original. 
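As a quick illustration of the register_device / get_device_funcs round trip defined above, here is a hedged sketch; the device id and stub functions are hypothetical, not a real TVM target:

from tvm.micro.device import register_device, get_device_funcs

def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
    # a real implementation forwards to create_micro_lib_base with
    # device-specific compiler flags
    raise NotImplementedError

def generate_config(server_addr, server_port, section_constraints=None):
    # a real implementation returns the full MicroTVM config dict
    return {"device_id": "acme.examplemcu",
            "server_addr": server_addr, "server_port": server_port}

register_device("acme.examplemcu", {
    "create_micro_lib": create_micro_lib,
    "generate_config": generate_config,
})
cfg = get_device_funcs("acme.examplemcu")["generate_config"]("127.0.0.1", 6666)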
temp_src_path = tmp_dir.relpath("temp.c") with open(in_src_path, "r") as f: src_lines = f.read().splitlines() - src_lines.insert(0, "#include \"utvm_device_dylib_redirect.c\"") + src_lines.insert(0, '#include "utvm_device_dylib_redirect.c"') with open(temp_src_path, "w") as f: f.write("\n".join(src_lines)) new_in_src_path = temp_src_path - base_compile_cmd += ["-c"] else: raise RuntimeError("unknown lib type") src_paths += [new_in_src_path] + # add any src paths required by the operator + if lib_src_paths is not None: + src_paths += lib_src_paths + for path in include_paths: base_compile_cmd += ["-I", path] prereq_obj_paths = [] for src_path in src_paths: - curr_obj_path = Path(src_path).with_suffix(".o").name + curr_obj_path = tmp_dir.relpath(pathlib.Path(src_path).with_suffix(".o").name) assert curr_obj_path not in prereq_obj_paths prereq_obj_paths.append(curr_obj_path) curr_compile_cmd = base_compile_cmd + [src_path, "-o", curr_obj_path] + # TODO(weberlo): make compilation fail if there are any warnings run_cmd(curr_compile_cmd) ld_cmd = [f"{toolchain_prefix}ld", "-relocatable"] @@ -156,6 +171,65 @@ def create_micro_lib_base( run_cmd(ld_cmd) +# TODO: we may not need an enum for this; consider simplifying. +class MemConstraint(enum.Enum): + """Represents a constraint on the device's memory layout""" + ABSOLUTE_BYTES = 0 + WEIGHT = 1 + + +def gen_mem_layout(base_addr, available_mem, word_size_bits, section_constraints): + """Template function to generate memory layout for devices. + + Parameters + ---------- + base_addr: Number + The address where usable memory begins on this device. + + available_mem: Number + Available memory at base_addr, given in bytes. + + word_size_bits: Number + Number of bits in one word on this device. + + section_constraints: Optional[Dict[str, Tuple[Number, MemConstraint]]] + maps section name to the quantity of available memory + """ + assert word_size_bits in (32, 64), "only 32- or 64-bit devices are supported now" + word_size_bytes = word_size_bits // 8 + byte_sum = sum(x[0] + for x in section_constraints.values() + if x[1] == MemConstraint.ABSOLUTE_BYTES) + weight_sum = sum(x[0] + for x in section_constraints.values() + if x[1] == MemConstraint.WEIGHT) + assert byte_sum <= available_mem + available_weight_mem = available_mem - byte_sum + + res = {} + curr_addr = base_addr + for section in DEVICE_SECTIONS: + (val, cons_type) = section_constraints[section] + if cons_type == MemConstraint.ABSOLUTE_BYTES: + assert val % word_size_bytes == 0, \ + f"constraint {val} for {section} section is not word-aligned" + size = val + res[section] = { + "start": curr_addr, + "size": size, + } + else: + size = int((val / weight_sum) * available_weight_mem) + size = (size // word_size_bytes) * word_size_bytes + res[section] = { + "start": curr_addr, + "size": size, + } + curr_addr += size + + return res + + def _get_device_source_dir(device_id): """Grabs the source directory for device-specific uTVM files""" dev_subdir = "/".join(device_id.split(".")) diff --git a/python/tvm/micro/device/host.py b/python/tvm/micro/device/host.py index a5495b60cf99f..0cf29874ab573 100644 --- a/python/tvm/micro/device/host.py +++ b/python/tvm/micro/device/host.py @@ -17,12 +17,26 @@ """Compilation and config definitions for the host emulated device""" import sys -from . import create_micro_lib_base, register_device +from . 
import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint DEVICE_ID = "host" TOOLCHAIN_PREFIX = "" +WORD_SIZE_BITS = 64 if sys.maxsize > 2**32 else 32 -def create_micro_lib(obj_path, src_path, lib_type, options=None): +# we pretend we only have 3200kb in the default case, so we can use `gen_mem_layout` +DEFAULT_AVAILABLE_MEM = 3200000 +DEFAULT_SECTION_CONSTRAINTS = { + "text": (20480, MemConstraint.ABSOLUTE_BYTES), + "rodata": (20480, MemConstraint.ABSOLUTE_BYTES), + "data": (768, MemConstraint.ABSOLUTE_BYTES), + "bss": (4096, MemConstraint.ABSOLUTE_BYTES), + "args": (4096, MemConstraint.ABSOLUTE_BYTES), + "heap": (262144, MemConstraint.ABSOLUTE_BYTES), + "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), + "stack": (80, MemConstraint.ABSOLUTE_BYTES), +} + +def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): """Wrapper over `create_micro_lib_base` to add device-specific options Parameters @@ -38,59 +52,65 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): options : Optional[List[str]] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + paths to additional source files to be compiled into the library """ if options is None: options = [] + else: + options = list(options) + # Cannot increase optimization level on host due to code loading method. + options.append("-O0") if sys.maxsize > 2**32 and sys.platform.startswith("linux"): options += ["-mcmodel=large"] create_micro_lib_base( - obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options) + obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options, + lib_src_paths=lib_src_paths) -def default_config(): - """Generates a default configuration for the host emulated device +def generate_config(available_mem=None, section_constraints=None): + """Generates a configuration for the host emulated device + + Parameters + ---------- + available_mem: int + number of RW bytes available for use on device + + section_constraints: Optional[Dict[str, Tuple[Number, MemConstraint]]] + maps section name to the quantity of available memory Return ------ config : Dict[str, Any] MicroTVM config dict for this device """ + if available_mem is None: + available_mem = DEFAULT_AVAILABLE_MEM + if section_constraints is None: + section_constraints = DEFAULT_SECTION_CONSTRAINTS + mem_layout = gen_mem_layout(0, available_mem, WORD_SIZE_BITS, section_constraints) + # TODO the host emulated device is an outlier, since we don't know what + # its base address will be until we've created it in the C++. Is there any + # way to change the infrastructure around this so it's not so much of an + # outlier?
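To make `gen_mem_layout` concrete, here is the arithmetic it performs for the STM32F746XX defaults shown earlier, where heap is the only WEIGHT section and everything else is ABSOLUTE_BYTES (a worked sketch, not the implementation):

    # ABSOLUTE_BYTES sizes for text, rodata, data, bss, args, workspace, stack
    absolute_sizes = [18000, 100, 100, 600, 4096, 64000, 32]
    available_mem = 320000       # AVAILABLE_MEM on the STM32F746XX
    word_size_bytes = 32 // 8    # WORD_SIZE_BITS // 8

    byte_sum = sum(absolute_sizes)          # 86928
    weight_mem = available_mem - byte_sum   # 233072 bytes left for heap
    heap_size = (weight_mem // word_size_bytes) * word_size_bytes  # 233072, already word-aligned

Each section's start address is then the running sum of the section sizes before it, beginning at base_addr.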
+ + # need to zero out all start addresses, because they don't make sense for a + # host device (the memory region is allocated in the backend) + for section in mem_layout: + mem_layout[section]["start"] = 0 return { "device_id": DEVICE_ID, "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": { - "text": { - "size": 20480, - }, - "rodata": { - "size": 20480, - }, - "data": { - "size": 768, - }, - "bss": { - "size": 768, - }, - "args": { - "size": 1280, - }, - "heap": { - "size": 262144, - }, - "workspace": { - "size": 20480, - }, - "stack": { - "size": 80, - }, - }, - "word_size": 8 if sys.maxsize > 2**32 else 4, + "mem_layout": mem_layout, + "word_size_bits": WORD_SIZE_BITS, "thumb_mode": False, + "use_device_timer": False, "comms_method": "host", } register_device(DEVICE_ID, { "create_micro_lib": create_micro_lib, - "default_config": default_config, + "generate_config": generate_config, }) diff --git a/python/tvm/micro/device/riscv_spike.py b/python/tvm/micro/device/riscv_spike.py index 923e5dfb23a22..32881cab6ba9f 100644 --- a/python/tvm/micro/device/riscv_spike.py +++ b/python/tvm/micro/device/riscv_spike.py @@ -15,14 +15,25 @@ # specific language governing permissions and limitations # under the License. """Compilation and config definitions for Spike, a RISC-V functional ISA simulator""" -from collections import OrderedDict -from . import create_micro_lib_base, register_device +from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint DEVICE_ID = "riscv_spike" TOOLCHAIN_PREFIX = "riscv64-unknown-elf-" +WORD_SIZE_BITS = 64 -def create_micro_lib(obj_path, src_path, lib_type, options=None): +DEFAULT_SECTION_CONSTRAINTS = { + "text": (18000, MemConstraint.ABSOLUTE_BYTES), + "rodata": (128, MemConstraint.ABSOLUTE_BYTES), + "data": (128, MemConstraint.ABSOLUTE_BYTES), + "bss": (2048, MemConstraint.ABSOLUTE_BYTES), + "args": (4096, MemConstraint.ABSOLUTE_BYTES), + "heap": (100.0, MemConstraint.WEIGHT), + "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), + "stack": (32, MemConstraint.ABSOLUTE_BYTES), +} + +def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): """Wrapper over `create_micro_lib_base` to add device-specific options Parameters @@ -38,6 +49,9 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): options : Optional[List[str]] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + paths to additional source files to be compiled into the library """ create_micro_lib_base( obj_path, @@ -45,11 +59,13 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, - options=options) + options=options, + lib_src_paths=lib_src_paths + ) -def default_config(base_addr, server_addr, server_port): - """Generates a default configuration for Spike +def generate_config(base_addr, available_mem, server_addr, server_port, section_constraints=None): + """Generates a configuration for Spike Parameters ---------- @@ -62,56 +78,31 @@ def default_config(base_addr, server_addr, server_port): server_port : int port of OpenOCD server to connect to
+ section_constraints: Optional[Dict[str, Tuple[Number, MemConstraint]]] + maps section name to the quantity of available memory + Return ------ config : Dict[str, Any] MicroTVM config dict for this device """ - res = { + if section_constraints is None: + section_constraints = DEFAULT_SECTION_CONSTRAINTS + return { "device_id": DEVICE_ID, "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": OrderedDict([ - ("text", { - "size": 20480, - }), - ("rodata", { - "size": 20480, - }), - ("data", { - "size": 768, - }), - ("bss", { - "size": 768, - }), - ("args", { - "size": 1280, - }), - ("heap", { - "size": 262144, - }), - ("workspace", { - "size": 20480, - }), - ("stack", { - "size": 80, - }), - ]), - "word_size": 4, - "thumb_mode": True, + "mem_layout": gen_mem_layout(base_addr, available_mem, WORD_SIZE_BITS, section_constraints), + "word_size_bits": WORD_SIZE_BITS, + "thumb_mode": False, + "use_device_timer": False, "comms_method": "openocd", "server_addr": server_addr, "server_port": server_port, } - # generate section start addresses from the given `base_addr` - curr_offset = 0 - mem_layout = res["mem_layout"] - for region_dict in mem_layout.values(): - region_dict["start"] = base_addr + curr_offset - curr_offset += region_dict["size"] - return res register_device(DEVICE_ID, { "create_micro_lib": create_micro_lib, - "default_config": default_config, + "generate_config": generate_config, }) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index c97e67ba6ee38..1d97b55773617 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -154,7 +154,9 @@ def __call__(self, args, attrs, type_args): "nn.dropout": op.nn.dropout_raw, "zeros": op.zeros, "split": op.split, - "cast": op.cast + "cast": op.cast, + "clip": op.clip, + "right_shift": op.right_shift, } TYPE_PREFIXES = [ diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index bf91bc1ffa3ad..43065bef838a1 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -207,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _): return _op.transpose(inexpr, axes=(0,) + keras_layer.dims) +def _convert_embedding(inexpr, keras_layer, etab): + indices = inexpr + weightList = keras_layer.get_weights() + weight = etab.new_const(weightList[0]) + out = _op.take(weight, indices.astype('int32'), axis=0) + + return out + def _convert_dense(inexpr, keras_layer, etab): weightList = keras_layer.get_weights() weight = etab.new_const(weightList[0].transpose([1, 0])) @@ -893,7 +901,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument 'Maximum' : _convert_merge, 'Dot' : _convert_merge, 'Permute' : _convert_permute, - # 'Embedding' : _convert_embedding, + 'Embedding' : _convert_embedding, # 'RepeatVector' : _convert_repeat_vector, 'InputLayer' : _default_skip, diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 775eb53d25928..7dbc7881f43cd 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1712,6 +1712,33 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): res = _op.nn.relu(res) return res + +def _mx_broadcast_to(inputs, attrs): + data = inputs[0] + tgt_shape = attrs.get_int_tuple("shape", []) + + return _op.broadcast_to(data, tgt_shape) + + +def _mx_logical_not(inputs, input_types): + data = inputs[0] + dtype = _infer_type(data).checked_type.dtype + data = _op.cast(data, "bool") if dtype != "bool" else data + + return _op.cast(_op.logical_not(data), dtype) + + +def
_mx_broadcast_logical(logical_op): + def impl(inputs, input_types): + lhs_type = _infer_type(inputs[0]).checked_type.dtype + rhs_type = _infer_type(inputs[1]).checked_type.dtype + lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0] + rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1] + + return _op.cast(logical_op(lhs, rhs), lhs_type) + return impl + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -1738,12 +1765,15 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "_copy" : _rename(_op.copy), "relu" : _rename(_op.nn.relu), "broadcast_add" : _rename(_op.add), + "broadcast_plus" : _rename(_op.add), "broadcast_sub" : _rename(_op.subtract), + "broadcast_minus" : _rename(_op.subtract), "broadcast_mul" : _rename(_op.multiply), "broadcast_div" : _rename(_op.divide), "broadcast_mod" : _rename(_op.mod), "broadcast_maximum" : _rename(_op.maximum), "broadcast_minimum" : _rename(_op.minimum), + "broadcast_power" : _rename(_op.power), "arctan" : _rename(_op.atan), "broadcast_equal" : _mx_compare(_op.equal, _rename), "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename), @@ -1751,6 +1781,11 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), "broadcast_lesser" : _mx_compare(_op.less, _rename), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), + "broadcast_logical_or" : _mx_broadcast_logical(_op.logical_or), + "broadcast_logical_and" : _mx_broadcast_logical(_op.logical_and), + "broadcast_logical_xor" : _mx_broadcast_logical(_op.logical_xor), + "broadcast_to" : _mx_broadcast_to, + "logical_not" : _mx_logical_not, "_equal" : _mx_compare(_op.equal, _rename), "_not_equal" : _mx_compare(_op.not_equal, _rename), "_greater" : _mx_compare(_op.greater, _rename), @@ -1860,6 +1895,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "reverse" : _mx_reverse, "squeeze" : _mx_squeeze, "broadcast_axis": _mx_broadcast_axis, + "broadcast_axes": _mx_broadcast_axis, "BlockGrad" : _mx_BlockGrad, "shape_array" : _mx_shape_array, "Embedding" : _mx_embedding, @@ -1897,7 +1933,6 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # - # "broadcast_to", # "contrib_fifo_buffer": _mx_contrib_fifo_buffer, "ring_buffer": _mx_contrib_fifo_buffer, # Qnn ops diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 245b3853ae90d..1a4aee0a0d6c6 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -26,6 +26,7 @@ from .. import expr as _expr from .. import function as _function from .. import op as _op +from .. import vision as _vision from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels from .common import infer_type, infer_value, infer_value_simulated, get_name @@ -324,7 +325,6 @@ class Conv(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. 
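As background for the auto_pad handling below: ONNX's SAME_UPPER/SAME_LOWER modes choose a total padding so the output size is ceil(input / stride), with the odd leftover unit placed at the end (UPPER) or the beginning (LOWER). A sketch of that rule for one spatial dimension (helper name hypothetical):

    import math

    def onnx_same_pad_1d(in_size, kernel, stride, dilation=1, upper=True):
        out_size = math.ceil(in_size / stride)
        eff_kernel = (kernel - 1) * dilation + 1
        total = max((out_size - 1) * stride + eff_kernel - in_size, 0)
        small, big = total // 2, total - total // 2
        # SAME_UPPER puts the extra padding at the end, SAME_LOWER at the beginning.
        return (small, big) if upper else (big, small)

    assert onnx_same_pad_1d(5, 3, 2) == (1, 1)
    assert onnx_same_pad_1d(4, 3, 2) == (0, 1)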
input_shape = infer_shape(inputs[0]) - if 'auto_pad' in attr: attr['auto_pad'] = attr['auto_pad'].decode('utf-8') if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): @@ -349,7 +349,10 @@ def _impl_v1(cls, inputs, attr, params): attr.pop('auto_pad') elif len(attr['kernel_shape']) == 2: sym_pad = True - padding = attr['pads'] + if 'pads' in attr: + padding = attr['pads'] + else: + padding = [0, 0, 0, 0] for i in range(0, len(padding), 2): sym_pad = sym_pad and padding[i] == padding[i + 1] @@ -556,6 +559,31 @@ def _impl_v2(cls, inputs, attr, params): }, )(inputs, attr, params) + @classmethod + def _impl_v11(cls, inputs, attr, params): + pad_width = [] + pads = infer_value_simulated(inputs[1], params).asnumpy() + if len(inputs) == 3: + value = infer_value_simulated(inputs[2], params).asnumpy().item() + else: + value = 0 + attr["pad_value"] = value + dims = int(len(pads) / 2) + for i in range(dims): + pad_width.append((pads[i], pads[i+dims])) + attr['pad_width'] = pad_width + pad_mode = attr.get('mode', b'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') + + return AttrCvt('pad')(inputs[:1], attr, params) + + + class ParametricSoftPlus(OnnxOpConverter): """ Operator converter for ParametricSoftPlus. @@ -575,7 +603,12 @@ class Prelu(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) - return _op.nn.prelu(inputs[0], inputs[1]) + alpha_shape = infer_shape(inputs[1]) + if len(alpha_shape) != 1: + alpha = _op.reshape(inputs[1], (-1,)) + else: + alpha = inputs[1] + return _op.nn.prelu(inputs[0], alpha) class Reciprocal(OnnxOpConverter): @@ -615,7 +648,7 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params): if get_name(inputs[1]) in params: # pop shape out of parameters since it wont be needed later. - shape = tuple(params.pop(inputs[1].name_hint).asnumpy()) + shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32")) out = _op.reshape(inputs[0], shape) else: data, shape = inputs @@ -781,7 +814,10 @@ def _impl_v9(cls, inputs, attr, params): if not scales: #Here we are going to higher OPSET version. assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) - scales = params[inputs[1].name_hint].asnumpy() + if get_name(inputs[1]) in params: + scales = params[inputs[1].name_hint].asnumpy() + else: + scales = infer_value_simulated(inputs[1], params).asnumpy() inputs = inputs[:1] assert scales[0] == 1.0 and scales[1] == 1.0 input_shape = infer_shape(inputs[0]) @@ -942,6 +978,14 @@ def _impl_v1(cls, inputs, attr, params): extras={'axis': axis})(inputs, {}) +class GatherND(OnnxOpConverter): + """ Operator converter for GatherND. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.gather_nd(inputs[0], inputs[1]) + + class Greater(OnnxOpConverter): """ Operator logical greater. """ @@ -1059,6 +1103,11 @@ class ReduceProd(Reduce): """ name = 'prod' +class ReduceLogSumExp(Reduce): + """ Operator converter for ReduceLogSumExp. + """ + name = 'logsumexp' + class ArgMax(OnnxOpConverter): """ Operator converter for ArgMax. 
""" @@ -1468,8 +1517,54 @@ def _impl_v9(cls, inputs, attr, params): raise ValueError("Expect 1 input only") output = AttrCvt(op_name='argwhere')(inputs, attr, params) + # ONNX NonZero always outputs int64 + output = _op.cast(output, "int64") return _op.transpose(output, axes=(1, 0)) +class TopK(OnnxOpConverter): + """Operator converter for TopK + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if len(inputs) != 2: + raise ValueError("Expect 2 input only") + axis = attr.get("axis", -1) + largest = attr.get("largest", 1) + + if largest == 0: + raise ValueError("TVM only supports finding TopK largest elements") + + K = int(infer_value(inputs[1], params).asnumpy()[0]) + + return _op.topk(inputs[0], k=K, axis=axis) + + +class RoiAlign(OnnxOpConverter): + """Operator converter for TopK + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if len(inputs) != 3: + raise ValueError("Expect 3 inputs only") + x = inputs[0] + rois = inputs[1] + batch_indices = inputs[2] + mode = attr.get("mode", "avg") + if mode != b'avg': + raise ValueError("RoiAlign in Relay only uses avg mode") + output_height = attr.get("output_height", 1) + output_width = attr.get("output_width", 1) + + sampling_ratio = attr.get("sampling_ratio", 0) + spatial_scale = attr.get("spatial_scale", 1.0) + + batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1) + batch_indices = _op.cast( + batch_indices, infer_type(rois).type_annotation.dtype) + rois = _op.concatenate([batch_indices, rois], 1) + + return _vision.roi_align(x, rois, [output_height, output_width], + spatial_scale, sampling_ratio) # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1520,6 +1615,9 @@ def _get_convert_map(opset): 'Reciprocal': Reciprocal.get_converter(opset), 'Floor': Renamer('floor'), 'Ceil': Renamer('ceil'), + 'Round': Renamer('round'), + 'IsInf': Renamer('isinf'), + 'IsNaN': Renamer('isnan'), 'Sqrt': Renamer('sqrt'), 'Relu': Renamer('relu'), 'LeakyRelu': Renamer('leaky_relu'), @@ -1565,16 +1663,21 @@ def _get_convert_map(opset): # Recurrent Layers 'LSTM': LSTM.get_converter(opset), + # defs/vision + 'RoiAlign': RoiAlign.get_converter(opset), + # defs/reduction 'ReduceMax': ReduceMax.get_converter(opset), 'ReduceMin': ReduceMin.get_converter(opset), 'ReduceSum': ReduceSum.get_converter(opset), 'ReduceMean': ReduceMean.get_converter(opset), 'ReduceProd': ReduceProd.get_converter(opset), - # 'ReduceProd' - # 'ReduceLogSumExp' + 'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset), + + #defs/sorting 'ArgMax': ArgMax.get_converter(opset), 'ArgMin': ArgMin.get_converter(opset), + 'TopK': TopK.get_converter(opset), # defs/tensor 'Cast': Cast.get_converter(opset), @@ -1587,6 +1690,7 @@ def _get_convert_map(opset): 'DepthToSpace': DepthToSpace.get_converter(opset), 'SpaceToDepth': SpaceToDepth.get_converter(opset), 'Gather': Gather.get_converter(opset), + 'GatherND': GatherND.get_converter(opset), 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), 'Unsqueeze': Unsqueeze.get_converter(opset), 'Pad': Pad.get_converter(opset), diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e185f5817e879..64f30f35b3767 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -923,7 +923,7 @@ def _impl(inputs, input_types): axes[src] = dst axes[dst] = src else: - axes = inputs[1] + axes = _infer_shape(inputs[1], prelude.mod) return _op.transform.transpose(data, axes) return _impl diff --git 
a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a30387e34f2ac..ab9e9e656516f 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1944,6 +1944,8 @@ def _impl(inputs, attr, params, mod): # for N to 1 mapping, currently not supported(?) _convert_map = { 'Abs' : AttrCvt('abs'), + 'Acos' : AttrCvt('acos'), + 'Acosh' : AttrCvt('acosh'), 'Add' : _elemwise('add'), 'AddN' : _add_n(), 'AddV2' : _elemwise('add'), @@ -1951,8 +1953,11 @@ def _impl(inputs, attr, params, mod): 'Any' : _reduce('any'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), + 'Asin' : AttrCvt('asin'), + 'Asinh' : AttrCvt('asinh'), 'Assert' : _assert(), 'Atan' : AttrCvt('atan'), + 'Atanh' : AttrCvt('atanh'), 'Atan2' : _atan2(), 'AvgPool' : _pooling('avg_pool'), 'AvgPool3D' : _pool3d('avg_pool3d'), @@ -1972,6 +1977,7 @@ def _impl(inputs, attr, params, mod): 'Conv2DBackpropInput' : _conv('conv_transpose'), 'Conv3D' : _conv3d('conv'), 'Cos' : AttrCvt('cos'), + 'Cosh' : AttrCvt('cosh'), 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthToSpace' : _depth_to_space(), @@ -2048,6 +2054,7 @@ def _impl(inputs, attr, params, mod): 'Sigmoid' : AttrCvt('sigmoid'), 'Sign' : AttrCvt('sign'), 'Sin' : AttrCvt('sin'), + 'Sinh' : AttrCvt('sinh'), 'Size' : _size(), 'Slice' : _slice(), 'Softmax' : _softmax(), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 275d0ce11ef24..5a645c67cf61b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -29,9 +29,9 @@ from .. import op as _op from .. import qnn as _qnn from ... import nd as _nd -from .util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape +from .tflite_flexbuffer import FlexBufferDecoder __all__ = ['from_tflite'] @@ -65,6 +65,7 @@ def __init__(self, model, subgraph, exp_tab): self.convert_map = { 'ABS': self.convert_abs, 'ADD': self.convert_add, + 'ADD_N': self.convert_add_n, 'AVERAGE_POOL_2D': self.convert_average_pool2d, 'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd, 'CAST': self.convert_cast, @@ -79,6 +80,7 @@ def __init__(self, model, subgraph, exp_tab): 'ELU': self.convert_elu, 'EQUAL': self.convert_equal, 'EXP': self.convert_exp, + 'FILL': self.convert_fill, 'FLOOR_DIV': self.convert_floor_div, 'FLOOR_MOD': self.convert_floor_mod, 'FLOOR': self.convert_floor, @@ -88,16 +90,18 @@ def __init__(self, model, subgraph, exp_tab): 'GREATER': self.convert_greater, 'HARD_SWISH': self.convert_hard_swish, 'L2_NORMALIZATION': self.convert_l2_normalization, + 'L2_POOL_2D': self.convert_l2_pool2d, 'LESS_EQUAL': self.convert_less_equal, 'LESS': self.convert_less, 'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn, 'LOG': self.convert_log, 'LOGICAL_AND': self.convert_logical_and, + 'LOGICAL_NOT': self.convert_logical_not, 'LOGICAL_OR': self.convert_logical_or, 'LOGISTIC': self.convert_logistic, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, - 'MEAN': self._convert_reduce_mean, + 'MEAN': self.convert_reduce_mean, 'MINIMUM': self.convert_minimum, 'MIRROR_PAD': self.convert_mirror_pad, 'MUL': self.convert_mul, @@ -107,29 +111,31 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, - 'REDUCE_ANY': self._convert_reduce_any, - 'REDUCE_MAX': self._convert_reduce_max, - 'REDUCE_MIN': self._convert_reduce_min, - 'REDUCE_PROD': 
self._convert_reduce_prod, + 'REDUCE_ANY': self.convert_reduce_any, + 'REDUCE_MAX': self.convert_reduce_max, + 'REDUCE_MIN': self.convert_reduce_min, + 'REDUCE_PROD': self.convert_reduce_prod, 'RELU':self.convert_relu, 'RESHAPE': self.convert_reshape, 'RESIZE_BILINEAR': self.convert_resize_bilinear, 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SELECT': self.convert_select, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'SPACE_TO_DEPTH': self.convert_space_to_depth, 'SPLIT': self.convert_split, + 'SPLIT_V': self.convert_split_v, 'SQRT': self.convert_sqrt, 'SQUARE': self.convert_square, 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'SQUEEZE': self.convert_squeeze, 'STRIDED_SLICE': self.convert_strided_slice, 'SUB': self.convert_sub, - 'SUM': self._convert_reduce_sum, + 'SUM': self.convert_reduce_sum, 'TAN': self.convert_tan, 'TANH':self.convert_tanh, 'TILE': self.convert_tile, @@ -137,6 +143,7 @@ def __init__(self, model, subgraph, exp_tab): 'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': self.convert_transpose, 'UNPACK': self.convert_unpack, + 'WHERE': self.convert_select, 'ZEROS_LIKE': self.convert_zeros_like, } @@ -317,6 +324,45 @@ def dequantize(self, expr, tensor): input_zero_point=tensor.qnn_params['zero_point']) return dequantized + + def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, + scale, zero_point, dtype): + """Convert TFLite fused activation function. The expr is an input quantized tensor with + scale and zero point """ + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + + # Quantize a float value to a quantized integer value + quantize = lambda x: float(int(round(x / scale)) + zero_point) + + # Get min/max of the output dtype. This will be used to ensure that the clip a_min/a_max + # values do not fall outside the dtype range. + qmin = float(tvm.tir.op.min_value(dtype).value) + qmax = float(tvm.tir.op.max_value(dtype).value) + + # The input expr is a quantized tensor with its scale and zero point. We calculate the + # suitable clip-off points based on this scale and zero point.
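To make the clip bounds concrete, take an assumed uint8 output with scale 6/255 and zero point 0: RELU6's float bounds [0.0, 6.0] then quantize to the full [0, 255] range, so the fused activation becomes a clip over the whole quantized domain (all values here are illustrative):

    scale, zero_point = 6.0 / 255.0, 0
    quantize = lambda x: float(int(round(x / scale)) + zero_point)
    qmin, qmax = 0.0, 255.0             # uint8 limits
    a_min = max(qmin, quantize(0.0))    # 0.0
    a_max = min(qmax, quantize(6.0))    # 255.0 -> clip(expr, 0, 255)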
+ if fused_activation_fn == ActivationFunctionType.NONE: + return expr + if fused_activation_fn == ActivationFunctionType.RELU6: + return _op.clip(expr, + a_min=max(qmin, quantize(0)), + a_max=min(qmax, quantize(6.0))) + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + return _op.clip(expr, + a_min=max(qmin, quantize(-1.0)), + a_max=min(qmax, quantize(1.0))) + if fused_activation_fn == ActivationFunctionType.RELU: + return _op.clip(expr, + a_min=max(qmin, quantize(0.0)), + a_max=qmax) + + fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] + raise tvm.error.OpNotImplemented( + 'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str)) + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -333,6 +379,10 @@ def convert_max_pool2d(self, op): """Convert TFLite max pool2d""" return self.convert_pool2d(op, "max") + def convert_l2_pool2d(self, op): + """Convert TFLite l2 pool2d""" + return self.convert_pool2d(op, "l2") + def convert_reshape(self, op): """Convert TFLite reshape""" try: @@ -421,7 +471,6 @@ def convert_l2_normalization(self, op): try: from tflite.BuiltinOptions import BuiltinOptions from tflite.L2NormOptions import L2NormOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -446,17 +495,15 @@ def convert_l2_normalization(self, op): if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + # TFL uses only the default epsilon value out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1]) # if we have fused activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'TFLite quantized L2_NORMALIZATION operator\ - with fused activation function is not supported yet.') + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -601,7 +648,6 @@ def convert_concatenation(self, op): try: from tflite.ConcatenationOptions import ConcatenationOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -633,14 +679,20 @@ def convert_concatenation(self, op): output_zero_point=output_tensor.qnn_params['zero_point'], axis=concatenation_axis) - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' 
- .format('qnn.op.concatenate')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def _convert_unary_elemwise(self, relay_op, op): @@ -783,7 +835,6 @@ def _convert_elemwise(self, relay_op, op): from tflite.MulOptions import MulOptions from tflite.DivOptions import DivOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -791,28 +842,9 @@ def _convert_elemwise(self, relay_op, op): assert len(input_tensors) == 2, "input tensors length should be 2" lhs_tensor = input_tensors[0] - if self.has_expr(lhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. - lhs_expr = self.get_expr(lhs_tensor.tensor_idx) - else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. - lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type()) - lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor), - dtype=lhs_type_str) - rhs_tensor = input_tensors[1] - if self.has_expr(rhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. - rhs_expr = self.get_expr(rhs_tensor.tensor_idx) - else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. 
- rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) - rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), - dtype=rhs_type_str) + lhs_expr = self.get_tensor_expr(lhs_tensor) + rhs_expr = self.get_tensor_expr(rhs_tensor) output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" @@ -848,13 +880,20 @@ def _convert_elemwise(self, relay_op, op): op_options = op.BuiltinOptions() options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = options.FusedActivationFunction() - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if output_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Elemwise operators with fused activation are not supported yet.') - out = self.convert_fused_activation_function(out, fused_activation_fn) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): @@ -864,6 +903,20 @@ def convert_add(self, op): return self._convert_elemwise(_qnn.op.add, op) return self._convert_elemwise(_op.add, op) + def convert_add_n(self, op): + """Convert TFLite ADD_N""" + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + + input_tensors = self.get_input_tensors(op) + assert not input_tensors[0].qnn_params, "TFLite does not support quantized ADD_N." 
+ lhs_expr = self.get_tensor_expr(input_tensors[0]) + for rhs_tensor in input_tensors[1:]: + assert not rhs_tensor.qnn_params, "TFLite does not support quantized ADD_N" + rhs_expr = self.get_tensor_expr(rhs_tensor) + lhs_expr = _op.add(lhs_expr, rhs_expr) + return lhs_expr + def convert_sub(self, op): """Convert TFLite SUB""" # Check if the input tensor is quantized, call QNN op @@ -986,6 +1039,16 @@ def convert_logical_or(self, op): """Convert tflite LOGICAL_OR""" return self._convert_logical_binary(_op.logical_or, op) + def convert_logical_not(self, op): + """Convert tflite LOGICAL_NOT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + data = self.get_expr(input_tensors[0].tensor_idx) + out = _op.logical_not(data) + + return out + def convert_gather(self, op): """Method to Convert TFLite GATHER operator""" try: @@ -1208,8 +1271,23 @@ def convert_zeros_like(self, op): return out + def convert_fill(self, op): + """Convert TFLite FILL""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + if self.has_expr(input_tensors[0].tensor_idx): + raise tvm.error.OpNotImplemented("For dims parameter of Fill operator," + " only constant values are supported.") + + in_dims = list(self.get_tensor_value(input_tensors[0])) + in_value_expr = self.get_expr(input_tensors[1].tensor_idx) + out = _op.full(in_value_expr, in_dims) + + return out + def _convert_reduce(self, relay_op, op): - """Generic method to Convert TFLite MEAN operators""" + """Generic method to Convert TFLite REDUCE operators""" try: from tflite.BuiltinOptions import BuiltinOptions from tflite.ReducerOptions import ReducerOptions @@ -1253,22 +1331,22 @@ def _convert_reduce(self, relay_op, op): return out - def _convert_reduce_min(self, op): + def convert_reduce_min(self, op): return self._convert_reduce(_op.reduce.min, op) - def _convert_reduce_max(self, op): + def convert_reduce_max(self, op): return self._convert_reduce(_op.reduce.max, op) - def _convert_reduce_mean(self, op): + def convert_reduce_mean(self, op): return self._convert_reduce(_op.reduce.mean, op) - def _convert_reduce_prod(self, op): + def convert_reduce_prod(self, op): return self._convert_reduce(_op.reduce.prod, op) - def _convert_reduce_sum(self, op): + def convert_reduce_sum(self, op): return self._convert_reduce(_op.reduce.sum, op) - def _convert_reduce_any(self, op): + def convert_reduce_any(self, op): return self._convert_reduce(_op.reduce.any, op) def convert_fully_connected(self, op): @@ -1277,7 +1355,6 @@ def convert_fully_connected(self, op): from tflite.FullyConnectedOptions import FullyConnectedOptions from tflite.BuiltinOptions import BuiltinOptions from tflite.TensorType import TensorType - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -1297,16 +1374,28 @@ def convert_fully_connected(self, op): input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy() - # reshape input tensor from N H W C to N H*W*C - input_size_per_batch = 1 - for s in range(1, len(input_tensor_shape)): - input_size_per_batch *= input_tensor_shape[s] - assert input_size_per_batch == weight_tensor_shape[1], \ - "input size and weight size are mismatched" - target_shape = tuple((input_tensor_shape[0], input_size_per_batch)) + # Weight should have only 2 dimensions(TFLite convention) + assert 
len(weight_tensor_shape) == 2, "Weight should be only 2-dim" + + # Input shape: [i_batch_size, ..., n_inputs] + # Filter shape: [n_inputs, n_units] + # + # As we will transform Fully_Connected Input to Dense Op inputs as below + # Dense expected Input shape: [batch_size, n_units] + # Dense expected Weight shape: [out_dim, n_units] + # Dense output shape: [batch_size, out_dim] + # So it is evident that input shape: [batch_size = input_size / n_units, n_units] + input_size = 1 + for _, shape in enumerate(input_tensor_shape): + input_size *= shape + + # First get the batch size + batch_size = int(input_size / weight_tensor_shape[1]) + target_shape = tuple((batch_size, weight_tensor_shape[1])) in_expr = self.get_expr(input_tensor_idx) in_expr = _op.reshape(in_expr, target_shape) + #TODO: Change the output shape calculation based on keep_dim option assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions op_options = op.BuiltinOptions() fully_connected_options = FullyConnectedOptions() @@ -1318,8 +1407,11 @@ def convert_fully_connected(self, op): assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) - weight_value = self.get_tensor_value(weight_tensor) - weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) + if self.has_expr(weight_tensor.tensor_idx): + weight_expr = self.get_expr(weight_tensor.tensor_idx) + else: + weight_value = self.get_tensor_value(weight_tensor) + weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) weight_shape = _infer_shape(weight_expr) if input_tensor.qnn_params: @@ -1344,15 +1436,6 @@ def convert_fully_connected(self, op): dtype=bias_tensor_type_str) out = _op.nn.bias_add(out, bias_expr) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.dense')) - # Finally if the dense is quantized. Add a requantize at the end. 
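A worked example of the shape calculation above (shapes assumed for illustration): since the TFLite weight is [out_dim, n_inputs], dividing the flattened input size by n_inputs recovers the batch size that nn.dense expects:

    import numpy as np

    input_shape = (1, 7, 7, 16)      # N, H, W, C from the TFLite model
    weight_shape = (10, 784)         # [out_dim, n_inputs]

    input_size = int(np.prod(input_shape))        # 784
    batch_size = input_size // weight_shape[1]    # 1
    target_shape = (batch_size, weight_shape[1])  # (1, 784), fed to _op.reshape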
if output_tensor.qnn_params: data_scale = input_tensor.qnn_params['scale'] @@ -1362,6 +1445,8 @@ def convert_fully_connected(self, op): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, @@ -1369,6 +1454,19 @@ def convert_fully_connected(self, op): output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_squeeze(self, op): @@ -1403,7 +1501,9 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert fused_activation_fn != ActivationFunctionType.NONE + + if fused_activation_fn == ActivationFunctionType.NONE: + return in_expr if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(in_expr, a_min=0, a_max=6) if fused_activation_fn == ActivationFunctionType.RELU: @@ -1414,13 +1514,12 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Fused activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.TensorType import TensorType from tflite.Conv2DOptions import Conv2DOptions from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions @@ -1551,17 +1650,9 @@ def convert_conv(self, op, conv_type): channel_axis = 3 out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.conv2d')) - - # Finally if the conv is quantized. Add a requantize at the end. + # Handle fused activation. if output_tensor.qnn_params: + # Calculate the intermediate scale and zero point of the int32 output. 
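The intermediate scale computed below is the product of the input and weight scales, because a quantized multiply-accumulate carries that combined scale in its int32 result; requantize then maps it onto the output's quantization parameters. A numeric sketch with assumed values:

    data_scale, weight_scale = 0.5, 0.25
    acc_scale = data_scale * weight_scale        # int32 accumulator scale: 0.125, zero point 0
    output_scale, output_zero_point = 0.1, 3

    acc = 40                                     # example int32 accumulator value
    real_value = acc * acc_scale                 # 5.0
    q_out = int(round(real_value / output_scale)) + output_zero_point  # 53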
data_scale = input_tensor.qnn_params['scale'] weight_scale = weight_tensor.qnn_params['scale'] data_scale_val = get_scalar_from_constant(data_scale) @@ -1569,6 +1660,8 @@ def convert_conv(self, op, conv_type): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Finally requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, @@ -1576,6 +1669,18 @@ def convert_conv(self, op, conv_type): output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_split(self, op): @@ -1613,6 +1718,35 @@ def convert_split(self, op): return out + def convert_split_v(self, op): + """SPLIT_V implementation.""" + input_tensors = self.get_input_tensors(op) + + assert len(input_tensors) == 3, "input tensors length should be 3" + + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + in_expr = self.get_expr(input_tensor_idx) + + if self.has_expr(input_tensors[1].tensor_idx): + raise tvm.error.OpNotImplemented("For size_splits parameter of SPLIT_V operator, " + "only constant values are supported.") + size_splits = list(self.get_tensor_value(input_tensors[1])) + size_splits = tuple(np.cumsum(size_splits)[:-1]) + + axis_tensor = input_tensors[2] + split_axis = self.get_tensor_value(axis_tensor) + + out = _op.split(in_expr, size_splits, axis=int(split_axis)) + # Relay does not like a TupleWrapper of 1 element, further this + # only shows up with tf1.13 if we use a split with num_splits==1. + # In tf 1.14 this doesn't appear as it is automatically a reshape + # operation. 
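Note how size_splits is converted in SPLIT_V above: TFLite records per-piece sizes, while Relay's split expects the cut indices, so a cumulative sum with the final element dropped bridges the two:

    import numpy as np

    size_splits = [2, 3, 5]                       # TFLite: piece sizes along the axis
    indices = tuple(np.cumsum(size_splits)[:-1])  # Relay: cut points (2, 5)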
+ if isinstance(out, _expr.TupleWrapper) and out.size == 1: + out = out[0] + + return out + def convert_slice(self, op): """Convert TFLite SLICE""" input_tensors = self.get_input_tensors(op) @@ -1636,6 +1770,18 @@ def convert_slice(self, op): return out + def convert_select(self, op): + """Convert TFLite SELECT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + cond = self.get_tensor_expr(input_tensors[0]) + x = self.get_tensor_expr(input_tensors[1]) + y = self.get_tensor_expr(input_tensors[2]) + + out = _op.where(cond, x, y) + + return out + def convert_transpose(self, op): """transpose implementation.""" input_tensors = self.get_input_tensors(op) @@ -1710,7 +1856,6 @@ def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.Pool2DOptions import Pool2DOptions from tflite.Padding import Padding except ImportError: @@ -1771,17 +1916,33 @@ def convert_pool2d(self, op, pool_type): assert self.has_same_qnn_params(input_tensor, output_tensor), \ "qnn.op.max_pool2d requires input and output qnn params to be same" out = _op.nn.max_pool2d(in_expr, **params) + elif pool_type == "l2": + # L2_POOL_2D is equivalent to square_root(avg_pool(square(in_data))) + # TFLite does not have support for quantised L2_POOL_2D op. + assert not input_tensor.qnn_params, \ + "As TFLite does not have support for quantized L2_POOL_2D, \ + Quantized input is not expected." + exp_type = self.get_tensor_type_str(output_tensor.tensor.Type()) + square_exp = _op.power(in_expr, relay.const(2, exp_type)) + avg_pool_exp = _op.nn.avg_pool2d(square_exp, **params) + out = _op.sqrt(avg_pool_exp) else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if input_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.pool2d')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_pad(self, op): @@ -2170,33 +2331,21 @@ def convert_transpose_conv(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - _option_names = [ - "w_scale", - "max_detections", - "_output_quantized", - "detections_per_class", - "x_scale", - "nms_score_threshold", - "num_classes", - "max_classes_per_detection", - "use_regular_nms", - "y_scale", - "h_scale", - "_support_output_type_float_in_quantized_op", - "nms_iou_threshold" - ] - - custom_options = get_custom_options(op, _option_names) - if custom_options["use_regular_nms"]: - raise tvm.error.OpAttributeUnImplemented( - "use_regular_nms=True is not yet supported for operator {}." 
- .format("TFLite_Detection_PostProcess") - ) + flexbuffer = op.CustomOptionsAsNumpy().tobytes() + custom_options = FlexBufferDecoder(flexbuffer).decode() + + if "use_regular_nms" in custom_options: + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." + .format("TFLite_Detection_PostProcess") + ) inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" cls_pred = self.get_expr(inputs[1].tensor_idx) loc_prob = self.get_expr(inputs[0].tensor_idx) + batch_size = inputs[1].tensor.Shape(0) anchor_values = self.get_tensor_value(inputs[2]) anchor_boxes = len(anchor_values) anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type()) @@ -2224,7 +2373,7 @@ def convert_detection_postprocess(self, op): loc_prob = _op.concatenate( [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2 ) - loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4]) + loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes*4]) # anchor coords are in yxhw format # need to convert to ltrb @@ -2267,10 +2416,14 @@ def convert_detection_postprocess(self, op): ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs) ret = _op.vision.get_valid_counts(ret, 0) valid_count = ret[0] + # keep only the top 'max_detections' rows + ret = _op.strided_slice(ret[1], + [0, 0, 0], + [batch_size, custom_options["max_detections"], anchor_boxes]) # the output needs some reshaping to match tflite - ret = _op.split(ret[1], 6, axis=2) - cls_ids = ret[0] - scores = ret[1] + ret = _op.split(ret, 6, axis=2) + cls_ids = _op.reshape(ret[0], [batch_size, -1]) + scores = _op.reshape(ret[1], [batch_size, -1]) boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2) ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) return ret @@ -2281,6 +2434,31 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def get_tensor_expr(self, tensor): + """ Returns constant expr for constant else a tensor expr""" + if self.has_expr(tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses elemwise operators + # with constants - it means both will be tensors. + expr = self.get_expr(tensor.tensor_idx) + else: + # However, in some corner cases, the elemwise operator is not fused, + # we can receive as constant. + type_str = self.get_tensor_type_str(tensor.tensor.Type()) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + + return expr + + +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + + def build_str_map(obj): """Build string map of TFLite enum int value @@ -2346,98 +2524,13 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def get_custom_options(op, option_names): - """Get the options of a custom operator. - - This implements partial flexbuffer deserialization to be able - to read custom options. 
It is not intended to be a general - purpose flexbuffer deserializer and as such only supports a - limited number of types and assumes the data is a flat map. - - Parameters - ---------- - op: - A custom TFlite operator. - option_names: list - A complete list of the custom option names. - - Returns - ------- - options: dict - A dictionary of the custom options. - - """ - import struct - from enum import IntEnum - - class _FlexBufferType(IntEnum): - """Flexbuffer type schema from flexbuffers.h""" - FBT_NULL = 0 - FBT_INT = 1 - FBT_UINT = 2 - FBT_FLOAT = 3 - # Types above stored inline, types below store an offset. - FBT_KEY = 4 - FBT_STRING = 5 - FBT_INDIRECT_INT = 6 - FBT_INDIRECT_UINT = 7 - FBT_INDIRECT_FLOAT = 8 - FBT_MAP = 9 - FBT_VECTOR = 10 # Untyped. - FBT_VECTOR_INT = 11 # Typed any size (stores no type table). - FBT_VECTOR_UINT = 12 - FBT_VECTOR_FLOAT = 13 - FBT_VECTOR_KEY = 14 - FBT_VECTOR_STRING = 15 - FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). - FBT_VECTOR_UINT2 = 17 - FBT_VECTOR_FLOAT2 = 18 - FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). - FBT_VECTOR_UINT3 = 20 - FBT_VECTOR_FLOAT3 = 21 - FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). - FBT_VECTOR_UINT4 = 23 - FBT_VECTOR_FLOAT4 = 24 - FBT_BLOB = 25 - FBT_BOOL = 26 - FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type - - buffer = op.CustomOptionsAsNumpy().tobytes() - value_vector_offset = buffer[-3] - buffer = buffer[:-3] - num_bytes = 4 # Assume all values are stored in 32 bit width - value_vector_size = struct.unpack( - "> 2) - value_offset = -value_vector_offset + i*num_bytes - value_bytes = buffer[value_offset:value_offset+num_bytes] - if flex_type == _FlexBufferType.FBT_BOOL: - value = bool(value_bytes[0]) - if flex_type == _FlexBufferType.FBT_INT: - value = struct.unpack("> 2) + value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width] + if value_type == FlexBufferType.FBT_BOOL: + value = bool(value_bytes[0]) + elif value_type == FlexBufferType.FBT_INT: + value = struct.unpack("> 2) + byte_width = 1 << BitWidth(root_packed_type & 3) + + if root_type == FlexBufferType.FBT_MAP: + return self.decode_map(root_end, byte_width, root_byte_width) + raise NotImplementedError("Flexbuffer Decoding is partially imlpemented.") diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index ac006a4debe49..e029e0cc3589d 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -16,10 +16,10 @@ # under the License. 
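The inverse-trigonometric gradients registered below follow the standard closed forms (for instance, d/dx asinh(x) = 1/sqrt(1 + x^2)); a quick finite-difference sanity check of that identity in plain Python:

    import math

    def numeric_grad(f, x, eps=1e-6):
        # Central difference approximation of f'(x).
        return (f(x + eps) - f(x - eps)) / (2 * eps)

    x = 0.7
    analytic = 1.0 / math.sqrt(1.0 + x * x)   # asinh'(x)
    assert abs(numeric_grad(math.asinh, x) - analytic) < 1e-6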
#pylint: disable=invalid-name, unused-argument, len-as-condition """Backend compiler related feature registration""" -import topi from tvm.runtime import convert from tvm.te.hybrid import script +import topi from topi.util import get_const_tuple from .op import register_compute, register_shape_func from .op import register_broadcast_schedule, register_injective_schedule @@ -34,7 +34,12 @@ register_broadcast_schedule("cosh") register_broadcast_schedule("sin") register_broadcast_schedule("sinh") +register_broadcast_schedule("acos") +register_broadcast_schedule("acosh") +register_broadcast_schedule("asin") +register_broadcast_schedule("asinh") register_broadcast_schedule("atan") +register_broadcast_schedule("atanh") register_broadcast_schedule("exp") register_broadcast_schedule("erf") register_broadcast_schedule("sqrt") diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 82e8e30e56c14..8be335842f0ee 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -35,6 +35,7 @@ power, sin, sinh, + sqrt, zeros_like, equal, shape_of, @@ -98,10 +99,9 @@ def cos_grad(orig, grad): @register_gradient("cosh") def cosh_grad(orig, grad): - """Returns [grad * (-sinh(x))]""" + """Returns [grad * sinh(x)]""" x = orig.args[0] - ones = ones_like(x) - return [grad * (-ones * sinh(x))] + return [grad * sinh(x)] @register_gradient("sin") @@ -110,18 +110,61 @@ def sin_grad(orig, grad): x = orig.args[0] return [grad * cos(x)] + @register_gradient("sinh") def sinh_grad(orig, grad): """Returns [grad * cosh(x)]""" x = orig.args[0] return [grad * cosh(x)] + +@register_gradient("acos") +def acos_grad(orig, grad): + """Returns [grad * -1/((1 - (x ^ 2)) ^ 1/2)]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * (-ones / sqrt(ones - (x * x)))] + + +@register_gradient("acosh") +def acosh_grad(orig, grad): + """Returns [grad * 1/((x - 1) ^ 1/2 * (x + 1) ^ 1/2)]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / sqrt((x * x) - ones)] + + +@register_gradient("asin") +def asin_grad(orig, grad): + """Returns [grad * 1/((1 - (x ^ 2)) ^ (1/2))]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / sqrt(ones - (x * x))] + + +@register_gradient("asinh") +def asinh_grad(orig, grad): + """Returns [grad * 1/((1 + (x ^ 2)) ^ (1/2))]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / sqrt(ones + (x * x))] + + @register_gradient("atan") def atan_grad(orig, grad): """Returns [grad * 1 / (1 + x ^ 2)]""" x = orig.args[0] - a = const(2.0) - return [grad * ones_like(x) / (ones_like(x) + power(x, a))] + ones = ones_like(x) + return [grad * ones / (ones + (x * x))] + + +@register_gradient("atanh") +def atanh_grad(orig, grad): + """Returns [grad * 1 / (1 - x ^ 2)]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / (ones - (x * x))] + @register_gradient("exp") def exp_grad(orig, grad): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5f6aa898711b8..ad8c654269a8e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -502,6 +502,15 @@ def compute_cross_entropy(attrs, inputs, out_dtype): reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) +# dilate +@reg.register_compute("nn.dilate") +def compute_dilate(attrs, inputs, out_dtype): + return [topi.nn.dilate(inputs[0], attrs.strides)] + +reg.register_broadcast_schedule("nn.dilate") +reg.register_pattern("nn.dilate", OpPattern.INJECTIVE) + + # cross_entropy_with_logits 
@reg.register_compute("nn.cross_entropy_with_logits") def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): @@ -697,6 +706,21 @@ def pad_shape_func(attrs, inputs, _): pad_width.append(get_const_tuple(pair)) return [_pad_shape_func(inputs[0], convert(pad_width))] +@script +def _dilate_shape_func(data_shape, strides): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0]): + out[i] = (data_shape[i] - 1) * strides[i] + 1 + + return out + +@reg.register_shape_func("nn.dilate", False) +def dilate_shape_func(attrs, inputs, _): + """ + Shape function for dilate op. + """ + return [_dilate_shape_func(inputs[0], convert(attrs.strides))] + reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) reg.register_shape_func("nn.softmax", False, elemwise_shape_func) reg.register_shape_func("nn.relu", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 622b0faaccea2..c879eb6cb415c 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1347,6 +1347,25 @@ def pad(data, return _make.pad(data, pad_width, pad_value, pad_mode) +def dilate(data, strides): + """Dilate data with zeros. + + Parameters + ---------- + data : tvm.relay.Expr + n-D, can be any layout. + + strides : + Dilation stride on each dimension, 1 means no dilation. + + Returns + ------- + Output : tvm.relay.Expr + The computed result + """ + return _make.dilate(data, strides) + + def mirror_pad(data, pad_width, mode="SYMMETRIC"): diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index a47be76738304..a1c73ef41ba54 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -350,6 +350,11 @@ class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" +@tvm._ffi.register_object("relay.attrs.DilateAttrs") +class DilateAttrs(Attrs): + """Attributes used in dilate operators""" + + @tvm._ffi.register_object("relay.attrs.SubPixelAttrs") class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index d3226012e8872..988c94928d336 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -18,7 +18,7 @@ # pylint: disable=redefined-builtin from . import _make -from .tensor import sqrt +from .tensor import sqrt, log, exp from .transform import squeeze from ..expr import Tuple, TupleWrapper @@ -475,3 +475,40 @@ def prod(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis return _make.prod(data, axis, keepdims, exclude) + + +def logsumexp(data, axis=None, keepdims=False): + """Compute the log of the sum of exponentials of input elements over given axes. + + This function is more numerically stable than log(sum(exp(input))). + It avoids overflows caused by taking the exp of large inputs and underflows + caused by taking the log of small inputs. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a standard deviation operation is performed. + The default, axis=None, will compute the log of the sum of exponentials of all elements + in the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. 
+ + Returns + ------- + result : relay.Expr + The computed result. + """ + + axis = [axis] if isinstance(axis, int) else axis + max_x = max(data, axis, True) + exp_x = exp(data - max_x) + sum_x = sum(exp_x, axis, True) + out_x = log(sum_x) + max_x + if not keepdims: + out_x = squeeze(out_x, axis) + return out_x diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 942d4c7f86af9..6bdec67617e12 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -20,24 +20,25 @@ import logging import topi +from ....target import arm_isa from .generic import * from .. import op as _op logger = logging.getLogger('strategy') -@schedule_injective.register("arm_cpu") +@schedule_injective.register(["arm_cpu", "micro_dev"]) def schedule_injective_arm_cpu(_, outs, target): """schedule injective ops for arm cpu""" with target: return topi.arm_cpu.schedule_injective(outs) -@schedule_concatenate.register("arm_cpu") +@schedule_concatenate.register(["arm_cpu", "micro_dev"]) def schedule_concatenate_arm_cpu(_, outs, target): """schedule concatenate for arm cpu""" with target: return topi.arm_cpu.schedule_concatenate(outs) -@conv2d_strategy.register("arm_cpu") +@conv2d_strategy.register(["arm_cpu", "micro_dev"]) def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d arm cpu strategy""" strategy = _op.OpStrategy() @@ -51,6 +52,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") + isa = arm_isa.IsaAnalyzer(target) + if groups == 1: if layout == "NCHW": if kernel_layout == "OIHW": @@ -102,11 +105,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), name="conv2d_hwcn.generic") elif layout == "NHWC": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), - name="conv2d_nhwc_spatial_pack.arm_cpu") + channels = data.shape[3] + if "SMLAD" in isa and (channels % 4) == 0 and kernel_layout == "HWOI": + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_direct_simd), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd), + name='conv2d_direct_simd.micro_dev') + elif kernel_layout == "HWIO": + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.arm_cpu") + else: + raise RuntimeError("Unsupported kernel layout {} for conv2d NHWC". 
+                    format(kernel_layout))
+
+
         else:
             raise RuntimeError("Unsupported conv2d layout {} for arm cpu".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
@@ -232,7 +246,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out
                          format(layout))
     return strategy
 
-@conv2d_transpose_strategy.register("arm_cpu")
+@conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"])
 def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d_transpose arm cpu strategy"""
     layout = attrs.data_layout
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index e9769783269fb..83e4e40b53b9b 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -58,7 +58,7 @@ def schedule_pool_grad_cuda(attrs, outs, target):
 def schedule_adaptive_pool_cuda(attrs, outs, target):
     """schedule adaptive pooling ops for cuda"""
     with target:
-        return topi.cuda.schedule_adaptive_pool(outs)
+        return topi.cuda.schedule_adaptive_pool(outs, attrs.layout)
 
 @softmax_strategy.register(["cuda", "gpu"])
 def softmax_strategy_cuda(attrs, inputs, out_type, target):
@@ -136,8 +136,32 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
                 name="conv2d_nhwc.cuda")
-            N, _, _, _ = get_const_tuple(data.shape)
-            _, _, CI, CO = get_const_tuple(kernel.shape)
+            N, H, W, _ = get_const_tuple(data.shape)
+            KH, KW, CI, CO = get_const_tuple(kernel.shape)
+            # Winograd shape related judgment
+            judge_winograd_tensorcore, judge_winograd_shape = winograd_judge(N, H, W, KH, KW,
+                                                                             CI, CO, padding,
+                                                                             stride_h, stride_w,
+                                                                             dilation_h, dilation_w,
+                                                                             pre_flag=False)
+            if judge_winograd_shape:
+                if target.target_name == "cuda" and \
+                    nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
+                    judge_winograd_tensorcore:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_tensorcore),
+                        wrap_topi_schedule(
+                            topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore),
+                        name="conv2d_nhwc_winograd_tensorcore.cuda",
+                        plevel=5)
+                else:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(
+                            topi.cuda.conv2d_nhwc_winograd_direct),
+                        wrap_topi_schedule(
+                            topi.cuda.schedule_conv2d_nhwc_winograd_direct),
+                        name="conv2d_nhwc_winograd_direct.cuda",
+                        plevel=5)
             if target.target_name == "cuda":
                 if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
                     if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
@@ -220,6 +244,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
     dilation = attrs.get_int_tuple("dilation")
     groups = attrs.get_int("groups")
     layout = attrs.data_layout
+    data, kernel = inputs
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    padding = attrs.get_int_tuple("padding")
     assert dilation == (1, 1), "Do not support dilation now"
     assert groups == 1, "Do not support arbitrary group number"
     strategy = _op.OpStrategy()
@@ -229,6 +256,30 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
                 wrap_topi_schedule(
                     topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform),
                 name="conv2d_nchw_winograd_without_weight_transform.cuda")
+    elif layout == "NHWC":
+        N, H, W, _ = get_const_tuple(data.shape)
+        alpha, _, CI, CO = get_const_tuple(kernel.shape)
+        dilation_h, dilation_w = dilation
+        judge_winograd_tensorcore, _ = winograd_judge(N, H, W, alpha, alpha, CI, CO,
+                                                      padding, stride_h, stride_w,
+                                                      dilation_h, dilation_w,
+
pre_flag=True) + if target.target_name == "cuda" and \ + nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ + judge_winograd_tensorcore: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_tensorcore_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform), + name="conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") + else: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_direct_without_weight_transform), + name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda") else: raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}". format(layout)) @@ -516,3 +567,26 @@ def proposal_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_proposal), name="proposal.cuda") return strategy + +def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h, + stride_w, dilation_h, dilation_w, pre_flag): + """Winograd judgement about tensorcore and shape""" + if H % 8 == 0: + tile_size = 4 + else: + tile_size = 2 + if pre_flag: + alpha = KH + KH = KW = alpha + 1 - tile_size + pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (KH, KW)) + OH = (H + pt + pb - KH) // stride_h + 1 + OW = (W + pl + pr - KW) // stride_w + 1 + nH, nW = (OH + tile_size - 1) // tile_size, (OW + tile_size - 1) // tile_size + P = N * nH * nW + judge_winograd_tensorcore = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + judge_winograd_shape = 2 < KH < 8 and 2 < KW < 8 and KH == KW and \ + stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1 + return judge_winograd_tensorcore, judge_winograd_shape diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index a50a25808e6d9..d5ae5cd3d2d7c 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -152,6 +152,66 @@ def sinh(data): """ return _make.sinh(data) +def acos(data): + """Compute elementwise acos of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.acos(data) + +def acosh(data): + """Compute elementwise acosh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.acosh(data) + +def asin(data): + """Compute elementwise asin of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.asin(data) + +def asinh(data): + """Compute elementwise asinh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.asinh(data) + def atan(data): """Compute elementwise atan of data. @@ -167,6 +227,21 @@ def atan(data): """ return _make.atan(data) +def atanh(data): + """Compute elementwise atanh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.atanh(data) + def exp(data): """Compute elementwise exp of data. 
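
Editor's note: the gradient rules registered earlier in this patch (in _tensor_grad.py) encode the textbook closed-form derivatives of the new inverse trigonometric ops. As a standalone sanity check (illustrative only, not part of the patch), each formula can be compared against a central finite difference in plain NumPy:

    import numpy as np

    def check_grad(f, df, xs, eps=1e-5, tol=1e-4):
        # Compare the analytic derivative df against a central difference of f.
        for x in xs:
            numeric = (f(x + eps) - f(x - eps)) / (2.0 * eps)
            assert abs(numeric - df(x)) < tol, (x, numeric, df(x))

    check_grad(np.arcsin,  lambda x: 1.0 / np.sqrt(1.0 - x * x),  [-0.5, 0.0, 0.5])
    check_grad(np.arccos,  lambda x: -1.0 / np.sqrt(1.0 - x * x), [-0.5, 0.0, 0.5])
    check_grad(np.arcsinh, lambda x: 1.0 / np.sqrt(1.0 + x * x),  [-2.0, 0.0, 2.0])
    check_grad(np.arccosh, lambda x: 1.0 / np.sqrt(x * x - 1.0),  [1.5, 2.0, 3.0])
    check_grad(np.arctanh, lambda x: 1.0 / (1.0 - x * x),         [-0.5, 0.0, 0.5])

These are exactly the expressions the Relay gradients build out of ones_like, sqrt, and elementwise arithmetic, so the check carries over directly.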
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index b1c19092b4c70..c96a730ee6ed2 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -20,8 +20,8 @@
 
 import tvm
 from tvm import relay
+import numpy as np
 from .. import op as reg
-from ...frontend.util import get_scalar_from_constant
 
 #################################################
 # Register the functions for different operators.
@@ -54,6 +54,15 @@ def qnn_dense_legalize(attrs, inputs, types):
 # Helper functions.
 ###################
 
+def get_scalar_from_constant(expr):
+    """ Returns scalar value from Relay constant scalar. """
+    assert isinstance(expr, relay.Constant) and not expr.data.shape, \
+        "Expr is not a constant scalar."
+    value = expr.data.asnumpy()
+    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
+        "value must be float32/int32"
+    return np.asscalar(value)
+
 # Helper function for lowering in the absence of fast Int8 arithmetic units.
 def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
     """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 5c1baef4db945..5a3106d1e7875 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -18,7 +18,7 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm.relay.expr import Tuple
+from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.util import get_pad_tuple2d
 from . import _make
 
@@ -156,7 +156,7 @@ def concatenate(data,
 
     Parameters
     ----------
-    data : Union(List[relay.Expr], Tuple[relay.Expr])
+    data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr])
         The list of quantized tensors.
 
     input_scales : List[relay.Expr]
@@ -180,15 +180,16 @@ def concatenate(data,
        The concatenated quantized tensor.
""" - data = list(data) - if not data: - raise ValueError("relay.concatenate requires data to be non-empty.") + if isinstance(data, (list, tuple)): + data = Tuple(data) + elif isinstance(data, TupleWrapper): + data = data.tuple_value if not isinstance(axis, int): raise ValueError("For now, we only support integer axis") input_scales = list(input_scales) input_zero_points = list(input_zero_points) - return _make.concatenate(Tuple(data), + return _make.concatenate(data, Tuple(input_scales), Tuple(input_zero_points), output_scale, diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1a231eb1aaedf..dc7937c0b3465 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -183,11 +183,16 @@ def get_workload_official(model_url, model_sub_path): model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official']) dir_path = os.path.dirname(model_path) - import tarfile if model_path.endswith("tgz") or model_path.endswith("gz"): + import tarfile tar = tarfile.open(model_path) tar.extractall(path=dir_path) tar.close() + elif model_path.endswith("zip"): + import zipfile + zip_object = zipfile.ZipFile(model_path) + zip_object.extractall(path=dir_path) + zip_object.close() else: raise RuntimeError('Could not decompress the file: ' + model_path) return os.path.join(dir_path, model_sub_path) diff --git a/python/tvm/rpc/__init__.py b/python/tvm/rpc/__init__.py index 5f959eb44745b..b64ba33d9e096 100644 --- a/python/tvm/rpc/__init__.py +++ b/python/tvm/rpc/__init__.py @@ -26,4 +26,6 @@ """ from .server import Server -from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker +from .client import connect, connect_tracker +from .client import RPCSession, LocalSession, PopenSession, TrackerSession +from .minrpc import with_minrpc diff --git a/python/tvm/relay/frontend/util.py b/python/tvm/rpc/_ffi_api.py similarity index 53% rename from python/tvm/relay/frontend/util.py rename to python/tvm/rpc/_ffi_api.py index a7f89a30b9964..1a7cc739b5c1f 100644 --- a/python/tvm/relay/frontend/util.py +++ b/python/tvm/rpc/_ffi_api.py @@ -14,20 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -""" Utility functions that are used across many directories. """ -from __future__ import absolute_import -import numpy as np -from .. import expr as _expr +"""FFI APIs for tvm.rpc""" +import tvm._ffi -def get_scalar_from_constant(expr): - """ Returns scalar value from Relay constant scalar. """ - assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ - "Expr is not a constant scalar." 
-    value = expr.data.asnumpy()
-    if value.dtype == np.dtype(np.int32):
-        return int(value)
-    if value.dtype == np.dtype(np.float32):
-        return float(value)
-    assert False, "Constant expr must be float32/int32"
-    return None  # To suppress pylint
+
+tvm._ffi._init_api("rpc", __name__)
diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py
index bc81534a12d99..f0e33f8503f28 100644
--- a/python/tvm/rpc/base.py
+++ b/python/tvm/rpc/base.py
@@ -17,8 +17,6 @@
 """Base definitions for RPC."""
 # pylint: disable=invalid-name
 
-from __future__ import absolute_import
-
 import socket
 import time
 import json
@@ -26,7 +24,6 @@
 import struct
 import random
 import logging
-import tvm._ffi
 
 from .._ffi.base import py_str
 
@@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
             logger.warning("Cannot connect to tracker %s, retry in %g secs...",
                            str(addr), retry_period)
             time.sleep(retry_period)
-
-
-# Still use tvm.rpc for the foreign functions
-tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index ed57e0d4276d4..3f38c4f2f668b 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -15,19 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 """RPC client tools"""
-from __future__ import absolute_import
-
 import os
+import stat
 import socket
 import struct
 import time
+
 import tvm._ffi
 from tvm.contrib import util
 from tvm._ffi.base import TVMError
 from tvm.runtime import ndarray as nd
-from tvm.runtime import load_module as _load_module
 from . import base
+from . import server
+from . import _ffi_api
 
 
 class RPCSession(object):
@@ -38,9 +39,23 @@ class RPCSession(object):
     # pylint: disable=invalid-name
     def __init__(self, sess):
         self._sess = sess
-        self._tbl_index = base._SessTableIndex(sess)
+        self._tbl_index = _ffi_api.SessTableIndex(sess)
         self._remote_funcs = {}
 
+    def system_lib(self):
+        """Get system-wide library module.
+
+        Returns
+        -------
+        module : runtime.Module
+            The system-wide library module.
+
+        See Also
+        --------
+        tvm.runtime.system_lib
+        """
+        return self.get_function("runtime.SystemLib")()
+
     def get_function(self, name):
         """Get function from the session.
 
@@ -145,7 +160,7 @@ def load_module(self, path):
         m : Module
             The remote module containing remote function.
         """
-        return base._LoadRemoteModule(self._sess, path)
+        return _ffi_api.LoadRemoteModule(self._sess, path)
 
     def cpu(self, dev_id=0):
         """Construct CPU device."""
@@ -175,6 +190,10 @@ def ext_dev(self, dev_id=0):
         """Construct extension device."""
         return self.context(12, dev_id)
 
+    def webgpu(self, dev_id=0):
+        """Construct WebGPU device."""
+        return self.context(15, dev_id)
+
 
 class LocalSession(RPCSession):
     """RPCSession interface backed by local environment.
 
    This class can be used to implement functions that
    need to be run both locally and remotely.
    """
     def __init__(self):
-        # pylint: disable=super-init-not-called
-        self.context = nd.context
-        self.get_function = tvm._ffi.get_global_func
-        self._temp = util.tempdir()
+        self._temp = server._server_env([])
+        RPCSession.__init__(self, _ffi_api.LocalSession())
 
-    def upload(self, data, target=None):
-        if isinstance(data, bytearray):
-            if not target:
-                raise ValueError("target must present when file is a bytearray")
-            blob = data
-        else:
-            blob = bytearray(open(data, "rb").read())
-            if not target:
-                target = os.path.basename(data)
-        with open(self._temp.relpath(target), "wb") as f:
-            f.write(blob)
 
-    def download(self, path):
-        return bytearray(open(self._temp.relpath(path), "rb").read())
+@tvm._ffi.register_func("rpc.PopenSession")
+def _popen_session(binary):
+    temp = util.tempdir()
+
+    if isinstance(binary, (bytes, bytearray)):
+        path_exec = temp.relpath("server.minrpc")
+        with open(path_exec, "wb") as outfile:
+            outfile.write(binary)
+        os.chmod(path_exec, stat.S_IXUSR | stat.S_IRUSR)
+        path_exec = os.path.abspath(path_exec)
+    else:
+        path_exec = os.path.abspath(binary)
+        if not os.path.isfile(path_exec):
+            raise RuntimeError(f"{path_exec} does not exist.")
+        if not os.access(path_exec, os.X_OK):
+            raise RuntimeError(f"{path_exec} is not executable.")
+
+    sess = _ffi_api.CreatePipeClient(path_exec)
+    return sess
 
-    def load_module(self, path):
-        return _load_module(self._temp.relpath(path))
+
+class PopenSession(RPCSession):
+    """RPCSession interface backed by popen.
+
+    Parameters
+    ----------
+    binary : List[Union[str, bytes]]
+        The binary to be executed.
+    """
+    def __init__(self, binary):
+        RPCSession.__init__(self, _popen_session(binary))
 
 
 class TrackerSession(object):
@@ -378,7 +410,7 @@ def request_and_run(self, key, max_retry,
                        str(last_err)))
 
 
-def connect(url, port, key="", session_timeout=0):
+def connect(url, port, key="", session_timeout=0, session_constructor_args=None):
     """Connect to RPC Server
 
     Parameters
@@ -397,15 +429,43 @@
         the connection when duration is longer than this value.
         When duration is zero, it means the request must always be kept alive.
 
+    session_constructor_args: List
+        List of additional arguments to be passed to the remote session constructor.
+        The first element of the list is always a string specifying the name of
+        the session constructor, the following args are the positional args to that function.
+
     Returns
     -------
     sess : RPCSession
         The connected session.
+
+    Examples
+    --------
+    Normal usage
+
+    .. code-block:: python
+
+        client = rpc.connect(server_url, server_port, server_key)
+
+    The session constructor can be used to customize the session on the remote machine.
+    The following code connects to a remote internal server via a proxy
+    by constructing another RPCClientSession on the proxy machine and using that
+    as the serving session of the proxy endpoint.
+
+    ..
code-block:: python
+
+        client_via_proxy = rpc.connect(
+            proxy_server_url, proxy_server_port, proxy_server_key,
+            session_constructor_args=[
+                "rpc.Connect", internal_url, internal_port, internal_key])
+
     """
     try:
         if session_timeout:
             key += " -timeout=%s" % str(session_timeout)
-        sess = base._Connect(url, port, key)
+        session_constructor_args = session_constructor_args if session_constructor_args else []
+        if not isinstance(session_constructor_args, (list, tuple)):
+            raise TypeError("Expect the session constructor to be a list or tuple")
+        sess = _ffi_api.Connect(url, port, key, *session_constructor_args)
     except NameError:
         raise RuntimeError("Please compile with USE_RPC=1")
     return RPCSession(sess)
diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py
new file mode 100644
index 0000000000000..760c5362f11db
--- /dev/null
+++ b/python/tvm/rpc/minrpc.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities to locate and attach the minimum RPC server."""
+import os
+from tvm._ffi import libinfo
+from tvm.contrib import cc
+
+
+def find_minrpc_server_libpath(server="posix_popen_server"):
+    """Get the path of the minrpc server library.
+
+    Parameters
+    ----------
+    server : str
+        The kind of built-in minrpc server.
+
+    Returns
+    -------
+    path : str
+        The path to the min server library.
+    """
+    curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+    source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", ".."))
+
+    path = os.path.join(
+        source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server))
+
+    candidates = [path]
+    if not os.path.isfile(path):
+        raise RuntimeError("Cannot find min server %s in candidates %s" % (server, candidates))
+    return path
+
+
+def with_minrpc(compile_func,
+                server="posix_popen_server",
+                runtime="libtvm"):
+    """Attach minrpc-related options to the compiler function.
+
+    Parameters
+    ----------
+    compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
+        The compilation function to decorate.
+
+    server : str
+        The server type.
+
+    runtime : str
+        The runtime library.
+
+    Returns
+    -------
+    fcompile : function
+        The decorated compilation function.
+    """
+    server_path = find_minrpc_server_libpath(server)
+    runtime_path = libinfo.find_lib_path(
+        [runtime, runtime + ".so", runtime + ".dylib"])[0]
+
+    runtime_dir = os.path.abspath(os.path.dirname(runtime_path))
+    options = ["-std=c++14"]
+    # Make sure the rpath to libtvm is set so we can run local tests.
+    # Note, however, that this approach won't work on a remote machine.
+    # It is always recommended to link statically.
+ options += ["-Wl,-rpath=" + runtime_dir] + options += ["-I" + path for path in libinfo.find_include_path()] + fcompile = cc.cross_compiler( + compile_func, + options=options, + add_files=[server_path, runtime_path]) + fcompile.__name__ = "with_minrpc" + fcompile.need_system_lib = True + return fcompile diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index c3a3647948eea..994e230b982ae 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -42,6 +42,7 @@ raise ImportError( "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg) +from . import _ffi_api from . import base from .base import TrackerCode from .server import _server_env @@ -129,7 +130,7 @@ def close_pair(self): def on_close_event(self): """on close event""" assert not self._done - logging.info("RPCProxy:on_close %s ...", self.name()) + logging.info("RPCProxy:on_close_event %s ...", self.name()) if self.match_key: key = self.match_key if self._proxy._client_pool.get(key, None) == self: @@ -157,10 +158,12 @@ def on_message(self, message): self.on_data(message) def on_close(self): + logging.info("RPCProxy: on_close %s ...", self.name()) + self._close_process = True + if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None - logging.info("%s Close socket..", self.name()) self.on_close_event() @@ -186,6 +189,7 @@ def send_data(self, message): self.on_error(err) def on_close(self): + logging.info("RPCProxy: on_close %s ...", self.name()) if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None @@ -549,7 +553,7 @@ def _fsend(data): data = bytes(data) conn.write_message(data, binary=True) return len(data) - on_message = base._CreateEventDrivenServer( + on_message = _ffi_api.CreateEventDrivenServer( _fsend, "WebSocketProxyServer", "%toinit") return on_message diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 627d67a0a8359..15a3c7de789d3 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -43,6 +43,7 @@ from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import util +from . import _ffi_api from . import base from . base import TrackerCode @@ -56,7 +57,7 @@ def _server_env(load_library, work_path=None): temp = util.tempdir() # pylint: disable=unused-variable - @tvm._ffi.register_func("tvm.rpc.server.workpath") + @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) @@ -81,7 +82,7 @@ def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" sockfd = sock.fileno() temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) + _ffi_api.ServerLoop(sockfd) if not work_path: temp.remove() logger.info("Finish serving %s", addr) @@ -325,9 +326,12 @@ def __init__(self, key="", load_library=None, custom_addr=None, - silent=False): + silent=False, + utvm_dev_id=None, + utvm_dev_config_args=None, + ): try: - if base._ServerLoop is None: + if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") @@ -355,6 +359,10 @@ def __init__(self, cmd += ["--custom-addr", custom_addr] if silent: cmd += ["--silent"] + if utvm_dev_id is not None: + assert utvm_dev_config_args is not None + cmd += [f"--utvm-dev-id={utvm_dev_id}"] + cmd += [f"--utvm-dev-config-args={utvm_dev_config_args}"] # prexec_fn is not thread safe and may result in deadlock. 
# python 3.2 introduced the start_new_session parameter as diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 7845a26bfca29..3cdb28f8c4965 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -109,7 +109,6 @@ def __call__(self, *args): # pylint: disable=not-callable return self.entry_func(*args) - def __repr__(self): return "Module(%s, %x)" % (self.type_key, self.handle.value) @@ -245,6 +244,7 @@ def _dso_exportable(self): def export_library(self, file_name, fcompile=None, + addons=None, **kwargs): """Export the module and its imported device code one library. @@ -284,7 +284,7 @@ def export_library(self, modules = self._collect_dso_modules() temp = _util.tempdir() - files = [] + files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None @@ -314,9 +314,12 @@ def export_library(self, if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"): llvm_target_triple = fcompile.get_target_triple() + if getattr(fcompile, "need_system_lib", False) and not is_system_lib: + raise ValueError("%s need --system-lib option" % str(fcompile)) + if self.imported_modules: if enabled("llvm") and llvm_target_triple: - path_obj = temp.relpath("devc.o") + path_obj = temp.relpath("devc." + object_format) m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple) m.save(path_obj) files.append(path_obj) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 10bbb6ef54c29..9f5f0f685e8df 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -219,7 +219,7 @@ def context(dev_type, dev_id=0): """ if isinstance(dev_type, string_types): if '-device=micro_dev' in dev_type: - dev_type = 'micro_dev' + dev_type = TVMContext.STR2MASK['micro_dev'] else: dev_type = dev_type.split()[0] if dev_type not in TVMContext.STR2MASK: @@ -478,6 +478,22 @@ def hexagon(dev_id=0): return TVMContext(14, dev_id) +def webgpu(dev_id=0): + """Construct a webgpu device. + + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + ctx : TVMContext + The created context + """ + return TVMContext(15, dev_id) + + cl = opencl mtl = metal diff --git a/python/tvm/target/arm_isa.py b/python/tvm/target/arm_isa.py new file mode 100644 index 0000000000000..c40296e507130 --- /dev/null +++ b/python/tvm/target/arm_isa.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Defines functions to analyze available opcodes in the ARM ISA.""" + + +ARM_ISA_MAP = { + 'armv7e-m': ['SMLAD'], +} + + +class IsaAnalyzer(object): + + def __init__(self, target): + self.target = target + # TODO: actually parse -mcpu + arch = 'armv7e-m' + self._isa_map = ARM_ISA_MAP[arch] + + def __contains__(self, instruction): + return instruction in self._isa_map diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index d22b35039042c..939956c1a0054 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -19,8 +19,9 @@ """ # expose all operators in tvm tir.op from tvm.tir import any, all, min_value, max_value, trace -from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil +from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, sqrt, rsqrt, floor, ceil from tvm.tir import sinh, cosh, log2, log10 +from tvm.tir import asin, asinh, acos, acosh, atan, atanh from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else from tvm.tir import isnan, isfinite, isinf from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 7d06eea9632e5..07e0c9ca0f276 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -37,7 +37,9 @@ from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp -from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2 +from .op import sin, sinh, asin, asinh +from .op import cos, cosh, acos, acosh +from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else from .op import isnan, isfinite, isinf, copysign diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index e783fe7dc3e74..b87db19738b91 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -522,6 +522,38 @@ def cosh(x): return call_pure_intrin(x.dtype, "cosh", x) +def acos(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "acos", x) + + +def acosh(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "acosh", x) + + def sin(x): """Take sin of input x. @@ -554,6 +586,38 @@ def sinh(x): return call_pure_intrin(x.dtype, "sinh", x) +def asin(x): + """Take asin of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "asin", x) + + +def asinh(x): + """Take asinh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "asinh", x) + + def atan(x): """Take atan of input x. @@ -570,6 +634,22 @@ def atan(x): return call_pure_intrin(x.dtype, "atan", x) +def atanh(x): + """Take atanh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "atanh", x) + + def atan2(x1, x2): """Take arctan2(x1, x2). 
diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml index 3c51bb384c683..5a1f1d27514f8 100644 --- a/rust/.rustfmt.toml +++ b/rust/.rustfmt.toml @@ -29,3 +29,4 @@ merge_derives = true use_try_shorthand = false use_field_init_shorthand = false force_explicit_abi = true + diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8467f6a92ea80..b4a159c92a2b9 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -22,9 +22,11 @@ members = [ "runtime", "runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_dso", + "runtime/tests/test_wasm32", "runtime/tests/test_nn", "frontend", "frontend/tests/basics", "frontend/tests/callback", - "frontend/examples/resnet" + "frontend/examples/resnet", + "tvm-sys" ] diff --git a/rust/common/build.rs b/rust/common/build.rs index b3ae7b6d18377..07326f41f8018 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -51,6 +51,7 @@ fn main() { .layout_tests(false) .derive_partialeq(true) .derive_eq(true) + .derive_default(true) .generate() .expect("unable to generate bindings") .write_to_file(PathBuf::from("src/c_runtime_api.rs")) diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs index d0a66a62b8bf8..a8f4f989c1467 100644 --- a/rust/common/src/array.rs +++ b/rust/common/src/array.rs @@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray { shape: arr.shape().as_ptr() as *const i64 as *mut i64, strides: arr.strides().as_ptr() as *const isize as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index 2ae64e7a32b32..33b2993bf3da2 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -31,8 +31,13 @@ pub mod ffi { include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - pub type BackendPackedCFunc = - extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + pub type BackendPackedCFunc = extern "C" fn( + args: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + out_ret_value: *mut TVMValue, + out_ret_tcode: *mut u32, + ) -> c_int; } pub mod array; diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index 2b6c7c217e28c..c38b3ff8e527f 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -297,6 +297,7 @@ impl<'a> Tensor<'a> { self.strides.as_ref().unwrap().as_ptr() } as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 518bf724f3194..71541ba27826e 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -382,7 +382,18 @@ named! { // Converts a bytes to String. named! 
{ name, - map_res!(length_data!(le_u64), |b: &[u8]| String::from_utf8(b.to_vec())) + do_parse!( + len_l: le_u32 >> + len_h: le_u32 >> + data: take!(len_l) >> + ( + if len_h == 0 { + String::from_utf8(data.to_vec()).unwrap() + } else { + panic!("Too long string") + } + ) + ) } // Parses a TVMContext diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs index 856dd78193bc3..cb4d7776dd0bb 100644 --- a/rust/runtime/src/module/mod.rs +++ b/rust/runtime/src/module/mod.rs @@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< (val, code as i32) }) .unzip(); - let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); + let ret: TVMRetValue = TVMRetValue::default(); + let (mut ret_val, mut ret_type_code) = ret.to_tvm_value(); + let exit_code = func( + values.as_ptr(), + type_codes.as_ptr(), + values.len() as i32, + &mut ret_val, + &mut ret_type_code, + ); if exit_code == 0 { - Ok(TVMRetValue::default()) + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code)) } else { Err(tvm_common::errors::FuncCallError::get_with_context( func_name.clone(), diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index f473bbf3990ae..b8be01270ae7d 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -18,7 +18,6 @@ */ use std::{ - env, os::raw::{c_int, c_void}, sync::{ atomic::{AtomicUsize, Ordering}, @@ -27,6 +26,9 @@ use std::{ thread::{self, JoinHandle}, }; +#[cfg(not(target_arch = "wasm32"))] +use std::env; + use crossbeam::channel::{bounded, Receiver, Sender}; use tvm_common::ffi::TVMParallelGroupEnv; @@ -147,7 +149,10 @@ impl ThreadPool { fn run_worker(queue: Receiver) { loop { - let task = queue.recv().expect("should recv"); + let task = match queue.recv() { + Ok(v) => v, + Err(_) => break, + }; let result = task.run(); if result == ::min_value() { break; diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs index 8344dfbb1adfe..65ad25324cae4 100644 --- a/rust/runtime/src/workspace.rs +++ b/rust/runtime/src/workspace.rs @@ -64,7 +64,7 @@ impl WorkspacePool { .iter() .fold(None, |cur_ws_idx: Option, &idx| { let ws_size = self.workspaces[idx].size(); - if !ws_size >= size { + if ws_size < size { return cur_ws_idx; } cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { @@ -92,9 +92,8 @@ impl WorkspacePool { break; } } - if let Some(ws_idx) = ws_idx { - self.free.push(ws_idx); - } + let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?; + self.free.push(ws_idx); Ok(()) } } @@ -135,6 +134,5 @@ pub extern "C" fn TVMBackendFreeWorkspace( Ok(()) => 0, Err(_) => -1, }) as c_int - }); - 0 + }) } diff --git a/rust/runtime/tests/test_wasm32/.cargo/config b/rust/runtime/tests/test_wasm32/.cargo/config new file mode 100644 index 0000000000000..6b77899cb3333 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/.cargo/config @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasi" diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml new file mode 100644 index 0000000000000..1d3373a9e60f1 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-wasm32" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs new file mode 100644 index 0000000000000..8b72be2902677 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/build.rs @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::PathBuf, process::Command}; + +fn main() { + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).unwrap(); + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest_wasm32.a"); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + obj_file.exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8"); + let output = Command::new(ar) + .arg("rcs") + .arg(&lib_file) + .arg(&obj_file) + .output() + .expect("Failed to execute command"); + assert!( + lib_file.exists(), + "Could not create archive: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-lib=static=test_wasm32"); + println!("cargo:rustc-link-search=native={}", out_dir.display()); +} diff --git a/tests/webgl/test_local_multi_stage.py b/rust/runtime/tests/test_wasm32/src/build_test_lib.py old mode 100644 new mode 100755 similarity index 50% rename from tests/webgl/test_local_multi_stage.py rename to rust/runtime/tests/test_wasm32/src/build_test_lib.py index 54a554b74ed9b..6016c60c4ea34 --- a/tests/webgl/test_local_multi_stage.py +++ b/rust/runtime/tests/test_wasm32/src/build_test_lib.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information @@ -14,34 +15,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te -import numpy as np - -def test_local_multi_stage(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute((n,), lambda i: A[i] + 1, name="B") - C = te.compute((n,), lambda i: B[i] * 2, name="C") +"""Prepares a simple TVM library for testing.""" - s = te.create_schedule(C.op) - s[B].opengl() - s[C].opengl() +from os import path as osp +import sys - f = tvm.build(s, [A, C], "opengl", name="multi_stage") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) - c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) - f(a, c) +import tvm +from tvm import te - tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2) +def main(): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.te.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o')) -if __name__ == "__main__": - test_local_multi_stage() +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_wasm32/src/main.rs b/rust/runtime/tests/test_wasm32/src/main.rs new file mode 100644 index 0000000000000..a46cfa979becd --- /dev/null +++ b/rust/runtime/tests/test_wasm32/src/main.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern "C" { + static __tvm_module_ctx: i32; +} + +#[no_mangle] +unsafe fn __get_tvm_module_ctx() -> i32 { + // Refer a symbol in the libtest_wasm32.a to make sure that the link of the + // library is not optimized out. 
+ __tvm_module_ctx +} + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +fn main() { + // try static + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml new file mode 100644 index 0000000000000..fe4d0bf987bf7 --- /dev/null +++ b/rust/tvm-sys/Cargo.toml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-sys" +version = "0.1.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +edition = "2018" + +[features] +bindings = [] + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +ndarray = "0.12" +enumn = "^0.1" + +[build-dependencies] +bindgen = "0.51" diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs new file mode 100644 index 0000000000000..85e16bead0852 --- /dev/null +++ b/rust/tvm-sys/build.rs @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +extern crate bindgen; + +use std::path::PathBuf; + +use std::env; + +fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + crate_dir + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); + + if cfg!(feature = "bindings") { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-search={}/build", tvm_home); + } + + // @see rust-bindgen#550 for `blacklist_type` + bindgen::Builder::default() + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) + .clang_arg(format!("-I{}/include/", tvm_home)) + .blacklist_type("max_align_t") + .layout_tests(false) + .derive_partialeq(true) + .derive_eq(true) + .generate() + .expect("unable to generate bindings") + .write_to_file(PathBuf::from("src/c_runtime_api.rs")) + .expect("can not write the bindings!"); +} diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs new file mode 100644 index 0000000000000..1627e9e228606 --- /dev/null +++ b/rust/tvm-sys/src/array.rs @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + mem, + os::raw::{c_int, c_void}, +}; + +use crate::ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! 
impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + }, + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const i64 as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs new file mode 100644 index 0000000000000..40f28f45d76ba --- /dev/null +++ b/rust/tvm-sys/src/byte_array.rs @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use std::os::raw::c_char; + +use crate::ffi::TVMByteArray; + +/// A newtype wrapping a raw TVM byte-array. +/// +/// ## Example +/// +/// ``` +/// let v = b"hello"; +/// let barr = tvm_sys::ByteArray::from(&v); +/// assert_eq!(barr.len(), v.len()); +/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); +/// ``` +pub struct ByteArray { + /// The raw FFI ByteArray. 
+    array: TVMByteArray,
+}
+
+impl ByteArray {
+    /// Gets the underlying byte-array
+    pub fn data(&self) -> &'static [u8] {
+        unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size) }
+    }
+
+    /// Gets the length of the underlying byte-array
+    pub fn len(&self) -> usize {
+        self.array.size
+    }
+
+    /// Converts the underlying byte-array to `Vec<u8>`
+    pub fn to_vec(&self) -> Vec<u8> {
+        self.data().to_vec()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+}
+
+// Needs AsRef<[u8]> for Vec<u8>
+impl<T: AsRef<[u8]>> From<T> for ByteArray {
+    fn from(arg: T) -> Self {
+        let arg = arg.as_ref();
+        ByteArray {
+            array: TVMByteArray {
+                data: arg.as_ptr() as *const c_char,
+                size: arg.len(),
+            },
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn convert() {
+        let v = vec![1u8, 2, 3];
+        let barr = ByteArray::from(&v);
+        assert_eq!(barr.len(), v.len());
+        assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
+        let v = b"hello";
+        let barr = ByteArray::from(&v);
+        assert_eq!(barr.len(), v.len());
+        assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
+    }
+}
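// Editor's note: a minimal usage sketch, not part of this diff. It assumes the
// `tvm-sys` crate exactly as introduced above. `ByteArray` only borrows the
// buffer it is built from, so the source buffer must outlive the wrapper.
fn byte_array_roundtrip_sketch() {
    use tvm_sys::ByteArray;

    let buf = vec![1u8, 2, 3, 4];
    let barr = ByteArray::from(&buf); // any T: AsRef<[u8]> works here
    assert_eq!(barr.len(), buf.len());
    assert_eq!(barr.to_vec(), buf); // copies the bytes back out of the FFI view
}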
diff --git a/rust/tvm-sys/src/context.rs b/rust/tvm-sys/src/context.rs
new file mode 100644
index 0000000000000..64b58b9f42c93
--- /dev/null
+++ b/rust/tvm-sys/src/context.rs
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! Provides [`Context`] and related device queries.
+//!
+//! Create a new context for device type and device id.
+//!
+//! # Example
+//!
+//! ```
+//! # use tvm_sys::{DeviceType, Context};
+//! let cpu = DeviceType::from("cpu");
+//! let ctx = Context::new(cpu, 0);
+//! let cpu0 = Context::cpu(0);
+//! assert_eq!(ctx, cpu0);
+//! ```
+//!
+//! Or from a supported device name.
+//!
+//! ```
+//! use tvm_sys::Context;
+//! let cpu0 = Context::from("cpu");
+//! println!("{}", cpu0);
+//! ```
+
+use std::convert::TryFrom;
+use std::fmt::{self, Display, Formatter};
+use std::str::FromStr;
+
+use crate::ffi::{self, *};
+use crate::packed_func::{ArgValue, RetValue};
+
+use anyhow::Result;
+use enumn::N;
+use thiserror::Error;
+
+/// Device type represents the set of devices supported by
+/// [TVM](https://github.com/apache/incubator-tvm).
+///
+/// ## Example
+///
+/// ```
+/// use tvm_sys::DeviceType;
+/// let cpu = DeviceType::from("cpu");
+/// println!("device is: {}", cpu);
+///```
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, N)]
+#[repr(i64)]
+pub enum DeviceType {
+    CPU = 1,
+    GPU,
+    CPUPinned,
+    OpenCL,
+    Vulkan,
+    Metal,
+    VPI,
+    ROCM,
+    ExtDev,
+}
+
+impl Default for DeviceType {
+    /// default device is cpu.
+    fn default() -> Self {
+        DeviceType::CPU
+    }
+}
+
+impl From<DeviceType> for ffi::DLDeviceType {
+    fn from(device_type: DeviceType) -> Self {
+        device_type as Self
+    }
+}
+
+impl From<ffi::DLDeviceType> for DeviceType {
+    fn from(device_type: ffi::DLDeviceType) -> Self {
+        Self::n(device_type as _).expect("invalid enumeration value for ffi::DLDeviceType")
+    }
+}
+
+impl Display for DeviceType {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                DeviceType::CPU => "cpu",
+                DeviceType::GPU => "gpu",
+                DeviceType::CPUPinned => "cpu_pinned",
+                DeviceType::OpenCL => "opencl",
+                DeviceType::Vulkan => "vulkan",
+                DeviceType::Metal => "metal",
+                DeviceType::VPI => "vpi",
+                DeviceType::ROCM => "rocm",
+                DeviceType::ExtDev => "ext_device",
+                // DeviceType(_) => "rpc",
+            }
+        )
+    }
+}
+
+impl<'a> From<&'a str> for DeviceType {
+    fn from(type_str: &'a str) -> Self {
+        match type_str {
+            "cpu" => DeviceType::CPU,
+            "llvm" => DeviceType::CPU,
+            "stackvm" => DeviceType::CPU,
+            "gpu" => DeviceType::GPU,
+            "cuda" => DeviceType::GPU,
+            "nvptx" => DeviceType::GPU,
+            "cl" => DeviceType::OpenCL,
+            "opencl" => DeviceType::OpenCL,
+            "metal" => DeviceType::Metal,
+            "vpi" => DeviceType::VPI,
+            "rocm" => DeviceType::ROCM,
+            _ => panic!("{:?} not supported!", type_str),
+        }
+    }
+}
+
+impl<'a> From<&DeviceType> for ArgValue<'a> {
+    fn from(dev: &DeviceType) -> Self {
+        Self::Int(*dev as _)
+    }
+}
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+pub struct Context {
+    pub device_type: DeviceType,
+    pub device_id: usize,
+}
+
+impl Context {
+    pub fn new(device_type: DeviceType, device_id: usize) -> Context {
+        Context {
+            device_type,
+            device_id,
+        }
+    }
+}
+
+impl<'a> From<&'a Context> for DLContext {
+    fn from(ctx: &'a Context) -> Self {
+        Self {
+            device_type: ctx.device_type.into(),
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Default for Context {
+    fn default() -> Self {
+        Self {
+            device_type: DLDeviceType_kDLCPU.into(),
+            device_id: 0,
+        }
+    }
+}
+
+#[derive(Debug, Error)]
+#[error("unsupported device: {0}")]
+pub struct UnsupportedDeviceError(String);
+
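// Editor's note: illustrative sketch only, not part of the diff. It exercises
// the DeviceType conversions defined above; names follow the tvm-sys API as
// added in this file.
fn device_type_sketch() {
    use tvm_sys::DeviceType;

    // Several target strings alias onto one device type.
    let dev = DeviceType::from("cuda");
    assert_eq!(dev, DeviceType::GPU);
    // Display always prints the canonical name, not the alias.
    assert_eq!(dev.to_string(), "gpu");
}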
+macro_rules! impl_tvm_context {
+    ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+        /// Creates a Context from a string (e.g., "cpu", "gpu", "ext_dev")
+        impl FromStr for Context {
+            type Err = UnsupportedDeviceError;
+            fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+                Ok(Self {
+                    device_type: match type_str {
+                        $( $( stringify!($dev_name) )|+ => $dev_type.into()),+,
+                        _ => return Err(UnsupportedDeviceError(type_str.to_string())),
+                    },
+                    device_id: 0,
+                })
+            }
+        }
+
+        impl Context {
+            $(
+                $(
+                    pub fn $dev_name(device_id: usize) -> Self {
+                        Self {
+                            device_type: $dev_type.into(),
+                            device_id: device_id,
+                        }
+                    }
+                )+
+            )+
+        }
+    };
+}
+
+impl_tvm_context!(
+    DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLOpenCL: [cl],
+    DLDeviceType_kDLMetal: [metal],
+    DLDeviceType_kDLVPI: [vpi],
+    DLDeviceType_kDLROCM: [rocm],
+    DLDeviceType_kDLExtDev: [ext_dev]
+);
+
+impl<'a> From<&'a str> for Context {
+    fn from(target: &str) -> Self {
+        Context::new(DeviceType::from(target), 0)
+    }
+}
+
+impl From<ffi::DLContext> for Context {
+    fn from(ctx: ffi::DLContext) -> Self {
+        Context {
+            device_type: DeviceType::from(ctx.device_type),
+            device_id: ctx.device_id as usize,
+        }
+    }
+}
+
+impl From<Context> for ffi::DLContext {
+    fn from(ctx: Context) -> Self {
+        ffi::DLContext {
+            device_type: ctx.device_type.into(),
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Display for Context {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(f, "{}({})", self.device_type, self.device_id)
+    }
+}
+
+impl From<Context> for RetValue {
+    fn from(ret_value: Context) -> RetValue {
+        RetValue::Context(ret_value.into())
+    }
+}
+
+impl TryFrom<RetValue> for Context {
+    type Error = anyhow::Error;
+    fn try_from(ret_value: RetValue) -> anyhow::Result<Context> {
+        match ret_value {
+            RetValue::Context(dt) => Ok(dt.into()),
+            // TODO(@jroesch): improve
+            _ => Err(anyhow::anyhow!("unable to convert datatype from ...")),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn context() {
+        let ctx = Context::cpu(0);
+        println!("ctx: {}", ctx);
+        let default_ctx = Context::new(DeviceType::CPU, 0);
+        assert_eq!(ctx.clone(), default_ctx);
+        assert_ne!(ctx, Context::gpu(0));
+
+        let str_ctx = Context::new(DeviceType::GPU, 0);
+        assert_eq!(str_ctx.clone(), str_ctx);
+        assert_ne!(str_ctx, Context::new(DeviceType::CPU, 0));
+    }
+}
diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs
new file mode 100644
index 0000000000000..5dd414c179602
--- /dev/null
+++ b/rust/tvm-sys/src/datatype.rs
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::any::TypeId;
+use std::convert::TryFrom;
+use std::str::FromStr;
+
+use crate::ffi::DLDataType;
+use crate::packed_func::RetValue;
+
+use thiserror::Error;
+
+const DL_INT_CODE: u8 = 0;
+const DL_UINT_CODE: u8 = 1;
+const DL_FLOAT_CODE: u8 = 2;
+const DL_HANDLE: u8 = 3;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+    code: u8,
+    bits: u8,
+    lanes: u16,
+}
+
+impl DataType {
+    pub fn new(code: u8, bits: u8, lanes: u16) -> DataType {
+        DataType { code, bits, lanes }
+    }
+
+    /// Returns the number of bytes occupied by an element of this `DataType`.
+    pub fn itemsize(&self) -> usize {
+        (self.bits as usize * self.lanes as usize) >> 3
+    }
+
+    /// Returns whether this `DataType` represents primitive type `T`.
+    pub fn is_type<T: 'static>(&self) -> bool {
+        if self.lanes != 1 {
+            return false;
+        }
+        let typ = TypeId::of::<T>();
+        (typ == TypeId::of::<i32>() && self.code == DL_INT_CODE && self.bits == 32)
+            || (typ == TypeId::of::<i64>() && self.code == DL_INT_CODE && self.bits == 64)
+            || (typ == TypeId::of::<u32>() && self.code == DL_UINT_CODE && self.bits == 32)
+            || (typ == TypeId::of::<u64>() && self.code == DL_UINT_CODE && self.bits == 64)
+            || (typ == TypeId::of::<f32>() && self.code == DL_FLOAT_CODE && self.bits == 32)
+            || (typ == TypeId::of::<f64>() && self.code == DL_FLOAT_CODE && self.bits == 64)
+    }
+
+    pub fn code(&self) -> usize {
+        self.code as usize
+    }
+
+    pub fn bits(&self) -> usize {
+        self.bits as usize
+    }
+
+    pub fn lanes(&self) -> usize {
+        self.lanes as usize
+    }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+    fn from(dtype: &'a DataType) -> Self {
+        Self {
+            code: dtype.code as u8,
+            bits: dtype.bits as u8,
+            lanes: dtype.lanes as u16,
+        }
+    }
+}
+
+impl From<DLDataType> for DataType {
+    fn from(dtype: DLDataType) -> Self {
+        Self {
+            code: dtype.code,
+            bits: dtype.bits,
+            lanes: dtype.lanes,
+        }
+    }
+}
+
+#[derive(Debug, Error)]
+pub enum ParseDataTypeError {
+    #[error("invalid number: {0}")]
+    InvalidNumber(std::num::ParseIntError),
+    #[error("missing data type specifier (e.g., int32, float64)")]
+    MissingDataType,
+    #[error("unknown type: {0}")]
+    UnknownType(String),
+}
+
+/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
+/// such as "int32", "float32" or with lane "float32x1".
+impl FromStr for DataType {
+    type Err = ParseDataTypeError;
+
+    fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+        use ParseDataTypeError::*;
+
+        if type_str == "bool" {
+            return Ok(DataType::new(1, 1, 1));
+        }
+
+        let mut type_lanes = type_str.split('x');
+        let typ = type_lanes.next().ok_or(MissingDataType)?;
+        let lanes = type_lanes
+            .next()
+            .map(|l| <u16>::from_str_radix(l, 10))
+            .unwrap_or(Ok(1))
+            .map_err(InvalidNumber)?;
+        let (type_name, bits) = match typ.find(char::is_numeric) {
+            Some(idx) => {
+                let (name, bits_str) = typ.split_at(idx);
+                (
+                    name,
+                    u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?,
+                )
+            }
+            None => (typ, 32),
+        };
+
+        let type_code = match type_name {
+            "int" => DL_INT_CODE,
+            "uint" => DL_UINT_CODE,
+            "float" => DL_FLOAT_CODE,
+            "handle" => DL_HANDLE,
+            _ => return Err(UnknownType(type_name.to_string())),
+        };
+
+        Ok(DataType::new(type_code, bits, lanes))
+    }
+}
+
+impl std::fmt::Display for DataType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        if self.bits == 1 && self.lanes == 1 {
+            return write!(f, "bool");
+        }
+        let mut type_str = match self.code {
+            DL_INT_CODE => "int",
+            DL_UINT_CODE => "uint",
+            DL_FLOAT_CODE => "float",
+            DL_HANDLE => "handle",
+            _ => "unknown",
+        }
+        .to_string();
+
+        type_str += &self.bits.to_string();
+        if self.lanes > 1 {
+            type_str += &format!("x{}", self.lanes);
+        }
+        f.write_str(&type_str)
+    }
+}
+
+impl From<DataType> for RetValue {
+    fn from(dt: DataType) -> RetValue {
+        RetValue::DataType((&dt).into())
+    }
+}
+
+impl TryFrom<RetValue> for DataType {
+    type Error = anyhow::Error;
+    fn try_from(ret_value: RetValue) -> anyhow::Result<DataType> {
+        match ret_value {
+            RetValue::DataType(dt) => Ok(dt.into()),
+            // TODO(@jroesch): improve
+            _ => Err(anyhow::anyhow!("unable to convert datatype from ...")),
+        }
+    }
+}
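// Editor's note: illustrative round-trip through the parser and Display impl
// above (not part of the diff).
fn datatype_sketch() {
    use std::str::FromStr;
    use tvm_sys::DataType;

    let dt = DataType::from_str("float32x4").unwrap();
    assert_eq!((dt.bits(), dt.lanes()), (32, 4));
    assert_eq!(dt.itemsize(), 16); // 32 bits * 4 lanes / 8
    assert_eq!(dt.to_string(), "float32x4"); // Display mirrors the parser
}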
diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs
new file mode 100644
index 0000000000000..8479ec62f19f6
--- /dev/null
+++ b/rust/tvm-sys/src/errors.rs
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")]
+pub struct ValueDowncastError {
+    pub actual_type: String,
+    pub expected_type: &'static str,
+}
+
+#[derive(Error, Debug)]
+#[error("Function call `{context:?}` returned error: {message:?}")]
+pub struct FuncCallError {
+    context: String,
+    message: String,
+}
+
+impl FuncCallError {
+    pub fn get_with_context(context: String) -> Self {
+        Self {
+            context,
+            message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
+                .to_str()
+                .expect("double fault")
+                .to_owned(),
+        }
+    }
+}
diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs
new file mode 100644
index 0000000000000..dd28e3603f903
--- /dev/null
+++ b/rust/tvm-sys/src/lib.rs
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This crate contains the minimal interface over TVM's
+//! C runtime API.
+//!
+//! These common bindings are useful to both runtimes
+//! written in Rust, as well as higher level API bindings.
+//!
+//! See the `tvm-rt` or `tvm` crates for full bindings to
+//! the TVM API.
+
+/// The low-level C runtime FFI API for TVM.
+pub mod ffi {
+    #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
+
+    use std::os::raw::{c_char, c_int, c_void};
+
+    include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
+
+    pub type BackendPackedCFunc =
+        extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
+}
+
+pub mod array;
+pub mod byte_array;
+pub mod context;
+pub mod datatype;
+pub mod errors;
+#[macro_use]
+pub mod packed_func;
+pub mod value;
+
+pub use byte_array::ByteArray;
+pub use context::{Context, DeviceType};
+pub use datatype::DataType;
+pub use errors::*;
+pub use packed_func::{ArgValue, RetValue};
diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs
new file mode 100644
index 0000000000000..e4b27397900c5
--- /dev/null
+++ b/rust/tvm-sys/src/packed_func.rs
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    convert::TryFrom,
+    ffi::{CStr, CString},
+    os::raw::c_void,
+};
+
+use crate::{errors::ValueDowncastError, ffi::*};
+
+pub use crate::ffi::TVMValue;
+
+pub trait PackedFunc:
+    Fn(&[ArgValue]) -> Result<RetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
+
+impl<T> PackedFunc for T where
+    T: Fn(&[ArgValue]) -> Result<RetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
+
+/// Calls a packed function and returns a `RetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+    ($fn:expr, $($args:expr),+) => {
+        $fn(&[$($args.into(),)+])
+    };
+    ($fn:expr) => {
+        $fn(&Vec::new())
+    };
+}
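// Editor's note: sketch of the blanket impl above (not part of the diff). Any
// closure of the right shape is a `PackedFunc`, and `call_packed!` converts
// each argument into an `ArgValue` before the call.
fn packed_func_sketch() {
    use std::convert::TryFrom;
    use tvm_sys::errors::FuncCallError;
    use tvm_sys::packed_func::{ArgValue, RetValue};

    let add = |args: &[ArgValue]| -> Result<RetValue, FuncCallError> {
        let x = i64::try_from(&args[0]).expect("first arg is an int");
        let y = i64::try_from(&args[1]).expect("second arg is an int");
        Ok(RetValue::from(x + y))
    };
    let ret = tvm_sys::call_packed!(add, 2i64, 3i64).unwrap();
    assert_eq!(i64::try_from(ret).unwrap(), 5);
}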
+/// Constructs a derivative of a TVMPodValue.
+macro_rules! TVMPODValue {
+    {
+        $(#[$m:meta])+
+        $name:ident $(<$a:lifetime>)? {
+            $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)?
+        },
+        match $value:ident {
+            $($tvm_type:ident => { $from_tvm_type:expr })+
+        },
+        match &self {
+            $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+
+        }
+        $(,)?
+    } => {
+        $(#[$m])+
+        #[derive(Clone, Debug)]
+        pub enum $name $(<$a>)? {
+            Int(i64),
+            UInt(i64),
+            Float(f64),
+            Null,
+            DataType(DLDataType),
+            String(CString),
+            Context(TVMContext),
+            Handle(*mut c_void),
+            ArrayHandle(TVMArrayHandle),
+            ObjectHandle(*mut c_void),
+            ModuleHandle(TVMModuleHandle),
+            FuncHandle(TVMFunctionHandle),
+            NDArrayHandle(*mut c_void),
+            $($extra_variant($variant_type)),+
+        }
+
+        impl $(<$a>)? $name $(<$a>)? {
+            pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self {
+                use $name::*;
+                #[allow(non_upper_case_globals)]
+                unsafe {
+                    match type_code as _ {
+                        DLDataTypeCode_kDLInt => Int($value.v_int64),
+                        DLDataTypeCode_kDLUInt => UInt($value.v_int64),
+                        DLDataTypeCode_kDLFloat => Float($value.v_float64),
+                        TVMTypeCode_kTVMNullptr => Null,
+                        TVMTypeCode_kTVMDataType => DataType($value.v_type),
+                        TVMTypeCode_kTVMContext => Context($value.v_ctx),
+                        TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
+                        TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
+                        TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+                        TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
+                        TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
+                        TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
+                        $( $tvm_type => { $from_tvm_type } ),+
+                        _ => unimplemented!("{}", type_code),
+                    }
+                }
+            }
+
+            pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
+                use $name::*;
+                match self {
+                    Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
+                    UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
+                    Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
+                    Null => (TVMValue { v_int64: 0 }, TVMTypeCode_kTVMNullptr),
+                    DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
+                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
+                    String(val) => {
+                        (
+                            TVMValue { v_handle: val.as_ptr() as *mut c_void },
+                            TVMTypeCode_kTVMStr,
+                        )
+                    }
+                    Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
+                    ArrayHandle(val) => {
+                        (
+                            TVMValue { v_handle: *val as *const _ as *mut c_void },
+                            TVMTypeCode_kTVMNDArrayHandle,
+                        )
+                    },
+                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
+                    ModuleHandle(val) =>
+                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
+                    FuncHandle(val) => (
+                        TVMValue { v_handle: *val },
+                        TVMTypeCode_kTVMPackedFuncHandle
+                    ),
+                    NDArrayHandle(val) =>
+                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
+                    $( $self_type($val) => { $from_self_type } ),+
+                }
+            }
+        }
+    }
+}
+
+TVMPODValue! {
+    /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
+    /// to obtain an `ArgValue` is automatically via `call_packed!`.
+    ArgValue<'a> {
+        Bytes(&'a TVMByteArray),
+        Str(&'a CStr),
+    },
+    match value {
+        TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
+        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
+    },
+    match &self {
+        Bytes(val) => {
+            (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
+        }
+        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
+    }
+}
+
+TVMPODValue! {
+    /// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
+    /// Can be downcast using `try_from` if it contains the desired type.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use std::convert::{TryFrom, TryInto};
+    /// use tvm_sys::RetValue;
+    ///
+    /// let a = 42u32;
+    /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap();
+    ///
+    /// let s = "hello, world!";
+    /// let t: RetValue = s.to_string().into();
+    /// assert_eq!(String::try_from(t).unwrap(), s);
+    /// ```
+    RetValue {
+        Bytes(TVMByteArray),
+        Str(&'static CStr),
+    },
+    match value {
+        TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
+        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
+    },
+    match &self {
+        Bytes(val) =>
+            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
+        Str(val) =>
+            { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
+    }
+}
+
+#[macro_export]
+macro_rules! try_downcast {
+    ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => {
+        match $val {
+            $( $pat => { Ok($converter) } )+
+            _ => Err($crate::errors::ValueDowncastError {
+                actual_type: format!("{:?}", $val),
+                expected_type: stringify!($into),
+            }),
+        }
+    };
+}
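// Editor's note: sketch of the `to_tvm_value`/`from_tvm_value` pair the macro
// above generates (not part of the diff). The concrete integer type of
// `TVMTypeCode` comes from bindgen, hence the hedged `as u32` cast.
fn pod_roundtrip_sketch() {
    use tvm_sys::packed_func::RetValue;

    let (raw, code) = RetValue::from(42i64).to_tvm_value();
    match RetValue::from_tvm_value(raw, code as u32) {
        RetValue::Int(x) => assert_eq!(x, 42),
        other => panic!("unexpected variant: {:?}", other),
    }
}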
+/// Creates a conversion to an `ArgValue` for a primitive type and DLDataTypeCode.
+macro_rules! impl_pod_value {
+    ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => {
+        $(
+            impl<'a> From<$type> for ArgValue<'a> {
+                fn from(val: $type) -> Self {
+                    Self::$variant(val as $inner_ty)
+                }
+            }
+
+            impl<'a, 'v> From<&'a $type> for ArgValue<'v> {
+                fn from(val: &'a $type) -> Self {
+                    Self::$variant(*val as $inner_ty)
+                }
+            }
+
+            impl<'a> TryFrom<ArgValue<'a>> for $type {
+                type Error = $crate::errors::ValueDowncastError;
+                fn try_from(val: ArgValue<'a>) -> Result<Self, Self::Error> {
+                    try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type })
+                }
+            }
+
+            impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type {
+                type Error = $crate::errors::ValueDowncastError;
+                fn try_from(val: &'a ArgValue<'v>) -> Result<Self, Self::Error> {
+                    try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type })
+                }
+            }
+
+            impl From<$type> for RetValue {
+                fn from(val: $type) -> Self {
+                    Self::$variant(val as $inner_ty)
+                }
+            }
+
+            impl TryFrom<RetValue> for $type {
+                type Error = $crate::errors::ValueDowncastError;
+                fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+                    try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type })
+                }
+            }
+        )+
+    };
+}
+
+impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
+impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
+impl_pod_value!(Float, f64, [f32, f64]);
+impl_pod_value!(DataType, DLDataType, [DLDataType]);
+impl_pod_value!(Context, TVMContext, [TVMContext]);
+
+impl<'a> From<&'a str> for ArgValue<'a> {
+    fn from(s: &'a str) -> Self {
+        Self::String(CString::new(s).unwrap())
+    }
+}
+
+impl<'a> From<String> for ArgValue<'a> {
+    fn from(s: String) -> Self {
+        Self::String(CString::new(s).unwrap())
+    }
+}
+
+impl<'a> From<&'a CStr> for ArgValue<'a> {
+    fn from(s: &'a CStr) -> Self {
+        Self::Str(s)
+    }
+}
+
+impl<'a> From<CString> for ArgValue<'a> {
+    fn from(s: CString) -> Self {
+        Self::String(s)
+    }
+}
+
+impl<'a> From<&'a TVMByteArray> for ArgValue<'a> {
+    fn from(s: &'a TVMByteArray) -> Self {
+        Self::Bytes(s)
+    }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for &'a str {
+    type Error = ValueDowncastError;
+    fn try_from(val: ArgValue<'a>) -> Result<Self, Self::Error> {
+        try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() })
+    }
+}
+
+impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str {
+    type Error = ValueDowncastError;
+    fn try_from(val: &'a ArgValue<'v>) -> Result<Self, Self::Error> {
+        try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() })
+    }
+}
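// Editor's note: what the `impl_pod_value!` expansions above provide, as a
// sketch (not part of the diff): symmetric From/TryFrom for plain scalars.
fn pod_conversion_sketch() {
    use std::convert::TryFrom;
    use tvm_sys::{ArgValue, RetValue};

    let arg = ArgValue::from(1.5f64);
    assert_eq!(f64::try_from(&arg).unwrap(), 1.5);

    let ret = RetValue::from(7u32); // widened into the UInt(i64) variant
    assert_eq!(u32::try_from(ret).unwrap(), 7);
}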
+/// Converts an unspecialized handle to an `ArgValue`.
+impl<T> From<*const T> for ArgValue<'static> {
+    fn from(ptr: *const T) -> Self {
+        Self::Handle(ptr as *mut c_void)
+    }
+}
+
+/// Converts an unspecialized mutable handle to an `ArgValue`.
+impl<T> From<*mut T> for ArgValue<'static> {
+    fn from(ptr: *mut T) -> Self {
+        Self::Handle(ptr as *mut c_void)
+    }
+}
+
+impl<'a> From<&'a mut DLTensor> for ArgValue<'a> {
+    fn from(arr: &'a mut DLTensor) -> Self {
+        Self::ArrayHandle(arr as *mut DLTensor)
+    }
+}
+
+impl<'a> From<&'a DLTensor> for ArgValue<'a> {
+    fn from(arr: &'a DLTensor) -> Self {
+        Self::ArrayHandle(arr as *const _ as *mut DLTensor)
+    }
+}
+
+impl TryFrom<RetValue> for String {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+        try_downcast!(
+            val -> String,
+            |RetValue::String(s)| { s.into_string().unwrap() },
+            |RetValue::Str(s)| { s.to_str().unwrap().to_string() }
+        )
+    }
+}
+
+impl From<String> for RetValue {
+    fn from(s: String) -> Self {
+        Self::String(std::ffi::CString::new(s).unwrap())
+    }
+}
+
+impl From<TVMByteArray> for RetValue {
+    fn from(arr: TVMByteArray) -> Self {
+        Self::Bytes(arr)
+    }
+}
+
+impl TryFrom<RetValue> for TVMByteArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+        try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val })
+    }
+}
+
+impl Default for RetValue {
+    fn default() -> Self {
+        Self::Int(0)
+    }
+}
+
+impl TryFrom<RetValue> for std::ffi::CString {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+        try_downcast!(val -> std::ffi::CString,
+            |RetValue::Str(val)| { val.into() })
+    }
+}
diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs
new file mode 100644
index 0000000000000..a9ad5f523fdef
--- /dev/null
+++ b/rust/tvm-sys/src/value.rs
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::str::FromStr;
+
+use crate::ffi::*;
+
+use thiserror::Error;
+
+macro_rules! impl_pod_tvm_value {
+    ($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
+        $(
+            impl From<$ty> for TVMValue {
+                fn from(val: $ty) -> Self {
+                    TVMValue { $field: val as $field_ty }
+                }
+            }
+
+            impl From<TVMValue> for $ty {
+                fn from(val: TVMValue) -> Self {
+                    unsafe { val.$field as $ty }
+                }
+            }
+        )+
+    };
+    ($field:ident, $ty:ty) => {
+        impl_pod_tvm_value!($field, $ty, $ty);
+    }
+}
+
+impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
+impl_pod_tvm_value!(v_float64, f64, f32, f64);
+impl_pod_tvm_value!(v_type, DLDataType);
+impl_pod_tvm_value!(v_ctx, TVMContext);
+
+#[derive(Debug, Error)]
+#[error("unsupported device: {0}")]
+pub struct UnsupportedDeviceError(String);
+
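// Editor's note: sketch of the raw union conversions generated by
// `impl_pod_tvm_value!` above (not part of the diff). Reading a TVMValue is
// only sound when the same union field that was written is read back.
fn raw_tvm_value_sketch() {
    use tvm_sys::ffi::TVMValue;

    let v = TVMValue::from(42i64); // writes the v_int64 field
    let x: i64 = v.into(); // reads v_int64 (unsafe inside the generated impl)
    assert_eq!(x, 42);
}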
+macro_rules! impl_tvm_context {
+    ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+        /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
+        impl FromStr for TVMContext {
+            type Err = UnsupportedDeviceError;
+            fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+                Ok(Self {
+                    device_type: match type_str {
+                        $( $( stringify!($dev_name) )|+ => $dev_type ),+,
+                        _ => return Err(UnsupportedDeviceError(type_str.to_string())),
+                    },
+                    device_id: 0,
+                })
+            }
+        }
+
+        impl TVMContext {
+            $(
+                $(
+                    pub fn $dev_name(device_id: usize) -> Self {
+                        Self {
+                            device_type: $dev_type,
+                            device_id: device_id as i32,
+                        }
+                    }
+                )+
+            )+
+        }
+    };
+}
+
+impl_tvm_context!(
+    DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLOpenCL: [cl],
+    DLDeviceType_kDLMetal: [metal],
+    DLDeviceType_kDLVPI: [vpi],
+    DLDeviceType_kDLROCM: [rocm],
+    DLDeviceType_kDLExtDev: [ext_dev]
+);
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 83dfc64009cf3..037c76665d4b6 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -20,9 +20,9 @@
 /*!
  * \file tvm/arith/analyzer.cc
 */
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
-#include <tvm/arith/analyzer.h>
 #include <tvm/tir/op.h>
 
 namespace tvm {
 namespace arith {
@@ -33,34 +33,33 @@ Analyzer::Analyzer()
       modular_set(this),
       rewrite_simplify(this),
       canonical_simplify(this),
-      int_set(this) {
-}
+      int_set(this) {}
 
-void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
+void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
   PrimExpr new_expr = expr;
   new_expr = this->canonical_simplify(new_expr);
   new_expr = this->rewrite_simplify(new_expr);
 
-  this->const_int_bound.Update(var, this->const_int_bound(new_expr));
-  this->modular_set.Update(var, this->modular_set(new_expr));
-  this->rewrite_simplify.Update(var, new_expr);
-  this->canonical_simplify.Update(var, new_expr);
+  this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
+  this->modular_set.Update(var, this->modular_set(new_expr), override);
+  this->rewrite_simplify.Update(var, new_expr, override);
+  this->canonical_simplify.Update(var, new_expr, override);
 }
 
-void Analyzer::Bind(const Var& var, const Range& range) {
+void Analyzer::Bind(const Var& var, const Range& range, bool override) {
   CHECK(range.defined());
   if (tir::is_one(range->extent)) {
-    this->Bind(var, range->min);
+    this->Bind(var, range->min, override);
   } else {
-    this->const_int_bound.Bind(var, range);
+    this->const_int_bound.Bind(var, range, override);
   }
   // skip modular_set
   // skip rewrite simplify
 }
 
-void Analyzer::Bind(const Map<Var, Range>& variables) {
+void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
   for (const auto& iter : variables) {
-    this->Bind(iter.first, iter.second);
+    this->Bind(iter.first, iter.second, override);
   }
 }
 
@@ -92,6 +91,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
   return false;
 }
 
+bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
+  if (const auto* ptr = expr.as<IntImmNode>()) {
+    return ptr->value < upper_bound;
+  }
+  auto bd = this->const_int_bound(this->rewrite_simplify(expr));
+  if (bd->max_value < upper_bound) return true;
+  return false;
+}
+
 bool Analyzer::CanProve(const PrimExpr& expr) {
   if (const auto* ptr = expr.as<IntImmNode>()) {
     return ptr->value != 0;
@@ -115,63 +123,53 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
   return res;
 }
 
-TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    using runtime::PackedFunc;
-    using runtime::TypedPackedFunc;
-    auto self = std::make_shared<Analyzer>();
-    auto f = 
[self](std::string name) -> PackedFunc { - if (name == "const_int_bound") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->const_int_bound(args[0]); - }); - } else if (name == "modular_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->modular_set(args[0]); - }); - } else if (name == "const_int_bound_update") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - self->const_int_bound.Update(args[0], args[1], args[2]); - }); - } else if (name == "Simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->Simplify(args[0]); - }); - } else if (name == "rewrite_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->rewrite_simplify(args[0]); - }); - } else if (name == "canonical_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->canonical_simplify(args[0]); - }); - } else if (name == "int_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->int_set(args[0], args[1]); - }); - } else if (name == "bind") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - if (args[1].IsObjectRef()) { - self->Bind(args[0], args[1].operator Range()); - } else { - self->Bind(args[0], args[1].operator PrimExpr()); - } - }); - } else if (name == "enter_constraint_context") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr >( - new With(self.get(), args[0])); - auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { - ctx.reset(); - }; - *ret = PackedFunc(fexit); - }); - } - return PackedFunc(); - }; - *ret = TypedPackedFunc(f); +TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + auto self = std::make_shared(); + auto f = [self](std::string name) -> PackedFunc { + if (name == "const_int_bound") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); }); + } else if (name == "modular_set") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); }); + } else if (name == "const_int_bound_update") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + self->const_int_bound.Update(args[0], args[1], args[2]); + }); + } else if (name == "Simplify") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->Simplify(args[0]); }); + } else if (name == "rewrite_simplify") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); + } else if (name == "canonical_simplify") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); + } else if (name == "int_set") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); }); + } else if (name == "bind") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + if (args[1].IsObjectRef()) { + self->Bind(args[0], args[1].operator Range()); + } else { + self->Bind(args[0], args[1].operator PrimExpr()); + } + }); + } else if (name == "enter_constraint_context") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + // can't use make_shared due to noexcept(false) decl in destructor, + // see https://stackoverflow.com/a/43907314 + auto 
ctx = std::shared_ptr >( + new With(self.get(), args[0])); + auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; + *ret = PackedFunc(fexit); + }); + } + return PackedFunc(); + }; + *ret = TypedPackedFunc(f); }); } // namespace arith diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index eeaaa8af0ae53..496eb204f24b7 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -21,13 +21,14 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ +#include #include #include #include -#include -#include #include +#include + #include "interval_set.h" namespace tvm { @@ -37,7 +38,7 @@ using namespace tir; // a visitor to find the path to the target variable // from a expression. -class VariablePathFinder: public ExprVisitor { +class VariablePathFinder : public ExprVisitor { public: explicit VariablePathFinder(PrimExpr target) : target_(target) {} @@ -67,17 +68,17 @@ std::vector GetPath(PrimExpr target, PrimExpr expr) { return v.path_; } -enum CompareOp {kGreater, kLess, kEqual}; +enum CompareOp { kGreater, kLess, kEqual }; // a visitor to deduce the bound of a variable from a expression -class BoundDeducer: public ExprVisitor { +class BoundDeducer : public ExprVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(PrimExpr target, PrimExpr expr, const std::unordered_map& hint_map, const std::unordered_map& relax_map) - : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} + : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} void Deduce(); @@ -119,7 +120,7 @@ class BoundDeducer: public ExprVisitor { result_ += op->b; } else { result_ -= op->a; - result_ = - result_; + result_ = -result_; comp_op = ReverseOp(comp_op); } this->VisitExpr(left ? op->a : op->b); @@ -148,7 +149,7 @@ class BoundDeducer: public ExprVisitor { // always use relax bound bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); - result_ = floordiv(result_, operand); // rounding down here + result_ = floordiv(result_, operand); // rounding down here if (!divided) { if (comp_op == kGreater) { @@ -193,7 +194,7 @@ class BoundDeducer: public ExprVisitor { Analyzer analyzer_; }; -class BoundDeduceInputChecker: public ExprVisitor { +class BoundDeduceInputChecker : public ExprVisitor { public: bool Check(BoundDeducer* deducer) { deducer_ = deducer; @@ -219,9 +220,12 @@ void BoundDeducer::Init() { CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { switch (comp_op) { - case kEqual: return kEqual; // IntSet can not represent range for `NE - case kGreater: return kLess; - case kLess: return kGreater; + case kEqual: + return kEqual; // IntSet can not represent range for `NE + case kGreater: + return kLess; + case kLess: + return kGreater; default: LOG(FATAL) << "Not a valid compare op"; return kGreater; // return some default value @@ -318,18 +322,18 @@ void BoundDeducer::Relax() { // Both LHS and RHS of the EQ should behave as constants e.g. i == j, // can not be resolved when either `i` or `j` or both are variables with // some Range OR `i` and `j` both should be a single point in IntSet - if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max()) - || !analyzer_.CanProve(a.min() == a.max()))) { + if (comp_op == kEqual && + (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) { success_ = false; return; } - expr_ = (comp_op == kGreater) ? a.min() : a.max(); + expr_ = (comp_op == kGreater) ? 
a.min() : a.max(); result_ = (comp_op == kGreater) ? b.max() : b.min(); } IntSet DeduceBound(PrimExpr v, PrimExpr e, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) { + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success_) return IntSet::nothing(); @@ -347,8 +351,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. -IntSet DeduceBound(PrimExpr v, PrimExpr e, - const Map& hint_map, +IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, const Map& relax_map) { std::unordered_map hmap; for (auto kv : hint_map) { @@ -361,16 +364,11 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, return DeduceBound(v, e, hmap, rmap); } - TVM_REGISTER_GLOBAL("arith.DeduceBound") -.set_body_typed([]( - PrimExpr v, PrimExpr cond, - const Map hint_map, - const Map relax_map -) { - return DeduceBound(v, cond, hint_map, relax_map); -}); - + .set_body_typed([](PrimExpr v, PrimExpr cond, const Map hint_map, + const Map relax_map) { + return DeduceBound(v, cond, hint_map, relax_map); + }); } // namespace arith } // namespace tvm diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 2bb0189890a6a..273870704ca2c 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -22,8 +22,8 @@ * \brief Canonical form based simplification. */ #include -#include #include +#include #include "const_fold.h" #include "pattern_match.h" @@ -37,7 +37,6 @@ using namespace tir; class SumExpr; class SplitExpr; - /*! * \brief Base class of all temporary expression introduced * for canonicalization. @@ -53,8 +52,7 @@ class CanonicalExprNode : public PrimExprNode { virtual PrimExpr Normalize() const = 0; // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const uint32_t _type_child_slots = 2; @@ -111,9 +109,7 @@ class SplitExprNode : public CanonicalExprNode { DivMode div_mode{kTruncDiv}; /*! \brief verify that this is a valid entry. */ - void Verify() const { - CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); - } + void Verify() const { CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; @@ -135,13 +131,9 @@ class SplitExprNode : public CanonicalExprNode { return res; } - PrimExpr Normalize() const final { - return NormalizeWithScale(1); - } + PrimExpr Normalize() const final { return NormalizeWithScale(1); } - void MulToSelf(int64_t scale) { - this->scale *= scale; - } + void MulToSelf(int64_t scale) { this->scale *= scale; } inline bool IndexEqual(const SplitExpr& other) const; inline bool DivModeCompatibleTo(DivMode mode) const; @@ -186,9 +178,7 @@ class SumExprNode : public CanonicalExprNode { /*! \brief Base value in the summation. */ int64_t base{0}; /*! \brief The expression equals zero. */ - bool IsZero() const { - return base == 0 && args.size() == 0; - } + bool IsZero() const { return base == 0 && args.size() == 0; } /*! * \brief Return the normal Expr that is equivalent to self. * \return The normal expression. 
@@ -198,9 +188,7 @@ class SumExprNode : public CanonicalExprNode { if (this->args.size() == 0) { return make_const(this->dtype, this->base); } - return Normalize_(this->dtype, - SimplifySplitExprs(args), - base); + return Normalize_(this->dtype, SimplifySplitExprs(args), base); } /*! * \brief Whether self is divisible by scale. @@ -239,9 +227,7 @@ class SumExprNode : public CanonicalExprNode { * \brief add constant value to self. * \param value to be added. */ - void AddToSelf(int64_t value) { - this->base += value; - } + void AddToSelf(int64_t value) { this->base += value; } /*! * \brief self += other * scale; * \param other The expression to be added. @@ -257,8 +243,7 @@ class SumExprNode : public CanonicalExprNode { if (args[start]->IndexEqual(other)) break; } for (size_t j = start; j < args.size(); ++j) { - if (!args[j]->IndexEqual(other) || - other->lower_factor > args[j]->lower_factor) { + if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) { other.CopyOnWrite()->scale *= scale; this->args.insert(this->args.begin() + j, other); return; @@ -286,8 +271,7 @@ class SumExprNode : public CanonicalExprNode { * \param args The original list of arguments. * \return simplified version. */ - static std::vector - SimplifySplitExprs(std::vector args) { + static std::vector SimplifySplitExprs(std::vector args) { // NOTE: This algorithm relies on the factor that args are divided into segments // and each segment is sorted in descending order of lower_factor. for (size_t i = 0; i < args.size(); ++i) { @@ -297,14 +281,12 @@ class SumExprNode : public CanonicalExprNode { SplitExpr& rhs = args[j]; if (!lhs->IndexEqual(rhs)) break; if (lhs->upper_factor < rhs->lower_factor) break; - if (lhs->upper_factor == rhs->upper_factor && - lhs->lower_factor == rhs->lower_factor && + if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor && lhs->DivModeCompatibleTo(rhs->div_mode)) { // folding same co-efficient. rhs.CopyOnWrite()->scale += lhs->scale; lhs.CopyOnWrite()->scale = 0; - } else if (lhs->lower_factor == rhs->upper_factor && - rhs->scale != 0 && + } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 && lhs->scale % rhs->scale == 0 && lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor && lhs->DivModeCompatibleTo(rhs->div_mode)) { @@ -385,9 +367,7 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static PrimExpr Normalize_(DataType dtype, - const std::vector& args, - int64_t base) { + static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { // Positive scales first PrimExpr res = make_const(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { @@ -432,9 +412,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { public: using Rewriter = RewriteSimplifier::Impl; - explicit Impl(Analyzer* parent) - : Rewriter(parent) {} - + explicit Impl(Analyzer* parent) : Rewriter(parent) {} PrimExpr CanonicalSimplify(PrimExpr expr) { expr = operator()(expr); @@ -448,9 +426,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } // Normal mutation without normalization. 
- PrimExpr CanonicalMutate(PrimExpr expr) { - return Rewriter::VisitExpr(expr); - } + PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); } using Rewriter::VisitExpr_; PrimExpr VisitExpr_(const AddNode* op) final; @@ -486,9 +462,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \param out_divisible The result divisible component. * \param out_non_divisible The non-divisible component. */ - void SeparateDivisibleParts(const SumExprNode* psum, - int64_t coeff, - SumExpr* out_divisible, + void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible); /*! * \brief Normalize expr to normal expr. @@ -568,8 +542,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { PrimExpr SimplifyReduceCombiner(const ReduceNode* op); }; -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const AddNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -594,8 +567,7 @@ VisitExpr_(const AddNode* op) { return std::move(ret); } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const SubNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -620,9 +592,7 @@ VisitExpr_(const SubNode* op) { return std::move(ret); } - -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const MulNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -660,11 +630,9 @@ VisitExpr_(const MulNode* op) { } } -void CanonicalSimplifier::Impl:: -SeparateDivisibleParts(const SumExprNode* psum, - int64_t coeff, - SumExpr* out_divisible, - SumExpr* out_non_divisible) { +void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, + SumExpr* out_divisible, + SumExpr* out_non_divisible) { auto divisible = make_object(); auto non_divisible = make_object(); divisible->dtype = psum->dtype; @@ -686,8 +654,7 @@ SeparateDivisibleParts(const SumExprNode* psum, *out_non_divisible = SumExpr(non_divisible); } -SplitExpr CanonicalSimplifier::Impl:: -SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { +SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { CHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); @@ -728,8 +695,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const DivNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -764,8 +730,7 @@ VisitExpr_(const DivNode* op) { } else { // if 0 <= extra < cval, it means the extra can be eliminated. if (TryCompare(temp, cval) != kLT) { - lhs.CopyOnWrite()->AddToSelf( - SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); + lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); } } return std::move(lhs); @@ -789,8 +754,7 @@ VisitExpr_(const DivNode* op) { } } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorDivNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -821,8 +785,7 @@ VisitExpr_(const FloorDivNode* op) { } else { // if 0 <= extra < cval, it means the extra can be eliminated. 
if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) { - lhs.CopyOnWrite()->AddToSelf( - SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); + lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); } } return std::move(lhs); @@ -845,8 +808,7 @@ VisitExpr_(const FloorDivNode* op) { } } -SplitExpr CanonicalSimplifier::Impl:: -SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { +SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { CHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); @@ -860,16 +822,15 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { // (x / c1) % c2 => (x % (c1 * c2)) / c2 int64_t new_upper_factor = lhs->lower_factor * scaled_cval; // try to see if we can reduce the existing upper modular. - if (lhs->upper_factor == SplitExprNode::kPosInf || - lhs->upper_factor % new_upper_factor == 0) { + if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) { // we gained a new upper factor that is smaller // than the original one // Perhaps there are more chances in simplifying the index // Do a recursive call to simplify the mod with the new factor. - if (new_upper_factor < lhs->upper_factor && - lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(this->VisitExpr(ModImpl( - lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { + auto updated = ToSplitExpr(this->VisitExpr( + ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + updated.CopyOnWrite()->scale = lhs->scale; // re-apply the lower_factor if (lhs->lower_factor != 1) { return SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -895,8 +856,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const ModNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -940,8 +900,7 @@ VisitExpr_(const ModNode* op) { // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); int64_t new_base = psum->base % cval; - if (cbound->min_value >= 0 && - cbound->min_value - psum->base + new_base >= 0) { + if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { SumExpr sum_expr = Downcast(a); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); @@ -965,8 +924,7 @@ VisitExpr_(const ModNode* op) { } } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorModNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -990,8 +948,7 @@ VisitExpr_(const FloorModNode* op) { return floormod(temp, c1.Eval()); } else { // If temp < cval && temp >=0 then can remove the mod. - if (TryCompare(temp, cval) == kLT && - analyzer_->CanProveGreaterEqual(temp, 0)) { + if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) { return temp; } else { // contonue to use logic below. @@ -1026,8 +983,7 @@ VisitExpr_(const FloorModNode* op) { } // Simplify reduce expression. 
-PrimExpr CanonicalSimplifier::Impl:: -SimplifyReduceCombiner(const ReduceNode* op) { +PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results Array simplified_result; for (const auto& res : op->combiner->result) { @@ -1061,8 +1017,7 @@ SimplifyReduceCombiner(const ReduceNode* op) { // components which have side effects should also be preserved for (size_t i = 0; i < used.size(); ++i) { - if (HasSideEffect(op->source[i]) || - HasSideEffect(op->combiner->identity_element[i]) || + if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || HasSideEffect(op->combiner->result[i])) { mark_used(i); } @@ -1090,14 +1045,11 @@ SimplifyReduceCombiner(const ReduceNode* op) { } } - CommReducer new_combiner = - CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return ReduceNode::make( - new_combiner, new_source, op->axis, op->condition, new_value_index); + CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + return ReduceNode::make(new_combiner, new_source, op->axis, op->condition, new_value_index); } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const ReduceNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { // Recursively call simplification when necessary. PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op); op = ret.as(); @@ -1108,10 +1060,8 @@ VisitExpr_(const ReduceNode* op) { // assumption we would have to perform a single iteration of the loop, i.e. use // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. - return this->VisitExpr( - SelectNode::make(op->condition, - op->source[op->value_index], - op->combiner->identity_element[op->value_index])); + return this->VisitExpr(SelectNode::make(op->condition, op->source[op->value_index], + op->combiner->identity_element[op->value_index])); } // combiner simplification. ret = SimplifyReduceCombiner(op); @@ -1122,19 +1072,13 @@ PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { return impl_->CanonicalSimplify(expr); } -void CanonicalSimplifier::Update(const Var& var, - const PrimExpr& info, - bool override) { +void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { impl_->Update(var, info, override); } -CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) - : impl_(new Impl(parent)) { -} +CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} -CanonicalSimplifier::~CanonicalSimplifier() { - delete impl_; -} +CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h index fddd34bcaab83..39530ff9d49fa 100644 --- a/src/arith/compute_expr.h +++ b/src/arith/compute_expr.h @@ -26,8 +26,9 @@ #include #include -#include + #include +#include namespace tvm { namespace arith { @@ -39,7 +40,7 @@ namespace arith { * \tparam Op the computation operator * \return The result. */ -template +template inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { return OP::make(lhs, rhs); } @@ -52,46 +53,45 @@ inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { * \tparam Op The computation operator * \return The result. 
 */
-template<typename Op>
-inline PrimExpr ComputeReduce(
-    const Array<PrimExpr>& values, PrimExpr empty_value);
+template <typename Op>
+inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value);
 
-template<>
+template <>
 inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
   return a + b;
 }
 
-template<>
+template <>
 inline PrimExpr Compute<tir::SubNode>(PrimExpr a, PrimExpr b) {
   return a - b;
 }
 
-template<>
+template <>
 inline PrimExpr Compute<tir::MulNode>(PrimExpr a, PrimExpr b) {
   return a * b;
 }
 
-template<>
+template <>
 inline PrimExpr Compute<tir::DivNode>(PrimExpr a, PrimExpr b) {
   return truncdiv(a, b);
}
 
-template<>
+template <>
 inline PrimExpr Compute<tir::ModNode>(PrimExpr a, PrimExpr b) {
   return truncmod(a, b);
 }
 
-template<>
+template <>
 inline PrimExpr Compute<tir::MaxNode>(PrimExpr a, PrimExpr b) {
   return max(a, b);
 }
 
-template<>
+template <>
 inline PrimExpr Compute<tir::MinNode>(PrimExpr a, PrimExpr b) {
   return min(a, b);
 }
 
-template<typename Op>
+template <typename Op>
 inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
   if (values.size() == 0U) {
     CHECK(empty_value.defined());
@@ -106,4 +106,4 @@ inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_valu
 }  // namespace arith
 }  // namespace tvm
 
-#endif // TVM_ARITH_COMPUTE_EXPR_H_
+#endif  // TVM_ARITH_COMPUTE_EXPR_H_
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index a440af9942026..ad6570ee563d1 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -26,8 +26,10 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
+
 #include <algorithm>
 #include <cmath>
+
 #include "int_operator.h"
 
 namespace tvm {
 namespace arith {
@@ -43,7 +45,7 @@
  * \note a and b must already have matched data types with each other.
 * \return nullptr if constant fold fails, otherwise return folded result.
 */
-template<typename Op>
+template <typename Op>
 inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) {
   return PrimExpr();
 }
@@ -57,7 +59,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) {
 * \note a and b must already have matched data types with each other.
 * \return nullptr if constant fold fails, otherwise return folded result.
 */
-template<typename Op>
+template <typename Op>
 inline PrimExpr TryConstFold(PrimExpr a);
 
 /*!
@@ -70,254 +72,250 @@ inline PrimExpr TryConstFold(PrimExpr a);
 * \return the checked result.
 */
 inline bool IsIndexType(const DataType& type) {
-  return type.is_int() && type.lanes() == 1 &&
-      (type.bits() == 32 || type.bits() == 64);
+  return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
 }
-
-#define TVM_ARITH_CONST_PROPAGATION(BODY) \
-  using tir::FloatImmNode; \
-  const IntImmNode* pa = a.as<IntImmNode>(); \
-  const IntImmNode* pb = b.as<IntImmNode>(); \
-  const FloatImmNode* fa = a.as<FloatImmNode>(); \
-  const FloatImmNode* fb = b.as<FloatImmNode>(); \
+#define TVM_ARITH_CONST_PROPAGATION(BODY)          \
+  using tir::FloatImmNode;                         \
+  const IntImmNode* pa = a.as<IntImmNode>();       \
+  const IntImmNode* pb = b.as<IntImmNode>();       \
+  const FloatImmNode* fa = a.as<FloatImmNode>();   \
+  const FloatImmNode* fb = b.as<FloatImmNode>();   \
   BODY;
 
-
-#define TVM_INDEX_CONST_PROPAGATION(BODY) \
-  const IntImmNode* pa = a.as<IntImmNode>(); \
-  const IntImmNode* pb = b.as<IntImmNode>(); \
-  const DataType& ta = a.dtype(); \
-  const DataType& tb = b.dtype(); \
-  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
-    BODY; \
-  } \
-
+#define TVM_INDEX_CONST_PROPAGATION(BODY)                  \
+  const IntImmNode* pa = a.as<IntImmNode>();               \
+  const IntImmNode* pb = b.as<IntImmNode>();               \
+  const DataType& ta = a.dtype();                          \
+  const DataType& tb = b.dtype();                          \
+  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) {  \
+    BODY;                                                  \
+  }
 
 // specialization of constant folders.
-template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value + pb->value); - if (pa && pa->value == 0) return b; - if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value + fb->value); - if (fa && fa->value == 0) return b; - if (fb && fb->value == 0) return a; - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); + if (pa && pa->value == 0) return b; + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm(rtype, fa->value + fb->value); + if (fa && fa->value == 0) return b; + if (fb && fb->value == 0) return a; + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value - pb->value); - if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value - fb->value); - if (fb && fb->value == 0) return a; - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm(rtype, fa->value - fb->value); + if (fb && fb->value == 0) return a; + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value * pb->value); - if (pa) { - if (pa->value == 1) return b; - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - if (pb->value == 0) return b; - } - if (fa && fb) return FloatImm(rtype, fa->value * fb->value); - if (fa) { - if (fa->value == 1) return b; - if (fa->value == 0) return a; - } - if (fb) { - if (fb->value == 1) return a; - if (fb->value == 0) return b; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); + if (pa) { + if (pa->value == 1) return b; + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + if (pb->value == 0) return b; + } + if (fa && fb) return FloatImm(rtype, fa->value * fb->value); + if (fa) { + if (fa->value == 1) return b; + if (fa->value == 0) return a; + } + if (fb) { + if (fb->value == 1) return a; + if (fb->value == 0) return b; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - // due to division and mod can have different modes - // NOTE: this will assumes truc div. - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value / pb->value); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm(rtype, fa->value / fb->value); - } - if (fa && fa->value == 0) return a; - if (fb) { - if (fb->value == 1) return a; - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. 
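// (Illustrative aside, not part of the patch.) Truncating division rounds
// toward zero while floor division rounds toward negative infinity, so the
// two conventions agree for same-sign operands and differ by one otherwise:
//   truncdiv(-7, 2) == -3 with truncmod(-7, 2) == -1
//   floordiv(-7, 2) == -4 with floormod(-7, 2) == +1
// In both conventions the identity a == div(a, b) * b + mod(a, b) holds.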
+ CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, pa->value / pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm(rtype, fa->value / fb->value); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value % pb->value); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, pa->value % pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return tir::make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, arith::floordiv(pa->value, pb->value)); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm(rtype, std::floor(fa->value / fb->value)); - } - if (fa && fa->value == 0) return a; - if (fb) { - if (fb->value == 1) return a; - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm(rtype, std::floor(fa->value / fb->value)); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, floormod(pa->value, pb->value)); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, floormod(pa->value, pb->value)); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return tir::make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::min(fa->value, 
fb->value)); - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); + }); if (a.same_as(b)) return a; return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); + }); if (a.same_as(b)) return a; return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); + }); return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); @@ 
-328,7 +326,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); @@ -339,7 +337,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } -template<> +template <> inline PrimExpr TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { @@ -364,9 +362,7 @@ struct SymbolicLimits { * * \return positive infinity. */ -inline PrimExpr pos_inf() { - return SymbolicLimits::pos_inf_; -} +inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; } /*! * \brief Check if value is positive infinity. @@ -374,9 +370,7 @@ inline PrimExpr pos_inf() { * * \return The check result. */ -inline bool is_pos_inf(const PrimExpr& value) { - return value.same_as(SymbolicLimits::pos_inf_); -} +inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); } /*! * \brief Opaque expression representing negative infinity. @@ -386,9 +380,7 @@ inline bool is_pos_inf(const PrimExpr& value) { * * \return negative infinity. */ -inline PrimExpr neg_inf() { - return SymbolicLimits::neg_inf_; -} +inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; } /*! * \brief Check if value is negative infinity. @@ -396,9 +388,7 @@ inline PrimExpr neg_inf() { * * \return The check result. */ -inline bool is_neg_inf(const PrimExpr& value) { - return value.same_as(SymbolicLimits::neg_inf_); -} +inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); } } // namespace arith } // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 57dfc157fc219..0f4d9c0620862 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -20,10 +20,12 @@ /*! 
* \file tvm/arith/const_int_bound.cc */ -#include #include +#include #include + #include + #include "int_operator.h" #include "pattern_match.h" @@ -34,8 +36,7 @@ using namespace tir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); -ConstIntBound::ConstIntBound( - int64_t min_value, int64_t max_value) { +ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { auto node = make_object(); node->min_value = min_value; node->max_value = max_value; @@ -46,8 +47,7 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.ConstIntBound") -.set_body_typed(MakeConstIntBound); +TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == ConstIntBound::kPosInf) { @@ -60,31 +60,29 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ConstIntBound["; - PrintBoundValue(p->stream, op->min_value); - p->stream << ','; - PrintBoundValue(p->stream, op->max_value); - p->stream << ']'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ConstIntBound["; + PrintBoundValue(p->stream, op->min_value); + p->stream << ','; + PrintBoundValue(p->stream, op->max_value); + p->stream << ']'; + }); // internal entry for const int bound struct ConstIntBoundAnalyzer::Entry { int64_t min_value; int64_t max_value; - bool is_const(int64_t value) const { - return min_value == max_value && min_value == value; - } + bool is_const(int64_t value) const { return min_value == max_value && min_value == value; } bool operator==(const Entry& other) const { return min_value == other.min_value && max_value == other.max_value; } }; -class ConstIntBoundAnalyzer::Impl : - public ExprFunctor { +class ConstIntBoundAnalyzer::Impl + : public ExprFunctor { public: /*! 
\brief additional bound info about expr \in bound */ struct BoundInfo { @@ -94,46 +92,39 @@ class ConstIntBoundAnalyzer::Impl : Entry bound; BoundInfo() {} - BoundInfo(PrimExpr expr, Entry bound) - : expr(expr), bound(bound) { - } + BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} }; - void Bind(const Var& var, const Range& range) { + void Bind(const Var& var, const Range& range, bool override) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); Entry ret; ret.min_value = a.min_value; ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); - Update(var, ret, false); + Update(var, ret, override); } - void Update(const Var& var, - const Entry& info, - bool override) { + void Update(const Var& var, const Entry& info, bool override) { if (!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " - << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) - << ", new=" << ConstIntBound(info.min_value, info.max_value); + CHECK(it->second == info) << "Trying to update var \'" << var << "\'" + << " with a different const bound: " + << "original=" + << ConstIntBound(it->second.min_value, it->second.max_value) + << ", new=" << ConstIntBound(info.min_value, info.max_value); } } var_map_[var] = info; } - void Update(const Var& var, - const ConstIntBound& info, - bool override) { + void Update(const Var& var, const ConstIntBound& info, bool override) { Update(var, MakeBound(info->min_value, info->max_value), override); } // Override visitor behaviors Entry VisitExprDefault_(const Object* op) final { - return Everything( - static_cast(op)->dtype); + return Everything(static_cast(op)->dtype); } Entry VisitExpr(const PrimExpr& expr) final { @@ -147,15 +138,16 @@ class ConstIntBoundAnalyzer::Impl : } } if (bound_) { - const PrimExprNode* op = expr.as(); - auto val = bound_->find(op); + auto val = bound_->find(expr); if (val != bound_->end()) { - CHECK(val->second->min_value == res.min_value && - val->second->max_value == res.max_value) - << "Detected bound for " << expr - << "conflicts with memorization"; + auto everything = Everything(expr->dtype); + CHECK( + (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || + (val->second->min_value == everything.min_value && + val->second->max_value == everything.max_value)) + << "Detected bound for " << expr << "conflicts with memorization"; } - (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); + (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value); } return res; } @@ -176,9 +168,7 @@ class ConstIntBoundAnalyzer::Impl : return Intersect(a, b); } - Entry VisitExpr_(const IntImmNode* op) final { - return MakeBound(op->value, op->value); - } + Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); @@ -223,8 +213,7 @@ class ConstIntBoundAnalyzer::Impl : // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; // other case, we can get close to 0 - return MakeBound(0, - std::min(a.max_value, b_max_cap)); + return MakeBound(0, std::min(a.max_value, b_max_cap)); } else { return MakeBound(std::max(a.min_value, -b_max_cap), std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); @@ -367,7 +356,7 @@ class ConstIntBoundAnalyzer::Impl : // additional bound info std::vector additional_info_; // look up table 
for memorization - std::unordered_map* bound_{nullptr}; + BoundMapType* bound_{nullptr}; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; @@ -382,7 +371,7 @@ class ConstIntBoundAnalyzer::Impl : * \tparam F the operator function type. * \return The result. */ - template + template static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) { Entry ret; // The boundary point must be shihft of the original boundary. @@ -560,35 +549,28 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { return ConstIntBound(ret.min_value, ret.max_value); } -ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, - std::unordered_map* bound) { +ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) { impl_->bound_ = bound; Entry ret = impl_->VisitExpr(expr); impl_->bound_ = nullptr; return ConstIntBound(ret.min_value, ret.max_value); } -void ConstIntBoundAnalyzer::Update(const Var& var, - const ConstIntBound& info, - bool override) { +void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) { impl_->Update(var, info, override); } -void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { - impl_->Bind(var, range); +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) { + impl_->Bind(var, range, override); } std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) - : impl_(new Impl()) { -} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} -ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { - delete impl_; -} +ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index c7f90f535e292..2bc7209655f3b 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -21,13 +21,13 @@ * \file detect_linear_equation.cc * \brief Utility to detect patterns in the expression. 
*/ +#include #include -#include #include -#include +#include #include +#include #include -#include namespace tvm { namespace arith { @@ -45,11 +45,9 @@ struct IntervalEntry { PrimExpr max_value; }; -class LinearEqDetector - : public ExprFunctor { +class LinearEqDetector : public ExprFunctor { public: - explicit LinearEqDetector(Var var) - : var_(var) {} + explicit LinearEqDetector(Var var) : var_(var) {} bool Detect(const PrimExpr& e, LinearEqEntry* ret) { *ret = VisitExpr(e, e); @@ -142,8 +140,7 @@ class LinearEqDetector } }; -Array DetectLinearEquation(const PrimExpr& e, - const Array& vars) { +Array DetectLinearEquation(const PrimExpr& e, const Array& vars) { PrimExpr base = e; Array coeff; @@ -157,9 +154,7 @@ Array DetectLinearEquation(const PrimExpr& e, } std::unordered_set vset; - auto vset_contains = [&](const VarNode* node) { - return vset.count(node) != 0; - }; + auto vset_contains = [&](const VarNode* node) { return vset.count(node) != 0; }; for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); @@ -173,9 +168,8 @@ Array DetectLinearEquation(const PrimExpr& e, } // Detect clip condition as min max value -bool DetectClipBound( - const PrimExpr& cond, - std::unordered_map* bmap) { +bool DetectClipBound(const PrimExpr& cond, + std::unordered_map* bmap) { int flag = 0; Var var; auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { @@ -237,8 +231,7 @@ bool DetectClipBound( return false; } - -template +template void SplitCommExpr(const PrimExpr& e, std::vector* ret) { if (const OP* op = e.as()) { SplitCommExpr(op->a, ret); @@ -276,12 +269,11 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_REGISTER_GLOBAL("arith.DetectLinearEquation") -.set_body_typed(DetectLinearEquation); +TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); TVM_REGISTER_GLOBAL("arith.DetectClipBound") -.set_body_typed([](const PrimExpr& e, const Array& vars) { - return DetectClipBound(e, vars); -}); + .set_body_typed([](const PrimExpr& e, const Array& vars) { + return DetectClipBound(e, vars); + }); } // namespace arith } // namespace tvm diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 81443db38d479..0ac4a893a77fc 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -21,13 +21,13 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ +#include +#include #include #include -#include -#include -#include #include +#include namespace tvm { namespace arith { @@ -37,12 +37,8 @@ using namespace tir; // Find Read region of the tensor in the stmt. 
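(Illustrative aside, separate from the patch.) "Touched domain" here means the hull of every index an access can use. A self-contained one-dimensional toy — Hull and the index list are hypothetical names; the real visitor below unions per-dimension IntSets instead:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// The touched domain of a set of accesses is the min/max hull of the indices.
struct Hull {
  int64_t lo = std::numeric_limits<int64_t>::max();
  int64_t hi = std::numeric_limits<int64_t>::min();
};

int main() {
  std::vector<int64_t> touched = {4, 0, 7, 2};  // indices a loop nest reads
  Hull h;
  for (int64_t i : touched) {
    h.lo = std::min(h.lo, i);
    h.hi = std::max(h.hi, i);
  }
  std::cout << "[" << h.lo << ", " << h.hi << "]\n";  // prints [0, 7]
}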
class BufferTouchedDomain final : public StmtExprVisitor { public: - BufferTouchedDomain(const Buffer &buffer, - bool consider_loads, - bool consider_stores) - : buffer_(buffer), - consider_loads_(consider_loads), - consider_stores_(consider_stores) {} + BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores) + : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {} Domain Find(const Stmt& stmt) { operator()(stmt); @@ -54,17 +50,15 @@ class BufferTouchedDomain final : public StmtExprVisitor { return ret; } - void VisitStmt_(const ForNode *op) final { + void VisitStmt_(const ForNode* op) final { const VarNode* var = op->loop_var.get(); - dom_map_[var] = IntSet::range( - Range::make_by_min_extent(op->min, op->extent)); + dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent)); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } void VisitStmt_(const LetStmtNode* op) final { - dom_map_[op->var.get()] = - arith::EvalSet(op->value, dom_map_); + dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(op->var.get()); } @@ -107,21 +101,18 @@ class BufferTouchedDomain final : public StmtExprVisitor { } } - const Buffer &buffer_; + const Buffer& buffer_; bool consider_loads_, consider_stores_; std::vector > bounds_; std::unordered_map dom_map_; }; -Domain DomainTouched(const Stmt& stmt, - const Buffer& buffer, - bool consider_loads, +Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, bool consider_stores) { return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); } -TVM_REGISTER_GLOBAL("arith.DomainTouched") -.set_body_typed(DomainTouched); +TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 34efa986e9856..62858d2dc9e2c 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -22,19 +22,18 @@ * \brief The integer constraints data structures. 
*/ #include +#include #include #include -#include -#include #include #include +#include namespace tvm { namespace arith { -IntConstraints::IntConstraints(Array variables, - Map ranges, +IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { @@ -46,7 +45,7 @@ IntConstraints::IntConstraints(Array variables, CHECK(relations.defined()); for (const auto& var : variables) { CHECK(var.dtype().is_int() || var.dtype().is_uint()) - << "Variables in IntConstraints must be integers"; + << "Variables in IntConstraints must be integers"; } node->variables = std::move(variables); node->ranges = std::move(ranges); @@ -57,18 +56,13 @@ IntConstraints::IntConstraints(Array variables, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntConstraints(" - << op->variables - << ", " << op->ranges - << ", " << op->relations - << ")"; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations + << ")"; + }); -IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, - IntConstraints dst, +IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { ObjectPtr node = make_object(); @@ -82,15 +76,12 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntConstraintsTransform(" - << "\n\t" << op->src - << "\n\t" << op->dst - << "\n\t" << op->src_to_dst - << "\n\t" << op->dst_to_src - << "\n)"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraintsTransform(" + << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t" + << op->dst_to_src << "\n)"; + }); } // namespace arith } // namespace tvm diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index 3be34b6387779..8e4dda0284e21 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -38,56 +38,41 @@ namespace arith { * \return Whether overflow can happen. * \tparam Op The integer operator. 
*/ -template -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if ((y > 0) && (x > max_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true; return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if ((y > 0) && (x < min_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true; return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if (y == 0) return false; if (y > 0) { - if (x < min_value / y) return true; - if (x > max_value / y) return true; + if (x < min_value / y) return true; + if (x > max_value / y) return true; } else { if (y == -1 && x == std::numeric_limits::min()) return true; - if (x > min_value / y) return true; - if (x < max_value / y) return true; + if (x > min_value / y) return true; + if (x < max_value / y) return true; } return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { return y == 0; } @@ -97,9 +82,7 @@ inline bool WillOverflow(int64_t x, * \param y The right operand. * \return the result. */ -inline int64_t truncdiv(int64_t x, int64_t y) { - return x / y; -} +inline int64_t truncdiv(int64_t x, int64_t y) { return x / y; } /*! * \brief Compute the truncdiv remainder of two integers. @@ -107,9 +90,7 @@ inline int64_t truncdiv(int64_t x, int64_t y) { * \param y The right operand. * \return the result. */ -inline int64_t truncmod(int64_t x, int64_t y) { - return x % y; -} +inline int64_t truncmod(int64_t x, int64_t y) { return x % y; } /*! * \brief Peform floor division of two integers. @@ -120,13 +101,10 @@ inline int64_t truncmod(int64_t x, int64_t y) { inline int64_t floordiv(int64_t x, int64_t y) { int64_t rdiv = x / y; int64_t rmod = x % y; - bool is_floor_div = - (y >= 0 && rmod >= 0) || - (y < 0 && rmod <= 0); + bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0); return is_floor_div ? rdiv : (rdiv - 1); } - /*! * \brief Compute the floordiv remainder of two integers. * \param x The left operand. @@ -135,9 +113,7 @@ inline int64_t floordiv(int64_t x, int64_t y) { */ inline int64_t floormod(int64_t x, int64_t y) { int64_t rmod = x % y; - bool is_floor_div = - (y >= 0 && rmod >= 0) || - (y < 0 && rmod <= 0); + bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0); return is_floor_div ? 
rmod : rmod + y; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 027259a4d2252..7462808c92a90 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -22,23 +22,24 @@ * \brief The integer set functions */ #include +#include #include #include -#include -#include #include #include +#include + #include "interval_set.h" #include "pattern_match.h" namespace tvm { namespace arith { +using tir::is_one; +using tir::is_zero; using tir::make_const; using tir::make_zero; -using tir::is_zero; -using tir::is_one; PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); @@ -54,9 +55,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.IntervalSet") -.set_body_typed(MakeIntervalSet); - +TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); @@ -77,15 +76,15 @@ IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } // type traits -template +template struct is_logical_op { static const bool value = false; }; -#define TVM_DECLARE_LOGICAL_OP(OP) \ - template<> \ - struct is_logical_op { \ - static const bool value = true; \ +#define TVM_DECLARE_LOGICAL_OP(OP) \ + template <> \ + struct is_logical_op { \ + static const bool value = true; \ }; TVM_DECLARE_LOGICAL_OP(AndNode); @@ -102,18 +101,15 @@ TVM_DECLARE_LOGICAL_OP(NotNode); * \brief Combine two interval set under arithmetic operations. * \note this can possibly relax the set. */ -template -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr res = TryConstFold(a->min_value, b->min_value); if (!res.defined()) res = Op::make(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { - return IntervalSet(make_const(a->min_value.dtype(), 0), - make_const(a->min_value.dtype(), 1)); + return IntervalSet(make_const(a->min_value.dtype(), 0), make_const(a->min_value.dtype(), 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -122,47 +118,36 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; PrimExpr min_value = - a->HasLowerBound() && b->HasLowerBound() ? - a->min_value + b->min_value : neg_inf(); + a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf(); PrimExpr max_value = - a->HasUpperBound() && b->HasUpperBound() ? - a->max_value + b->max_value : pos_inf(); + a->HasUpperBound() && b->HasUpperBound() ? 
a->max_value + b->max_value : pos_inf(); return IntervalSet(min_value, max_value); } -template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; PrimExpr min_value = - a->HasLowerBound() && b->HasUpperBound() ? - a->min_value - b->max_value : neg_inf(); + a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf(); PrimExpr max_value = - a->HasUpperBound() && b->HasLowerBound() ? - a->max_value - b->min_value : pos_inf(); + a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } - -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -194,10 +179,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -229,10 +212,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -259,11 +240,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } - -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -295,10 +273,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -311,6 +287,16 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { + if (divisor.as()) { + // a mod b = a - (a / b) * b if a_max / b == a_min / b + auto qmax = floordiv(a->max_value, divisor); + auto qmin = floordiv(a->min_value, divisor); + if (analyzer->CanProve(qmax == qmin)) { + auto tmax = a->max_value - divisor * qmin; + auto tmin = a->min_value - divisor * qmin; + return IntervalSet(tmin, tmax); + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; @@ -321,30 +307,24 @@ inline IntervalSet Combine(Analyzer* analyzer, return 
IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { - return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); + return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - return IntervalSet(max(a->min_value, b->min_value), - max(a->max_value, b->max_value)); + return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value)); } -template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - return IntervalSet(min(a->min_value, b->min_value), - min(a->max_value, b->max_value)); + return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value)); } // internal helper function to get an interval set @@ -360,20 +340,12 @@ using namespace tir; // Simplified version of int set evaluator that operates on IntervalSet // We might use better set analysis in the future to replace the intervalset. -class IntervalSetEvaluator : - public ExprFunctor { +class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, - const Map& dom_map, - bool eval_vec = false) - : analyzer_(analyzer), - dom_map_(dom_map), - eval_vec_(eval_vec) { - } + IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, bool eval_vec = false) + : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {} - IntervalSet Eval(const PrimExpr& val) { - return this->VisitExpr(val); - } + IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } // evaluate and relax the set IntervalSet Eval(IntervalSet val) { // avoid recursive indefinite recursive expansion. 
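// (Illustrative aside, not part of the patch.) The Combine specializations
// above are plain interval arithmetic lifted to symbolic bounds. On concrete
// values: [-2, 3] + [4, 5] = [2, 8], while a product must take the extremes
// over all four corner products,
//   [-2, 3] * [4, 5] -> corners {-8, -10, 12, 15} -> [-10, 15],
// because a sign flip in one operand can change which corner is extreme.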
@@ -394,8 +366,7 @@ class IntervalSetEvaluator : auto it = dom_map_.find(var); if (it != dom_map_.end()) { IntervalSet res = ToIntervalSet((*it).second); - if (res->min_value.same_as(var) && - res->max_value.same_as(var)) { + if (res->min_value.same_as(var) && res->max_value.same_as(var)) { return res; } // recursively evaluate mapped result @@ -406,74 +377,39 @@ class IntervalSetEvaluator : } } + IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AddNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const SubNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MulNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const DivNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const ModNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorDivNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorModNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MinNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MaxNode* op) final { - return VisitBinaryExpr_(op); - } - - IntervalSet VisitExpr_(const EQNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const NENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LTNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GTNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AndNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const OrNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); @@ -482,16 +418,12 @@ class IntervalSetEvaluator : if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; - if (vstride> 0) { - return 
Combine( - analyzer_, - base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); + if (vstride > 0) { + return Combine(analyzer_, base, + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { - return Combine( - analyzer_, - base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); + return Combine(analyzer_, base, + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); @@ -516,12 +448,11 @@ class IntervalSetEvaluator : private: // whether set is exactly single point that equals value. - bool MatchPoint(const IntervalSet& set, - const PrimExpr& value) const { + bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { return set->min_value.same_as(value) && set->max_value.same_as(value); } - template + template inline IntervalSet VisitBinaryExpr_(const T* op) { IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); @@ -541,9 +472,7 @@ class IntervalSetEvaluator : class IntSetAnalyzer::Impl { public: - explicit Impl(Analyzer* analyzer) - : analyzer_(analyzer) { - } + explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); @@ -553,16 +482,11 @@ class IntSetAnalyzer::Impl { Analyzer* analyzer_; }; -IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) - : impl_(new Impl(parent)) { -} +IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} -IntSetAnalyzer::~IntSetAnalyzer() { - delete impl_; -} +IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } -IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, - const Map& dom_map) { +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& dom_map) { return impl_->Eval(expr, dom_map); } @@ -574,8 +498,8 @@ Range IntSet::cover_range(Range max_range) const { const IntervalSetNode* s_int = (*this).as(); CHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return Range::make_by_min_extent( - s_int->min_value, analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); + return Range::make_by_min_extent(s_int->min_value, + analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; } @@ -654,17 +578,11 @@ PrimExpr IntSet::point_value() const { return s_int->min_value; } -IntSet IntSet::nothing() { - return IntervalSet::Empty(); -} +IntSet IntSet::nothing() { return IntervalSet::Empty(); } -IntSet IntSet::everything() { - return IntervalSet::Everything(); -} +IntSet IntSet::everything() { return IntervalSet::Everything(); } -IntSet IntSet::single_point(PrimExpr x) { - return IntervalSet::SinglePoint(x); -} +IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); } IntSet IntSet::interval(PrimExpr min, PrimExpr max) { if (min.same_as(max)) { @@ -692,7 +610,7 @@ bool IntSet::match_range(const Range& b) const { if (!a_int) return false; Analyzer ana; return ProveEqual(&ana, a_int->min_value, b->min) && - ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); + ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } IntSet Union(const Array& sets) { @@ -703,8 +621,7 @@ IntSet Union(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Union(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(ana.Simplify(x->min_value), - ana.Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } IntSet 
Intersect(const Array& sets) { @@ -715,8 +632,7 @@ IntSet Intersect(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Intersect(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(ana.Simplify(x->min_value), - ana.Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } Map ConvertDomMap(const Map& dom_map) { @@ -727,8 +643,7 @@ Map ConvertDomMap(const Map& dom_map) { return dmap; } -Map ConvertDomMap( - const std::unordered_map& dom_map) { +Map ConvertDomMap(const std::unordered_map& dom_map) { Map dmap; for (auto kv : dom_map) { dmap.Set(GetRef(kv.first), kv.second); @@ -736,8 +651,7 @@ Map ConvertDomMap( return dmap; } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map) { Analyzer ana; return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } @@ -748,49 +662,40 @@ IntSet IntSet::vector(PrimExpr x) { return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map) { +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; - auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); + auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); return std::move(res); } -IntSet EvalSet(Range r, - const std::unordered_map& dom_map) { +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map) { +IntSet EvalSet(IntSet s, const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); IntervalSetEvaluator m(&ana, dmap); const IntervalSetNode* s_int = s.as(); - PrimExpr vmax = s_int->HasUpperBound() ? - m.Eval(s_int->max_value).max() : s_int->max_value; - PrimExpr vmin = s_int->HasLowerBound() ? - m.Eval(s_int->min_value).min() : s_int->min_value; + PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; + PrimExpr vmin = s_int->HasLowerBound() ? 
m.Eval(s_int->min_value).min() : s_int->min_value; return IntervalSet(vmin, vmax); } class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntervalSetEvaluator( - Analyzer* analyzer, - const Map& dom_map) + explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} IntervalSet VisitExpr(const PrimExpr& n) final { @@ -802,9 +707,8 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { ExprIntSetMap expr_map; }; -ExprIntSetMap EvalSetForEachSubExpr( - PrimExpr e, - const std::unordered_map& dom_map) { +ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); SubExprIntervalSetEvaluator m(&ana, dmap); @@ -812,42 +716,32 @@ ExprIntSetMap EvalSetForEachSubExpr( return m.expr_map; } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(Range r, const Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } TVM_REGISTER_NODE_TYPE(IntervalSetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntervalSet" - << "[" << op->min_value << ", " - << op->max_value << ']'; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntervalSet" + << "[" << op->min_value << ", " << op->max_value << ']'; + }); -TVM_REGISTER_GLOBAL("arith.intset_single_point") -.set_body_typed(IntSet::single_point); +TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::single_point); -TVM_REGISTER_GLOBAL("arith.intset_vector") -.set_body_typed(IntSet::vector); +TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::vector); -TVM_REGISTER_GLOBAL("arith.intset_interval") -.set_body_typed(IntSet::interval); +TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::interval); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin") -.set_body_method(&IntSet::min); +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax") -.set_body_method(&IntSet::max); +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); -TVM_REGISTER_GLOBAL("arith.IntSetIsNothing") -.set_body_method(&IntSet::is_nothing); +TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::is_nothing); -TVM_REGISTER_GLOBAL("arith.IntSetIsEverything") -.set_body_method(&IntSet::is_everything); +TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::is_everything); } // namespace arith } // namespace tvm diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 51b500adb4125..eb308dd385a4d 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -26,7 +26,9 @@ #include #include + #include + #include "const_fold.h" namespace tvm { @@ -53,26 +55,18 @@ class IntervalSetNode : public IntSetNode { } /*! \return Whether the interval has upper bound. */ - bool HasUpperBound() const { - return !is_pos_inf(max_value) && !IsEmpty(); - } + bool HasUpperBound() const { return !is_pos_inf(max_value) && !IsEmpty(); } /*! \return Whether the interval has lower bound. */ - bool HasLowerBound() const { - return !is_neg_inf(min_value) && !IsEmpty(); - } + bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. 
*/ - bool IsSinglePoint() const { - return min_value.same_as(max_value); - } + bool IsSinglePoint() const { return min_value.same_as(max_value); } /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. return is_pos_inf(min_value) || is_neg_inf(max_value); } /*! \return whether interval represent everything */ - bool IsEverything() const { - return is_neg_inf(min_value) && is_pos_inf(max_value); - } + bool IsEverything() const { return is_neg_inf(min_value) && is_pos_inf(max_value); } static constexpr const char* _type_key = "arith.IntervalSet"; TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode); @@ -97,24 +91,18 @@ class IntervalSet : public IntSet { * \param value The value to be represented. * \return The result set. */ - static IntervalSet SinglePoint(PrimExpr value) { - return IntervalSet(value, value); - } + static IntervalSet SinglePoint(PrimExpr value) { return IntervalSet(value, value); } /*! * \brief Create an IntervalSet that represents everything. * \param value The value to be represented. * \return The result set. */ - static IntervalSet Everything() { - return IntervalSet(neg_inf(), pos_inf()); - } + static IntervalSet Everything() { return IntervalSet(neg_inf(), pos_inf()); } /*! * \brief Create an empty eet. * \return The result set. */ - static IntervalSet Empty() { - return IntervalSet(pos_inf(), neg_inf()); - } + static IntervalSet Empty() { return IntervalSet(pos_inf(), neg_inf()); } TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode); TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); @@ -136,7 +124,7 @@ TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); * \param b The second set. * \return The result set. */ -TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b); +TVM_DLL IntervalSet Intersect(Analyzer* analzyer, IntervalSet a, IntervalSet b); } // namespace arith } // namespace tvm diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 0ae98412b0177..e09ff1d65a5e7 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -20,24 +20,22 @@ /*! * \file tvm/arith/ir_mutator_with_analyzer.cc */ +#include "ir_mutator_with_analyzer.h" + #include #include -#include "ir_mutator_with_analyzer.h" namespace tvm { namespace arith { using namespace tir; -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const ForNode* op) { - analyzer_->Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); +Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprMutator::VisitStmt_(op); } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const LetStmtNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -45,8 +43,7 @@ VisitStmt_(const LetStmtNode* op) { // We keep the let-binding here // as sub-class may or maynot choose to replace it. 
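(Illustrative aside, separate from the patch.) The With<ConstraintContext> objects used just below push a condition while one branch is visited and pop it when the scope closes. A tiny RAII analogue with no TVM dependency — ScopedConstraint and the string "facts" are stand-in names:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Push a fact for the duration of a scope; pop it automatically on exit.
class ScopedConstraint {
 public:
  ScopedConstraint(std::vector<std::string>* facts, std::string fact) : facts_(facts) {
    facts_->push_back(std::move(fact));
  }
  ~ScopedConstraint() { facts_->pop_back(); }

 private:
  std::vector<std::string>* facts_;
};

int main() {
  std::vector<std::string> facts;
  {
    ScopedConstraint c(&facts, "cond == true");  // entering the then-branch
    std::cout << facts.size() << "\n";           // 1: the constraint is active
  }
  std::cout << facts.size() << "\n";             // 0: popped on scope exit
}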
Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -56,8 +53,7 @@ VisitStmt_(const LetStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const IfThenElseNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case, else_case; { @@ -65,9 +61,8 @@ VisitStmt_(const IfThenElseNode* op) { then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { - With ctx(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(condition))); - else_case = this->VisitStmt(op->else_case); + With ctx(analyzer_, analyzer_->rewrite_simplify(NotNode::make(condition))); + else_case = this->VisitStmt(op->else_case); } if (is_one(condition)) return then_case; if (is_zero(condition)) { @@ -77,8 +72,7 @@ VisitStmt_(const IfThenElseNode* op) { return EvaluateNode::make(0); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -90,14 +84,11 @@ VisitStmt_(const IfThenElseNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == tir::attr::thread_extent || - op->attr_key == tir::attr::virtual_thread) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_->Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; } else { @@ -105,16 +96,13 @@ VisitStmt_(const AttrStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AssertStmtNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); Stmt body = this->VisitStmt(op->body); - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { + if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -125,8 +113,7 @@ VisitStmt_(const AssertStmtNode* op) { } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const CallNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); @@ -146,21 +133,17 @@ VisitExpr_(const CallNode* op) { if (is_one(cond)) { return true_value; } - if (cond.same_as(op->args[0]) && - true_value.same_as(op->args[1]) && + if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, - {cond, true_value, false_value}, - op->call_type); + return CallNode::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const LetNode* op) { +PrimExpr 
IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -168,16 +151,14 @@ VisitExpr_(const LetNode* op) { // We keep the let-binding here // as sub-class may or may not choose to replace it. PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const SelectNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr true_value, false_value; { @@ -185,8 +166,7 @@ VisitExpr_(const SelectNode* op) { true_value = VisitExpr(op->true_value); } { - With constraint(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(cond))); + With constraint(analyzer_, analyzer_->rewrite_simplify(NotNode::make(cond))); false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { @@ -196,8 +176,7 @@ VisitExpr_(const SelectNode* op) { return true_value; } // normal path - if (cond.same_as(op->condition) && - true_value.same_as(op->true_value) && + if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { @@ -205,8 +184,7 @@ VisitExpr_(const SelectNode* op) { } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const ReduceNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_->Bind(iv->var, iv->dom); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index f6004e2ad9b9a..004265bbe50a6 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -24,8 +24,9 @@ #ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ -#include #include +#include + #include namespace tvm { @@ -42,11 +43,10 @@ namespace arith { */ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { public: - explicit IRMutatorWithAnalyzer(Analyzer* analyzer) - : analyzer_(analyzer) {} + explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} - using StmtExprMutator::VisitStmt_; using StmtExprMutator::VisitExpr_; + using StmtExprMutator::VisitStmt_; // override functions that need to populate the context information.
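
The Select and IfThenElse visitors above lean on With<ConstraintContext> so that the branch condition, or its rewritten negation, only constrains the analyzer while the corresponding branch is visited. A hedged sketch of that RAII discipline, with a plain string stack standing in for the analyzer's constraint state (the names here are illustrative, not TVM's):

    #include <cassert>
    #include <string>
    #include <utility>
    #include <vector>

    // Toy stand-in for With<ConstraintContext>: push a constraint on
    // construction, pop it again on scope exit.
    class ScopedConstraint {
     public:
      ScopedConstraint(std::vector<std::string>* stack, std::string cond) : stack_(stack) {
        stack_->push_back(std::move(cond));
      }
      ~ScopedConstraint() { stack_->pop_back(); }

     private:
      std::vector<std::string>* stack_;
    };

    int main() {
      std::vector<std::string> active;
      {
        ScopedConstraint then_scope(&active, "cond");  // visit true_value here
        assert(active.back() == "cond");
      }
      {
        ScopedConstraint else_scope(&active, "!cond");  // visit false_value here
        assert(active.back() == "!cond");
      }
      assert(active.empty());  // neither assumption leaks past its branch
    }
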
tir::Stmt VisitStmt_(const tir::ForNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index b2dbe9d10c08f..810949b56e1ff 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -34,23 +34,18 @@ namespace tir { class IRVisitorWithAnalyzer final : public StmtExprVisitor { public: - PrimExpr Simplify(const PrimExpr& expr) { - return analyzer_.Simplify(expr); - } + PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); + analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 40cd7f8793ee9..7ddb8f5251e7e 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -21,13 +21,15 @@ * \file modular_set.cc * \brief Modular set analysis */ -#include #include -#include +#include #include +#include + #include -#include #include +#include + #include "pattern_match.h" namespace tvm { @@ -46,19 +48,15 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ModularSet(" - << "coeff=" << op->coeff << ", base=" - << op->base << ')'; - }); - -ModularSet MakeModularSet(int64_t coeff, int64_t base) { - return ModularSet(coeff, base); -} + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ModularSet(" + << "coeff=" << op->coeff << ", base=" << op->base << ')'; + }); + +ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_GLOBAL("arith.ModularSet") -.set_body_typed(MakeModularSet); +TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); // internal entry for const int bound struct ModularSetAnalyzer::Entry { @@ -77,37 +75,27 @@ struct ModularSetAnalyzer::Entry { this->base = base; } - bool is_const() const { - return coeff == 0; - } + bool is_const() const { return coeff == 0; } - bool operator==(const Entry& other) const { - return coeff == other.coeff && base == other.base; - } + bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; } bool operator==(const ModularSet& other) const { - return other.defined() && - coeff == other->coeff && base == other->base; + return other.defined() && coeff == other->coeff && base == other->base; } }; -class ModularSetAnalyzer::Impl : - public ExprFunctor { +class ModularSetAnalyzer::Impl : public ExprFunctor { public: - explicit Impl(Analyzer* parent) - : parent_(parent) {} + explicit Impl(Analyzer* parent) : parent_(parent) {} - void Update(const Var& var, - const ModularSet& info, - bool override) { + void Update(const Var& var, const ModularSet& info, bool override) { if 
(!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " - << "original=" << ModularSet(it->second.coeff, it->second.base) - << ", new=" << info; + CHECK(it->second == info) << "Trying to update var \'" << var << "\'" + << " with a different const bound: " + << "original=" << ModularSet(it->second.coeff, it->second.base) + << ", new=" << info; } } var_map_[var] = Entry(info->coeff, info->base); @@ -127,17 +115,11 @@ class ModularSetAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Object* op) final { - return Everything(); - } + Entry VisitExprDefault_(const Object* op) final { return Everything(); } - Entry VisitExpr_(const CastNode* op) final { - return VisitExpr(op->value); - } + Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } - Entry VisitExpr_(const IntImmNode* op) final { - return Entry(0, op->value); - } + Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); @@ -167,9 +149,7 @@ class ModularSetAnalyzer::Impl : return Entry(coeff, a.base * b.base); } - Entry DivByConst(const PrimExpr& lhs, - int64_t val, - bool round_down) { + Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) { Entry a = VisitExpr(lhs); CHECK_NE(val, 0); if (a.coeff % val == 0) { @@ -179,8 +159,7 @@ class ModularSetAnalyzer::Impl : } // positive division has a clear rounding mode. // Only handle case where we clearly know we need to round down. - if (a.base > 0 && val > 0 && - (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { + if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { return Entry(a.coeff / val, a.base / val); } } @@ -269,9 +248,7 @@ class ModularSetAnalyzer::Impl : } var_map_[var] = Intersect(old, entry); // recover function. - return [this, old, var]() { - var_map_[var] = old; - }; + return [this, old, var]() { var_map_[var] = old; }; } /*! * \brief Create union of two sets. @@ -385,16 +362,12 @@ class ModularSetAnalyzer::Impl : * \brief return everything dtype can represent. * \return Bound that represents everything dtype can represent. */ - static Entry Everything() { - return Entry(1, 0); - } + static Entry Everything() { return Entry(1, 0); } /*! * \brief return an empty set * \return Bound that represents an empty set.
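
For orientation, a ModularSet(coeff, base) in the hunks above denotes the residue set {coeff * k + base | k integer}, so Everything() is {1*k + 0} and a constant c is {0*k + c}. A small numeric sketch of the addition rule the AddNode visitor implements; the gcd-combine and the base normalization shown here are assumptions inferred from the Entry(coeff, base) constructor, since the diff elides most of the visitor body:

    #include <cstdint>
    #include <iostream>
    #include <numeric>  // std::gcd, C++17

    struct Mod { int64_t coeff, base; };

    // If x is in {c1*k + b1} and y is in {c2*k + b2}, then x + y is in
    // {gcd(c1, c2)*k + (b1 + b2)}, with the base reduced mod the coefficient.
    Mod AddMod(Mod a, Mod b) {
      int64_t coeff = std::gcd(a.coeff, b.coeff);
      int64_t base = a.base + b.base;
      if (coeff != 0) base %= coeff;
      return {coeff, base};
    }

    int main() {
      Mod x{4, 1};  // x = 4k + 1
      Mod y{6, 2};  // y = 6m + 2
      Mod s = AddMod(x, y);
      std::cout << s.coeff << " " << s.base << "\n";  // 2 1: x + y is always odd
    }
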
*/ - static Entry Nothing() { - return Entry(0, 1); - } + static Entry Nothing() { return Entry(0, 1); } }; ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { @@ -402,9 +375,7 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { return ModularSet(ret.coeff, ret.base); } -void ModularSetAnalyzer::Update(const Var& var, - const ModularSet& info, - bool override) { +void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) { impl_->Update(var, info, override); } @@ -412,13 +383,9 @@ std::function ModularSetAnalyzer::EnterConstraint(const PrimExpr& constr return impl_->EnterConstraint(constraint); } -ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) - : impl_(new Impl(parent)) { -} +ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} -ModularSetAnalyzer::~ModularSetAnalyzer() { - delete impl_; -} +ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 14cfbd6f57f96..2a02303d7b16b 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -67,7 +67,9 @@ #include #include + #include + #include "const_fold.h" namespace tvm { @@ -84,7 +86,7 @@ namespace arith { * * \tparam Derived The type of the derived class. */ -template +template class Pattern { public: /*! @@ -108,30 +110,26 @@ class Pattern { * * \return whether value matches the pattern. */ - template + template bool Match(const NodeType& value) const { derived().InitMatch_(); return derived().Match_(value); } /*! \return Derived instance of current class. */ - const Derived& derived() const { - return *static_cast(this); - } + const Derived& derived() const { return *static_cast(this); } }; /*! * \brief Default deep equality checker * \tparam T the comparison point. */ -template +template class PEqualChecker { public: - bool operator()(const T& lhs, const T& rhs) const { - return lhs == rhs; - } + bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; } }; -template<> +template <> class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { @@ -140,20 +138,16 @@ class PEqualChecker { } }; -template<> +template <> class PEqualChecker { public: - bool operator()(const IntImm& lhs, const IntImm& rhs) const { - return lhs->value == rhs->value; - } + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; -template<> +template <> class PEqualChecker { public: - bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { - return lhs.same_as(rhs); - } + bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } }; /*! @@ -166,15 +160,13 @@ class PEqualChecker { * \note PVar is not thread safe. * Do not use the same PVar in multiple threads. */ -template -class PVar : public Pattern > { +template +class PVar : public Pattern> { public: // Store PVars by reference in the expression. using Nested = const PVar&; - void InitMatch_() const { - filled_ = false; - } + void InitMatch_() const { filled_ = false; } bool Match_(const T& value) const { if (!filled_) { @@ -186,9 +178,8 @@ class PVar : public Pattern > { } } - template::value>::type> + template ::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { return Match_(GetRef(ptr)); @@ -214,21 +205,17 @@ class PVar : public Pattern > { * * \tparam T the type of the hole. 
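
The fill-then-compare protocol above is the core of the pattern matcher: InitMatch_ clears the variable, the first Match_ captures a value, and every later occurrence must compare equal under PEqualChecker. A self-contained toy version, with std::optional standing in for the filled_ flag and raw storage:

    #include <cassert>
    #include <optional>

    template <typename T>
    struct ToyPVar {
      mutable std::optional<T> slot;
      void InitMatch() const { slot.reset(); }  // like InitMatch_()
      bool Match(const T& v) const {            // like Match_()
        if (!slot) {
          slot = v;  // first occurrence binds the variable
          return true;
        }
        return *slot == v;  // later occurrences must agree
      }
      const T& Eval() const { return *slot; }
    };

    int main() {
      ToyPVar<int> x;
      x.InitMatch();
      assert(x.Match(3));   // binds x = 3
      assert(x.Match(3));   // consistent reuse, as in the pattern x + x
      assert(!x.Match(4));  // conflicting binding fails the match
      assert(x.Eval() == 3);
    }

This is also why the header warns that a PVar is not thread safe: the binding lives as mutable state inside the pattern object itself.
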
*/ -template -class PConst : public Pattern > { +template +class PConst : public Pattern> { public: PConst(T value) // NOLINT(*) : value_(value) {} void InitMatch_() const {} - bool Match_(const T& value) const { - return PEqualChecker()(value_, value); - } + bool Match_(const T& value) const { return PEqualChecker()(value_, value); } - T Eval() const { - return value_; - } + T Eval() const { return value_; } private: const T value_; @@ -240,9 +227,8 @@ class PConst : public Pattern > { * \tparam TA The pattern type of the first operand. * \tparam TB The pattern type of the second operand. */ -template -class PBinaryExpr : - public Pattern > { +template +class PBinaryExpr : public Pattern> { public: PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} @@ -274,12 +260,10 @@ class PBinaryExpr : typename TB::Nested b_; }; -template -class PConstWithTypeLike : - public Pattern > { +template +class PConstWithTypeLike : public Pattern> { public: - PConstWithTypeLike(const TA& ref, int64_t value) - : ref_(ref), value_(value) {} + PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {} void InitMatch_() const {} @@ -291,39 +275,33 @@ class PConstWithTypeLike : } } - PrimExpr Eval() const { - return tir::make_const(ref_.Eval().dtype(), value_); - } + PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); } private: typename TA::Nested ref_; int64_t value_; }; - -#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ - template \ - inline PBinaryExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - CheckStep; \ - return PBinaryExpr(a.derived(), b.derived()); \ - } \ - template \ - inline PBinaryExpr > \ - FuncName(const Pattern& a, int64_t b) { \ - CheckStep; \ - return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ - } \ - template \ - inline PBinaryExpr, TA> \ - FuncName(int64_t b, const Pattern& a) { \ - CheckStep; \ - return FuncName(PConstWithTypeLike(a.derived(), b), a); \ - } - -#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ - TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) - +#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ + template \ + inline PBinaryExpr FuncName(const Pattern& a, const Pattern& b) { \ + CheckStep; \ + return PBinaryExpr(a.derived(), b.derived()); \ + } \ + template \ + inline PBinaryExpr> FuncName(const Pattern& a, \ + int64_t b) { \ + CheckStep; \ + return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ + } \ + template \ + inline PBinaryExpr, TA> FuncName(int64_t b, \ + const Pattern& a) { \ + CheckStep; \ + return FuncName(PConstWithTypeLike(a.derived(), b), a); \ + } + +#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) // raise ambiguity error for operator overload of / and % TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a)); @@ -355,15 +333,12 @@ TVM_PATTERN_BINARY_OP(operator||, tir::OrNode); * \brief Pattern not expression. * \tparam TA The pattern type of the true operand. 
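
The TVM_PATTERN_BINARY_OP machinery above is a classic expression-template trick: the operator overloads on patterns return new pattern nodes rather than computing values. A minimal sketch of the idea outside TVM (Val, AddPat, and is_pattern are invented for the example; the enable_if guard plays the role that taking Pattern<TA>& plays in the real header):

    #include <cassert>
    #include <type_traits>

    template <typename T> struct is_pattern : std::false_type {};

    struct Val {  // leaf pattern that evaluates to a fixed int
      int v;
      int Eval() const { return v; }
    };
    template <> struct is_pattern<Val> : std::true_type {};

    template <typename A, typename B>
    struct AddPat {  // node built by operator+, like PBinaryExpr
      A a;
      B b;
      int Eval() const { return a.Eval() + b.Eval(); }
    };
    template <typename A, typename B>
    struct is_pattern<AddPat<A, B>> : std::true_type {};

    // Enabled only for pattern types, so plain int + int is untouched.
    template <typename A, typename B,
              typename std::enable_if<is_pattern<A>::value && is_pattern<B>::value,
                                      int>::type = 0>
    AddPat<A, B> operator+(const A& a, const B& b) {
      return AddPat<A, B>{a, b};
    }

    int main() {
      Val x{2}, y{3};
      auto tree = x + y;         // builds AddPat<Val, Val>; nothing computed yet
      assert(tree.Eval() == 5);  // evaluation walks the pattern tree
    }
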
*/ -template -class PNotExpr : public Pattern > { +template +class PNotExpr : public Pattern> { public: - explicit PNotExpr(const TA& value) - : value_(value) {} + explicit PNotExpr(const TA& value) : value_(value) {} - void InitMatch_() const { - value_.InitMatch_(); - } + void InitMatch_() const { value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::NotNode* ptr = node.as()) { @@ -374,15 +349,13 @@ class PNotExpr : public Pattern > { } } - PrimExpr Eval() const { - return tir::NotNode::make(value_.Eval()); - } + PrimExpr Eval() const { return tir::NotNode::make(value_.Eval()); } private: typename TA::Nested value_; }; -template +template inline PNotExpr operator!(const Pattern& value) { return PNotExpr(value.derived()); } @@ -394,16 +367,11 @@ inline PNotExpr operator!(const Pattern& value) { * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -class PSelectExpr : - public Pattern > { +template +class PSelectExpr : public Pattern> { public: - PSelectExpr(const TCond& condition, - const TA& true_value, - const TB& false_value) - : condition_(condition), - true_value_(true_value), - false_value_(false_value) {} + PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value) + : condition_(condition), true_value_(true_value), false_value_(false_value) {} void InitMatch_() const { condition_.InitMatch_(); @@ -423,8 +391,7 @@ class PSelectExpr : } PrimExpr Eval() const { - return tir::SelectNode::make( - condition_.Eval(), true_value_.Eval(), false_value_.Eval()); + return tir::SelectNode::make(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } private: @@ -446,13 +413,12 @@ class PSelectExpr : * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -inline PSelectExpr -select(const Pattern& condition, - const Pattern& true_value, - const Pattern& false_value) { - return PSelectExpr( - condition.derived(), true_value.derived(), false_value.derived()); +template +inline PSelectExpr select(const Pattern& condition, + const Pattern& true_value, + const Pattern& false_value) { + return PSelectExpr(condition.derived(), true_value.derived(), + false_value.derived()); } /*! @@ -460,13 +426,10 @@ select(const Pattern& condition, * \tparam DType The Pattern type of dtype. * \tparam TA The pattern type of the first operand. */ -template -class PCastExpr : - public Pattern > { +template +class PCastExpr : public Pattern> { public: - PCastExpr(const DType& dtype, const TA& value) - : dtype_(dtype), value_(value) { - } + PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {} void InitMatch_() const { dtype_.InitMatch_(); @@ -483,9 +446,7 @@ class PCastExpr : } } - PrimExpr Eval() const { - return tir::CastNode::make(dtype_.Eval(), value_.Eval()); - } + PrimExpr Eval() const { return tir::CastNode::make(dtype_.Eval(), value_.Eval()); } private: typename DType::Nested dtype_; @@ -503,9 +464,8 @@ class PCastExpr : * \tparam DType The pattern type of type. * \tparam TA The pattern type of value. */ -template -inline PCastExpr -cast(const Pattern& dtype, const Pattern& value) { +template +inline PCastExpr cast(const Pattern& dtype, const Pattern& value) { return PCastExpr(dtype.derived(), value.derived()); } @@ -515,15 +475,11 @@ cast(const Pattern& dtype, const Pattern& value) { * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. 
*/ -template -class PRampExpr : - public Pattern > { +template +class PRampExpr : public Pattern> { public: - PRampExpr(const TBase& base, - const TStride& stride, - const TLanes& lanes) - : base_(base), stride_(stride), lanes_(lanes) { - } + PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes) + : base_(base), stride_(stride), lanes_(lanes) {} void InitMatch_() const { base_.InitMatch_(); @@ -542,9 +498,7 @@ class PRampExpr : } } - PrimExpr Eval() const { - return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); - } + PrimExpr Eval() const { return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: typename TBase::Nested base_; @@ -565,24 +519,18 @@ class PRampExpr : * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ -template -inline PRampExpr -ramp(const Pattern& base, - const Pattern& stride, - const Pattern& lanes) { - return PRampExpr( - base.derived(), stride.derived(), lanes.derived()); +template +inline PRampExpr ramp(const Pattern& base, + const Pattern& stride, + const Pattern& lanes) { + return PRampExpr(base.derived(), stride.derived(), lanes.derived()); } -template -inline PRampExpr, PConst> -ramp(const Pattern& base, - int stride, - int lanes) { +template +inline PRampExpr, PConst> ramp(const Pattern& base, + int stride, int lanes) { return PRampExpr, PConst>( - base.derived(), - PConstWithTypeLike(base.derived(), stride), - PConst(lanes)); + base.derived(), PConstWithTypeLike(base.derived(), stride), PConst(lanes)); } /*! @@ -590,14 +538,10 @@ ramp(const Pattern& base, * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. */ -template -class PBroadcastExpr : - public Pattern > { +template +class PBroadcastExpr : public Pattern> { public: - PBroadcastExpr(const TA& value, - const TLanes& lanes) - : value_(value), lanes_(lanes) { - } + PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {} void InitMatch_() const { value_.InitMatch_(); @@ -614,9 +558,7 @@ class PBroadcastExpr : } } - PrimExpr Eval() const { - return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); - } + PrimExpr Eval() const { return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); } private: typename TA::Nested value_; @@ -634,40 +576,37 @@ class PBroadcastExpr : * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. 
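
Semantically, ramp(base, stride, lanes) stands for the vector [base, base+stride, ..., base+(lanes-1)*stride], which is what makes the lane-wise vector rewrites later in this diff valid. A small numeric check of one such rule, ramp(b1,s1,n) + ramp(b2,s2,n) == ramp(b1+b2, s1+s2, n), with std::vector as the stand-in (toy code, not TVM's):

    #include <cassert>
    #include <vector>

    // Toy ramp: ramp(b, s, n)[i] == b + i*s, the semantics the PRampExpr
    // pattern above matches and rebuilds.
    std::vector<int> make_ramp(int b, int s, int n) {
      std::vector<int> v(n);
      for (int i = 0; i < n; ++i) v[i] = b + i * s;
      return v;
    }

    int main() {
      const int b1 = 1, s1 = 2, b2 = 5, s2 = -1, n = 8;
      std::vector<int> sum = make_ramp(b1, s1, n);
      std::vector<int> rhs = make_ramp(b1 + b2, s1 + s2, n);
      for (int i = 0; i < n; ++i) sum[i] += b2 + i * s2;  // add ramp(b2, s2, n)
      assert(sum == rhs);
    }
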
*/ -template -inline PBroadcastExpr -broadcast(const Pattern& value, const Pattern& lanes) { +template +inline PBroadcastExpr broadcast(const Pattern& value, + const Pattern& lanes) { return PBroadcastExpr(value.derived(), lanes.derived()); } // internal namespace namespace detail { // implementation details for CallExpr -template +template struct tuple_for_each_dispatcher { - template - static void run(F& f, const TTuple& tuple) { // NOLINT(*) + template + static void run(F& f, const TTuple& tuple) { // NOLINT(*) f(I, std::get(tuple)); - tuple_for_each_dispatcher< - (I + 1) == std::tuple_size::value, (I + 1), F> - ::run(f, tuple); + tuple_for_each_dispatcher<(I + 1) == std::tuple_size::value, (I + 1), F>::run(f, tuple); } }; -template +template struct tuple_for_each_dispatcher { - template - static void run(F& f, const TTuple& tuple) {} // NOLINT(*) + template + static void run(F& f, const TTuple& tuple) {} // NOLINT(*) }; -template +template inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*) - tuple_for_each_dispatcher::value == 0, 0, F> - ::run(f, tuple); + tuple_for_each_dispatcher::value == 0, 0, F>::run(f, tuple); } struct PCallExprInitMatchFunctor { - template + template void operator()(size_t i, const T& pattern) const { pattern.InitMatch_(); } @@ -677,10 +616,9 @@ struct PCallExprMatchFunctor { const tir::CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const tir::CallNode* call) - : call_(call) {} + explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} - template + template void operator()(size_t i, const T& pattern) { matched_ = matched_ && pattern.Match_(call_->args[i]); } @@ -689,7 +627,7 @@ struct PCallExprMatchFunctor { struct PCallExprEvalArgsFunctor { Array args_; - template + template void operator()(size_t i, const T& pattern) { args_.push_back(pattern.Eval()); } @@ -703,13 +641,10 @@ struct PCallExprEvalArgsFunctor { * \note Op functor contains the name of the function and * the implementation of Eval. */ -template -class PCallExpr : - public Pattern > { +template +class PCallExpr : public Pattern> { public: - explicit PCallExpr(const TArgs&... args) - : args_(args...) { - } + explicit PCallExpr(const TArgs&... args) : args_(args...) 
{} void InitMatch_() const { detail::PCallExprInitMatchFunctor finit; @@ -739,18 +674,16 @@ class PCallExpr : }; // arithmetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, \ - tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - return PCallExpr(a.derived(), b.derived()); \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ + return PCallExpr(a.derived(), b.derived()); \ } TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); @@ -760,18 +693,16 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, \ - tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); @@ -779,9 +710,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return tir::CallNode::make( - args[1].dtype(), kName, args, - tir::CallNode::PureIntrinsic); + return tir::CallNode::make(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; }; @@ -799,13 +728,12 @@ struct PIfThenElseOp { * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -inline PCallExpr -if_then_else(const Pattern& cond, - const Pattern& true_value, - const Pattern& false_value) { - return PCallExpr( - cond.derived(), true_value.derived(), false_value.derived()); +template +inline PCallExpr if_then_else(const Pattern& cond, + const Pattern& true_value, + const Pattern& false_value) { + return PCallExpr(cond.derived(), true_value.derived(), + false_value.derived()); } } // namespace arith diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 126310813cc46..3b8ccfb01a931 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -22,12 +22,15 @@ * \brief Rewrite-rule based simplification. */ // Acknowledgement: Most rewrite-rules are from Halide.
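
The rewrite_simplify.cc hunks that follow are easiest to read with the macro definitions in mind. Using the TVM_TRY_REWRITE definition from the next hunk, a use such as TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)) expands, verbatim, to nothing more than:

    if ((ramp(b1, s1, lanes) + broadcast(x, lanes)).Match(ret)) {
      return (ramp(b1 + x, s1, lanes)).Eval();
    }

Match walks ret against the pattern tree, binding the PVars along the way; Eval then rebuilds the replacement expression from those bindings.
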
+#include "rewrite_simplify.h" + #include #include + #include + #include "const_fold.h" #include "pattern_match.h" -#include "rewrite_simplify.h" namespace tvm { namespace arith { @@ -35,9 +38,9 @@ namespace arith { using namespace tir; // macro for doing simple rewrite -#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ - if ((SrcExpr).Match(ret)) { \ - return (ResExpr).Eval(); \ +#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return (ResExpr).Eval(); \ } // macro for rewrite + recursively rewrite ResExpr @@ -47,15 +50,15 @@ using namespace tir; } // macro rewrite only if CondExor is true after match. -#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ - if ((SrcExpr).Match(ret) && (CondExpr)) { \ - return (ResExpr).Eval(); \ +#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return (ResExpr).Eval(); \ } // macro rewrite + recursive_rewrite only if CondExor is true after match. -#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ - if ((SrcExpr).Match(ret) && (CondExpr)) { \ - return RecursiveRewrite((ResExpr).Eval()); \ +#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ } // NOTE for developers: @@ -66,8 +69,8 @@ using namespace tir; // // try to prove x equals val -RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: -TryCompare(const PrimExpr& x, int64_t val) { +RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, + int64_t val) { PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { @@ -100,23 +103,19 @@ TryCompare(const PrimExpr& x, int64_t val) { return kUnknown; } -void RewriteSimplifier::Impl:: -Update(const Var& var, const PrimExpr& info, bool can_override) { +void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) { if (!can_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(ExprDeepEqual()(it->second, info)) - << "Trying to update var \'" << var << "\'" - << " with a different value: " - << "original=" << it->second - << ", new=" << info; + CHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'" + << " with a different value: " + << "original=" << it->second << ", new=" << info; } } var_map_[var] = info; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const AddNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -129,14 +128,10 @@ VisitExpr_(const AddNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), - ramp(b1 + b2, s1 + s2, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), - ramp(b1 + x, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), - ramp(x + b1, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), - broadcast(x + y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); } if 
(IsIndexType(op->dtype)) { @@ -167,14 +162,10 @@ VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y); - TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), - c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), c1.Eval()->value == -c2.Eval()->value); // constant folding // NOTE: canonicalization might be better at this. @@ -213,8 +204,7 @@ VisitExpr_(const AddNode* op) { } // condition rules. - TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), - select(x, b1 + s1, b2 + s2)); + TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), select(x, b1 + s1, b2 + s2)); // default value return ret; } @@ -230,8 +220,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c return frecover; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const SubNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -244,14 +233,10 @@ VisitExpr_(const SubNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), - ramp(b1 - b2, s1 - s2, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), - ramp(b1 - x, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), - ramp(x - b1, 0 - s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), - broadcast(x - y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes)); } if (IsIndexType(op->dtype)) { @@ -293,20 +278,20 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE((y + x) - (z + x), y - z); TVM_TRY_REWRITE((y + x) - (x + z), y - z); - TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); - TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); - TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); - TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); - TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); - TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); - TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); - TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); -
TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); - TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); - TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); - TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x)); @@ -324,10 +309,8 @@ VisitExpr_(const SubNode* op) { // DivMod rules // truncdiv // NOTE: c*(x/c) + x % c == x is true in all division modes. - TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), - c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), - c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y, c1.Eval()->value != 0); TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1), @@ -337,45 +320,40 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); // Proof in the case of floordiv requires a positive condition.
// let x = a * c3 + r // (x + c1) / c3 - x / c3 => (r + c1) / c3 // NOTE: the use of floormod(c2, c3) was intentional to simplify the const. - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3), CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && - c1.Eval()->value >= c2.Eval()->value && - c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x, c3), - truncdiv(truncmod(x, c3) + c1, c3), - CanProveGreaterEqual(x.Eval(), 0) && - c1.Eval()->value >= 0 && - c3.Eval()->value > 0); + c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0); + TVM_TRY_REWRITE_IF( + truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3), + CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0); // floordiv - TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), - c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), - c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y, c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1), @@ -385,30 +363,29 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x - y, c1) * c3 - x * c2, (0 - 
floormod(x - y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3), floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3), c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), - floordiv(floormod(x, c3) + c1, c3), + TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3), c3.Eval()->value > 0); // canonicalization rule @@ -420,17 +397,13 @@ VisitExpr_(const SubNode* op) { } // condition rules. - TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), - select(x, b1 - s1, b2 - s2)); - TVM_TRY_REWRITE(select(x, y, z) - z, - select(x, y - z, ZeroWithTypeLike(z))); - TVM_TRY_REWRITE(select(x, y, z) - y, - select(x, ZeroWithTypeLike(y), z - y)); + TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2)); + TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z))); + TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y)); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MulNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -443,12 +416,9 @@ VisitExpr_(const MulNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), - broadcast(x * y, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), - ramp(b1 * x, s1 * x, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), - ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); } if (IsIndexType(op->dtype)) { @@ -461,15 +431,12 @@ VisitExpr_(const MulNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); - TVM_TRY_RECURSIVE_REWRITE_IF( - (x - y) * c1, (y - x) * (0 - c1), - c1.Eval()->value < 0); + TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0); } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const DivNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -490,8 +457,7 @@ VisitExpr_(const DivNode* op) { // Vector rules if (op->dtype.lanes() != 1) { // NOTE: use div as the pattern also works for float. 
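
A quick numeric check of the condition rules in the hunk above, with select(c, a, b) modeled by C's conditional operator, before the division rules continue below:

    #include <cassert>

    int main() {
      for (int c = 0; c <= 1; ++c)
        for (int y = -3; y < 3; ++y)
          for (int z = -3; z < 3; ++z) {
            // select(x, y, z) - z -> select(x, y - z, 0)
            assert((c ? y : z) - z == (c ? y - z : 0));
            // select(x, b1, b2) - select(x, s1, s2) -> select(x, b1 - s1, b2 - s2)
            assert((c ? y : z) - (c ? z : y) == (c ? y - z : z - y));
          }
    }

Both identities hold because the same condition picks the same arm on both sides, so the subtraction can be pushed into each arm.
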
- TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(div(x, y), lanes)); + TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); // ramp / bcast if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; @@ -532,10 +498,8 @@ VisitExpr_(const DivNode* op) { c1.Eval()->value > 0 && c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3), - c1.Eval()->value > 0 && - c2.Eval()->value >= 0 && - c3.Eval()->value > 0 && - CanProveGreaterEqual(x.Eval(), 0)); + c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 && + CanProveGreaterEqual(x.Eval(), 0)); if (truncdiv(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; @@ -551,147 +515,102 @@ VisitExpr_(const DivNode* op) { TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), - x * truncdiv(c1, c2) + truncdiv(y, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), - min(x * truncdiv(c1, c2), truncdiv(y, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), - max(x * truncdiv(c1, c2), truncdiv(y, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), - truncdiv(y, c2) + x * truncdiv(c1, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), - min(truncdiv(y, c2), x * truncdiv(c1, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), - max(truncdiv(y, c2), x * truncdiv(c1, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + 
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); // Rules involving 3-operands. - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2), - x * truncdiv(c1, c2) + truncdiv(y + z, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2), - x * truncdiv(c1, c2) + truncdiv(z - y, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((z - y).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2), - x * truncdiv(c1, c2) + truncdiv(y - z, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y - z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2), - x * truncdiv(c1, c2) + truncdiv(y + z, c2), - c1.Eval()->value > 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), - truncdiv(x, c2) + truncdiv(c1, c2), - c1.Eval()->value > 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + x, x), 
truncdiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv((x + y) + z, x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv((y + x) + z, x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(y + (z + x), x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(y + (x + z), x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z), - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z), - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const ModNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -725,8 +644,7 @@ VisitExpr_(const ModNode* op) { if (ramp_min == ramp_max) { return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); } else { - return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), - broadcast(c2, lanes)).Eval(); + return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), 
broadcast(c2, lanes)).Eval(); } } } @@ -738,41 +656,34 @@ VisitExpr_(const ModNode* op) { // We adopt the default C division, which uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual((x * c1).Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual((x * c1).Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value >= 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value >= 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y * c1).Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y * c1).Eval(), 0)); // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), - truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), + truncmod(x, c1), truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis if (truncmod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); int64_t c1val = c1.Eval()->value; - if (mod->coeff % c1val == 0 && - c1val > 0 && - CanProveGreaterEqual(x.Eval(), 0)) { + if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) { return truncmod(mod->base, c1).Eval(); } } @@ -780,8 +691,7 @@ VisitExpr_(const ModNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const FloorDivNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -836,67 +746,43 @@ VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(c1 * x, x), c1); // Rules involving 2-operands.
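
The ModNode comment in this hunk explains most of the guard conditions in the file: C's truncating % disagrees with floormod as soon as operands go negative, which is why the truncdiv/truncmod rules demand non-negativeness while the floordiv rules below mostly need only c2 > 0. A compact demonstration, including the identity c*(x/c) + x % c == x quoted a few hunks up, which holds in any division mode:

    #include <cassert>
    #include <iostream>

    // floormod for b > 0, built from C's truncating %.
    int floormod(int a, int b) { return ((a % b) + b) % b; }

    int main() {
      std::cout << -7 % 4 << "\n";           // -3: truncated remainder
      std::cout << floormod(-7, 4) << "\n";  // 1: floored remainder
      // c*(x/c) + x % c == x holds in C regardless of signs.
      for (int x = -20; x <= 20; ++x)
        for (int c = -5; c <= 5; ++c)
          if (c != 0) assert(c * (x / c) + x % c == x);
    }
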
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), - x * floordiv(c1, c2) + floordiv(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), - min(x * floordiv(c1, c2), floordiv(y, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), - max(x * floordiv(c1, c2), floordiv(y, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), - floordiv(y, c2) + x * floordiv(c1, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), - min(floordiv(y, c2), x * floordiv(c1, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), - max(floordiv(y, c2), x * floordiv(c1, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // Rules involving 3-operands. 
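
A quick aside before the three-operand variants below. The two-operand floordiv rules just listed share one divisibility-guarded pattern: when c2 > 0 and c2 divides c1, the term x * c1 contributes exactly x * (c1 / c2) whole quotients, for any sign of x and y. That is why this floordiv family carries no CanProveGreaterEqual conditions, unlike its truncdiv counterpart. A numeric spot-check (illustrative only, same hand-rolled floordiv as in the earlier sketch):

// Checks floordiv(x * c1 + y, c2) == x * floordiv(c1, c2) + floordiv(y, c2)
// under the guards c2 > 0 and c1 % c2 == 0, over both signs of x and y.
#include <cassert>
#include <cstdint>

int64_t floordiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  return (a % b != 0 && ((a % b < 0) != (b < 0))) ? q - 1 : q;
}

int main() {
  const int64_t c1 = 6, c2 = 3;  // c2 > 0 and c1 % c2 == 0
  for (int64_t x = -4; x <= 4; ++x) {
    for (int64_t y = -7; y <= 7; ++y) {
      assert(floordiv(x * c1 + y, c2) == x * (c1 / c2) + floordiv(y, c2));
    }
  }
  return 0;
}
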
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), - x * floordiv(c1, c2) + floordiv(y + z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), - x * floordiv(c1, c2) + floordiv(z - y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), - x * floordiv(c1, c2) + floordiv(y - z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), - x * floordiv(c1, c2) + floordiv(y + z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), - floordiv(x, c2) + floordiv(c1, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), x * floordiv(c1, c2) + floordiv(y - z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -907,10 +793,8 @@ VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); - TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, - CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, - CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z), CanProveGreaterEqual(z.Eval(), 0)); @@ -924,8 +808,7 @@ VisitExpr_(const FloorDivNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const FloorModNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -967,20 +850,16 @@ VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 
0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // try modular analysis if (floormod(x, c1).Match(ret)) { @@ -994,8 +873,7 @@ VisitExpr_(const FloorModNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MinNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1009,8 +887,7 @@ VisitExpr_(const MinNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(min(x, y), lanes)); + TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes)); TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), min(x, broadcast(min(y, z), lanes))); } @@ -1035,8 +912,7 @@ VisitExpr_(const MinNode* op) { return (x + c2).Eval(); } } - if (min(x + c1, x).Match(ret) || - min(x, x + c1).Match(ret)) { + if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) { if (c1.Eval()->value < 0) { return (x + c1).Eval(); } else { @@ -1055,40 +931,30 @@ VisitExpr_(const MinNode* op) { // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); // Divide up rounding: floor div TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), 
floordiv(x, c2) * c2, - c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0); TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y)); TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y)); @@ -1168,19 +1034,15 @@ VisitExpr_(const MinNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE_IF( - min(c1 - x, c2), c1 - max(x, c1 - c2), - c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. - TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), - select(x, min(y, s1), min(z, s2))); + TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2))); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MaxNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1194,8 +1056,7 @@ VisitExpr_(const MaxNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(max(x, y), lanes)); + TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes)); TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), max(x, broadcast(max(y, z), lanes))); } @@ -1220,8 +1081,7 @@ VisitExpr_(const MaxNode* op) { return (x + c2).Eval(); } } - if (max(x + c1, x).Match(ret) || - max(x, x + c1).Match(ret)) { + if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) { if (c1.Eval()->value > 0) { return (x + c1).Eval(); } else { @@ -1239,27 +1099,19 @@ VisitExpr_(const MaxNode* op) { // DivMod rules // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) - TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), - truncdiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), - truncdiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), truncdiv(x + c1, c2) * c2, + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), truncdiv(x + c1, c2) * c2, + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); // Divide up rounding: floor div TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, - c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, c2.Eval()->value > 0); TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y)); 
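
The min/max DivMod rules above all rest on a single "divide up rounding" fact: for c2 > 0 and c1 + 1 == c2, the expression floordiv(x + c1, c2) * c2 is x rounded up to the next multiple of c2, hence never smaller than x; min against x therefore collapses to x, and max picks the rounded value. A self-contained check under those guards (a sketch, not patch code); the truncdiv variants lean on the same guard together with the NOTE above that the truncated quotient is never below the floored one. The max/min interchange rules then continue below.

// Illustrative check of the round-up identity behind the min/max rules.
#include <algorithm>
#include <cassert>
#include <cstdint>

int64_t floordiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  return (a % b != 0 && ((a % b < 0) != (b < 0))) ? q - 1 : q;
}

int main() {
  const int64_t c2 = 4, c1 = c2 - 1;  // the c1 + 1 == c2 guard
  for (int64_t x = -9; x <= 9; ++x) {
    int64_t round_up = floordiv(x + c1, c2) * c2;
    assert(round_up >= x && round_up < x + c2);  // smallest multiple of c2 that is >= x
    assert(std::min(round_up, x) == x);          // min(floordiv(x + c1, c2) * c2, x) => x
    assert(std::max(round_up, x) == round_up);   // max(...) => the rounded value
  }
  return 0;
}
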
TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y)); @@ -1342,18 +1194,15 @@ VisitExpr_(const MaxNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE_IF( - max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. - TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), - select(x, max(y, s1), max(z, s2))); + TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2))); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const EQNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1367,8 +1216,7 @@ VisitExpr_(const EQNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), - broadcast(x == y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } if (IsIndexType(op->a.dtype())) { @@ -1386,28 +1234,23 @@ VisitExpr_(const EQNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const NENode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { return this->VisitExpr(NotNode::make(op->a == op->b)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LENode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { return this->VisitExpr(NotNode::make(op->b < op->a)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const GTNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const GENode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { return this->VisitExpr(NotNode::make(op->a < op->b)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LTNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1421,10 +1264,8 @@ VisitExpr_(const LTNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), - broadcast(x < y, lanes)); - TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), - broadcast(x < y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); + TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } if (IsIndexType(op->a.dtype())) { @@ -1436,6 +1277,7 @@ VisitExpr_(const LTNode* op) { return make_const(op->dtype, false); } + // clang-format off TVM_TRY_REWRITE(x + y < x + z, y < z); TVM_TRY_REWRITE(x + y < z + x, y < z); TVM_TRY_REWRITE(y + x < x + z, y < z); @@ -1449,100 +1291,76 @@ VisitExpr_(const LTNode* op) { TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x); TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1); - TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, - c1.Eval()->value < 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0); // constant cancelation: only need to make use of one mod // truc div - TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1, - c1.Eval()->value > 0 && - 
c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c2 < c1, + x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2), - c1.Eval()->value <= 0 && - c2.Eval()->value > 0); + c1.Eval()->value <= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (euclidean is ok too, floored is not) - TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, - c1.Eval()->value > 0 && + TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 && c2.Eval()->value < 0); // NOTE: trunc div required (floored is ok too, euclidean is not) TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x, - c1.Eval()->value <= 0 && - c2.Eval()->value < 0); + c1.Eval()->value <= 0 && c2.Eval()->value < 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x, - c1.Eval()->value < 0 && - c2.Eval()->value > 0); + c1.Eval()->value < 0 && c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x, - c1.Eval()->value >= 0 && - c2.Eval()->value > 0); + c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (floored is ok too, euclidean is not) TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1, - c1.Eval()->value < 0 && - c2.Eval()->value < 0); + c1.Eval()->value < 0 && c2.Eval()->value < 0); // NOTE: trunc div required (euclidean is ok too, floored is not) TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value < 0); + c1.Eval()->value >= 0 && c2.Eval()->value < 0); // DivMod rules // trucdiv - TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2, - c1.Eval()->value > 0 && - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2, + c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1, - c1.Eval()->value > 0 && - c2.Eval()->value <= 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1, + c1.Eval()->value > 0 && c2.Eval()->value <= 0); TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x, - c1.Eval()->value >= 0 && - c2.Eval()->value > 0); + c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x, - c1.Eval()->value < 0 && - c2.Eval()->value > 0); + c1.Eval()->value < 0 && c2.Eval()->value > 0); // invariance for any div mod: x - (x / c1) * c1 == x % c1 - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1), - c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, + 0 < truncmod(x, c1) + y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, + y < truncmod(x, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x, - c2 < truncmod(x + c2, c1), - c1.Eval()->value > 0); + c2 < truncmod(x + c2, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y, - c2 < truncmod(x + c2, c1) + y, - c1.Eval()->value > 0); + c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y, - y < truncmod(x + c2, c1) + (0 - c2), - c1.Eval()->value > 0); + y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // floordiv - 
TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, - c2.Eval()->value > 0); - - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, 0 < floormod(x, c1) + y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, y < floormod(x, c1), - c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, c2.Eval()->value > 0); + + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, + 0 < floormod(x, c1) + y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, + y < floormod(x, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x, - c2 < floormod(x + c2, c1), - c1.Eval()->value > 0); + c2 < floormod(x + c2, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y, - c2 < floormod(x + c2, c1) + y, - c1.Eval()->value > 0); + c2 < floormod(x + c2, c1) + y, c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y, - y < floormod(x + c2, c1) + (0 - c2), - c1.Eval()->value > 0); + y < floormod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // canonicalization rule TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z); @@ -1558,12 +1376,12 @@ VisitExpr_(const LTNode* op) { TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1); + // clang-format on } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const NotNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a); @@ -1587,8 +1405,7 @@ VisitExpr_(const NotNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const AndNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1601,8 +1418,7 @@ VisitExpr_(const AndNode* op) { PVar lanes; if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), - broadcast(x && y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } auto cfalse = PConst(make_const(op->dtype, false)); @@ -1612,32 +1428,23 @@ VisitExpr_(const AndNode* op) { TVM_TRY_REWRITE(x <= y && y < x, cfalse); TVM_TRY_REWRITE(y < x && x <= y, cfalse); - TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, - c2.Eval()->value + 1 >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, - c2.Eval()->value + 1 >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, - c2.Eval()->value >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, - c2.Eval()->value > c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, - c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, 
c2.Eval()->value + 1 >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value); TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const OrNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); PrimExpr const_res = TryConstFold(op->a, op->b); @@ -1650,8 +1457,7 @@ VisitExpr_(const OrNode* op) { PVar lanes; if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), - broadcast(x || y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } auto ctrue = PConst(make_const(op->dtype, true)); @@ -1662,32 +1468,23 @@ VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(x <= y || y < x, ctrue); TVM_TRY_REWRITE(y < x || x <= y, ctrue); - TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, - c2.Eval()->value < c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, - c2.Eval()->value < c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, - c2.Eval()->value <= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, - c2.Eval()->value <= c1.Eval()->value + 1); - TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, - c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const SelectNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; @@ -1697,8 +1494,7 @@ VisitExpr_(const SelectNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const CallNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // add condition context to if_then_else PrimExpr ret = 
IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); @@ -1728,8 +1524,7 @@ VisitExpr_(const CallNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const VarNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) { @@ -1738,15 +1533,13 @@ VisitExpr_(const VarNode* op) { return GetRef(op); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const CastNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); return cast(op->dtype, op->value); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LetNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { // it is fine to discard the let binding @@ -1755,8 +1548,7 @@ VisitExpr_(const LetNode* op) { return this->VisitExpr(op->body); } PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -1775,9 +1567,7 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { return res; } -void RewriteSimplifier::Update(const Var& var, - const PrimExpr& info, - bool override) { +void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { impl_->Update(var, info, override); } @@ -1785,13 +1575,9 @@ std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constra return impl_->EnterConstraint(constraint); } -RewriteSimplifier::RewriteSimplifier(Analyzer* parent) - : impl_(new Impl(parent)) { -} +RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} -RewriteSimplifier::~RewriteSimplifier() { - delete impl_; -} +RewriteSimplifier::~RewriteSimplifier() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 8798df92777d8..fd248b97b19ca 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -26,11 +26,13 @@ #include #include + #include #include + #include "const_fold.h" -#include "pattern_match.h" #include "ir_mutator_with_analyzer.h" +#include "pattern_match.h" namespace tvm { namespace arith { @@ -46,8 +48,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; - explicit Impl(Analyzer* parent) - : IRMutatorWithAnalyzer(parent) {} + explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} void Update(const Var& var, const PrimExpr& info, bool override_info); PrimExpr VisitExpr_(const AddNode* op) override; @@ -78,15 +79,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { protected: /*! \brief internal structure for comparison. */ - enum CompareResult { - kUnknown, - kEQ, - kGT, - kGE, - kLT, - kLE, - kNE - }; + enum CompareResult { kUnknown, kEQ, kGT, kGE, kLT, kLE, kNE }; // counter to record recursive rewrite depth. 
int recur_depth_{0}; // internal variable map @@ -127,18 +120,17 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { return res; } - template + template PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 0); } - template + template PConstWithTypeLike OneWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 1); } }; - } // namespace arith } // namespace tvm #endif // TVM_ARITH_REWRITE_SIMPLIFY_H_ diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index a89cebe0bf049..50a3243abfb2f 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -21,26 +21,23 @@ * \file tvm/arith/solve_linear_equation.cc * \brief Solve linear equations. */ -#include -#include #include #include -#include #include - +#include +#include +#include +#include #include #include -#include namespace tvm { namespace arith { using namespace tvm::runtime; -void SmithNormalFormDiag(std::vector >* S, - std::vector >* V, - std::vector* x, - std::vector* y) { +void SmithNormalFormDiag(std::vector>* S, std::vector>* V, + std::vector* x, std::vector* y) { if (S->empty() || V->empty()) return; size_t m = S->size(); size_t n = (*S)[0].size(); // n is # of variables @@ -124,9 +121,9 @@ void SmithNormalFormDiag(std::vector >* S, for (size_t j = index; j < (*S)[i].size(); ++j) { // Multiply index-th row by a and add the i-th row multiplied by b // This will make the index-th diagonal element equal to the gcd - int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j]; + int64_t new_index_j = a * (*S)[index][j] + b * (*S)[i][j]; // This transformation performs zeroing of matrix[i][index] - int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j]; + int64_t new_i_j = n_g * (*S)[index][j] - m_g * (*S)[i][j]; (*S)[index][j] = new_index_j; (*S)[i][j] = new_i_j; } @@ -135,8 +132,8 @@ void SmithNormalFormDiag(std::vector >* S, PrimExpr eb = tir::make_const((*y)[i].dtype(), b); PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g); PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g); - PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i]; - PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i]; + PrimExpr new_index_rhs = ea * (*y)[index] + eb * (*y)[i]; + PrimExpr new_i_rhs = e_n_g * (*y)[index] - e_m_g * (*y)[i]; (*y)[index] = new_index_rhs; (*y)[i] = new_i_rhs; } @@ -178,15 +175,15 @@ void SmithNormalFormDiag(std::vector >* S, int64_t n_g = (*S)[index][j] / g; for (size_t i = index; i < m; ++i) { - int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j]; - int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j]; + int64_t new_i_index = a * (*S)[i][index] + b * (*S)[i][j]; + int64_t new_i_j = n_g * (*S)[i][index] - m_g * (*S)[i][j]; (*S)[i][index] = new_i_index; (*S)[i][j] = new_i_j; } // We do exactly the same transformations with V for (size_t i = 0; i < n; ++i) { - int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j]; - int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j]; + int64_t new_i_index = a * (*V)[i][index] + b * (*V)[i][j]; + int64_t new_i_j = n_g * (*V)[i][index] - m_g * (*V)[i][j]; (*V)[i][index] = new_i_index; (*V)[i][j] = new_i_j; } @@ -195,8 +192,8 @@ void SmithNormalFormDiag(std::vector >* S, PrimExpr eb = tir::make_const((*x)[index].dtype(), b); PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g); PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g); - PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j]; - PrimExpr new_j = eb*(*x)[index] - 
ea*(*x)[j]; + PrimExpr new_index = e_m_g * (*x)[index] + e_n_g * (*x)[j]; + PrimExpr new_j = eb * (*x)[index] - ea * (*x)[j]; (*x)[index] = new_index; (*x)[j] = new_j; } @@ -210,8 +207,7 @@ void SmithNormalFormDiag(std::vector >* S, } } -Map InferRange(const Map& vars_to_infer, - const Array& ori_vars, +Map InferRange(const Map& vars_to_infer, const Array& ori_vars, const Map& ori_ranges) { // The resulting ranges Map new_ranges; @@ -245,8 +241,7 @@ Map InferRange(const Map& vars_to_infer, // pretty print matrix equation void DebugPrint(const std::vector>& S, - const std::vector>& V, - const std::vector& V_inv_x, + const std::vector>& V, const std::vector& V_inv_x, const std::vector& rhs) { std::cout << "S:\n"; for (size_t i = 0; i < S.size(); ++i) { @@ -267,7 +262,7 @@ void DebugPrint(const std::vector>& S, std::cout << "\n" << std::endl; } -IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) { +IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve) { // m: # of equations // n: # of variables // we first construct A_{mxn} x_{nx1} = y_{mx1} @@ -275,10 +270,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // S_{mxn} = U_{mxm} A_{mxn} V_{nxn} // => U^{-1} S V^{-1} x = y // S V^{-1} x = U y - std::vector Uy; // mx1 + std::vector Uy; // mx1 std::vector> S; // mxn std::vector> V; // nxn - std::vector V_inv_x; // V^{-1} x, nx1 + std::vector V_inv_x; // V^{-1} x, nx1 // Conditions we don't know what to do with std::vector rest; @@ -301,9 +296,8 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol for (const PrimExpr& equation : system_to_solve->relations) { if (const tir::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] - Array coeffs = arith::DetectLinearEquation( - analyzer_problem.Simplify(eq->a - eq->b), - system_to_solve->variables); + Array coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b), + system_to_solve->variables); if (!coeffs.empty()) { std::vector row; for (size_t j = 0; j < coeffs.size() - 1; ++j) { @@ -365,13 +359,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol new_relation = analyzer_problem.Simplify(new_relation); if (tir::is_const_int(new_relation, 0)) { // unable to solve the system. 
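
For context on the transforms above: SmithNormalFormDiag diagonalizes the coefficient matrix with row and column operations built from the extended Euclidean algorithm, the xgcd of src/arith/util.cc, whose CHECK asserts the Bezout identity old_r == old_s * a + old_t * b. Once S is diagonal, each equation reduces to S[j][j] * (V^{-1} x)[j] = (U y)[j], which is solvable exactly when S[j][j] divides the right-hand side; when that divisibility condition simplifies to provably false, the solver takes the "unable to solve" branch here. A standalone sketch of one elimination step (illustrative, not the patch's code):

// xgcd returns (g, s, t) with g = gcd(a, b) and g == s*a + t*b.
// Combining two rows with coefficients (s, t) puts gcd(p, q) on the
// diagonal, while (q/g, -p/g) zeroes the entry beneath it.
#include <cassert>
#include <cstdint>
#include <tuple>

std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b) {
  int64_t old_r = a, r = b;
  int64_t old_s = 1, s = 0;
  int64_t old_t = 0, t = 1;
  while (r != 0) {
    int64_t q = old_r / r, tmp;
    tmp = r; r = old_r - q * r; old_r = tmp;
    tmp = s; s = old_s - q * s; old_s = tmp;
    tmp = t; t = old_t - q * t; old_t = tmp;
  }
  return {old_r, old_s, old_t};  // gcd and its Bezout coefficients
}

int main() {
  auto [g, s, t] = xgcd(240, 46);
  assert(g == 2 && g == s * 240 + t * 46);   // 2 == (-9)*240 + 47*46

  int64_t p = 240, q = 46;                    // leading entries of two rows
  int64_t diag = s * p + t * q;               // becomes gcd(p, q)
  int64_t below = (q / g) * p - (p / g) * q;  // becomes 0
  assert(diag == g && below == 0);
  return 0;
}
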
- return IntConstraintsTransform( - system_to_solve, - IntConstraints( - /*variables=*/{}, - /*ranges=*/{}, - /*relations=*/{tir::make_zero(DataType::Bool())}), - {}, {}); + return IntConstraintsTransform(system_to_solve, + IntConstraints( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}), + {}, {}); } else if (!tir::is_const_int(new_relation, 1)) { new_relations.push_back(new_relation); } @@ -405,14 +398,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]); - solution_for_V_inv_x.push_back( - analyzer_problem.Simplify(floordiv(Uy[j], a))); + solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]); - solution_for_V_inv_x.push_back( - analyzer_problem.Simplify(floordiv(-Uy[j], a))); + solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a))); } } } @@ -421,15 +412,15 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol for (size_t i = 0; i < num_vars; ++i) { PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype()); for (size_t j = 0; j < num_vars; ++j) { - e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; + e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem.Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); } // The resulting ranges - Map new_ranges = InferRange( - new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + Map new_ranges = + InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); @@ -440,10 +431,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol const Range& old_range = p.second; if (old_to_new_map.count(old_var)) { PrimExpr express_by_new_vars = old_to_new_map[old_var]; - PrimExpr lower_cond = analyzer_solution.Simplify( - old_range->min <= express_by_new_vars); - PrimExpr upper_cond = analyzer_solution.Simplify( - express_by_new_vars < old_range->min + old_range->extent); + PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); + PrimExpr upper_cond = + analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); if (!tir::is_const_int(lower_cond, 1)) { new_relations.push_back(lower_cond); } @@ -459,23 +449,21 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol } IntConstraints solution(new_vars, new_ranges, new_relations); - IntConstraintsTransform transform( - system_to_solve, solution, old_to_new_map, new_to_old_map); + IntConstraintsTransform transform(system_to_solve, solution, old_to_new_map, new_to_old_map); return transform; } -TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 1) { - *ret = SolveLinearEquations(args[0]); - } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveLinearEquations(problem); - } else { - LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); - } - }); +TVM_REGISTER_GLOBAL("arith.SolveLinearEquations").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = 
SolveLinearEquations(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveLinearEquations(problem); + } else { + LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); + } +}); } // namespace arith } // namespace tvm diff --git a/src/arith/util.cc b/src/arith/util.cc index 058c3e9595281..7b7189254408f 100644 --- a/src/arith/util.cc +++ b/src/arith/util.cc @@ -21,8 +21,8 @@ * \file util.cc * \brief The utils for arithmetic analysis. */ -#include #include +#include namespace tvm { namespace arith { @@ -44,7 +44,7 @@ std::tuple xgcd(int64_t a, int64_t b) { CHECK_EQ(a % old_r, 0); CHECK_EQ(b % old_r, 0); - CHECK(old_r == old_s*a + old_t*b); + CHECK(old_r == old_s * a + old_t * b); return std::make_tuple(old_r, old_s, old_t); } diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index da044babdd431..54fc2522db663 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -30,10 +30,9 @@ namespace autotvm { // for loop void FeatureVisitor::VisitStmt_(const ForNode* op) { - const auto *extent = op->extent.as(); + const auto* extent = op->extent.as(); int64_t loop_extent = -1; - if (extent != nullptr) - loop_extent = extent->value; + if (extent != nullptr) loop_extent = extent->value; AnnotationType ann = kSerial; switch (op->for_type) { case ForType ::Parallel: @@ -58,10 +57,9 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { // parallel axis, virtual thread void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Var var = op->node.as()->var; - const auto *extent = op->value.as(); + const auto* extent = op->value.as(); CHECK(extent); std::string name = var.get()->name_hint; diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 5391bddfa2f69..8180839b0668c 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -29,6 +29,7 @@ #include #include #include + #include namespace tvm { @@ -40,8 +41,17 @@ using namespace tvm::tir; * \brief Type of for loop, used as one-hot encoding in features */ enum AnnotationType { - kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ, - kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread, + kBlockX, + kBlockY, + kBlockZ, + kThreadX, + kThreadY, + kThreadZ, + kUnrolled, + kVectorized, + kParallel, + kSerial, + kVirtualThread, kNum, }; @@ -59,17 +69,17 @@ class FeatureVisitor : public StmtExprVisitor { void VisitExpr_(const LoadNode* op) final; void VisitStmt_(const StoreNode* op) final; - using StmtExprVisitor::VisitStmt_; using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; protected: /*! - * \brief Enter a for loop node - * \param var The expression to be printed. - * \param length The output stream - * \param ann_type The type for the for loop - * \return skip Whether skip this node - */ + * \brief Enter a for loop node + * \param var The expression to be printed. + * \param length The output stream + * \param ann_type The type for the for loop + * \return skip Whether skip this node + */ virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0; /*! 
\brief Exit a for loop subtree */ virtual void ExitItervar_() = 0; diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index fbd0829c8a60a..02dae64d6c388 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -24,9 +24,9 @@ #include "touch_extractor.h" -#include #include #include +#include #include namespace tvm { @@ -34,9 +34,14 @@ namespace autotvm { int ParallelLevel(AnnotationType ann) { switch (ann) { - case kBlockX: case kBlockY: case kBlockZ: + case kBlockX: + case kBlockY: + case kBlockZ: return 2; - case kThreadX: case kThreadY: case kThreadZ: case kParallel: + case kThreadX: + case kThreadY: + case kThreadZ: + case kParallel: return 1; default: return 0; @@ -44,7 +49,7 @@ int ParallelLevel(AnnotationType ann) { } // get touch pattern from index expression -class IndexParser: public ExprVisitor { +class IndexParser : public ExprVisitor { public: void Parse(PrimExpr expr) { pattern_map.clear(); @@ -95,11 +100,9 @@ bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_t itervar_map.erase(var); } - itervar_map.insert({var, ItervarFeature(var, length, - static_cast(itervar_stack_.size()), - ann_type, - topdown_product_, - static_cast(itervar_counter_++))}); + itervar_map.insert( + {var, ItervarFeature(var, length, static_cast(itervar_stack_.size()), ann_type, + topdown_product_, static_cast(itervar_counter_++))}); } return true; @@ -120,7 +123,7 @@ void TouchExtractor::ExitItervar_() { CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); touch_pattern->second.count *= itervar_map[var].length; } - } else { // multiply reuse ratio + } else { // multiply reuse ratio for (auto stack_var : itervar_stack_) { auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first); CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); @@ -131,8 +134,7 @@ void TouchExtractor::ExitItervar_() { itervar_stack_.pop_back(); int64_t length = itervar_map[var].length; - if (length != 0) - topdown_product_ /= length; + if (length != 0) topdown_product_ /= length; int64_t bottomup_product = -1; for (auto kv : itervar_map[var].touch_feature) { bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse); @@ -188,8 +190,7 @@ void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) { } } -void TouchExtractor::ExitMem_() { -} +void TouchExtractor::ExitMem_() {} /*! * \brief Get axis-based feature for all axes @@ -219,7 +220,7 @@ void TouchExtractor::ExitMem_() { * \note If you want to flatten these features as the input of your model, * You can use the faster one GetItervarFeatureFlatten below. 
*/ -void GetItervarFeature(Stmt stmt, bool take_log, Array > > *ret_feature) { +void GetItervarFeature(Stmt stmt, bool take_log, Array > >* ret_feature) { // extract TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); @@ -229,7 +230,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -237,28 +238,26 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > std::function trans; if (take_log) { trans = [](int64_t x) { - if (x < 0) - return -std::log(-x+1) / std::log(2); + if (x < 0) return -std::log(-x + 1) / std::log(2); x = x + 1; return std::log(x) / std::log(2); }; } else { - trans = [](int64_t x) { - return x; - }; + trans = [](int64_t x) { return x; }; } // serialize for front end for (auto var : vars) { Array > feature_row; - ItervarFeature &fea = touch_analyzer.itervar_map[var]; + ItervarFeature& fea = touch_analyzer.itervar_map[var]; feature_row.push_back(Array{tvm::tir::StringImmNode::make("_itervar_"), var}); - Array attr{tvm::tir::StringImmNode::make("_attr_"), - FloatImm(DataType::Float(32), trans(fea.length)), - IntImm(DataType::Int(32), fea.nest_level), - FloatImm(DataType::Float(32), trans(fea.topdown_product)), - FloatImm(DataType::Float(32), trans(fea.bottomup_product)), + Array attr{ + tvm::tir::StringImmNode::make("_attr_"), + FloatImm(DataType::Float(32), trans(fea.length)), + IntImm(DataType::Int(32), fea.nest_level), + FloatImm(DataType::Float(32), trans(fea.topdown_product)), + FloatImm(DataType::Float(32), trans(fea.bottomup_product)), }; // one hot annotation for (int i = 0; i < kNum; i++) { @@ -267,10 +266,11 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > feature_row.push_back(attr); // arithmetic - feature_row.push_back(Array{tvm::tir::StringImmNode::make("_arith_"), - FloatImm(DataType::Float(32), trans(fea.add_ct)), - FloatImm(DataType::Float(32), trans(fea.mul_ct)), - FloatImm(DataType::Float(32), trans(fea.div_ct)), + feature_row.push_back(Array{ + tvm::tir::StringImmNode::make("_arith_"), + FloatImm(DataType::Float(32), trans(fea.add_ct)), + FloatImm(DataType::Float(32), trans(fea.mul_ct)), + FloatImm(DataType::Float(32), trans(fea.div_ct)), }); // touch map @@ -280,16 +280,16 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > } std::sort(bufs.begin(), bufs.end()); for (auto k : bufs) { - TouchPattern &v = fea.touch_feature[k]; - feature_row.push_back( - Array{tvm::tir::StringImmNode::make(k), - FloatImm(DataType::Float(32), trans(v.stride)), - FloatImm(DataType::Float(32), trans(v.mod)), - FloatImm(DataType::Float(32), trans(v.count)), - FloatImm(DataType::Float(32), trans(v.reuse)), - FloatImm(DataType::Float(32), trans(v.thread_count)), - FloatImm(DataType::Float(32), trans(v.thread_reuse)), - }); + TouchPattern& v = fea.touch_feature[k]; + feature_row.push_back(Array{ + tvm::tir::StringImmNode::make(k), + FloatImm(DataType::Float(32), trans(v.stride)), + FloatImm(DataType::Float(32), trans(v.mod)), + FloatImm(DataType::Float(32), trans(v.count)), + FloatImm(DataType::Float(32), trans(v.reuse)), + FloatImm(DataType::Float(32), trans(v.thread_count)), + FloatImm(DataType::Float(32), trans(v.thread_reuse)), + }); } ret_feature->push_back(feature_row); @@ -305,7 +305,7 @@ void 
GetItervarFeature(Stmt stmt, bool take_log, Array > > * \note See GetItervarFeature for more details about the return value. * This is an optimized version of GetItervarFeature + Flatten. This runs much faster. */ -void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_feature) { +void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector* ret_feature) { // extract touch feature TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); @@ -315,7 +315,7 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -323,20 +323,17 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ std::function trans; if (take_log) { trans = [](int64_t x) { - if (x < 0) - return -std::log(-x+1) / std::log(2); + if (x < 0) return -std::log(-x + 1) / std::log(2); x = x + 1; return std::log(x) / std::log(2); }; } else { - trans = [](int64_t x) { - return x; - }; + trans = [](int64_t x) { return x; }; } // serialize for front end for (auto var : vars) { - ItervarFeature &fea = touch_analyzer.itervar_map[var]; + ItervarFeature& fea = touch_analyzer.itervar_map[var]; ret_feature->push_back(trans(fea.length)); ret_feature->push_back(fea.nest_level); @@ -360,7 +357,7 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ } std::sort(bufs.begin(), bufs.end()); for (auto k : bufs) { - TouchPattern &v = fea.touch_feature[k]; + TouchPattern& v = fea.touch_feature[k]; ret_feature->push_back(trans(v.stride)); ret_feature->push_back(trans(v.mod)); ret_feature->push_back(trans(v.count)); @@ -372,12 +369,12 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ } /*! - * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional vector. - * \param stmt The statement to be extracted - * \param sample_n The number of points used for sampling a curve (along one dimension) - * \param ret_feature The buffer where the return value is stored + * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional + * vector. 
\param stmt The statement to be extracted \param sample_n The number of points used for + * sampling a curve (along one dimension) \param ret_feature The buffer where the return value is + * stored */ -void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *ret_feature) { +void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector* ret_feature) { // extract touch feature TouchExtractor touch_ext; touch_ext.Analyze(stmt); @@ -387,7 +384,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r for (auto kv : touch_ext.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order; }); @@ -401,14 +398,14 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // find maximum depth of loop nest for (auto var : vars) { - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; max_depth = std::max(max_depth, fea.nest_level); } // mark inner most buffer for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) { auto var = *iter; - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; if (fea.nest_level == max_depth) { for (auto kv : fea.touch_feature) { // delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A') @@ -416,8 +413,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A') size_t pos = raw_name.find("."); - if (pos < kv.first.size()) - raw_name = raw_name.substr(0, pos); + if (pos < kv.first.size()) raw_name = raw_name.substr(0, pos); // If there are multiple innermost buffers that are derived from a same raw buffer // We only record the last occurrence (note the `iter` is in reverse order) @@ -441,7 +437,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // extract curves for (auto var : vars) { - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; for (auto kv : fea.touch_feature) { if (innermost_buffers.find(kv.first) != innermost_buffers.end()) { reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2)); @@ -453,7 +449,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r } // sample relation in the curve - auto sample_curve = [&](const std::vector &x, const std::vector &y, + auto sample_curve = [&](const std::vector& x, const std::vector& y, double weight) { for (int i = 0; i < sample_n; i++) { double xx = i * weight; @@ -469,9 +465,9 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // serialize to frontend for (auto k : innermost_buffers) { - std::vector &count = count_curve[k]; - std::vector &reuse = reuse_curve[k]; - std::vector &top_down = topdown_curve[k]; + std::vector& count = count_curve[k]; + std::vector& reuse = reuse_curve[k]; + std::vector& top_down = topdown_curve[k]; std::sort(count.begin(), count.end()); std::sort(reuse.begin(), reuse.end()); @@ -484,49 +480,45 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r } } - // register API for front end TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - bool take_log = args[1]; - Array > > ret_feature; + 
.set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + bool take_log = args[1]; + Array > > ret_feature; - GetItervarFeature(stmt, take_log, &ret_feature); - - *ret = ret_feature; -}); + GetItervarFeature(stmt, take_log, &ret_feature); + *ret = ret_feature; + }); TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - bool take_log = args[1]; - std::vector ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + bool take_log = args[1]; + std::vector ret_feature; - GetItervarFeatureFlatten(stmt, take_log, &ret_feature); - - TVMByteArray arr; - arr.size = sizeof(float) * ret_feature.size(); - arr.data = reinterpret_cast(ret_feature.data()); - *ret = arr; -}); + GetItervarFeatureFlatten(stmt, take_log, &ret_feature); + TVMByteArray arr; + arr.size = sizeof(float) * ret_feature.size(); + arr.data = reinterpret_cast(ret_feature.data()); + *ret = arr; + }); TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - int sample_n = args[1]; - std::vector ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + int sample_n = args[1]; + std::vector ret_feature; - GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); - - TVMByteArray arr; - arr.size = sizeof(float) * ret_feature.size(); - arr.data = reinterpret_cast(ret_feature.data()); - *ret = arr; -}); + GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); + TVMByteArray arr; + arr.size = sizeof(float) * ret_feature.size(); + arr.data = reinterpret_cast(ret_feature.data()); + *ret = arr; + }); } // namespace autotvm } // namespace tvm diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 23fbc54d843e6..973efb30d5306 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -25,16 +25,17 @@ #ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ +#include #include #include -#include -#include -#include +#include #include +#include #include -#include #include +#include + #include "feature_visitor.h" namespace tvm { @@ -55,11 +56,7 @@ struct TouchPattern { // all the feature of an iter var struct ItervarFeature { - ItervarFeature(Var var, - int64_t extent, - int nest, - AnnotationType ann_type, - int64_t topdown, + ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, int64_t topdown, int counter) : length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {} ItervarFeature() {} @@ -67,9 +64,9 @@ struct ItervarFeature { // Axis Attributes int64_t length; int nest_level; - AnnotationType ann; // one-hot axis type - int64_t topdown_product; // accumulative product of axis length, in top-down order - int64_t bottomup_product; // accumulative product of axis length, in bottom-up order + AnnotationType ann; // one-hot axis type + int64_t topdown_product; // accumulative product of axis length, in top-down order + int64_t bottomup_product; // accumulative product of axis length, in bottom-up order // bottomup_product = reuse * count for any touched buffer int order; // used for soring axis @@ -86,38 +83,31 @@ struct ItervarFeature { // extract iter vars and their touch pattern from ir class TouchExtractor : public FeatureVisitor { public: - void Analyze(const Stmt& stmt) { - operator()(stmt); - } + void Analyze(const Stmt& stmt) { operator()(stmt); } // arithmetic stats void 
VisitExpr_(const AddNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const SubNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const MulNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].mul_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const DivNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const ModNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index bb97900833dd6..f61ad33190c43 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -20,10 +20,12 @@ /*! * \file codegen_hybrid.cc */ +#include "codegen_hybrid.h" + #include -#include + #include -#include "codegen_hybrid.h" +#include namespace tvm { namespace contrib { @@ -34,7 +36,7 @@ using runtime::TVMRetValue; using namespace tir; std::string dot_to_underscore(std::string s) { - for (auto &ch : s) + for (auto& ch : s) if (ch == '.') ch = '_'; return s; } @@ -57,11 +59,9 @@ std::string CodeGenHybrid::GetUniqueName(std::string prefix) { return prefix; } -std::string CodeGenHybrid::Finish() { - return stream.str(); -} +std::string CodeGenHybrid::Finish() { return stream.str(); } -void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { +void CodeGenHybrid::PrintType(DataType t, std::ostream& os) { if (t.is_float()) { os << "float"; CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); @@ -80,20 +80,19 @@ void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL os << op->value; } -void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "'" << op->value << "'"; } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; if (isalpha(opstr[0])) { os << opstr << '('; p->PrintExpr(op->a, os); @@ -111,11 +110,10 @@ inline void PrintBinaryExpr(const T* op, } } -inline void PrintBinaryIntrinsitc(const CallNode* op, - const char* opstr, +inline void PrintBinaryIntrinsitc(const CallNode* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin intrin not 
implemented"; CHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); @@ -252,9 +250,7 @@ void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLIN LOG(FATAL) << "Phase 0 has no Load(s)!"; } -void CodeGenHybrid::VisitStmt_(const StoreNode* op) { - LOG(FATAL) << "Phase 0 has no Store(s)!"; -} +void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Let(s)!"; @@ -268,7 +264,7 @@ void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLIN LOG(FATAL) << "Ramp to be supported yet"; } -void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } @@ -293,8 +289,8 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { CHECK(iter_var); binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); PrintIndent(); - stream << "for " << binds_[iter_var->var.get()] << " in bind('" - << iter_var->var->name_hint << "', "; + stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint + << "', "; PrintExpr(op->value, stream); stream << "):\n"; indent_ += tab_; @@ -355,17 +351,16 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) { std::string extent = PrintExpr(op->extent); PrintIndent(); std::string vid = GetVarID(op->loop_var.get()); - stream << "for " << vid << " in " << "range(" << extent << "):\n"; + stream << "for " << vid << " in " + << "range(" << extent << "):\n"; indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; } -bool is_noop(const Stmt &stmt) { - if (!stmt.defined()) - return true; - if (auto eval = stmt.as()) - return is_const(eval->value); +bool is_noop(const Stmt& stmt) { + if (!stmt.defined()) return true; + if (auto eval = stmt.as()) return is_const(eval->value); return false; } @@ -395,17 +390,13 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; std::string str = PrintExpr(op->value); - if (!str.empty()) - stream << str << "\n"; + if (!str.empty()) stream << str << "\n"; } -void CodeGenHybrid::PrintIndent() { - stream << std::string(indent_, ' '); -} +void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } -std::string CodeGenHybrid::GetVarID(const VarNode *v) { - if (binds_.count(v)) - return binds_[v]; +std::string CodeGenHybrid::GetVarID(const VarNode* v) { + if (binds_.count(v)) return binds_[v]; auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; @@ -413,7 +404,7 @@ std::string CodeGenHybrid::GetVarID(const VarNode *v) { return id_map_[key] = GetUniqueName(v->name_hint); } -std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) { +std::string CodeGenHybrid::GetTensorID(const FunctionRef& func, int value_index) { auto key = std::make_pair(func.get(), value_index); if (id_map_.count(key)) { return id_map_[key]; @@ -469,10 +460,8 @@ void CodeGenHybrid::ReserveKeywords() { GetUniqueName("max_num_threads"); } -void CodeGenHybrid::DumpStmt(const Stmt &stmt, - const Array &inputs, - const Array &outputs, - const std::string &name) { +void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, + const Array& outputs, const std::string& name) { ReserveKeywords(); 
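A note on the renaming machinery being reformatted here: GetVarID first honors thread-axis bindings recorded in binds_ and then falls back to id_map_, allocating a fresh name through GetUniqueName on first use, while ReserveKeywords pre-claims names such as max_num_threads so generated identifiers can never collide with the printer's own helpers. The snippet below is a minimal standalone sketch of that uniquing scheme using only the standard library; UniqueNamer and its suffixing rule are illustrative assumptions, not the exact TVM implementation.

#include <iostream>
#include <string>
#include <unordered_map>

// Minimal sketch of prefix-based unique naming as used by code printers:
// the first request for a prefix returns it verbatim; later requests get
// "prefix_1", "prefix_2", ... Reserved keywords are simply claimed up front.
class UniqueNamer {
 public:
  void Reserve(const std::string& keyword) { Get(keyword); }

  std::string Get(const std::string& prefix) {
    auto it = count_.find(prefix);
    if (it == count_.end()) {  // first time: take the prefix itself
      count_[prefix] = 0;
      return prefix;
    }
    // already taken: append an increasing suffix until an unused name is found
    std::string name;
    do {
      name = prefix + "_" + std::to_string(++it->second);
    } while (count_.count(name));
    count_[name] = 0;
    return name;
  }

 private:
  std::unordered_map<std::string, int> count_;
};

int main() {
  UniqueNamer namer;
  namer.Reserve("range");                   // keyword: generated vars must avoid it
  std::cout << namer.Get("i") << "\n";      // i
  std::cout << namer.Get("i") << "\n";      // i_1
  std::cout << namer.Get("range") << "\n";  // range_1, the keyword stays untouched
}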
GetUniqueName(name); @@ -491,14 +480,12 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, indent_ += tab_; for (size_t i = 0; i < outputs.size(); ++i) { PrintIndent(); - stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) - << " = output_tensor(("; + stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) << " = output_tensor(("; for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { if (j) stream << ", "; PrintExpr(outputs[i]->shape[j], stream); } - if (outputs[i]->shape.size() == 1) - stream << ", "; + if (outputs[i]->shape.size() == 1) stream << ", "; stream << "), '" << outputs[i]->dtype << "')\n"; } PrintStmt(stmt); @@ -511,14 +498,13 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, stream << "\n"; } -TVM_REGISTER_GLOBAL("hybrid._Dump") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CodeGenHybrid codegen; - if (args.size() == 4) - codegen.DumpStmt(args[0], args[1], args[2], args[3]); - else - codegen.DumpStmt(args[0], args[1], args[2]); - *rv = codegen.Finish(); - }); +TVM_REGISTER_GLOBAL("hybrid._Dump").set_body([](TVMArgs args, TVMRetValue* rv) { + CodeGenHybrid codegen; + if (args.size() == 4) + codegen.DumpStmt(args[0], args[1], args[2], args[3]); + else + codegen.DumpStmt(args[0], args[1], args[2]); + *rv = codegen.Finish(); +}); } // namespace contrib } // namespace tvm diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index d282edbb19264..78a22b55dae79 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -24,10 +24,11 @@ #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ -#include -#include #include #include +#include +#include + #include #include #include @@ -45,9 +46,8 @@ using namespace tir; * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3. * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. */ -class CodeGenHybrid : - public ExprFunctor, - public StmtFunctor { +class CodeGenHybrid : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Dump the given function body to hybrid script. @@ -56,8 +56,8 @@ class CodeGenHybrid : * \param outputs Output tensors of this schedule. * \param name The name of the function. */ - void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, - const std::string &name = "hybrid_func"); + void DumpStmt(const Stmt& stmt, const Array& inputs, const Array& outputs, + const std::string& name = "hybrid_func"); /*! * \brief Finalize the compilation and return the code. * \return The code. @@ -69,55 +69,51 @@ class CodeGenHybrid : * \brief Print the Stmt n to CodeGenHybrid->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt &n) { - this->VisitStmt(n); - } + void PrintStmt(const Stmt& n) { this->VisitStmt(n); } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. * \param os The output stream */ - void PrintExpr(const PrimExpr &n, std::ostream &os) { - this->VisitExpr(n, os); - } + void PrintExpr(const PrimExpr& n, std::ostream& os) { this->VisitExpr(n, os); } /*! * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. 
*/ - std::string PrintExpr(const PrimExpr &n) { + std::string PrintExpr(const PrimExpr& n) { std::ostringstream os; PrintExpr(n, os); return os.str(); } // expression - void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) + void 
VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statement void VisitStmt_(const LetStmtNode* op) override; @@ -136,7 +132,7 @@ * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) private: /*! \brief The current indent of the code dump. */ @@ -150,9 +146,9 @@ /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ - std::map, std::string> id_map_; + std::map, std::string> id_map_; /*! \brief Variables (keys) bound to the threads (values). */ - std::map binds_; + std::map binds_; /*! * \brief Find an unallocated name for the given prefix. * \param prefix The given prefix. @@ -164,13 +160,13 @@ * \brief Get or allocate the ID for the given variable. * \param v The given variable. */ - std::string GetVarID(const VarNode *v); + std::string GetVarID(const VarNode* v); /*! * \brief Get or allocate the ID for the given tensor. * \param func The tensor to allocate a name. * \param value_index The value index of the given tensor. */ - std::string GetTensorID(const FunctionRef &func, int value_index); + std::string GetTensorID(const FunctionRef& func, int value_index); /*! 
\brief the storage scope of allocation */ std::map alloc_storage_scope_; }; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 849c74028d634..cdd9d5441b251 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -23,13 +23,12 @@ */ #include #include -#include - -#include -#include -#include #include #include +#include +#include +#include +#include #include #include @@ -37,9 +36,9 @@ namespace tvm { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); @@ -59,12 +58,8 @@ Target DefaultTargetHost(Target target) { } } -tir::Buffer BufferWithOffsetAlignment(Array shape, - DataType dtype, - std::string name, - int data_alignment, - int offset_factor, - bool compact) { +tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, + int data_alignment, int offset_factor, bool compact) { auto data = tir::Var(name, DataType::Handle()); bool has_any = false; if (!compact) { @@ -85,21 +80,19 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, } return tir::BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor, buffer_type); + data_alignment, offset_factor, buffer_type); } -void GetBinds(const Array& args, - bool compact, +void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, - Map* out_binds, - Array* out_arg_list, + Map* out_binds, Array* out_arg_list, const BuildConfig& config) { *out_binds = binds; - for (const auto &x : args) { + for (const auto& x : args) { if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, - config->data_alignment, config->offset_factor, compact); + auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, config->data_alignment, + config->offset_factor, compact); out_binds->Set(x, buf); out_arg_list->push_back(buf); } else { @@ -115,8 +108,7 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } - -template +template transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { @@ -128,10 +120,7 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } - -IRModule lower(te::Schedule sch, - const Array& args, - const std::string& name, +IRModule lower(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, const BuildConfig& config) { Array out_arg_list; @@ -147,8 +136,7 @@ IRModule lower(te::Schedule sch, GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); // build the function - tir::PrimFunc f = te::SchedulePostProcToPrimFunc( - out_arg_list, std::move(stmt), out_binds); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); if (config->restricted_func) { f = WithAttr(std::move(f), "tir.noalias", Integer(1)); @@ -159,8 +147,7 @@ IRModule lower(te::Schedule sch, // Phase 0 pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back( - tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); + pass_list.push_back(tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); // Phase 1 
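Before the Phase 1 passes that follow, it is worth noting the overall shape of this lowering driver: passes such as InjectPrefetch, StorageFlatten, Simplify, StorageRewrite, and UnrollLoop are pushed into pass_list in phase order and then run as one sequential pipeline. The sketch below is a rough standalone model of that pattern, with a pass viewed as a module-to-module function and a pipeline as their in-order application; Module, Pass, and RunSequential are illustrative names, not the actual tvm::transform API.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for an IR module; the real passes transform tvm::IRModule.
struct Module {
  std::string log;  // records which passes ran, in order
};

using Pass = std::function<Module(Module)>;

// Sequential composition: each pass consumes the result of the previous one.
// This is the essence of assembling pass_list and running it over the module.
Module RunSequential(const std::vector<Pass>& passes, Module mod) {
  for (const Pass& pass : passes) {
    mod = pass(std::move(mod));
  }
  return mod;
}

Pass Named(const std::string& name) {
  return [name](Module m) {
    m.log += name + ";";
    return m;
  };
}

int main() {
  // Mirrors the phase structure above: ordering matters because later
  // phases (e.g. unrolling) rely on invariants established by earlier
  // ones (e.g. flattened storage, simplified bounds).
  std::vector<Pass> pass_list = {
      Named("InjectPrefetch"), Named("StorageFlatten"),  // Phase 0
      Named("NarrowDataType"), Named("Simplify"),        // Phase 1
      Named("StorageRewrite"), Named("UnrollLoop"),      // later phases
  };
  std::cout << RunSequential(pass_list, Module{}).log << "\n";
}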
pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -170,10 +157,8 @@ IRModule lower(te::Schedule sch, pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop)); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back( - tir::transform::UnrollLoop(config->auto_unroll_max_step, - config->auto_unroll_max_depth, - config->auto_unroll_max_extent, - config->unroll_explicit)); + tir::transform::UnrollLoop(config->auto_unroll_max_step, config->auto_unroll_max_depth, + config->auto_unroll_max_extent, config->unroll_explicit)); // Phase 2 pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); @@ -189,16 +174,11 @@ IRModule lower(te::Schedule sch, return mod; } - -std::pair -split_dev_host_funcs(IRModule mod_mixed, - const Target& target, - const Target& target_host, - const BuildConfig& config) { - Array mixed_pass_list = { - BindTarget(target), - tir::transform::VerifyMemory() - }; +std::pair split_dev_host_funcs(IRModule mod_mixed, const Target& target, + const Target& target_host, + const BuildConfig& config) { + Array mixed_pass_list = {BindTarget(target), + tir::transform::VerifyMemory()}; if (config->detect_global_barrier) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } @@ -212,32 +192,30 @@ split_dev_host_funcs(IRModule mod_mixed, mod_mixed = opt_mixed(std::move(mod_mixed)); auto host_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target_host), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target_host), + tir::transform::LowerTVMBuiltin(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + tir::transform::CombineContextCall(), }; auto opt_host = transform::Sequential(host_pass_list); auto mhost = opt_host(mod_mixed); // device pipeline auto device_pass_list = { - Filter([](const tir::PrimFunc& f) { - return f->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::Simplify(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target), + tir::transform::LowerWarpMemory(), + tir::transform::Simplify(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), }; auto opt_device = transform::Sequential(device_pass_list); auto mdevice = opt_device(mod_mixed); @@ -246,33 +224,28 @@ split_dev_host_funcs(IRModule mod_mixed, auto keys = target->keys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && mdevice->functions.size() == 0) { - LOG(WARNING) << "Specified target " - << target->str() + LOG(WARNING) << "Specified target " << target->str() << " but cannot find device code. 
Did you forget to bind?"; } - if (target->device_type == target::llvm()->device_type && - target_host == target) { - CHECK(mdevice->functions.empty()) - << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + if (target->device_type == target::llvm()->device_type && target_host == target) { + CHECK(mdevice->functions.empty()) << "No device code should be generated when target " + << "and host_target are both llvm target." + << "\n"; } return {mhost, mdevice}; } - // Build for heterogeneous execution. -runtime::Module build(const Map& inputs, - const Target& target_host, +runtime::Module build(const Map& inputs, const Target& target_host, const BuildConfig& config) { std::vector device_modules; Target target_host_val = target_host; if (!target_host.defined()) { for (const auto& it : inputs) { - if (it.first->device_type == kDLCPU) { + if (it.first->device_type == kDLCPU || it.first->device_type == kDLMicroDev) { target_host_val = it.first; break; } @@ -286,8 +259,7 @@ runtime::Module build(const Map& inputs, IRModule mhost_all = IRModule(Map()); for (const auto& it : inputs) { - auto pair = - split_dev_host_funcs(it.second, it.first, target_host_val, config); + auto pair = split_dev_host_funcs(it.second, it.first, target_host_val, config); auto& mhost = pair.first; auto& mdevice = pair.second; @@ -308,8 +280,7 @@ runtime::Module build(const Map& inputs, } // Build for heterogeneous execution when target is a string. -runtime::Module build(const Map& inputs, - const Target& target_host, +runtime::Module build(const Map& inputs, const Target& target_host, const BuildConfig& config) { Map updated_input; for (const auto& it : inputs) { @@ -323,9 +294,7 @@ runtime::Module build(const Map& inputs, } // Build for homogeneous execution. -runtime::Module build(const IRModule& funcs, - const Target& target, - const Target& target_host, +runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host, const BuildConfig& config) { Map inputs = {{target, funcs}}; return build(inputs, target_host, config); diff --git a/src/ir/adt.cc b/src/ir/adt.cc index 4650a3bed4a7e..957905ded3cf1 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -21,14 +21,12 @@ * \file src/ir/adt.cc * \brief ADT type definitions. 
*/ -#include #include +#include namespace tvm { -Constructor::Constructor(std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { +Constructor::Constructor(std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->inputs = std::move(inputs); @@ -39,21 +37,18 @@ Constructor::Constructor(std::string name_hint, TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_GLOBAL("ir.Constructor") -.set_body_typed([](std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { - return Constructor(name_hint, inputs, belong_to); -}); + .set_body_typed([](std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { + return Constructor(name_hint, inputs, belong_to); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorNode(" << node->name_hint << ", " - << node->inputs << ", " << node->belong_to << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorNode(" << node->name_hint << ", " << node->inputs << ", " + << node->belong_to << ")"; + }); -TypeData::TypeData(GlobalTypeVar header, - tvm::Array type_vars, +TypeData::TypeData(GlobalTypeVar header, tvm::Array type_vars, tvm::Array constructors) { ObjectPtr n = make_object(); n->header = std::move(header); @@ -65,17 +60,16 @@ TypeData::TypeData(GlobalTypeVar header, TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_GLOBAL("ir.TypeData") -.set_body_typed([](GlobalTypeVar header, - tvm::Array type_vars, - tvm::Array constructors) { - return TypeData(header, type_vars, constructors); -}); + .set_body_typed([](GlobalTypeVar header, tvm::Array type_vars, + tvm::Array constructors) { + return TypeData(header, type_vars, constructors); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " - << node->constructors << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " + << node->constructors << ")"; + }); } // namespace tvm diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index dbd5a4fab23bd..56a561b6998f9 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -32,6 +32,7 @@ #include #include + #include namespace tvm { @@ -39,16 +40,13 @@ namespace tvm { template class AttrFunctor; -#define ATTR_FUNCTOR_DEFAULT \ +#define ATTR_FUNCTOR_DEFAULT \ { return VisitAttrDefault_(op, std::forward(args)...); } - -#define ATTR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitAttr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define ATTR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast(n.get()), std::forward(args)...); \ + }); // A functor for common attribute information. 
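The ATTR_FUNCTOR_DEFAULT / ATTR_FUNCTOR_DISPATCH pair above implements table-driven double dispatch: each node type registers a thunk in a static vtable that downcasts the object and forwards to the matching VisitAttr_ overload, with the default handler as the fallback. As a rough standalone analogue, using std::type_index in place of TVM's per-class vtable; AttrPrinter and the node types here are made up for illustration.

#include <functional>
#include <iostream>
#include <typeindex>
#include <unordered_map>

// Toy node hierarchy standing in for TVM's attribute AST.
struct Node { virtual ~Node() = default; };
struct IntNode : Node { int value = 42; };
struct StrNode : Node { const char* value = "hi"; };

class AttrPrinter {
 public:
  AttrPrinter() {
    // ATTR_FUNCTOR_DISPATCH registers one such entry per node type;
    // here it is done by hand with a type-indexed map.
    table_[typeid(IntNode)] = [this](const Node& n) {
      VisitAttr_(static_cast<const IntNode&>(n));
    };
    table_[typeid(StrNode)] = [this](const Node& n) {
      VisitAttr_(static_cast<const StrNode&>(n));
    };
  }

  void Visit(const Node& n) {
    auto it = table_.find(typeid(n));
    if (it != table_.end()) {
      it->second(n);               // dispatch to the registered handler
    } else {
      std::cout << "<default>\n";  // the ATTR_FUNCTOR_DEFAULT fallback
    }
  }

 private:
  void VisitAttr_(const IntNode& n) { std::cout << "int: " << n.value << "\n"; }
  void VisitAttr_(const StrNode& n) { std::cout << "str: " << n.value << "\n"; }

  std::unordered_map<std::type_index, std::function<void(const Node&)>> table_;
};

int main() {
  AttrPrinter p;
  IntNode i;
  StrNode s;
  p.Visit(i);  // int: 42
  p.Visit(s);  // str: hi
}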
template diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index bee103d7ed206..edc81ae201b08 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -22,20 +22,16 @@ */ #include #include + #include "attr_functor.h" namespace tvm { -void DictAttrsNode::VisitAttrs(AttrVisitor* v) { - v->Visit("__dict__", &dict); -} +void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { - v->Visit("__dict__", &dict); -} +void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -void DictAttrsNode::InitByPackedArgs( - const runtime::TVMArgs& args, bool allow_unknown) { +void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; @@ -49,9 +45,7 @@ void DictAttrsNode::InitByPackedArgs( } } -Array DictAttrsNode::ListFieldInfo() const { - return {}; -} +Array DictAttrsNode::ListFieldInfo() const { return {}; } DictAttrs::DictAttrs(Map dict) { ObjectPtr n = make_object(); @@ -60,22 +54,20 @@ DictAttrs::DictAttrs(Map dict) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dict; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dict; + }); TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); -TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict") -.set_body_typed([](DictAttrs attrs) { +TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { return attrs->dict; }); -TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") -.set_body_typed([](Attrs attrs) { +TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { return attrs->ListFieldInfo(); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 4d3ed30bc0320..7deff903cc1f0 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -26,16 +26,15 @@ namespace tvm { - using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "EnvFunc(" << op->name << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "EnvFunc(" << op->name << ")"; + }); ObjectPtr CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); @@ -46,31 +45,24 @@ ObjectPtr CreateEnvNode(const std::string& name) { return n; } -EnvFunc EnvFunc::Get(const std::string& name) { - return EnvFunc(CreateEnvNode(name)); -} +EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_GLOBAL("ir.EnvFuncGet") -.set_body_typed(EnvFunc::Get); +TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("ir.EnvFuncCall") -.set_body([](TVMArgs args, TVMRetValue* rv) { - EnvFunc env = args[0]; - CHECK_GE(args.size(), 1); - env->func.CallPacked(TVMArgs(args.values + 1, - args.type_codes + 1, - args.size() - 1), rv); - }); +TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) { + EnvFunc env = args[0]; + CHECK_GE(args.size(), 1); + env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv); +}); -TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") 
-.set_body_typed([](const EnvFunc&n) { - return n->func; - }); +TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) { + return n->func; +}); TVM_REGISTER_NODE_TYPE(EnvFuncNode) -.set_creator(CreateEnvNode) -.set_repr_bytes([](const Object* n) -> std::string { - return static_cast(n)->name; - }); + .set_creator(CreateEnvNode) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); } // namespace tvm diff --git a/src/ir/error.cc b/src/ir/error.cc index 9d498288d2ba2..9db61a078bdb7 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -22,8 +22,8 @@ * \brief Utilities for error tracking and reporting. */ -#include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. // // Rationale: use relay's printer for astext. #include +// clang-format off #include #include #include +// clang-format on namespace tvm { -template +template using NodeMap = std::unordered_map; void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { @@ -76,9 +78,9 @@ // Setup error map. auto it = error_maps.find(global); if (it != error_maps.end()) { - it->second.insert({ node, err_msg.str() }); + it->second.insert({node, err_msg.str()}); } else { - error_maps.insert({ global, { { node, err_msg.str() }}}); + error_maps.insert({global, {{node, err_msg.str()}}}); } } @@ -87,10 +89,10 @@ std::stringstream annotated_prog; // First we output a header for the errors. - annotated_prog << - rang::style::bold << std::endl << - "Error(s) have occurred. The program has been annotated with them:" - << std::endl << std::endl << rang::style::reset; + annotated_prog << rang::style::bold << std::endl + << "Error(s) have occurred. The program has been annotated with them:" << std::endl + << std::endl + << rang::style::reset; // For each global function which contains errors, we will // construct an annotated function. @@ -101,11 +103,8 @@ // We output the name of the function before displaying // the annotated program. - annotated_prog << - rang::style::bold << - "In `" << global->name_hint << "`: " << - std::endl << - rang::style::reset; + annotated_prog << rang::style::bold << "In `" << global->name_hint << "`: " << std::endl + << rang::style::reset; // We then call into the Relay printer to generate the program. // @@ -140,9 +139,9 @@ if (it != this->node_to_error_.end()) { it->second.push_back(index_to_insert); } else { - this->node_to_error_.insert({ node, { index_to_insert }}); + this->node_to_error_.insert({node, {index_to_insert}}); } - this->node_to_gv_.insert({ node, global }); + this->node_to_gv_.insert({node, global}); } } // namespace tvm diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7272213ad406f..000305b61c269 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -21,9 +21,9 @@ * \file src/ir/expr.cc * \brief The expression AST nodes for the common IR infra. */ -#include #include #include +#include // NOTE: reverse dependency on top/tir. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. 
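For readers new to EnvFunc, which the hunks above reformat: it is a serializable handle to a globally registered function. Only the name is persisted (set_repr_bytes returns n->name), and CreateEnvNode re-resolves the name through the registry on load, failing if the function is absent. A minimal standalone sketch of that name-keyed lookup follows, with hypothetical Register/Get helpers standing in for tvm::runtime::Registry.

#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// A name -> function registry. Storing only the name keeps a function
// reference serializable: on reload, the name is simply looked up again.
using Fn = std::function<int(int)>;

std::unordered_map<std::string, Fn>& Registry() {
  static std::unordered_map<std::string, Fn> reg;
  return reg;
}

void Register(const std::string& name, Fn f) { Registry()[name] = std::move(f); }

// Mirrors EnvFunc::Get: fail loudly if the environment lacks the function.
Fn Get(const std::string& name) {
  auto it = Registry().find(name);
  if (it == Registry().end()) {
    throw std::runtime_error("Cannot find global function: " + name);
  }
  return it->second;
}

int main() {
  Register("add_one", [](int x) { return x + 1; });
  // Serialize: keep only the string "add_one". Deserialize: resolve again.
  Fn f = Get("add_one");
  std::cout << f(41) << "\n";  // 42
}

Keeping the name rather than the function pointer is what lets the reference survive serialization: the environment is expected to re-register the function before the handle is used again.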
@@ -34,11 +34,9 @@ namespace tvm { -PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImm(DataType::Int(32), value)) {} +PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} -PrimExpr::PrimExpr(float value) - : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; @@ -52,17 +50,14 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { return tir::StringImmNode::make(GetRef(ptr)); } CHECK(ObjectTypeChecker::Check(ref.get())) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ref->GetTypeKey(); + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " + << ref->GetTypeKey(); return Downcast(ref); } - IntImm::IntImm(DataType dtype, int64_t value) { - CHECK(dtype.is_scalar()) - << "ValueError: IntImm can only take scalar."; - CHECK(dtype.is_int() || dtype.is_uint()) - << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm can only take scalar."; if (dtype.is_uint()) { CHECK_GE(value, 0U); } @@ -72,86 +67,75 @@ IntImm::IntImm(DataType dtype, int64_t value) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.IntImm") -.set_body_typed([](DataType dtype, int64_t value) { +TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value) { return IntImm(dtype, value); }); TVM_REGISTER_NODE_TYPE(IntImmNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - if (op->dtype == DataType::Int(32)) { - p->stream << op->value; - } else { - p->stream << "(" << op->dtype << ")" << op->value; - } - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + if (op->dtype == DataType::Int(32)) { + p->stream << op->value; + } else { + p->stream << "(" << op->dtype << ")" << op->value; + } + }); FloatImm::FloatImm(DataType dtype, double value) { - CHECK_EQ(dtype.lanes(), 1) - << "ValueError: FloatImm can only take scalar."; + CHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; ObjectPtr node = make_object(); node->dtype = dtype; node->value = value; data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.FloatImm") -.set_body_typed([](DataType dtype, double value) { +TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value) { return FloatImm(dtype, value); }); TVM_REGISTER_NODE_TYPE(FloatImmNode); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - auto& stream = p->stream; - switch (op->dtype.bits()) { - case 64: - stream << op->value; - break; - case 32: - stream << op->value << 'f'; - break; - case 16: - stream << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + auto& stream = p->stream; + switch (op->dtype.bits()) { + case 64: + stream << op->value; + break; + case 32: + stream << op->value << 'f'; + break; + case 16: + stream << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); + } + }); Range::Range(PrimExpr begin, PrimExpr end) - : Range(make_object( - begin, - 
tir::is_zero(begin) ? end : (end - begin))) { -} + : Range(make_object(begin, tir::is_zero(begin) ? end : (end - begin))) {} Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } -TVM_REGISTER_GLOBAL("ir.range_by_min_extent") -.set_body_typed(Range::make_by_min_extent); +TVM_REGISTER_GLOBAL("ir.range_by_min_extent").set_body_typed(Range::make_by_min_extent); -TVM_REGISTER_GLOBAL("ir.Range") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Range(args[0], args[1]); - }); +}); TVM_REGISTER_NODE_TYPE(RangeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; + }); GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); @@ -161,57 +145,56 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar") -.set_body_typed([](std::string name){ +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) { return GlobalVar(name); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalVar(" << node->name_hint << ")"; + }); // Container printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '['; - for (size_t i = 0 ; i < op->data.size(); ++i) { - if (i != 0) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->data.size(); ++i) { + if (i != 0) { + p->stream << ", "; + } + p->Print(op->data[i]); } - p->Print(op->data[i]); - } - p->stream << ']'; -}); + p->stream << ']'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != op->data.begin()) { + p->stream << ", "; + } + p->Print(it->first); + p->stream << ": "; + p->Print(it->second); } - p->Print(it->first); - p->stream << ": "; - p->Print(it->second); - } - p->stream << '}'; - }); + p->stream << '}'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != 
op->data.begin()) { + p->stream << ", "; + } + p->stream << '\"' << it->first << "\": "; + p->Print(it->second); } - p->stream << '\"' << it->first << "\": "; - p->Print(it->second); - } - p->stream << '}'; - }); + p->stream << '}'; + }); } // namespace tvm diff --git a/src/ir/function.cc b/src/ir/function.cc index 08cdc93e28b5a..57d62b4f17b57 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -21,40 +21,32 @@ * \file src/ir/function.cc * \brief The function data structure. */ -#include #include +#include // NOTE: reverse dependency on relay, tir/ // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. // // Rationale: We call into the type specific WithAttr function -#include #include - +#include namespace tvm { -TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs") -.set_body_typed([](BaseFunc func) { - return func->attrs; -}); +TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { return func->attrs; }); -TVM_REGISTER_GLOBAL("ir.BaseFuncCopy") -.set_body_typed([](BaseFunc func) { - return func; -}); +TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") -.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc { - if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - return func; - } -}); - + .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 6262150556c7f..c7393749dc37e 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,9 +21,9 @@ * \file module.cc * \brief The global module in Relay. */ -#include #include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. 
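The BaseFuncWithAttr registration just above branches on the concrete function type at runtime (IsInstance checks for a TIR PrimFunc versus a relay Function) and forwards to the type-specific WithAttr, failing loudly for anything else. Outside TVM's object protocol the same shape can be sketched with dynamic_cast; BaseFunc, PrimFunc, and RelayFunc below are illustrative stand-ins, not the real classes.

#include <iostream>

struct BaseFunc { virtual ~BaseFunc() = default; };
struct PrimFunc : BaseFunc {};
struct RelayFunc : BaseFunc {};

// Branch on the dynamic type and forward to a type-specific handler,
// mirroring the IsInstance<...>() checks in BaseFuncWithAttr.
void WithAttr(BaseFunc& f) {
  if (dynamic_cast<PrimFunc*>(&f) != nullptr) {
    std::cout << "attach attr to a TIR PrimFunc\n";
  } else if (dynamic_cast<RelayFunc*>(&f) != nullptr) {
    std::cout << "attach attr to a Relay Function\n";
  } else {
    std::cout << "unsupported function type\n";  // the LOG(FATAL) branch
  }
}

int main() {
  PrimFunc p;
  RelayFunc r;
  WithAttr(p);
  WithAttr(r);
}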
@@ -32,15 +32,15 @@ #include #include -#include #include +#include #include namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set) { + std::unordered_set import_set) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -52,14 +52,14 @@ IRModule::IRModule(tvm::Map functions, for (const auto& kv : n->functions) { // set global var map CHECK(n->global_var_map_.count(kv.first->name_hint) == 0) - << "Duplicate global function name " << kv.first->name_hint; + << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } for (const auto& kv : n->type_definitions) { // set global typevar map CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0) - << "Duplicate global type definition name " << kv.first->name_hint; + << "Duplicate global type definition name " << kv.first->name_hint; n->global_type_var_map_.Set(kv.first->name_hint, kv.first); n->RegisterConstructors(kv.first, kv.second); } @@ -87,9 +87,8 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { auto reduce_temp = [&]() { // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); hash_reduce(static_cast(temp.size())); // hash the content @@ -111,15 +110,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { reduce_temp(); } -bool IRModuleNode::ContainGlobalVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalVar(const String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalTypeVar(const String& name) const { return global_type_var_map_.find(name) != global_type_var_map_.end(); } -GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; @@ -146,15 +145,15 @@ tvm::Array IRModuleNode::GetGlobalVars() const { return tvm::Array(global_vars); } -GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { +GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) - << "Cannot find global type var " << name << " in the Module"; + << "Cannot find global type var " << name << " in the Module"; return (*it).second; } -Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const { +Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const { TypeData typeDef = this->LookupTypeDef(adt); for (Constructor c : typeDef->constructors) { if (cons.compare(c->name_hint) == 0) { @@ -174,7 +173,7 @@ tvm::Array IRModuleNode::GetGlobalTypeVars() const { return tvm::Array(global_type_vars); } -template +template tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { tvm::Array ret(l); for (const T& t : r) { @@ -184,55 +183,37 @@ tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { } // helper function to run type check -relay::Function RunTypeCheck(const IRModule& mod, - const GlobalVar& var, - 
relay::Function f) { +relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) { auto func = Downcast(relay::DeDup(std::move(f))); // Type check the item before we add it to the module. auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); if (fv.size() != 0) { - LOG(WARNING) - << "There are free variables: " - << fv - << " in function: " - << AsText(func, false) - << std::endl; + LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false) + << std::endl; } if (ftv.size() != 0) { - LOG(WARNING) - << "There are free type variables: " - << ftv - << " in function: " - << AsText(func, false) - << std::endl; + LOG(WARNING) << "There are free type variables: " << ftv + << " in function: " << AsText(func, false) << std::endl; } - func = relay::Function(concat(func->params, fv), - func->body, - func->ret_type, - concat(func->type_params, ftv), - func->attrs); + func = relay::Function(concat(func->params, fv), func->body, func->ret_type, + concat(func->type_params, ftv), func->attrs); // Type check the item before we add it to the module. relay::Function checked_func = InferType(func, mod, var); return checked_func; } -void IRModuleNode::Add(const GlobalVar& var, - const BaseFunc& f, - bool update) { +void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { BaseFunc checked_func = f; if (auto* ptr = f.as()) { - checked_func = RunTypeCheck(GetRef(this), - var, - GetRef(ptr)); + checked_func = RunTypeCheck(GetRef(this), var, GetRef(ptr)); } Type type = checked_func->checked_type(); CHECK(type.as() == nullptr); if (functions.find(var) != functions.end()) { - CHECK(update) - << "Already have definition for " << var->name_hint; + CHECK(update) << "Already have definition for " << var->name_hint; auto old_type = functions[var]->checked_type(); CHECK(tvm::StructuralEqual()(type, old_type)) << "Module#update changes type, not possible in this mode."; @@ -241,8 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } -void IRModuleNode::AddUnchecked(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { this->functions.Set(var, func); auto it = global_var_map_.find(var->name_hint); @@ -268,36 +248,31 @@ void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData } } -void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, - const TypeData& type, - bool update) { +void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { AddTypeDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially CHECK(relay::KindCheck(type, GetRef(this)) == TypeKind::kTypeData) - << "Invalid or malformed typedata given to module: " << type; + << "Invalid or malformed typedata given to module: " << type; } -void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, +void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { this->type_definitions.Set(var, type); if (!update) { // set global type var map CHECK(global_type_var_map_.count(var->name_hint) == 0) - << "Duplicate global type definition name " << var->name_hint; + << "Duplicate global type definition name " << var->name_hint; } global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); } -void IRModuleNode::Update(const GlobalVar& var, - const BaseFunc& 
func) { +void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) { this->Add(var, func, true); } -void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, - const TypeData& type) { +void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { this->AddTypeDef(var, type, true); } @@ -310,32 +285,29 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - CHECK(it != functions.end()) - << "There is no definition of " << var->name_hint; + CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -BaseFunc IRModuleNode::Lookup(const std::string& name) const { +BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); - CHECK(it != type_definitions.end()) - << "There is no definition of " << var->name_hint; + CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -TypeData IRModuleNode::LookupTypeDef(const std::string& name) const { +TypeData IRModuleNode::LookupTypeDef(const String& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupTypeDef(id); } Constructor IRModuleNode::LookupTag(const int32_t tag) { auto it = constructor_tag_map_.find(tag); - CHECK(it != constructor_tag_map_.end()) - << "There is no constructor with the tag " << tag; + CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; return (*it).second; } @@ -356,10 +328,9 @@ void IRModuleNode::Update(const IRModule& mod) { } } -IRModule IRModule::FromExpr( - const RelayExpr& expr, - const tvm::Map& global_funcs, - const tvm::Map& type_definitions) { +IRModule IRModule::FromExpr(const RelayExpr& expr, + const tvm::Map& global_funcs, + const tvm::Map& type_definitions) { auto mod = IRModule(global_funcs, type_definitions); BaseFunc func; std::string gv_name = "main"; @@ -371,39 +342,35 @@ IRModule IRModule::FromExpr( } } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), - relay::FreeTypeVars(expr, mod), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); } auto main_gv = GlobalVar(gv_name); mod->Add(main_gv, func); return mod; } -void IRModuleNode::Import(const std::string& path) { +void IRModuleNode::Import(const String& path) { if (this->import_set_.count(path) == 0) { this->import_set_.insert(path); DLOG(INFO) << "Importing: " << path; std::fstream src_file(path, std::fstream::in); - std::string file_contents { - std::istreambuf_iterator(src_file), - std::istreambuf_iterator() }; + std::string file_contents{std::istreambuf_iterator(src_file), + std::istreambuf_iterator()}; auto mod_to_import = IRModule::FromText(file_contents, path); Update(mod_to_import); } } -void IRModuleNode::ImportFromStd(const std::string& path) { +void IRModuleNode::ImportFromStd(const String& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); - return this->Import(std_path + "/" + path); + this->Import(std_path + "/" + path.operator std::string()); } -std::unordered_set IRModuleNode::Imports() const { - return this->import_set_; -} +std::unordered_set IRModuleNode::Imports() const { 
return this->import_set_; } -IRModule IRModule::FromText(const std::string& text, const std::string& source_path) { +IRModule IRModule::FromText(const String& text, const String& source_path) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; IRModule mod = (*f)(text, source_path); @@ -413,13 +380,12 @@ IRModule IRModule::FromText(const std::string& text, const std::string& source_p TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") -.set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); -}); + .set_body_typed([](tvm::Map funcs, + tvm::Map types) { + return IRModule(funcs, types, {}); + }); -TVM_REGISTER_GLOBAL("ir.Module_Add") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { IRModule mod = args[0]; GlobalVar var = args[1]; ObjectRef val = args[2]; @@ -443,75 +409,65 @@ TVM_REGISTER_GLOBAL("ir.Module_Add") *ret = mod; }); -TVM_REGISTER_GLOBAL("ir.Module_AddDef") -.set_body_method(&IRModuleNode::AddTypeDef); +TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") -.set_body_method(&IRModuleNode::GetGlobalVar); + .set_body_method(&IRModuleNode::GetGlobalVar); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars") -.set_body_method(&IRModuleNode::GetGlobalVars); + .set_body_method(&IRModuleNode::GetGlobalVars); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") -.set_body_method(&IRModuleNode::GetGlobalTypeVars); + .set_body_method(&IRModuleNode::GetGlobalTypeVars); TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") -.set_body_method(&IRModuleNode::ContainGlobalVar); + .set_body_method(&IRModuleNode::ContainGlobalVar); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") -.set_body_method(&IRModuleNode::GetGlobalTypeVar); + .set_body_method(&IRModuleNode::GetGlobalTypeVar); -TVM_REGISTER_GLOBAL("ir.Module_Lookup") -.set_body_typed([](IRModule mod, GlobalVar var) { +TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_Lookup_str") -.set_body_typed([](IRModule mod, std::string var) { +TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupDef") -.set_body_typed([](IRModule mod, GlobalTypeVar var) { +TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str") -.set_body_typed([](IRModule mod, std::string var) { +TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupTag") -.set_body_typed([](IRModule mod, int32_t tag) { - return mod->LookupTag(tag); - }); +TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { + return mod->LookupTag(tag); +}); TVM_REGISTER_GLOBAL("ir.Module_FromExpr") -.set_body_typed([](RelayExpr e, - tvm::Map funcs, - tvm::Map type_defs) { - return IRModule::FromExpr(e, funcs, type_defs); -}); + .set_body_typed([](RelayExpr e, tvm::Map funcs, + tvm::Map type_defs) { + return IRModule::FromExpr(e, funcs, type_defs); + }); -TVM_REGISTER_GLOBAL("ir.Module_Update") -.set_body_typed([](IRModule mod, IRModule from) { 
+TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); -TVM_REGISTER_GLOBAL("ir.Module_Import") -.set_body_typed([](IRModule mod, std::string path) { +TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); -TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd") -.set_body_typed([](IRModule mod, std::string path) { +TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { mod->ImportFromStd(path); -});; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IRModuleNode( " << node->functions << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IRModuleNode( " << node->functions << ")"; + }); } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index b024165c1a4cb..8f587686d7c4d 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -23,8 +23,8 @@ */ #include #include -#include #include +#include #include #include @@ -37,13 +37,11 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); namespace tvm { -using runtime::TVMRetValue; -using runtime::TVMArgs; using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; -::dmlc::Registry* OpRegistry::Registry() { - return ::dmlc::Registry::Get(); -} +::dmlc::Registry* OpRegistry::Registry() { return ::dmlc::Registry::Get(); } // single manager of operator information. struct OpManager { @@ -112,9 +110,7 @@ void OpRegistry::reset_attr(const std::string& key) { } } -void OpRegistry::UpdateAttr(const std::string& key, - TVMRetValue value, - int plevel) { +void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); std::unique_ptr& op_map = mgr->attr[key]; @@ -127,91 +123,81 @@ void OpRegistry::UpdateAttr(const std::string& key, op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); } std::pair& p = op_map->data_[index]; - CHECK(p.second != plevel) - << "Attribute " << key << " of operator " << this->name - << " is already registered with same plevel=" << plevel; + CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name + << " is already registered with same plevel=" << plevel; CHECK(value.type_code() != kTVMNullptr) - << "Registered packed_func is Null for " << key - << " of operator " << this->name; + << "Registered packed_func is Null for " << key << " of operator " << this->name; if (p.second < plevel && value.type_code() != kTVMNullptr) { op_map->data_[index] = std::make_pair(value, plevel); } } // Frontend APIs -TVM_REGISTER_GLOBAL("relay.op._ListOpNames") -.set_body_typed([]() { - Array ret; - for (const std::string& name : dmlc::Registry::ListAllNames()) { - ret.push_back(name); - } - return ret; - }); - -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); - -TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); - if (op_map.count(op)) { - *rv = op_map[op]; - } - }); - -TVM_REGISTER_GLOBAL("relay.op._OpSetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = - 
OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); - reg.set_attr(attr_name, value, plevel); - }); - -TVM_REGISTER_GLOBAL("relay.op._OpResetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op->name); - reg.reset_attr(attr_name); - }); - -TVM_REGISTER_GLOBAL("relay.op._Register") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string op_name = args[0]; - std::string attr_key = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - LOG(FATAL) << "attrs type key no longer supported"; +TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() { + Array ret; + for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) { + ret.push_back(name); + } + return ret; +}); + +TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op { + return Op::Get(name); +}); + +TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr<TVMRetValue>(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } +}); + +TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); + reg.set_attr(attr_name, value, plevel); +}); + +TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name); + reg.reset_attr(attr_name); +}); + +TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string op_name = args[0]; + std::string attr_key = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + // enable registration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + LOG(FATAL) << "attrs type key no longer supported"; + } else { + // normal attr table override. + if (args[2].type_code() == kTVMPackedFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = args[2]; + // If we get a function from frontend, avoid deleting it. + OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); + reg.set_attr(attr_key, f, plevel); } else { - // normal attr table override. - if (args[2].type_code() == kTVMPackedFuncHandle) { - // do an eager copy of the PackedFunc - PackedFunc f = args[2]; - // If we get a function from frontend, avoid deleting it. - OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); - reg.set_attr(attr_key, f, plevel); - } else { - reg.set_attr(attr_key, args[2], plevel); - } + reg.set_attr(attr_key, args[2], plevel); } - }); + } +});
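[Editor's note] A hedged sketch of the plevel contract that UpdateAttr enforces above: re-registering an attribute with a strictly higher plevel overrides the stored value, re-registering with the same plevel trips the CHECK, and a lower plevel is silently ignored. The op name "add", the attribute key "FMyAttr", and the demo function are illustrative only, not part of this patch:

#include <tvm/ir/op.h>

using tvm::OpRegistry;
using tvm::runtime::PackedFunc;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

void PlevelDemo() {
  auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__("add");
  PackedFunc f([](TVMArgs args, TVMRetValue* rv) {});
  reg.set_attr<PackedFunc>("FMyAttr", f, /*plevel=*/10);  // initial registration
  reg.set_attr<PackedFunc>("FMyAttr", f, /*plevel=*/11);  // higher plevel wins
  // reg.set_attr<PackedFunc>("FMyAttr", f, 11);          // same plevel: CHECK fails
}

// helper to get internal dev function in objectref.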
struct Op2ObjectPtr : public ObjectRef { - static ObjectPtr Get(const Op& op) { - return GetDataPtr(op); - } + static ObjectPtr Get(const Op& op) { return GetDataPtr(op); } }; ObjectPtr CreateOp(const std::string& name) { @@ -221,16 +207,14 @@ ObjectPtr CreateOp(const std::string& name) { return Op2ObjectPtr::Get(op); } -TVM_REGISTER_NODE_TYPE(OpNode) -.set_creator(CreateOp) -.set_repr_bytes([](const Object* n) { - return static_cast(n)->name; - }); +TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes([](const Object* n) { + return static_cast(n)->name; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Op(" << node->name << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Op(" << node->name << ")"; + }); } // namespace tvm diff --git a/src/ir/span.cc b/src/ir/span.cc index f84353de2a8b1..742c9858950cb 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -25,10 +25,10 @@ namespace tvm { -ObjectPtr GetSourceNameNode(const std::string& name) { +ObjectPtr GetSourceNameNode(const String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; + static std::unordered_map > source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { @@ -41,24 +41,25 @@ ObjectPtr GetSourceNameNode(const std::string& name) { } } -SourceName SourceName::Get(const std::string& name) { - return SourceName(GetSourceNameNode(name)); +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { + return GetSourceNameNode(name); } -TVM_REGISTER_GLOBAL("ir.SourceName") -.set_body_typed(SourceName::Get); +SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } + +TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "SourceName(" << node->name << ", " << node << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "SourceName(" << node->name << ", " << node << ")"; + }); TVM_REGISTER_NODE_TYPE(SourceNameNode) -.set_creator(GetSourceNameNode) -.set_repr_bytes([](const Object* n) { - return static_cast(n)->name; - }); + .set_creator(GetSourceNameNodeByStr) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); Span SpanNode::make(SourceName source, int lineno, int col_offset) { auto n = make_object(); @@ -70,13 +71,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span") -.set_body_typed(SpanNode::make); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed(SpanNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->lineno << ", " - << node->col_offset << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset + << ")"; + }); } // namespace tvm diff --git a/src/ir/tensor_type.cc b/src/ir/tensor_type.cc 
index 92f0ea2294212..0fab0acb89644 100644 --- a/src/ir/tensor_type.cc +++ b/src/ir/tensor_type.cc @@ -21,8 +21,8 @@ * \file src/ir/tensor_type.cc * \brief The type system AST nodes of Relay. */ -#include #include +#include #include namespace tvm { @@ -37,9 +37,7 @@ TensorType::TensorType(Array shape, DataType dtype) { data_ = std::move(n); } -TensorType TensorType::Scalar(DataType dtype) { - return TensorType({}, dtype); -} +TensorType TensorType::Scalar(DataType dtype) { return TensorType({}, dtype); } PrimExpr TensorTypeNode::Size() const { if (shape.size() == 0) { @@ -55,15 +53,14 @@ PrimExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("ir.TensorType") -.set_body_typed([](Array shape, DataType dtype) { +TVM_REGISTER_GLOBAL("ir.TensorType").set_body_typed([](Array shape, DataType dtype) { return TensorType(shape, dtype); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; + }); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index c1547d5205a4d..d7d9b063aa126 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -22,11 +22,11 @@ * \brief Infrastructure for transformation passes. */ #include -#include +#include +#include #include #include -#include -#include +#include // TODO(tqchen): Update to use String container after it is merged. #include @@ -37,9 +37,9 @@ namespace tvm { namespace transform { +using tvm::ReprPrinter; using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; -using tvm::ReprPrinter; struct PassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -48,32 +48,26 @@ struct PassContextThreadLocalEntry { /*! \brief The current pass context. */ std::stack context_stack; - PassContextThreadLocalEntry() { - default_context = PassContext(make_object()); - } + PassContextThreadLocalEntry() { default_context = PassContext(make_object()); } }; /*! \brief Thread local store to hold the pass context. 
*/ -typedef dmlc::ThreadLocalStore - RelayPassContextThreadLocalStore; +typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } void PassContext::ExitWithScope() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } PassContext PassContext::Current() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); if (!entry->context_stack.empty()) { return entry->context_stack.top(); } else { @@ -81,15 +75,13 @@ PassContext PassContext::Current() { } } -PassContext PassContext::Create() { - return PassContext(make_object()); -} +PassContext PassContext::Create() { return PassContext(make_object()); } void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { - auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); - } + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->trace_func != nullptr) { + pass_ctx_node->trace_func(module, info, is_before); + } } class ModulePass; @@ -114,9 +106,7 @@ class ModulePassNode : public PassNode { ModulePassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a module pass on given pass context. @@ -211,9 +201,7 @@ class SequentialNode : public PassNode { TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; -PassInfo::PassInfo(int opt_level, - std::string name, - tvm::Array required) { +PassInfo::PassInfo(int opt_level, std::string name, tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -221,9 +209,8 @@ PassInfo::PassInfo(int opt_level, data_ = std::move(pass_info); } -ModulePass::ModulePass( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { +ModulePass::ModulePass(runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); @@ -231,13 +218,10 @@ ModulePass::ModulePass( } // Module -> Module optimizations. -IRModule ModulePassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); - DLOG(INFO) << "Executing module pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; + DLOG(INFO) << "Executing module pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); @@ -307,20 +291,18 @@ Pass GetPass(const std::string& pass_name) { // pass } else if ((f = Registry::Get("relay._transform." 
+ pass_name))) { } - CHECK(f != nullptr) << "Cannot use " << pass_name - << "to create the pass"; + CHECK(f != nullptr) << "Cannot use " << pass_name << " to create the pass"; return (*f)(); } // TODO(zhiics): we currently only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. -IRModule SequentialNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); - if (!PassEnabled(pass_info)) continue; + if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); @@ -330,11 +312,9 @@ IRModule SequentialNode::operator()(IRModule mod, return mod; } -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { +Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, + int opt_level, const std::string& name, + const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); } @@ -342,55 +322,50 @@ Pass CreateModulePass( TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") -.set_body_typed([](int opt_level, std::string name, tvm::Array required) { - return PassInfo(opt_level, name, required); -}); + .set_body_typed([](int opt_level, std::string name, tvm::Array required) { + return PassInfo(opt_level, name, required); + }); -TVM_REGISTER_GLOBAL("transform.Info") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "The meta data of the pass: "; - p->stream << "pass name: " << node->name; - p->stream << "opt_level: " << node->opt_level; - p->stream << "required passes: [" << "\n"; - for (const auto& it : node->required) { - p->stream << it << ", "; - } - p->stream << "]\n"; -}); + .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) { + auto* node = static_cast<const PassInfoNode*>(ref.get()); + p->stream << "The meta data of the pass: "; + p->stream << "pass name: " << node->name; + p->stream << "opt_level: " << node->opt_level; + p->stream << "required passes: [" + << "\n"; + for (const auto& it : node->required) { + p->stream << it << ", "; + } + p->stream << "]\n"; + }); TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_GLOBAL("transform.MakeModulePass") -.set_body_typed( - [](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return ModulePass(pass_func, pass_info); -}); + .set_body_typed([](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func, + PassInfo pass_info) { return ModulePass(pass_func, pass_info); }); -TVM_REGISTER_GLOBAL("transform.RunPass") -.set_body_typed([](Pass pass, IRModule mod) { +TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) { return pass(std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Module
pass: " << info->name - << " at the optimization level " << info->opt_level; -}); + .set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast<const ModulePassNode*>(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Module pass: " << info->name << " at the optimization level " + << info->opt_level; + }); TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("transform.Sequential") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; std::string name = args[2]; @@ -400,23 +375,22 @@ TVM_REGISTER_GLOBAL("transform.Sequential") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Sequential pass: " << info->name - << " at the optimization level " << info->opt_level << ". "; - p->stream << "The passes will be executed are: ["; - for (const auto& it : node->passes) { - const PassInfo pass_info = it->Info(); - p->stream << pass_info->name << " "; - } - p->stream << "]"; -}); + .set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast<const SequentialNode*>(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Sequential pass: " << info->name << " at the optimization level " + << info->opt_level << ". "; + p->stream << "The passes to be executed are: ["; + for (const auto& it : node->passes) { + const PassInfo pass_info = it->Info(); + p->stream << pass_info->name << " "; + } + p->stream << "]"; + }); TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("transform.PassContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.PassContext").set_body([](TVMArgs args, TVMRetValue* ret) { auto pctx = PassContext::Create(); int opt_level = args[0]; int fallback_device = args[1]; @@ -432,59 +406,48 @@ TVM_REGISTER_GLOBAL("transform.PassContext") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Pass context information: " << "\n"; - p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\tfallback device: " - << runtime::DeviceName(node->fallback_device) - << "\n"; - - p->stream << "\trequired passes: [" << node->opt_level; - for (const auto& it : node->required_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; - - p->stream << "\tdisabled passes: [" << node->opt_level; - for (const auto& it : node->disabled_pass) { - p->stream << it << " "; - } - p->stream << "]"; -}); + .set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast<const PassContextNode*>(ref.get()); + p->stream << "Pass context information: " + << "\n"; + p->stream << "\topt_level: " << node->opt_level << "\n"; + p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n"; + + p->stream << "\trequired passes: ["; + for (const auto& it : node->required_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + + p->stream << "\tdisabled passes: ["; + for (const auto& it : node->disabled_pass) { + p->stream << it << " "; + } + p->stream << "]"; + }); class PassContext::Internal { public: - static void EnterScope(PassContext pass_ctx) { - pass_ctx.EnterWithScope(); - } + static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); } - static void
ExitScope(PassContext pass_ctx) { - pass_ctx.ExitWithScope(); - } + static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext") -.set_body_typed(PassContext::Current); - -TVM_REGISTER_GLOBAL("transform.EnterPassContext") -.set_body_typed(PassContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); -TVM_REGISTER_GLOBAL("transform.ExitPassContext") -.set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); Pass PrintIR(std::string header, bool show_meta_data) { - auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) { - LOG(INFO) << "PrintIR(" << header << "):\n" - << AsText(mod, show_meta_data); + auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { + LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); } -TVM_REGISTER_GLOBAL("transform.PrintIR") -.set_body_typed(PrintIR); +TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 5b038218c1272..212a6e5ea1bc5 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -33,17 +33,15 @@ PrimType::PrimType(runtime::DataType dtype) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("ir.PrimType") -.set_body_typed([](runtime::DataType dtype) { +TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << node->dtype; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->dtype; + }); PointerType::PointerType(Type element_type) { ObjectPtr n = make_object(); @@ -53,20 +51,18 @@ PointerType::PointerType(Type element_type) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); -TVM_REGISTER_GLOBAL("ir.PointerType") -.set_body_typed([](Type element_type) { +TVM_REGISTER_GLOBAL("ir.PointerType").set_body_typed([](Type element_type) { return PointerType(element_type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->Print(node->element_type); - p->stream << '*'; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->element_type); + p->stream << '*'; + }); - -TypeVar::TypeVar(std::string name, TypeKind kind) { +TypeVar::TypeVar(String name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); @@ -75,18 +71,15 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); -TVM_REGISTER_GLOBAL("ir.TypeVar") -.set_body_typed([](std::string name, int kind) { +TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) { return TypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVar(" << node->name_hint << ", " - << 
node->kind << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { ObjectPtr n = make_object(); @@ -97,21 +90,17 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalTypeVar") -.set_body_typed([](std::string name, int kind) { +TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) { return GlobalTypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVar(" << node->name_hint << ", " - << node->kind << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); -FuncType::FuncType(tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, +FuncType::FuncType(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints) { ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); @@ -124,21 +113,17 @@ FuncType::FuncType(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_GLOBAL("ir.FuncType") -.set_body_typed([](tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { - return FuncType(arg_types, ret_type, type_params, type_constraints); -}); + .set_body_typed([](tvm::Array arg_types, Type ret_type, tvm::Array type_params, + tvm::Array type_constraints) { + return FuncType(arg_types, ret_type, type_params, type_constraints); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FuncType(" << node->type_params << ", " - << node->arg_types << ", " << node->ret_type << ", " - << node->type_constraints << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " + << node->ret_type << ", " << node->type_constraints << ")"; + }); TupleType::TupleType(Array fields) { ObjectPtr n = make_object(); @@ -146,23 +131,19 @@ TupleType::TupleType(Array fields) { data_ = std::move(n); } -TupleType TupleType::Empty() { - return TupleType(Array()); -} +TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType") -.set_body_typed([](Array fields) { +TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { return TupleType(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleTypeNode(" << node->fields << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleTypeNode(" << node->fields << ")"; + }); IncompleteType::IncompleteType(TypeKind kind) { auto n = make_object(); @@ -172,17 +153,15 @@ IncompleteType::IncompleteType(TypeKind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); -TVM_REGISTER_GLOBAL("ir.IncompleteType") -.set_body_typed([](int kind) { - return IncompleteType(static_cast(kind)); - 
}); +TVM_REGISTER_GLOBAL("ir.IncompleteType").set_body_typed([](int kind) { + return IncompleteType(static_cast(kind)); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); RelayRefType::RelayRefType(Type value) { ObjectPtr n = make_object(); @@ -190,17 +169,16 @@ RelayRefType::RelayRefType(Type value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.RelayRefType") -.set_body_typed([](Type value) { +TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) { return RelayRefType(value); }); TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RelayRefTypeNode(" << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RelayRefTypeNode(" << node->value << ")"; + }); } // namespace tvm diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 9d9167fa1c0f5..21ce3d09d2aef 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -22,18 +22,16 @@ * \brief Implementations of type functors. */ #include + #include namespace tvm { -void TypeVisitor::VisitType_(const TypeVarNode* op) { -} +void TypeVisitor::VisitType_(const TypeVarNode* op) {} -void TypeVisitor::VisitType_(const TensorTypeNode* op) { -} +void TypeVisitor::VisitType_(const TensorTypeNode* op) {} -void TypeVisitor::VisitType_(const IncompleteTypeNode* op) { -} +void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {} void TypeVisitor::VisitType_(const FuncTypeNode* op) { for (auto type_param : op->type_params) { @@ -56,9 +54,7 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { } } -void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { - this->VisitType(op->value); -} +void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); } void TypeVisitor::VisitType_(const TypeRelationNode* op) { for (const Type& t : op->args) { @@ -66,8 +62,7 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { } } -void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) { -} +void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {} void TypeVisitor::VisitType_(const TypeCallNode* op) { this->VisitType(op->func); @@ -90,12 +85,9 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } -void TypeVisitor::VisitType_(const PrimTypeNode* op) { -} +void TypeVisitor::VisitType_(const PrimTypeNode* op) {} -void TypeVisitor::VisitType_(const PointerTypeNode* op) { - this->VisitType(op->element_type); -} +void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); } Type TypeMutator::VisitType(const Type& t) { return t.defined() ? 
TypeFunctor::VisitType(t) : t; @@ -115,18 +107,14 @@ Array TypeMutator::MutateArray(Array arr) { return arr; } -Type TypeMutator::VisitType_(const TypeVarNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const TensorTypeNode* op) { // TODO(tvm-team) recursively visit to replace Var return GetRef(op); } -Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const FuncTypeNode* op) { bool changed = false; @@ -145,8 +133,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { for (auto type_cs : op->type_constraints) { auto new_type_cs = VisitType(type_cs); changed = changed || !new_type_cs.same_as(type_cs); - if (const TypeConstraintNode* tin = - new_type_cs.as()) { + if (const TypeConstraintNode* tin = new_type_cs.as()) { type_constraints.push_back(GetRef(tin)); } else { LOG(FATAL) << new_type_cs; @@ -160,10 +147,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { changed = changed || !new_ret_type.same_as(op->ret_type); if (!changed) return GetRef(op); - return FuncType(new_args, - new_ret_type, - type_params, - type_constraints); + return FuncType(new_args, new_ret_type, type_params, type_constraints); } Type TypeMutator::VisitType_(const TupleTypeNode* op) { @@ -184,16 +168,11 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { if (new_args.same_as(type_rel->args)) { return GetRef(type_rel); } else { - return TypeRelation(type_rel->func, - new_args, - type_rel->num_inputs, - type_rel->attrs); + return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs); } } -Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const TypeCallNode* op) { Type new_func = VisitType(op->func); @@ -205,13 +184,9 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) { } } -Type TypeMutator::VisitType_(const TypeDataNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef(op); } -Type TypeMutator::VisitType_(const PrimTypeNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const PointerTypeNode* op) { Type element_type = VisitType(op->element_type); @@ -226,8 +201,7 @@ Type TypeMutator::VisitType_(const PointerTypeNode* op) { // Implements bind. 
class TypeBinder : public TypeMutator { public: - explicit TypeBinder(const tvm::Map& args_map) - : args_map_(args_map) {} + explicit TypeBinder(const tvm::Map& args_map) : args_map_(args_map) {} Type VisitType_(const TypeVarNode* op) override { auto id = GetRef(op); diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index ab479e782b563..f038a6678b425 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -35,22 +35,17 @@ TypeCall::TypeCall(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); -TVM_REGISTER_GLOBAL("ir.TypeCall") -.set_body_typed([](Type func, Array type) { +TVM_REGISTER_GLOBAL("ir.TypeCall").set_body_typed([](Type func, Array type) { return TypeCall(func, type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeCallNode(" << node->func << ", " - << node->args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; + }); -TypeRelation::TypeRelation(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { +TypeRelation::TypeRelation(TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); @@ -62,18 +57,13 @@ TypeRelation::TypeRelation(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_GLOBAL("ir.TypeRelation") -.set_body_typed([](TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { - return TypeRelation(func, args, num_inputs, attrs); -}); + .set_body_typed([](TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { + return TypeRelation(func, args, num_inputs, attrs); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeRelationNode(" - << node->func->name - << ", " << node->args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeRelationNode(" << node->func->name << ", " << node->args << ")"; + }); } // namespace tvm diff --git a/src/node/container.cc b/src/node/container.cc index 52e4bf19718c7..a5e7669fc66de 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -20,10 +20,11 @@ * Expose container API to frontend. 
* \file src/node/container.cc */ -#include -#include #include +#include +#include #include + #include "../support/str_escape.h" namespace tvm { @@ -32,14 +33,11 @@ namespace tvm { struct StringObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::StringObj* key, - SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue( - runtime::String::HashBytes(key->data, key->size)); + static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); } - static bool SEqualReduce(const runtime::StringObj* lhs, - const runtime::StringObj* rhs, + static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->size != rhs->size) return false; @@ -49,32 +47,29 @@ struct StringObjTrait { }; struct RefToObjectPtr : public ObjectRef { - static ObjectPtr Get(const ObjectRef& ref) { - return GetDataPtr(ref); - } + static ObjectPtr Get(const ObjectRef& ref) { return GetDataPtr(ref); } }; TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) -.set_creator([](const std::string& bytes) { - return RefToObjectPtr::Get(runtime::String(bytes)); -}) -.set_repr_bytes([](const Object* n) -> std::string { - return GetRef( - static_cast(n)).operator std::string(); -}); + .set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return GetRef(static_cast(n)) + . + operator std::string(); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; + }); struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::ADTObj* key, - SHashReducer hash_reduce) { + static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) { hash_reduce(key->tag); hash_reduce(static_cast(key->size)); for (uint32_t i = 0; i < key->size; ++i) { @@ -82,8 +77,7 @@ struct ADTObjTrait { } } - static bool SEqualReduce(const runtime::ADTObj* lhs, - const runtime::ADTObj* rhs, + static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->tag != rhs->tag) return false; @@ -98,39 +92,31 @@ struct ADTObjTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); - struct NDArrayContainerTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::NDArray::Container* key, - SHashReducer hash_reduce) { + static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { CHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK(runtime::IsContiguous(key->dl_tensor)) - << "Can only hash contiguous tensor"; + CHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; hash_reduce(runtime::DataType(key->dl_tensor.dtype)); hash_reduce(key->dl_tensor.ndim); for (int i = 0; i < key->dl_tensor.ndim; ++i) { hash_reduce(key->dl_tensor.shape[i]); } - hash_reduce->SHashReduceHashedValue( - 
runtime::String::HashBytes( - static_cast(key->dl_tensor.data), - runtime::GetDataSize(key->dl_tensor))); + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); } static bool SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { + const runtime::NDArray::Container* rhs, SEqualReducer equal) { if (lhs == rhs) return true; auto ldt = lhs->dl_tensor.dtype; auto rdt = rhs->dl_tensor.dtype; CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK(runtime::IsContiguous(lhs->dl_tensor)) - << "Can only compare contiguous tensor"; - CHECK(runtime::IsContiguous(rhs->dl_tensor)) - << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { @@ -147,21 +133,17 @@ struct NDArrayContainerTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); - struct ArrayNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ArrayNode* key, - SHashReducer hash_reduce) { + static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { hash_reduce(static_cast(key->data.size())); for (size_t i = 0; i < key->data.size(); ++i) { hash_reduce(key->data[i]); } } - static bool SEqualReduce(const ArrayNode* lhs, - const ArrayNode* rhs, - SEqualReducer equal) { + static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { if (lhs->data.size() != rhs->data.size()) return false; for (size_t i = 0; i < lhs->data.size(); ++i) { if (!equal(lhs->data[i], rhs->data[i])) return false; @@ -172,53 +154,45 @@ struct ArrayNodeTrait { TVM_REGISTER_OBJECT_TYPE(ArrayNode); TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -TVM_REGISTER_GLOBAL("node.Array") -.set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() != kTVMNullptr) { - data.push_back(args[i].operator ObjectRef()); - } else { - data.push_back(ObjectRef(nullptr)); - } + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kTVMNullptr) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.push_back(ObjectRef(nullptr)); } - auto node = make_object(); - node->data = std::move(data); - *ret = Array(node); - }); + } + auto node = make_object(); + node->data = std::move(data); + *ret = Array(node); +}); -TVM_REGISTER_GLOBAL("node.ArrayGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - CHECK_LT(static_cast(i), n->data.size()) - << "out of bound of array"; - *ret = n->data[static_cast(i)]; - }); - 
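[Editor's note] These node.Array globals are the packed-function surface that frontends use to build and index Array containers over the FFI. A minimal sketch of driving them from C++; only the registered names ("node.Array", "node.ArrayGetItem") come from this file, and the demo function itself is hypothetical:

#include <tvm/node/container.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>

using namespace tvm;

void ArrayDemo() {
  const runtime::PackedFunc* fmake = runtime::Registry::Get("node.Array");
  const runtime::PackedFunc* fget = runtime::Registry::Get("node.ArrayGetItem");
  CHECK(fmake != nullptr && fget != nullptr);
  // Arguments are forwarded one-to-one into the new array's elements.
  Array<ObjectRef> arr = (*fmake)(runtime::String("x"), runtime::String("y"));
  runtime::String first = (*fget)(arr, 0);  // "x"
}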
-TVM_REGISTER_GLOBAL("node.ArraySize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - CHECK(ptr->IsInstance()); - *ret = static_cast( - static_cast(ptr)->data.size()); - }); +TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast<Object*>(args[0].value().v_handle); + CHECK(ptr->IsInstance<ArrayNode>()); + auto* n = static_cast<ArrayNode*>(ptr); + CHECK_LT(static_cast<size_t>(i), n->data.size()) << "out of bounds of array"; + *ret = n->data[static_cast<size_t>(i)]; +}); +TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast<Object*>(args[0].value().v_handle); + CHECK(ptr->IsInstance<ArrayNode>()); + *ret = static_cast<int64_t>(static_cast<ArrayNode*>(ptr)->data.size()); +}); struct MapNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const MapNode* key, - SHashReducer hash_reduce) { + static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { // SHash's var handling depends on the determinism of traversal. // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store @@ -233,15 +207,15 @@ struct MapNodeTrait { } } // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // add size to the hash hash_reduce(static_cast(key->data.size())); // hash the content for (size_t i = 0; i < temp.size();) { size_t k = i + 1; - for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {} + for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { + } // ties are rare, but we need to skip them to make the hash deterministic if (k == i + 1) { hash_reduce->SHashReduceHashedValue(temp[i].first); @@ -251,9 +225,7 @@ struct MapNodeTrait { } } - static bool SEqualReduce(const MapNode* lhs, - const MapNode* rhs, - SEqualReducer equal) { + static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { // Only allow equal checking if the keys are already mapped @@ -272,16 +244,14 @@ struct MapNodeTrait { TVM_REGISTER_OBJECT_TYPE(MapNode); TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - + .set_creator([](const std::string&) -> ObjectPtr<Object> { + return ::tvm::runtime::make_object<MapNode>(); + }); struct StrMapNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const StrMapNode* key, - SHashReducer hash_reduce) { + static void SHashReduce(const StrMapNode* key, SHashReducer hash_reduce) { // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store // Map where Var is defined in the function using KV = std::pair; std::vector temp(key->data.begin(), key->data.end());
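// [Editor's note] Sorting the key/value pairs before reducing them is what keeps the
// structural hash independent of std::unordered_map iteration order: a StrMap built
// as {"a": x, "b": y} and one built as {"b": y, "a": x} reduce in the same order and
// therefore hash identically. Unlike MapNodeTrait above, the sort key here is the
// string key itself rather than its hash, so ties cannot occur (hence the
// "we won't have ties" note below).
// sort by the hash key of the keys.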
- std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // NOTE: we won't have ties // add size to the hash after sorting. hash_reduce(static_cast(key->data.size())); @@ -302,9 +271,7 @@ } } - static bool SEqualReduce(const StrMapNode* lhs, - const StrMapNode* rhs, - SEqualReducer equal) { + static bool SEqualReduce(const StrMapNode* lhs, const StrMapNode* rhs, SEqualReducer equal) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { auto it = rhs->data.find(kv.first); @@ -317,120 +284,104 @@ TVM_REGISTER_OBJECT_TYPE(StrMapNode); TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -TVM_REGISTER_GLOBAL("node.Map") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size() % 2, 0); - if (args.size() != 0 && args[0].type_code() == kTVMStr) { - // StrMap - StrMapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kTVMStr) - << "key of str map need to be str"; - CHECK(args[i + 1].IsObjectRef()) - << "value of the map to be NodeRef"; - data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); - } else { - // Container node. - MapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].IsObjectRef()) - << "key of str map need to be object"; - CHECK(args[i + 1].IsObjectRef()) - << "value of map to be NodeRef"; - data.emplace(std::make_pair(args[i].operator ObjectRef(), - args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); + .set_creator([](const std::string&) -> ObjectPtr<Object> { + return ::tvm::runtime::make_object<StrMapNode>(); + }); + +TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size() % 2, 0); + if (args.size() != 0 && args[0].type_code() == kTVMStr) { + // StrMap + StrMapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].type_code() == kTVMStr) << "key of str map needs to be str"; + CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of the map needs to be NodeRef"; + data.emplace( + std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef())); + } + auto node = make_object<StrMapNode>(); + node->data = std::move(data); + *ret = Map(node); + } else { + // Container node.
+ MapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].IsObjectRef<ObjectRef>()) << "key of map needs to be object"; + CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of map needs to be NodeRef"; + data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); } + auto node = make_object<MapNode>(); + node->data = std::move(data); + *ret = Map(node); + } +}); +TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast<Object*>(args[0].value().v_handle); + if (ptr->IsInstance<MapNode>()) { + auto* n = static_cast<MapNode*>(ptr); + *ret = static_cast<int64_t>(n->data.size()); + } else { + CHECK(ptr->IsInstance<StrMapNode>()); + auto* n = static_cast<StrMapNode*>(ptr); + *ret = static_cast<int64_t>(n->data.size()); + } +}); -TVM_REGISTER_GLOBAL("node.MapSize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); - } - }); +TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast<Object*>(args[0].value().v_handle); + + if (ptr->IsInstance<MapNode>()) { + auto* n = static_cast<MapNode*>(ptr); + auto it = n->data.find(args[1].operator ObjectRef()); + CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + } else { + CHECK(ptr->IsInstance<StrMapNode>()); + auto* n = static_cast<StrMapNode*>(ptr); + auto it = n->data.find(args[1].operator std::string()); + CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + } +}); -TVM_REGISTER_GLOBAL("node.MapGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - auto it = n->data.find(args[1].operator ObjectRef()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - auto it = n->data.find(args[1].operator std::string()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } - }); +TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast<Object*>(args[0].value().v_handle); -TVM_REGISTER_GLOBAL("node.MapCount") -.set_body([](TVMArgs args, TVMRetValue* ret) { + if (ptr->IsInstance<MapNode>()) { + auto* n = static_cast<MapNode*>(ptr); CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); + *ret = static_cast<int64_t>(n->data.count(args[1].operator ObjectRef())); + } else { + CHECK(ptr->IsInstance<StrMapNode>()); + auto* n = static_cast<StrMapNode*>(ptr); + *ret = static_cast<int64_t>(n->data.count(args[1].operator std::string())); + } +}); - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - *ret = static_cast( - n->data.count(args[1].operator ObjectRef())); - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast( - n->data.count(args[1].operator std::string())); - } - });
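[Editor's note] node.Map and node.MapItems fix the flattened calling convention the frontend relies on: a map is built from an even run of arguments [k0, v0, k1, v1, ...], and MapItems (registered just below) hands the pairs back re-flattened into one Array. A minimal sketch of the round trip from C++; only the registered global names come from this patch, everything else is illustrative:

#include <tvm/ir/expr.h>
#include <tvm/node/container.h>
#include <tvm/runtime/registry.h>

using namespace tvm;

void MapRoundTripDemo() {
  const runtime::PackedFunc* fmap = runtime::Registry::Get("node.Map");
  const runtime::PackedFunc* fitems = runtime::Registry::Get("node.MapItems");
  CHECK(fmap != nullptr && fitems != nullptr);
  // A non-string first key selects the ObjectRef-keyed MapNode branch above.
  Map<PrimExpr, PrimExpr> m = (*fmap)(PrimExpr(1), PrimExpr(2));  // {1: 2}
  Array<ObjectRef> kvs = (*fitems)(m);  // flattened as [key0, value0, ...]
  CHECK_EQ(kvs.size(), 2U);
}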
+TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); -TVM_REGISTER_GLOBAL("node.MapItems") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - auto rkvs = make_object(); - for (const auto& kv : n->data) { - rkvs->data.push_back(kv.first); - rkvs->data.push_back(kv.second); - } - *ret = Array(rkvs); - } else { - auto* n = static_cast(ptr); - auto rkvs = make_object(); - for (const auto& kv : n->data) { - rkvs->data.push_back(tir::StringImmNode::make(kv.first)); - rkvs->data.push_back(kv.second); - } - *ret = Array(rkvs); + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + auto rkvs = make_object(); + for (const auto& kv : n->data) { + rkvs->data.push_back(kv.first); + rkvs->data.push_back(kv.second); + } + *ret = Array(rkvs); + } else { + auto* n = static_cast(ptr); + auto rkvs = make_object(); + for (const auto& kv : n->data) { + rkvs->data.push_back(tir::StringImmNode::make(kv.first)); + rkvs->data.push_back(kv.second); } - }); + *ret = Array(rkvs); + } +}); } // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 08a914ff38f9f..c3397e7500c17 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -21,17 +21,17 @@ * Reflection utilities. * \file node/reflection.cc */ -#include -#include +#include #include +#include #include -#include +#include namespace tvm { -using runtime::TVMRetValue; -using runtime::TVMArgs; using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; // Attr getter. 
 class AttrGetter : public AttrVisitor {
@@ -39,9 +39,7 @@ class AttrGetter : public AttrVisitor {
   const std::string& skey;
   TVMRetValue* ret;

-  AttrGetter(const std::string &skey,
-             TVMRetValue* ret)
-      : skey(skey), ret(ret) {}
+  AttrGetter(const std::string& skey, TVMRetValue* ret) : skey(skey), ret(ret) {}

   bool found_ref_object{false};

@@ -86,8 +84,7 @@ class AttrGetter : public AttrVisitor {
   }
 };

-runtime::TVMRetValue ReflectionVTable::GetAttr(
-    Object* self, const std::string& field_name) const {
+runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const std::string& field_name) const {
   runtime::TVMRetValue ret;
   AttrGetter getter(field_name, &ret);
@@ -110,8 +107,8 @@ runtime::TVMRetValue ReflectionVTable::GetAttr(
     }
   }
   if (!success) {
-    LOG(FATAL) << "AttributeError: " << self->GetTypeKey()
-               << " object has no attributed " << getter.skey;
+    LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attribute "
+               << getter.skey;
   }
   return ret;
 }
@@ -121,40 +118,19 @@ class AttrDir : public AttrVisitor {
  public:
   std::vector<std::string>* names;

-  void Visit(const char* key, double* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, bool* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, void** value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, DataType* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, std::string* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, runtime::NDArray* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, runtime::ObjectRef* value) final {
-    names->push_back(key);
-  }
+  void Visit(const char* key, double* value) final { names->push_back(key); }
+  void Visit(const char* key, int64_t* value) final { names->push_back(key); }
+  void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
+  void Visit(const char* key, bool* value) final { names->push_back(key); }
+  void Visit(const char* key, int* value) final { names->push_back(key); }
+  void Visit(const char* key, void** value) final { names->push_back(key); }
+  void Visit(const char* key, DataType* value) final { names->push_back(key); }
+  void Visit(const char* key, std::string* value) final { names->push_back(key); }
+  void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); }
+  void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); }
 };

-std::vector<std::string>
-ReflectionVTable::ListAttrNames(Object* self) const {
+std::vector<std::string> ReflectionVTable::ListAttrNames(Object* self) const {
   std::vector<std::string> names;
   AttrDir dir;
   dir.names = &names;
@@ -176,13 +152,11 @@ ReflectionVTable* ReflectionVTable::Global() {
   return &inst;
 }

-ObjectPtr<Object>
-ReflectionVTable::CreateInitObject(const std::string& type_key,
-                                   const std::string& repr_bytes) const {
+ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key,
+                                                     const std::string& repr_bytes) const {
   uint32_t tindex = Object::TypeKey2Index(type_key);
   if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
-    LOG(FATAL) << "TypeError: " << type_key
-               << " is not registered via TVM_REGISTER_NODE_TYPE";
+    LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE";
   }
   return fcreate_[tindex](repr_bytes);
 }

@@ -192,30 +166,16 @@ class NodeAttrSetter : public AttrVisitor {
   std::string type_key;
   std::unordered_map<std::string, runtime::TVMArgValue> attrs;

-  void Visit(const char* key, double* value) final {
-    *value = GetAttr(key).operator double();
-  }
-  void Visit(const char* key, int64_t* value) final {
-    *value = GetAttr(key).operator int64_t();
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    *value = GetAttr(key).operator uint64_t();
-  }
-  void Visit(const char* key, int* value) final {
-    *value = GetAttr(key).operator int();
-  }
-  void Visit(const char* key, bool* value) final {
-    *value = GetAttr(key).operator bool();
-  }
+  void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
+  void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
+  void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
+  void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
+  void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
   void Visit(const char* key, std::string* value) final {
     *value = GetAttr(key).operator std::string();
   }
-  void Visit(const char* key, void** value) final {
-    *value = GetAttr(key).operator void*();
-  }
-  void Visit(const char* key, DataType* value) final {
-    *value = GetAttr(key).operator DataType();
-  }
+  void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
+  void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
   void Visit(const char* key, runtime::NDArray* value) final {
     *value = GetAttr(key).operator runtime::NDArray();
   }
@@ -240,8 +200,7 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
   setter.type_key = n->GetTypeKey();
   CHECK_EQ(args.size() % 2, 0);
   for (int i = 0; i < args.size(); i += 2) {
-    setter.attrs.emplace(args[i].operator std::string(),
-                         args[i + 1]);
+    setter.attrs.emplace(args[i].operator std::string(), args[i + 1]);
   }
   auto* reflection = ReflectionVTable::Global();
   reflection->VisitAttrs(n, &setter);
@@ -249,7 +208,7 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
   if (setter.attrs.size() != 0) {
     std::ostringstream os;
     os << setter.type_key << " does not contain field ";
-    for (const auto &kv : setter.attrs) {
+    for (const auto& kv : setter.attrs) {
       os << " " << kv.first;
     }
     LOG(FATAL) << os.str();
@@ -267,17 +226,17 @@ void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
   CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
   Object* self = static_cast<Object*>(args[0].value().v_handle);

-  auto names = std::make_shared<std::vector<std::string> >(
-      ReflectionVTable::Global()->ListAttrNames(self));
+  auto names =
+      std::make_shared<std::vector<std::string> >(ReflectionVTable::Global()->ListAttrNames(self));

-  *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
-      int64_t i = args[0];
-      if (i == -1) {
-        *rv = static_cast<int64_t>(names->size());
-      } else {
-        *rv = (*names)[i];
-      }
-    });
+  *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) {
+    int64_t i = args[0];
+    if (i == -1) {
+      *rv = static_cast<int64_t>(names->size());
+    } else {
+      *rv = (*names)[i];
+    }
+  });
 }

 // API function to make node.
@@ -297,13 +256,9 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { *rv = ObjectRef(n); } +TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr); -TVM_REGISTER_GLOBAL("node.NodeGetAttr") -.set_body(NodeGetAttr); - -TVM_REGISTER_GLOBAL("node.NodeListAttrNames") -.set_body(NodeListAttrNames); +TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames); -TVM_REGISTER_GLOBAL("node.MakeNode") -.set_body(MakeNode); +TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode); } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index bf41c82f5a76b..ea263439023fd 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -21,8 +21,8 @@ * Printer utilities * \file node/repr_printer.cc */ -#include #include +#include namespace tvm { @@ -51,16 +51,11 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } -void Dump(const runtime::ObjectRef& n) { - std::cerr << n << "\n"; -} +void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } -void Dump(const runtime::Object* n) { - Dump(runtime::GetRef(n)); -} +void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_REGISTER_GLOBAL("node.AsRepr") -.set_body_typed([](runtime::ObjectRef obj) { +TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](runtime::ObjectRef obj) { std::ostringstream os; os << obj; return os.str(); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index ee6072d77c1c3..4675c5339f8df 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -23,29 +23,25 @@ */ #include #include -#include -#include -#include +#include #include #include #include -#include +#include +#include +#include -#include #include #include +#include #include "../support/base64.h" namespace tvm { -inline std::string Type2String(const DataType& t) { - return runtime::DLDataType2String(t); -} +inline std::string Type2String(const DataType& t) { return runtime::DLDataType2String(t); } -inline DataType String2Type(std::string s) { - return DataType(runtime::String2DLDataType(s)); -} +inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); } inline std::string Base64Decode(std::string s) { dmlc::MemoryStringStream mstrm(&s); @@ -148,7 +144,7 @@ struct JSONNode { /*! \brief values of a map or array. 
*/ std::vector data; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("type_key", type_key); if (repr_bytes.size() != 0) { @@ -173,7 +169,7 @@ struct JSONNode { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); data.clear(); repr_bytes.clear(); @@ -213,36 +209,23 @@ class JSONAttrGetter : public AttrVisitor { s << (*value); node_->attrs[key] = s.str(); } - void Visit(const char* key, int64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, uint64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, bool* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, std::string* value) final { - node_->attrs[key] = *value; - } + void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to serialize a pointer"; } - void Visit(const char* key, DataType* value) final { - node_->attrs[key] = Type2String(*value); - } + void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } void Visit(const char* key, runtime::NDArray* value) final { - node_->attrs[key] = std::to_string( - tensor_index_->at(const_cast((*value).operator->()))); + node_->attrs[key] = + std::to_string(tensor_index_->at(const_cast((*value).operator->()))); } void Visit(const char* key, ObjectRef* value) final { - node_->attrs[key] = std::to_string( - node_index_->at(const_cast(value->get()))); + node_->attrs[key] = std::to_string(node_index_->at(const_cast(value->get()))); } // Get the node @@ -262,23 +245,19 @@ class JSONAttrGetter : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); for (size_t i = 0; i < n->data.size(); ++i) { - node_->data.push_back( - node_index_->at(const_cast(n->data[i].get()))); + node_->data.push_back(node_index_->at(const_cast(n->data[i].get()))); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); for (const auto& kv : n->data) { - node_->data.push_back( - node_index_->at(const_cast(kv.first.get()))); - node_->data.push_back( - node_index_->at(const_cast(kv.second.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.first.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } } else if (node->IsInstance()) { StrMapNode* n = static_cast(node); for (const auto& kv : n->data) { node_->keys.push_back(kv.first); - node_->data.push_back( - node_index_->at(const_cast(kv.second.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } } else { // recursively index normal object. 
@@ -304,7 +283,7 @@ class JSONAttrSetter : public AttrVisitor { } return it->second; } - template + template void ParseValue(const char* key, T* value) const { std::istringstream is(GetValue(key)); is >> *value; @@ -312,24 +291,12 @@ class JSONAttrSetter : public AttrVisitor { LOG(FATAL) << "Wrong value format for field " << key; } } - void Visit(const char* key, double* value) final { - ParseValue(key, value); - } - void Visit(const char* key, int64_t* value) final { - ParseValue(key, value); - } - void Visit(const char* key, uint64_t* value) final { - ParseValue(key, value); - } - void Visit(const char* key, int* value) final { - ParseValue(key, value); - } - void Visit(const char* key, bool* value) final { - ParseValue(key, value); - } - void Visit(const char* key, std::string* value) final { - *value = GetValue(key); - } + void Visit(const char* key, double* value) final { ParseValue(key, value); } + void Visit(const char* key, int64_t* value) final { ParseValue(key, value); } + void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); } + void Visit(const char* key, int* value) final { ParseValue(key, value); } + void Visit(const char* key, bool* value) final { ParseValue(key, value); } + void Visit(const char* key, std::string* value) final { *value = GetValue(key); } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to deserialize a pointer"; } @@ -363,15 +330,14 @@ class JSONAttrSetter : public AttrVisitor { MapNode* n = static_cast(node); CHECK_EQ(node_->data.size() % 2, 0U); for (size_t i = 0; i < node_->data.size(); i += 2) { - n->data[ObjectRef(node_list_->at(node_->data[i]))] - = ObjectRef(node_list_->at(node_->data[i + 1])); + n->data[ObjectRef(node_list_->at(node_->data[i]))] = + ObjectRef(node_list_->at(node_->data[i + 1])); } } else if (node->IsInstance()) { StrMapNode* n = static_cast(node); CHECK_EQ(node_->data.size(), node_->keys.size()); for (size_t i = 0; i < node_->data.size(); ++i) { - n->data[node_->keys[i]] - = ObjectRef(node_list_->at(node_->data[i])); + n->data[node_->keys[i]] = ObjectRef(node_list_->at(node_->data[i])); } } else { reflection_->VisitAttrs(node, this); @@ -390,7 +356,7 @@ struct JSONGraph { // global attributes AttrMap attrs; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("root", root); writer->WriteObjectKeyValue("nodes", nodes); @@ -401,7 +367,7 @@ struct JSONGraph { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("root", &root); @@ -471,8 +437,7 @@ ObjectRef LoadJSON(std::string json_str) { for (const JSONNode& jnode : jgraph.nodes) { if (jnode.type_key.length() != 0) { - ObjectPtr node = - reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); + ObjectPtr node = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); nodes.emplace_back(node); } else { nodes.emplace_back(ObjectPtr()); @@ -488,8 +453,7 @@ ObjectRef LoadJSON(std::string json_str) { // Skip the nodes that has an repr bytes representation. // NOTE: the second condition is used to guard the case // where the repr bytes itself is an empty string "". 
-    if (setter.node_->repr_bytes.length() == 0 &&
-        nodes[i] != nullptr &&
+    if (setter.node_->repr_bytes.length() == 0 && nodes[i] != nullptr &&
         !reflection->GetReprBytes(nodes[i].get(), nullptr)) {
       setter.Set(nodes[i].get());
     }
@@ -497,9 +461,7 @@ ObjectRef LoadJSON(std::string json_str) {
   return ObjectRef(nodes.at(jgraph.root));
 }

-TVM_REGISTER_GLOBAL("node.SaveJSON")
-.set_body_typed(SaveJSON);
+TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON);

-TVM_REGISTER_GLOBAL("node.LoadJSON")
-.set_body_typed(LoadJSON);
+TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON);
 }  // namespace tvm
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 03cdf9c1e429d..b1353154dd7b9 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -19,10 +19,10 @@
 /*!
  * \file src/node/structural_equal.cc
  */
-#include
-#include
 #include
 #include
+#include
+#include
 #include
 #include

@@ -30,13 +30,13 @@
 namespace tvm {

 // Define the dispatch function here since the primary user is in this file.
-bool ReflectionVTable::
-SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
+bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
+                                    SEqualReducer equal) const {
   uint32_t tindex = self->type_index();
   if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) {
     LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
-        << " is not registered via TVM_REGISTER_NODE_TYPE."
-        << " Did you forget to set _type_has_method_sequal_reduce=true?";
+               << " is not registered via TVM_REGISTER_NODE_TYPE."
+               << " Did you forget to set _type_has_method_sequal_reduce=true?";
   }
   return fsequal_reduce_[tindex](self, other, equal);
 }
@@ -50,11 +50,9 @@ SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const
  * The order in which SEqual is called is the same as if we eagerly
  * performed the recursive calls in SEqualReduce.
  */
-class RemapVarSEqualHandler :
-    public SEqualReducer::Handler {
+class RemapVarSEqualHandler : public SEqualReducer::Handler {
  public:
-  explicit RemapVarSEqualHandler(bool assert_mode)
-      : assert_mode_(assert_mode) {}
+  explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {}

   bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
     // We cannot simply use lhs.same_as(rhs) to check equality.
@@ -121,9 +119,8 @@
   // Check the result.
   bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
     if (assert_mode_ && !result) {
-      LOG(FATAL)
-          << "ValueError: StructuralEqual check failed, caused by\n"
-          << "lhs = " << lhs << "\nrhs = " << rhs;
+      LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n"
+                 << "lhs = " << lhs << "\nrhs = " << rhs;
     }
     return result;
   }
@@ -177,9 +174,7 @@
   // The default equal as registered in the structural equal vtable.
   bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
     auto compute = [=]() {
-      CHECK(lhs.defined() &&
-            rhs.defined() &&
-            lhs->type_index() == rhs->type_index());
+      CHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index());
       // skip entries that already have equality maps.
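+      // (illustrative) e.g. once Equal has mapped lhs var %x to rhs var %y,
+      // a later occurrence of %x resolves through equal_map_lhs_ below
+      // instead of being dispatched again.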
       auto it = equal_map_lhs_.find(lhs);
       if (it != equal_map_lhs_.end()) {
@@ -227,15 +222,12 @@
 };

 TVM_REGISTER_GLOBAL("node.StructuralEqual")
-.set_body_typed([](const ObjectRef& lhs,
-                   const ObjectRef& rhs,
-                   bool assert_mode,
-                   bool map_free_vars) {
-  return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
-});
-
-bool StructuralEqual::operator()(const ObjectRef& lhs,
-                                 const ObjectRef& rhs) const {
+    .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
+                       bool map_free_vars) {
+      return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
+    });
+
+bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
   return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
 }
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index a29340c931a46..91a252403f6a0 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -19,25 +19,23 @@
 /*!
  * \file src/node/structural_hash.cc
 */
-#include
-#include
 #include
 #include
+#include
+#include
 #include
-#include
 #include
-
+#include

 namespace tvm {

 // Define the dispatch function here since the primary user is in this file.
-void ReflectionVTable::
-SHashReduce(const Object* self, SHashReducer reducer) const {
+void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const {
   uint32_t tindex = self->type_index();
   if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) {
     LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey()
-        << " is not registered via TVM_REGISTER_NODE_TYPE";
+               << " is not registered via TVM_REGISTER_NODE_TYPE";
   }
   fshash_reduce_[tindex](self, reducer);
 }
@@ -49,8 +47,7 @@ SHashReduce(const Object* self, SHashReducer reducer) const {
 // In particular, when we traverse unordered_map, we should first sort
 // the entries by keys (or the hash of keys) before traversing.

-class VarCountingSHashHandler :
-    public SHashReducer::Handler {
+class VarCountingSHashHandler : public SHashReducer::Handler {
  public:
   /*! \brief Pending reduce tasks. */
   struct Task {
@@ -76,7 +73,6 @@
         : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {}
   };

-
   VarCountingSHashHandler() {}

   void MarkGraphNode() final {
@@ -95,8 +91,7 @@
   }

   void SHashReduceHashedValue(size_t hashed_value) final {
-    pending_tasks_.emplace_back(
-        Task(ObjectRef(nullptr), hashed_value, false));
+    pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false));
   }

   void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) final {
@@ -104,13 +99,11 @@
     if (map_free_vars) {
       // use counter value.
       size_t value = std::hash<size_t>()(free_var_counter_++);
-      pending_tasks_.emplace_back(
-          Task(ObjectRef(nullptr), value, false));
+      pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false));
     } else {
       // use pointer hash
       size_t value = std::hash<const runtime::Object*>()(var);
-      pending_tasks_.emplace_back(
-          Task(ObjectRef(nullptr), value, false));
+      pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false));
     }
   }

@@ -124,12 +117,10 @@
     }
     auto it = hash_memo_.find(object);
     if (it != hash_memo_.end()) {
-      pending_tasks_.emplace_back(
-          Task(ObjectRef(nullptr), it->second, false));
+      pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false));
     } else {
       // Push a pending task with initial value.
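+      // (illustrative) the task list emulates the recursion stack: the hash
+      // starts from the type-key hash and is later combined with the reduced
+      // hashes of the children once they become available.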
- pending_tasks_.emplace_back( - Task(object, object->GetTypeKeyHash(), map_free_vars)); + pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars)); } } @@ -195,9 +186,8 @@ class VarCountingSHashHandler : // Append the graph node counter to the hash // so that we can distinguish DAG from trees. if (entry.graph_node_hash) { - entry.reduced_hash = HashCombine( - entry.reduced_hash, - std::hash()(graph_node_counter_++)); + entry.reduced_hash = + HashCombine(entry.reduced_hash, std::hash()(graph_node_counter_++)); } hash_memo_[entry.object] = entry.reduced_hash; } @@ -268,13 +258,11 @@ class VarCountingSHashHandler : std::unordered_map hash_memo_; }; - TVM_REGISTER_GLOBAL("node.StructuralHash") -.set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { - size_t hashed_value = - VarCountingSHashHandler().Hash(object, map_free_vars); - return static_cast(hashed_value); -}); + .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { + size_t hashed_value = VarCountingSHashHandler().Hash(object, map_free_vars); + return static_cast(hashed_value); + }); size_t StructuralHash::operator()(const ObjectRef& object) const { return VarCountingSHashHandler().Hash(object, false); diff --git a/src/printer/doc.cc b/src/printer/doc.cc index ee260f41df55e..d487e3e7aa3e6 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -23,10 +23,12 @@ * * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 */ +#include "doc.h" + #include -#include + #include -#include "doc.h" +#include namespace tvm { @@ -38,9 +40,7 @@ class DocTextNode : public DocAtomNode { /*! \brief The str content in the text. */ std::string str; - explicit DocTextNode(std::string str_val) - : str(str_val) { - } + explicit DocTextNode(std::string str_val) : str(str_val) {} static constexpr const char* _type_key = "printer.DocText"; TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode); @@ -68,8 +68,7 @@ class DocLineNode : public DocAtomNode { /*! \brief The amount of indent in newline. 
*/ int indent; - explicit DocLineNode(int indent) - : indent(indent) {} + explicit DocLineNode(int indent) : indent(indent) {} static constexpr const char* _type_key = "printer.DocLine"; TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode); @@ -79,9 +78,7 @@ TVM_REGISTER_OBJECT_TYPE(DocLineNode); class DocLine : public DocAtom { public: - explicit DocLine(int indent) { - data_ = runtime::make_object(indent); - } + explicit DocLine(int indent) { data_ = runtime::make_object(indent); } TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode); }; @@ -89,14 +86,11 @@ class DocLine : public DocAtom { // DSL function implementations Doc& Doc::operator<<(const Doc& right) { CHECK(this != &right); - this->stream_.insert( - this->stream_.end(), right.stream_.begin(), right.stream_.end()); + this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); return *this; } -Doc& Doc::operator<<(std::string right) { - return *this << DocText(right); -} +Doc& Doc::operator<<(std::string right) { return *this << DocText(right); } Doc& Doc::operator<<(const DocAtom& right) { this->stream_.push_back(right); @@ -117,13 +111,9 @@ std::string Doc::str() { return os.str(); } -Doc Doc::NewLine(int indent) { - return Doc() << DocLine(indent); -} +Doc Doc::NewLine(int indent) { return Doc() << DocLine(indent); } -Doc Doc::Text(std::string text) { - return Doc() << DocText(text); -} +Doc Doc::Text(std::string text) { return Doc() << DocText(text); } Doc Doc::RawText(std::string text) { return Doc() << DocAtom(runtime::make_object(text)); @@ -152,10 +142,7 @@ Doc Doc::PyBoolLiteral(bool value) { } } -Doc Doc::Brace(std::string open, - const Doc& body, - std::string close, - int indent) { +Doc Doc::Brace(std::string open, const Doc& body, std::string close, int indent) { Doc doc; doc << open; doc << Indent(indent, NewLine() << body) << NewLine(); diff --git a/src/printer/doc.h b/src/printer/doc.h index 7d8d72e00b4ca..dc6ba8952f3e4 100644 --- a/src/printer/doc.h +++ b/src/printer/doc.h @@ -26,12 +26,13 @@ #ifndef TVM_PRINTER_DOC_H_ #define TVM_PRINTER_DOC_H_ +#include #include #include -#include + #include -#include #include +#include namespace tvm { @@ -48,7 +49,7 @@ class DocAtomNode : public Object { /*! * \brief Managed reference to DocAtomNode. * \sa DocAtomNode. -*/ + */ class DocAtom : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode); @@ -93,8 +94,7 @@ class Doc { * \tparam T the type of the value. * \return reference to self. */ - template::value>::type> + template ::value>::type> Doc& operator<<(const T& value) { std::ostringstream os; os << value; @@ -149,10 +149,7 @@ class Doc { * \param indent amount of indentation. * \return The created doc. */ - static Doc Brace(std::string open, - const Doc& body, - std::string close, - int indent = 2); + static Doc Brace(std::string open, const Doc& body, std::string close, int indent = 2); /*! * \brief Create a doc by concatenating together with separator. * \param vec The docs to be concatenated. 
diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index d3906926363cc..ebc76dcc80af9 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -24,10 +24,12 @@ #ifndef TVM_PRINTER_META_DATA_H_ #define TVM_PRINTER_META_DATA_H_ -#include #include +#include + #include #include + #include "doc.h" namespace tvm { @@ -98,8 +100,7 @@ class TextMetaDataContext { } std::string type_key = node->GetTypeKey(); CHECK(!type_key.empty()); - Array& mvector = - meta_data_[type_key]; + Array& mvector = meta_data_[type_key]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); Doc doc; @@ -108,6 +109,13 @@ class TextMetaDataContext { return meta_repr_[node]; } + /*! + * \brief Test whether a node has been put in meta + * \param node The query node + * \return whether the node has been put in meta + */ + bool InMeta(const ObjectRef& node) { return meta_repr_.find(node) != meta_repr_.end(); } + /*! * \brief Print a key value pair */ @@ -126,9 +134,7 @@ class TextMetaDataContext { } /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } + bool empty() const { return meta_data_.empty(); } private: /*! \brief additional metadata stored in TVM json format */ diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bda997a59d4d7..3c545ef5488e5 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -18,7 +18,7 @@ */ /*! - * \file text_format_printer.cc + * \file relay_text_printer.cc * \brief Printer to print out the IR text format * that can be parsed by a parser. * @@ -32,154 +32,129 @@ * - Var * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ -#include #include -#include +#include #include #include +#include + +#include "../ir/attr_functor.h" +#include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" -#include "../relay/analysis/dependency_graph.h" -#include "../ir/attr_functor.h" +#include "text_printer.h" namespace tvm { namespace relay { -class RelayTextPrinter : - public ExprFunctor, - public PatternFunctor, - public TypeFunctor, - public AttrFunctor { - public: - explicit RelayTextPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) - : show_meta_data_(show_meta_data), - annotate_(annotate) {} - - /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ - Doc PrintOptionalInfo(const Expr& expr) { - Doc doc; - // default annotations - if (annotate_ == nullptr) { - if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; - } - } else { - std::string annotated_expr = annotate_(expr); - if (annotated_expr != "") { - doc << annotated_expr; - } +/*! + * \brief Print additional info about expr in comment. + * \param expr The expression. 
+ */ +Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { + Doc doc; + // default annotations + if (annotate_ == nullptr) { + if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { + doc << " /* ty=" << Print(expr->checked_type()) << " */"; + } + } else { + std::string annotated_expr = annotate_(expr); + if (annotated_expr != "") { + doc << annotated_expr; } - - return doc; } - // indent a new body - Doc PrintBody(const ObjectRef& node, int indent = 2) { - Doc doc; - Doc body; - doc << "{"; - doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); - doc << "}"; - return doc; - } + return doc; +} - // create a new scope by creating a new printer object. This allows temp var - // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintScope(const ObjectRef& node) { - // print in a new scope - doc_stack_.push_back(Doc()); - // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false, true); - doc = doc_stack_.back() << doc; - doc_stack_.pop_back(); - return doc; - } +// indent a new body +Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) { + Doc doc; + Doc body; + doc << "{"; + doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); + doc << "}"; + return doc; +} - Doc PrintFinal(const ObjectRef& node) { - if (node->IsInstance() && - !node->IsInstance()) { - // Temporarily skip non-relay functions. - // TODO(tvm-team) enhance the code to work for all functions - } else if (node.as()) { - Expr expr = Downcast(node); - dg_ = DependencyGraph::Create(&arena_, expr); - } +// create a new scope by creating a new printer object. This allows temp var +// numbers to be reused and prevents hoisted vars from escaping too far +Doc RelayTextPrinter::PrintScope(const ObjectRef& node) { + // print in a new scope + doc_stack_.push_back(Doc()); + // must print first so doc_stack_.back() reference doesn't become stale + Doc doc = Print(node, false, true); + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + return doc; +} - Doc doc; - doc << PrintScope(node); - if (!meta_.empty()) { - doc << Doc::NewLine(); - if (show_meta_data_) { - // append meta data in the end. - doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); - } else { - doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; - } - } - return doc; +Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { + if (node->IsInstance() && !node->IsInstance()) { + // Temporarily skip non-relay functions. + // TODO(tvm-team) enhance the code to work for all functions + } else if (node.as()) { + Expr expr = Downcast(node); + dg_ = DependencyGraph::Create(&arena_, expr); } - std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); - std::vector PrintFuncAttrs(const Attrs& attrs); - - Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) { - bool is_non_relay_func = - node->IsInstance() && - !node->IsInstance(); - if (node.as() && !is_non_relay_func) { - return PrintExpr(Downcast(node), meta, try_inline); - } else if (node.as()) { - return PrintType(Downcast(node), meta); - } else if (node.as()) { - return PrintPattern(Downcast(node), meta); - } else if (node.as()) { - return PrintMod(Downcast(node)); - } else { - // default module. 
- std::ostringstream os; - os << node; - return Doc::RawText(os.str()); - } + Doc doc; + doc << PrintScope(node); + return doc; +} + +Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { + bool is_non_relay_func = + node->IsInstance() && !node->IsInstance(); + if (node.as() && !is_non_relay_func) { + return PrintExpr(Downcast(node), meta, try_inline); + } else if (node.as()) { + return PrintType(Downcast(node), meta); + } else if (node.as()) { + return PrintPattern(Downcast(node), meta); + } else if (node.as()) { + return PrintMod(Downcast(node)); + } else { + // default module. + std::ostringstream os; + os << node; + return Doc::RawText(os.str()); } +} - Doc TempVar(int n) { - Doc doc; - return doc << "%" << n; - } - - Doc AllocTemp() { - return TempVar(temp_var_counter_++); - } - - /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ - Doc GetUniqueName(const std::string& prefix) { - std::string unique_prefix = prefix; - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_alloc_map_.count(name) == 0) { - unique_prefix = name; - break; - } +Doc RelayTextPrinter::TempVar(int n) { + Doc doc; + return doc << "%" << n; +} + +Doc RelayTextPrinter::AllocTemp() { return TempVar(temp_var_counter_++); } + +/*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ +Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) { + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + unique_prefix = name; + break; } } - name_alloc_map_[unique_prefix] = 0; - return Doc::Text(unique_prefix); } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} - Doc Print(Kind k) { - switch (k) { +Doc RelayTextPrinter::Print(Kind k) { + switch (k) { case kType: return Doc::Text("Type"); case kShapeVar: @@ -195,642 +170,594 @@ class RelayTextPrinter : default: LOG(ERROR) << "Unknown Kind"; throw; - } } - /*! - * \brief Allocate name to a type variable. - * \param var The input type variable. - * \return The corresponding name. - */ - Doc AllocTypeVar(const TypeVar& var) { - if (memo_type_.count(var)) { - Doc val = memo_type_[var]; - val << "-malformed-ir"; - return val; - } - std::string name = var->name_hint; - if (name.length() == 0 || !std::isalpha(name[0])) { - name = "t" + name; - } - Doc val = GetUniqueName(name); - memo_type_[var] = val; - if (var->kind != kType) { - val << ": " << Print(var->kind); - } +} +/*! + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ +Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) { + if (memo_type_.count(var)) { + Doc val = memo_type_[var]; + val << "-malformed-ir"; return val; } + std::string name = var->name_hint; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "t" + name; + } + Doc val = GetUniqueName(name); + memo_type_[var] = val; + if (var->kind != kType) { + val << ": " << Print(var->kind); + } + return val; +} - /*! - * \brief Allocate name to a variable. - * \param var The input variable. 
- * \return The corresponding name. - */ - Doc AllocVar(const Var& var) { - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - Doc val = memo_[var]; - val << "-malformed-ir"; - return val; - } - std::string name = var->name_hint(); - // always make sure first name is alpha - if (name.length() == 0 || !std::isalpha(name[0])) { - name = "v" + name; - } - Doc val = GetUniqueName("%" + name); - memo_[var] = val; - if (var->type_annotation.defined()) { - val << ": " << Print(var->type_annotation); - } +/*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ +Doc RelayTextPrinter::AllocVar(const Var& var) { + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + Doc val = memo_[var]; + val << "-malformed-ir"; return val; } - - bool IsUnique(const Expr& expr) { - auto it = dg_.expr_node.find(expr); - if (it == dg_.expr_node.end()) { - return true; - } else { - return !(it->second->parents.head && it->second->parents.head->next); - } + std::string name = var->name_hint(); + // always make sure first name is alpha + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; } + Doc val = GetUniqueName("%" + name); + memo_[var] = val; + if (var->type_annotation.defined()) { + val << ": " << Print(var->type_annotation); + } + return val; +} - bool AlwaysInline(const Expr& expr) { - return expr.as() || expr.as() || expr.as() || - expr.as() || expr.as(); +bool RelayTextPrinter::IsUnique(const Expr& expr) { + auto it = dg_.expr_node.find(expr); + if (it == dg_.expr_node.end()) { + return true; + } else { + return !(it->second->parents.head && it->second->parents.head->next); } +} - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) { - // Exploit memoization to print GNF. - // The first time we visit an expression, we need to allocate a temp var - // for it. Every subsequent time we can just use its assigned variable. - // This works since hashing uses pointer equality. +bool RelayTextPrinter::AlwaysInline(const Expr& expr) { + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} - // determine whether to inline - bool inline_expr = AlwaysInline(expr); - if (try_inline) { - inline_expr |= IsUnique(expr); - } +//------------------------------------ +// Overload of Expr printing functions +//------------------------------------ +Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. 
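+  // (illustrative) if the same node object occurs twice in a DAG, its first
+  // visit binds it to a fresh temporary such as %0, and every later visit
+  // simply prints %0 instead of re-printing the whole subexpression.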
+ + // determine whether to inline + bool inline_expr = AlwaysInline(expr); + if (try_inline) { + inline_expr |= IsUnique(expr); + } + + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + + Doc printed_expr; + if (meta) { + printed_expr = meta_->GetMetaNode(GetRef(expr.get())); + } else if (!inline_expr && expr.as()) { + // wrap GNFed let in brackets + Doc body; + printed_expr << "("; + printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); + printed_expr << ")"; + } else { + printed_expr = VisitExpr(expr); + } - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - - Doc printed_expr; - if (meta) { - printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - } else if (!inline_expr && expr.as()) { - // wrap GNFed let in brackets - Doc body; - printed_expr << "("; - printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); - printed_expr << ")"; - } else { - printed_expr = VisitExpr(expr); - } + printed_expr << PrintOptionalInfo(expr); - printed_expr << PrintOptionalInfo(expr); - - // add expr to doc - if (expr.as()) { - // This is our first time visiting the var and we hit the VarNode case - // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); - // Memoization is done in AllocVar. - return memo_[expr]; - } else if (inline_expr) { - memo_[expr] = printed_expr; - return printed_expr; - } else { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); - return temp_var; - } + // add expr to doc + if (expr.as()) { + // This is our first time visiting the var and we hit the VarNode case + // in the visitor. Thus the variable is free. + doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); + // Memoization is done in AllocVar. + return memo_[expr]; + } else if (inline_expr) { + memo_[expr] = printed_expr; + return printed_expr; + } else { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); + return temp_var; } +} - // Should only be triggered when op is a free variable being visited for the - // first time. - Doc VisitExpr_(const VarNode* op) final { - return AllocVar(GetRef(op)); +// Should only be triggered when op is a free variable being visited for the +// first time. +Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef(op)); } + +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ +template +Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Int(32)) { + os << value; + } else if (dtype == DataType::Float(32)) { + os << value << 'f'; + } else if (dtype == DataType::Float(64)) { + os << value; + } else if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; } + return Doc::Text(os.str()); +} - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ - template - static Doc ScalarLiteral(DataType dtype, const T& value) { +Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) { + // Print out simple scalars directly. 
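+  // (illustrative) an int32 constant 1 prints as `1`, a float32 constant 2.5
+  // prints as `2.5f`; non-scalar constants fall through to the meta section.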
+ if (op->is_scalar()) { std::ostringstream os; + DataType dtype = DataType(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); if (dtype == DataType::Int(32)) { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Int(64)) { + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(32)) { - os << value << 'f'; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(64)) { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Bool()) { - return Doc::PyBoolLiteral(value != 0); - } else { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } - return Doc::Text(os.str()); } + // default fall-back, record it as meta node. + Doc doc; + return doc << Print(GetRef(op), true); +} - Doc VisitExpr_(const ConstantNode* op) final { - // Print out simple scalars directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = DataType(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == DataType::Int(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Int(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Bool()) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } - } - // default fall-back, record it as meta node. - Doc doc; - return doc << Print(GetRef(op), true); +Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(Print(field)); } - - Doc VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(Print(field)); - } - Doc doc; - doc << "(" << Doc::Concat(fields); - // conform to python tuple format (1,) - if (op->fields.size() == 1) { - doc << ","; - } - return doc << ")"; + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (op->fields.size() == 1) { + doc << ","; } + return doc << ")"; +} - Doc VisitExpr_(const TupleGetItemNode* op) final { - Doc doc; - return doc << Print(op->tuple) << "." << op->index; - } +Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { + Doc doc; + return doc << Print(op->tuple) << "." 
<< op->index; +} - Doc VisitExpr_(const IfNode* op) final { - Doc doc; - doc << "if (" << Print(op->cond) << ") "; - doc << PrintBody(op->true_branch); - doc << " else "; - doc << PrintBody(op->false_branch); - return doc; - } +Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { + Doc doc; + doc << "if (" << Print(op->cond) << ") "; + doc << PrintBody(op->true_branch); + doc << " else "; + doc << PrintBody(op->false_branch); + return doc; +} - Doc VisitExpr_(const LetNode* op) final { - Doc doc; - doc - << "let " - << AllocVar(op->var) - << " = " - << Print(op->value, false, true) - << ";" +Doc RelayTextPrinter::VisitExpr_(const LetNode* op) { + Doc doc; + doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << ";" << Doc::NewLine(); - // we use a scope here so GNF hoisting doesn't escape too far - // and nested, unique lets are not hoisted - doc << PrintScope(op->body); - return doc; + // we use a scope here so GNF hoisting doesn't escape too far + // and nested, unique lets are not hoisted + doc << PrintScope(op->body); + return doc; +} + +Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { + Doc doc; + doc << prefix; + if (fn->type_params.size() > 0) { + doc << "["; + std::vector type_params; + for (const TypeVar& tv : fn->type_params) { + type_params.push_back(Doc::Text(tv->name_hint)); + } + doc << Doc::Concat(type_params); + doc << "]"; + } + doc << "("; + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); + } + for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + params.push_back(d); } + doc << Doc::Concat(params) << ") "; + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; + } + doc << PrintBody(fn->body); + return doc; +} - Doc PrintFunc(const Doc& prefix, const relay::Function& fn) { +Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) { + if (auto* n = base_func.as()) { + return PrintFunc(prefix, GetRef(n)); + } else if (auto* n = base_func.as()) { + std::ostringstream os; + os << GetRef(n); + return Doc::RawText(os.str()); + } else { + // def @xyz = meta['ExternalFunc'][id] Doc doc; - doc << prefix; - if (fn->type_params.size() > 0) { - doc << "["; - std::vector type_params; - for (const TypeVar& tv : fn->type_params) { - type_params.push_back(Doc::Text(tv->name_hint)); - } - doc << Doc::Concat(type_params); - doc << "]"; - } - doc << "("; - std::vector params; - for (Var param : fn->params) { - params.push_back(AllocVar(param)); - } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { - params.push_back(d); - } - doc << Doc::Concat(params) << ") "; - if (fn->ret_type.defined()) { - doc << "-> " << Print(fn->ret_type) << " "; - } - doc << PrintBody(fn->body); + doc << prefix << " = " << meta_->GetMetaNode(base_func); return doc; } +} - Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) { - if (auto* n = base_func.as()) { - return PrintFunc(prefix, GetRef(n)); - } else if (auto* n = base_func.as()) { - std::ostringstream os; - os << GetRef(n); - return Doc::RawText(os.str()); - } else { - // def @xyz = meta['ExternalFunc'][id] - Doc doc; - doc << prefix << " = " << meta_.GetMetaNode(base_func); - return doc; +Doc RelayTextPrinter::PrintMod(const IRModule& mod) { + Doc doc; + int counter = 0; + // type definitions + for (const auto& kv : mod->type_definitions) { + if (counter++ != 0) { + doc << Doc::NewLine(); } + doc << Print(kv.second); + doc << Doc::NewLine(); } - - Doc PrintMod(const IRModule& mod) { - Doc doc; - int counter = 0; 
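+  // (illustrative) `counter` only tracks whether an item is the first one, so
+  // that a separating newline is emitted before every item except the first.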
- // type definitions - for (const auto& kv : mod->type_definitions) { - if (counter++ != 0) { - doc << Doc::NewLine(); - } - doc << Print(kv.second); - doc << Doc::NewLine(); + // functions + for (const auto& kv : mod->functions) { + if (kv.second.as()) { + dg_ = DependencyGraph::Create(&arena_, kv.second); } - // functions - for (const auto& kv : mod->functions) { - if (kv.second.as()) { - dg_ = DependencyGraph::Create(&arena_, kv.second); - } - if (counter++ != 0) { - doc << Doc::NewLine(); - } - std::ostringstream os; - os << "def @" << kv.first->name_hint; - doc << PrintFunc(Doc::Text(os.str()), kv.second); + if (counter++ != 0) { doc << Doc::NewLine(); } - return doc; - } - - Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Doc::Text("fn "), GetRef(op)); + std::ostringstream os; + os << "def @" << kv.first->name_hint; + doc << PrintFunc(Doc::Text(os.str()), kv.second); + doc << Doc::NewLine(); } + return doc; +} - Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc::Text('@' + op->name_hint); - } +Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { + return PrintFunc(Doc::Text("fn "), GetRef(op)); +} - Doc VisitExpr_(const OpNode* op) final { - return Doc::Text(op->name); - } +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); } - Doc VisitExpr_(const CallNode* op) final { - Doc doc; - // visit args first so they are lifted before the op - // this places op closer to its call site - std::vector args; - for (const Expr& arg : op->args) { - args.push_back(Print(arg)); - } - for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { - args.push_back(d); - } - const auto* cons_node = op->op.as(); - if (cons_node) { - doc << cons_node->name_hint; - } else { - doc << Print(op->op); - } +Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); } - if (cons_node && cons_node->inputs.size() == 0) { - // don't print as a call if it's a 0-arity cons - return doc; - } else { - return doc << "(" << Doc::Concat(args) << ")"; - } +Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { + Doc doc; + // visit args first so they are lifted before the op + // this places op closer to its call site + std::vector args; + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); + } + for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { + args.push_back(d); + } + const auto* cons_node = op->op.as(); + if (cons_node) { + doc << cons_node->name_hint; + } else { + doc << Print(op->op); } - Doc VisitExpr_(const RefCreateNode* op) final { - Doc doc; - return doc << "ref(" << Print(op->value) << ")"; + if (cons_node && cons_node->inputs.size() == 0) { + // don't print as a call if it's a 0-arity cons + return doc; + } else { + return doc << "(" << Doc::Concat(args) << ")"; } +} - Doc VisitExpr_(const RefReadNode* op) final { - Doc doc; - return doc << Print(op->ref) << "^"; - } +Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) { + Doc doc; + return doc << "ref(" << Print(op->value) << ")"; +} - Doc VisitExpr_(const RefWriteNode* op) final { - Doc doc; - return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; - } +Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) { + Doc doc; + return doc << Print(op->ref) << "^"; +} - Doc VisitExpr_(const MatchNode* op) final { - // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. 
- Doc doc; - Doc body; - doc << "match"; - if (!op->complete) { - doc << "?"; - } - doc << " (" << Print(op->data) << ") {"; - std::vector clause_docs; - for (const auto& clause : op->clauses) { - Doc clause_doc; - clause_doc << PrintPattern(clause->lhs, false) << " => "; - Doc rhs_doc = PrintScope(clause->rhs); - if (clause->rhs.as()) { - // only add braces if there are multiple lines on the rhs - rhs_doc = Doc::Brace("{", rhs_doc, "}"); - } - clause_doc << rhs_doc << ","; - clause_docs.push_back(clause_doc); - } - doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) - << Doc::NewLine() << "}"; - return doc; - } +Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) { + Doc doc; + return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; +} - Doc PrintPattern(const Pattern& pattern, bool meta) { - auto it = memo_pattern_.find(pattern); - if (it != memo_pattern_.end()) return it->second; - Doc printed_pattern; - if (meta) { - printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); - } else { - printed_pattern = VisitPattern(pattern); - } - memo_pattern_[pattern] = printed_pattern; - return printed_pattern; - } +Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) { + // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. + Doc doc; + Doc body; + doc << "match"; + if (!op->complete) { + doc << "?"; + } + doc << " (" << Print(op->data) << ") {"; + std::vector clause_docs; + for (const auto& clause : op->clauses) { + Doc clause_doc; + clause_doc << PrintPattern(clause->lhs, false) << " => "; + Doc rhs_doc = PrintScope(clause->rhs); + if (clause->rhs.as()) { + // only add braces if there are multiple lines on the rhs + rhs_doc = Doc::Brace("{", rhs_doc, "}"); + } + clause_doc << rhs_doc << ","; + clause_docs.push_back(clause_doc); + } + doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) + << Doc::NewLine() << "}"; + return doc; +} - Doc VisitPattern_(const PatternConstructorNode* p) final { - Doc doc; - doc << p->constructor->name_hint; - if (!p->patterns.empty()) { - doc << "("; - std::vector pats; - for (const auto& pat : p->patterns) { - pats.push_back(Print(pat)); - } - doc << Doc::Concat(pats) << ")"; - } - return doc; +Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) { + auto it = memo_pattern_.find(pattern); + if (it != memo_pattern_.end()) return it->second; + Doc printed_pattern; + if (meta) { + printed_pattern = meta_->GetMetaNode(GetRef(pattern.get())); + } else { + printed_pattern = VisitPattern(pattern); } + memo_pattern_[pattern] = printed_pattern; + return printed_pattern; +} - Doc VisitPattern_(const PatternTupleNode* pt) final { - Doc doc; +Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) { + Doc doc; + doc << p->constructor->name_hint; + if (!p->patterns.empty()) { doc << "("; std::vector pats; - for (const auto& pat : pt->patterns) { + for (const auto& pat : p->patterns) { pats.push_back(Print(pat)); } doc << Doc::Concat(pats) << ")"; - return doc; } + return doc; +} - Doc VisitPattern_(const PatternWildcardNode* pw) final { - return Doc::Text("_"); +Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) { + Doc doc; + doc << "("; + std::vector pats; + for (const auto& pat : pt->patterns) { + pats.push_back(Print(pat)); } + doc << Doc::Concat(pats) << ")"; + return doc; +} - Doc VisitPattern_(const PatternVarNode* pv) final { - return AllocVar(pv->var); - } +Doc 
RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { return Doc::Text("_"); } - Doc VisitExpr_(const ConstructorNode* n) final { - Doc doc; - doc << n->name_hint; - if (in_adt_def_ && n->inputs.size() != 0) { - doc << "("; - std::vector inputs; - for (Type input : n->inputs) { - inputs.push_back(Print(input)); - } - doc << Doc::Concat(inputs) << ")"; - } - return doc; - } +Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { return AllocVar(pv->var); } - //------------------------------------ - // Overload of Type printing functions - //------------------------------------ - Doc PrintType(const Type& type, bool meta) { - auto it = memo_type_.find(type); - if (it != memo_type_.end()) return it->second; - Doc printed_type; - if (meta) { - printed_type = meta_.GetMetaNode(GetRef(type.get())); - } else { - printed_type = VisitType(type); +Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) { + Doc doc; + doc << n->name_hint; + if (in_adt_def_ && n->inputs.size() != 0) { + doc << "("; + std::vector inputs; + for (Type input : n->inputs) { + inputs.push_back(Print(input)); } - memo_type_[type] = printed_type; - return printed_type; + doc << Doc::Concat(inputs) << ")"; } + return doc; +} - Doc VisitTypeDefault_(const Object* node) final { - // by default always print as meta data - return Print(GetRef(node), true); +//------------------------------------ +// Overload of Type printing functions +//------------------------------------ +Doc RelayTextPrinter::PrintType(const Type& type, bool meta) { + auto it = memo_type_.find(type); + if (it != memo_type_.end()) return it->second; + Doc printed_type; + if (meta) { + printed_type = meta_->GetMetaNode(GetRef(type.get())); + } else { + printed_type = VisitType(type); } + memo_type_[type] = printed_type; + return printed_type; +} - Doc VisitType_(const TypeVarNode* node) final { - return Doc::Text(node->name_hint); - } +Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) { + // by default always print as meta data + return Print(GetRef(node), true); +} - Doc VisitType_(const GlobalTypeVarNode* node) final { - return Doc::Text(node->name_hint); - } +Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { return Doc::Text(node->name_hint); } - Doc VisitType_(const TypeCallNode* node) final { - Doc doc = PrintType(node->func, false); - std::vector args; - for (const Type& t : node->args) { - args.push_back(PrintType(t, false)); - } - doc << "["; - doc << Doc::Concat(args); - doc << "]"; - return doc; - } +Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) { + return Doc::Text(node->name_hint); +} - Doc PrintDType(DataType dtype) { - return Doc::Text(runtime::DLDataType2String(dtype)); +Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) { + Doc doc = PrintType(node->func, false); + std::vector args; + for (const Type& t : node->args) { + args.push_back(PrintType(t, false)); } + doc << "["; + doc << Doc::Concat(args); + doc << "]"; + return doc; +} - Doc VisitType_(const TensorTypeNode* node) final { - // scalar type - if (node->shape.size() == 0) { - return PrintDType(node->dtype); - } - Doc doc; - doc << "Tensor[("; - std::vector shapes; - for (ObjectRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); - } - doc << Doc::Concat(shapes); - return doc << "), " << PrintDType(node->dtype) << "]"; - } +Doc RelayTextPrinter::PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); +} - Doc VisitType_(const TupleTypeNode* node) final { - std::vector fields; - for 
(Type field : node->fields) { - fields.push_back(Print(field)); - } - Doc doc; - doc << "(" << Doc::Concat(fields); - // conform to python tuple format (1,) - if (node->fields.size() == 1) { - doc << ","; - } - return doc << ")"; +Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); } - - Doc VisitType_(const FuncTypeNode* node) final { - Doc doc; - doc << "fn "; - if (node->type_params.size() != 0) { - doc << "["; - std::vector type_params; - for (Type type_param : node->type_params) { - type_params.push_back(Print(type_param)); - } - doc << Doc::Concat(type_params); - doc << "]"; - } - std::vector arg_types; - for (Type arg_type : node->arg_types) { - arg_types.push_back(Print(arg_type)); - } - return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); + Doc doc; + doc << "Tensor[("; + std::vector shapes; + for (ObjectRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); } + doc << Doc::Concat(shapes); + return doc << "), " << PrintDType(node->dtype) << "]"; +} - Doc VisitType_(const RelayRefTypeNode* node) final { - Doc doc; - return doc << "ref(" << Print(node->value) << ")"; +Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; +} - Doc VisitType_(const TypeDataNode* node) final { - in_adt_def_ = true; - Doc doc; - doc << "type " << Print(node->header); - - // type vars - if (node->type_vars.size() != 0) { - doc << "["; - std::vector type_vars; - for (Type type_var : node->type_vars) { - type_vars.push_back(Print(type_var)); - } - doc << Doc::Concat(type_vars) << "]"; - } - doc << " "; - - std::vector constructor_docs; - for (Constructor constructor : node->constructors) { - constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); - } - Doc separator; - separator << "," << Doc::NewLine(); - Doc adt_body; - adt_body << Doc::Concat(constructor_docs, separator); - // add trailing comma if there are any constructors - if (!constructor_docs.empty()) { - adt_body << ","; +Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) { + Doc doc; + doc << "fn "; + if (node->type_params.size() != 0) { + doc << "["; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); } - doc << Doc::Brace("{", adt_body, "}"); - in_adt_def_ = false; - return doc; + doc << Doc::Concat(type_params); + doc << "]"; } + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); + } + return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); +} - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ +Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) { + Doc doc; + return doc << "ref(" << Print(node->value) << ")"; +} - Doc PrintAttr(const ObjectRef& value, bool meta = false) { - if (value.defined()) { - Doc printed_attr; - if (value.as()) { - printed_attr << "?"; - } else if (meta) { - printed_attr = meta_.GetMetaNode(Downcast(value)); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; - } else { - return Doc::Text("None"); +Doc RelayTextPrinter::VisitType_(const 
TypeDataNode* node) { + in_adt_def_ = true; + Doc doc; + doc << "type " << Print(node->header); + + // type vars + if (node->type_vars.size() != 0) { + doc << "["; + std::vector type_vars; + for (Type type_var : node->type_vars) { + type_vars.push_back(Print(type_var)); } + doc << Doc::Concat(type_vars) << "]"; } + doc << " "; - Doc VisitAttrDefault_(const Object* op) final { - return PrintAttr(GetRef(op), true); + std::vector constructor_docs; + for (Constructor constructor : node->constructors) { + constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); } + Doc separator; + separator << "," << Doc::NewLine(); + Doc adt_body; + adt_body << Doc::Concat(constructor_docs, separator); + // add trailing comma if there are any constructors + if (!constructor_docs.empty()) { + adt_body << ","; + } + doc << Doc::Brace("{", adt_body, "}"); + in_adt_def_ = false; + return doc; +} - Doc VisitAttr_(const ArrayNode* op) final { - Doc doc; - doc << "["; - std::vector arr_vals; - for (auto val : op->data) { - arr_vals.push_back(PrintAttr(val)); +//------------------------------------ +// Overload of Attr printing functions +//------------------------------------ + +Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { + if (value.defined()) { + Doc printed_attr; + if (value.as()) { + printed_attr << "?"; + } else if (meta) { + printed_attr = meta_->GetMetaNode(Downcast(value)); + } else { + printed_attr = VisitAttr(value); } - doc << Doc::Concat(arr_vals); - doc << "]"; - return doc; + return printed_attr; + } else { + return Doc::Text("None"); } +} - Doc VisitAttr_(const tir::IntImmNode* op) final { - return ScalarLiteral(op->dtype, op->value); - } +Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { + return PrintAttr(GetRef(op), true); +} - Doc VisitAttr_(const tir::FloatImmNode* op) final { - return ScalarLiteral(op->dtype, op->value); - } +Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { + Doc doc; + doc << "["; + std::vector arr_vals; + for (auto val : op->data) { + arr_vals.push_back(PrintAttr(val)); + } + doc << Doc::Concat(arr_vals); + doc << "]"; + return doc; +} - Doc VisitAttr_(const tir::StringImmNode* op) final { - return Doc::StrLiteral(op->value); - } +Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { + return ScalarLiteral(op->dtype, op->value); +} - private: - /*! \brief Whether to print meta data. */ - bool show_meta_data_; - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; - /*! \brief Stack of docs to implement scoped GNFing. */ - std::vector doc_stack_{}; - /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_type_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_pattern_; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - /*! \brief meta data context */ - TextMetaDataContext meta_; - /*! \brief counter of temporary variable */ - size_t temp_var_counter_{0}; - /*! \brief whether the printer is currently in an ADT definition */ - bool in_adt_def_; - /*! \brief arena for dependency graph */ - support::Arena arena_; - /*! \brief dependency graph of the expr */ - DependencyGraph dg_; - class AttrPrinter; - friend class AttrPrinter; -}; +Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { + return ScalarLiteral(op->dtype, op->value); +} + +Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { + return Doc::StrLiteral(op->value); +} /*! 
* \brief Attribute printer which prints the attributes in the call. */ -class RelayTextPrinter::AttrPrinter : - public AttrVisitor { +class RelayTextPrinter::AttrPrinter : public AttrVisitor { public: - AttrPrinter(std::vector* doc, RelayTextPrinter* parent) - : docs(doc), parent_(parent) {} + AttrPrinter(std::vector* doc, RelayTextPrinter* parent) : docs(doc), parent_(parent) {} - template + template void PrintKV(const char* key, const T& value) { Doc doc; doc << key << "=" << value; @@ -842,24 +769,12 @@ class RelayTextPrinter::AttrPrinter : doc << key << "=" << *value << "f"; docs->push_back(doc); } - void Visit(const char* key, int64_t* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, uint64_t* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, int* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, bool* value) final { - PrintKV(key, Doc::PyBoolLiteral(*value)); - } - void Visit(const char* key, std::string* value) final { - PrintKV(key, Doc::StrLiteral(*value)); - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "do not allow void as argument"; - } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, int* value) final { PrintKV(key, *value); } + void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } + void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } + void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; } void Visit(const char* key, DataType* value) final { PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); } @@ -875,15 +790,14 @@ class RelayTextPrinter::AttrPrinter : RelayTextPrinter* parent_; }; -std::vector RelayTextPrinter::PrintCallAttrs( - const Attrs& attrs, const Expr& op) { +std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; if (!attrs.defined()) return docs; const auto* op_node = op.as(); if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { // fallback Doc doc; - doc << meta_.GetMetaNode(attrs); + doc << meta_->GetMetaNode(attrs); docs.push_back(doc); return docs; } else { @@ -905,38 +819,6 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { } return docs; } -} // namespace relay - -static const char* kSemVer = "v0.0.4"; - -// TODO(tvm-team): split into files, related: arith/analyzer.h -// -// - text_printer.h (common header) -// - text_printer.cc (prints modules dispatch into relay and tir files) -// - type_text_printer.cc(specific printing logics for types, -// can also consider put under type_text_printer) -// - Implements AsText -// - relay_text_printer.cc (specific printing logics for relay) -// - tir_text_printer.cc (specific printing logics for TIR) -std::string PrettyPrint(const ObjectRef& node) { - Doc doc; - doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); - return doc.str(); -} -std::string AsText(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - Doc doc; - doc << kSemVer << Doc::NewLine(); - doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); - return doc.str(); -} - - -TVM_REGISTER_GLOBAL("ir.PrettyPrint") -.set_body_typed(PrettyPrint); - -TVM_REGISTER_GLOBAL("ir.AsText") -.set_body_typed(AsText); +} // namespace relay } // namespace tvm diff 
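(Aside: with PrettyPrint/AsText relocated out of the Relay printer, a minimal usage sketch of the unified entry points defined in the new src/printer/text_printer.cc below. The DumpNode helper is invented for illustration, and the header carrying the printer declarations is assumed; only the two entry-point signatures come from this diff.)

#include <iostream>
#include <string>
#include <tvm/ir/module.h>
// (header declaring tvm::PrettyPrint / tvm::AsText assumed to be included)

void DumpNode(const tvm::ObjectRef& node) {  // hypothetical helper
  // compact form: no semver header, metadata elided
  std::string compact = tvm::PrettyPrint(node);
  // parseable form: "v0.0.4" header first, METADATA section appended on request
  std::string full = tvm::AsText(node, /*show_meta_data=*/true, nullptr);
  std::cout << compact << "\n" << full << "\n";
}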
--git a/src/printer/text_printer.cc b/src/printer/text_printer.cc new file mode 100644 index 0000000000000..2993d38234ead --- /dev/null +++ b/src/printer/text_printer.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.cc + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#include "text_printer.h" + +#include + +#include + +namespace tvm { + +static const char* kSemVer = "v0.0.4"; + +// TODO(tvm-team): split into files, related: arith/analyzer.h +// +// - text_printer.h (common header) +// - text_printer.cc (prints modules dispatch into relay and tir files) +// - type_text_printer.cc(specific printing logics for types, +// can also consider put under type_text_printer) +// - Implements AsText +// - relay_text_printer.cc (specific printing logics for relay) +// - tir_text_printer.cc (specific printing logics for TIR) + +Doc TextPrinter::PrintMod(const IRModule& mod) { + Doc doc; + int counter = 0; + // type definitions + for (const auto& kv : mod->type_definitions) { + if (counter++ != 0) { + doc << Doc::NewLine(); + } + doc << relay_text_printer_.Print(kv.second); + doc << Doc::NewLine(); + } + // functions + for (const auto& kv : mod->functions) { + if (kv.second.as()) { + relay_text_printer_.dg_ = + relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second); + } + if (counter++ != 0) { + doc << Doc::NewLine(); + } + if (kv.second.as()) { + std::ostringstream os; + os << "def @" << kv.first->name_hint; + doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); + } else if (kv.second.as()) { + doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); + } + doc << Doc::NewLine(); + } + return doc; +} + +String PrettyPrint(const ObjectRef& node) { + Doc doc; + doc << TextPrinter(false, nullptr).PrintFinal(node); + return doc.str(); +} + +String AsText(const ObjectRef& node, bool show_meta_data, + runtime::TypedPackedFunc annotate) { + Doc doc; + doc << kSemVer << Doc::NewLine(); + runtime::TypedPackedFunc ftyped = nullptr; + if (annotate != nullptr) { + ftyped = runtime::TypedPackedFunc( + [&annotate](const ObjectRef& expr) -> std::string { return annotate(expr); }); + } + doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node); + return doc.str(); +} + +TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint); + +TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText); + +} // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h new file mode 100644 index 0000000000000..00b6fb9007662 --- /dev/null +++ b/src/printer/text_printer.h @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.h + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#ifndef TVM_PRINTER_TEXT_PRINTER_H_ +#define TVM_PRINTER_TEXT_PRINTER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../ir/attr_functor.h" +#include "../relay/analysis/dependency_graph.h" +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +class TextPrinter; +} // namespace tvm + +namespace tvm { +namespace relay { + +class RelayTextPrinter : public ExprFunctor, + public PatternFunctor, + public TypeFunctor, + public AttrFunctor { + public: + explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, + runtime::TypedPackedFunc annotate) + : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} + + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Expr& expr); + // indent a new body + Doc PrintBody(const ObjectRef& node, int indent = 2); + // create a new scope by creating a new printer object. This allows temp var + // numbers to be reused and prevents hoisted vars from escaping too far + Doc PrintScope(const ObjectRef& node); + Doc PrintFinal(const ObjectRef& node); + std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); + std::vector PrintFuncAttrs(const Attrs& attrs); + + Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); + + Doc TempVar(int n); + Doc AllocTemp(); + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + Doc GetUniqueName(const std::string& prefix); + Doc Print(Kind k); + /*! + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ + Doc AllocTypeVar(const TypeVar& var); + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var); + bool IsUnique(const Expr& expr); + bool AlwaysInline(const Expr& expr); + + Doc PrintFunc(const Doc& prefix, const relay::Function& fn); + Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func); + Doc PrintMod(const IRModule& mod); + + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ + Doc PrintExpr(const Expr& expr, bool meta, bool try_inline); + // Should only be triggered when op is a free variable being visited for the + // first time. + Doc VisitExpr_(const VarNode* op) final; + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. 
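+ * \return The scalar value rendered as a Doc.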
+ */ + template + static Doc ScalarLiteral(DataType dtype, const T& value); + Doc VisitExpr_(const ConstantNode* op) final; + Doc VisitExpr_(const TupleNode* op) final; + Doc VisitExpr_(const TupleGetItemNode* op) final; + Doc VisitExpr_(const IfNode* op) final; + Doc VisitExpr_(const LetNode* op) final; + Doc VisitExpr_(const FunctionNode* op) final; + Doc VisitExpr_(const GlobalVarNode* op) final; + Doc VisitExpr_(const OpNode* op) final; + Doc VisitExpr_(const CallNode* op) final; + Doc VisitExpr_(const RefCreateNode* op) final; + Doc VisitExpr_(const RefReadNode* op) final; + Doc VisitExpr_(const RefWriteNode* op) final; + Doc VisitExpr_(const MatchNode* op) final; + Doc PrintPattern(const Pattern& pattern, bool meta); + Doc VisitPattern_(const PatternConstructorNode* p) final; + Doc VisitPattern_(const PatternTupleNode* pt) final; + Doc VisitPattern_(const PatternWildcardNode* pw) final; + Doc VisitPattern_(const PatternVarNode* pv) final; + Doc VisitExpr_(const ConstructorNode* n) final; + //------------------------------------ + // Overload of Type printing functions + //------------------------------------ + Doc PrintType(const Type& type, bool meta); + Doc VisitTypeDefault_(const Object* node) final; + Doc VisitType_(const TypeVarNode* node) final; + Doc VisitType_(const GlobalTypeVarNode* node); + Doc VisitType_(const TypeCallNode* node) final; + Doc PrintDType(DataType dtype); + Doc VisitType_(const TensorTypeNode* node) final; + Doc VisitType_(const TupleTypeNode* node) final; + Doc VisitType_(const FuncTypeNode* node) final; + Doc VisitType_(const RelayRefTypeNode* node) final; + Doc VisitType_(const TypeDataNode* node) final; + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + Doc PrintAttr(const ObjectRef& value, bool meta = false); + Doc VisitAttrDefault_(const Object* op) final; + Doc VisitAttr_(const ArrayNode* op) final; + Doc VisitAttr_(const tir::IntImmNode* op) final; + Doc VisitAttr_(const tir::FloatImmNode* op) final; + Doc VisitAttr_(const tir::StringImmNode* op) final; + + private: + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Stack of docs to implement scoped GNFing. */ + std::vector doc_stack_{}; + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_type_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_pattern_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + /*! \brief counter of temporary variable */ + size_t temp_var_counter_{0}; + /*! \brief whether the printer is currently in an ADT definition */ + bool in_adt_def_; + /*! \brief arena for dependency graph */ + support::Arena arena_; + /*! \brief dependency graph of the expr */ + DependencyGraph dg_; + class AttrPrinter; + friend class AttrPrinter; + friend class tvm::TextPrinter; +}; + +} // namespace relay +} // namespace tvm + +namespace tvm { +namespace tir { + +/*! 
+ * \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be printed directly (as a string literal, or by an identifier)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() ||
+        n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
+      : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext* meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  friend class tvm::TextPrinter;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const
BufferRealizeNode* op) override; + Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const FreeNode* op) override; + Doc VisitStmt_(const IfThenElseNode* op) override; + Doc VisitStmt_(const SeqStmtNode* op) override; + Doc VisitStmt_(const EvaluateNode* op) override; + Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const PrefetchNode* op) override; + Doc VisitStmtDefault_(const Object* op) override; + + Doc VisitType_(const PrimTypeNode* node) override; + Doc VisitType_(const PointerTypeNode* node) override; + Doc VisitType_(const TupleTypeNode* node) override; + + Doc PrintIRModule(const IRModule& module); + Doc PrintPrimFunc(const PrimFunc& primFunc); + Doc PrintArray(const ArrayNode* op); + Doc PrintIterVar(const IterVarNode* op); + Doc PrintRange(const RangeNode* op); + Doc PrintBuffer(const BufferNode* op); + Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } + + /*! + * \brief special method to print out data type + * \param dtype The data type + */ + static Doc PrintDType(DataType dtype); + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ + template + static Doc PrintConstScalar(DataType dtype, const T& data); + Doc GetUniqueName(std::string prefix); + Doc AllocVar(const Var& var); + Doc AllocBuf(const Buffer& buffer); + /*! + * \brief special method to render vectors of docs with a separator + * \param vec vector of docs + * \param sep separator + */ + static Doc PrintSep(const std::vector& vec, const Doc& sep); + Doc PrintBody(const Stmt& body, bool indent = true); +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { + +class TextPrinter { + public: + explicit TextPrinter(bool show_meta_data, + const runtime::TypedPackedFunc& annotate) + : show_meta_data_(show_meta_data), + annotate_(annotate), + relay_text_printer_(show_meta_data, &meta_, annotate), + tir_text_printer_(show_meta_data, &meta_) {} + + /*! \brief whether show meta data */ + bool show_meta_data_; + /*! \brief meta data context */ + TextMetaDataContext meta_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Relay Text Printer */ + relay::RelayTextPrinter relay_text_printer_; + /*! \brief TIR Text Printer */ + tir::TIRTextPrinter tir_text_printer_; + + Doc PrintFinal(const ObjectRef& node) { + Doc doc; + if (node->IsInstance()) { + doc << PrintMod(Downcast(node)); + } else if (node->IsInstance() || node->IsInstance() || + node->IsInstance()) { + doc << tir_text_printer_.Print(node); + } else { + doc << relay_text_printer_.PrintFinal(node); + } + if (!meta_.empty()) { + doc << Doc::NewLine(); + if (show_meta_data_) { + // append meta data in the end. + doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); + } else { + doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; + } + } + return doc; + } + + Doc PrintMod(const IRModule& mod); +}; +} // namespace tvm + +#endif // TVM_PRINTER_TEXT_PRINTER_H_ diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc new file mode 100644 index 0000000000000..511a24377738c --- /dev/null +++ b/src/printer/tir_text_printer.cc @@ -0,0 +1,612 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir_text_printer.cc + * \brief Printer to print out the IR text format + * that can be parsed by a parser. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +namespace tir { + +Doc TIRTextPrinter::Print(const ObjectRef& node) { + if (!node.defined()) return Doc::Text("(nullptr)"); + if (node->IsInstance()) { + return VisitStmt(Downcast(node)); + } else if (node->IsInstance()) { + return Doc::Text("?"); + } else if (node->IsInstance()) { + return VisitExpr(Downcast(node)); + } else if (node->IsInstance()) { + return VisitType(Downcast(node)); + } else if (node->IsInstance()) { + return PrintPrimFunc(Downcast(node)); + } else if (node->IsInstance()) { + return PrintIRModule(Downcast(node)); + } else if (node->IsInstance()) { + return PrintArray(node.as()); + } else if (node->IsInstance()) { + return PrintIterVar(node.as()); + } else if (node->IsInstance()) { + return PrintRange(node.as()); + } else if (node->IsInstance()) { + return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintString(node.as()); + } else { + return this->meta_->GetMetaNode(node); + } +} + +Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { + const auto* op = primFunc.operator->(); + const auto& signature = op->func_type_annotation(); + // collect Meta in DictAttr + for (const auto& it : primFunc->attrs->dict) { + meta_collector_.Collect(it.second); + } + // collect buffers in buffer_map + memo_var_.clear(); + memo_buf_.clear(); + for (const auto& it : op->buffer_map) { + memo_buf_[it.second] = AllocBuf(it.second); + } + // print PrimFunc + Doc doc; + doc << "primfn" + << "("; + // print params and its type annotation + std::vector params; + for (const auto& param : op->params) { + params.push_back(Print(param)); + } + Doc sep; + doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")"; + // print return type + doc << " -> " << Print(signature->ret_type); + // print attr + Doc attr_doc; + std::vector attr_docs; + for (const auto& it : op->attrs->dict) { + attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); + } + attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; + doc << Doc::Indent(2, attr_doc); + // print all the buffers in the tree + Doc buffer_doc; + std::vector buffer_docs; + for (const auto& it : memo_buf_) { + const auto& buf = it.first; + buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", " + << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " + << Print(buf->strides)); + if (!is_zero(buf->elem_offset)) { + buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->scope != "global") { + buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope); + } + if (buf->data_alignment != 128) { + 
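+      // 128 is treated as the default alignment; only a non-default value is printed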
buffer_docs.back() << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 1) { + buffer_docs.back() << ", offset_factor=" << buf->offset_factor; + } + if (buf->buffer_type != 1) { + buffer_docs.back() << ", type=" << Doc::StrLiteral("auto"); + } + buffer_docs.back() << ")"; + } + buffer_doc << Doc::NewLine() << "buffers = {"; + buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); + doc << Doc::Indent(2, buffer_doc) << "}"; + // print buffer_map + std::vector buffer_map_doc; + for (const auto& it : op->buffer_map) { + buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); + } + doc << Doc::Indent( + 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); + doc << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { + const auto* op = module.operator->(); + Doc doc; + + Doc body; + body << Doc::NewLine(); + std::vector functions; + for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { + if ((*it).second.as()) { + functions.push_back(Print((*it).second)); + } + } + body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); + doc << Doc::Indent(0, body); + return doc; +} + +Doc TIRTextPrinter::PrintArray(const ArrayNode* op) { + Doc doc; + doc << '['; + for (size_t i = 0; i < op->data.size(); ++i) { + if (i != 0) { + doc << ", "; + } + doc << Print(op->data[i]); + } + doc << ']'; + return doc; +} + +Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) { + Doc doc; + doc << "IterVar(" << Print(op->var); + if (op->dom.defined()) { + doc << ", [" << Print(op->dom) << "], "; + } else { + doc << ", " << Print(op->dom) << ", "; + } + doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "; + doc << Doc::StrLiteral(op->thread_tag) << ")"; + return doc; +} + +Doc TIRTextPrinter::PrintRange(const RangeNode* op) { + return Print(op->min) << ":" << Print(op->min + op->extent); +} + +Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { + const Buffer& buffer = GetRef(op); + CHECK_GT(memo_buf_.count(buffer), 0); + return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer]; +} + +Doc TIRTextPrinter::VisitExprDefault_(const Object* op) { + return this->meta_->GetMetaNode(GetRef(op)); +} + +Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) { + return this->meta_->GetMetaNode(GetRef(op)); +} + +Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) { + return PrintConstScalar(op->dtype, op->value); +} + +Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) { + return PrintConstScalar(op->dtype, op->value); +} + +Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } + +Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { + Doc doc; + doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const VarNode* op) { + const Var& var = GetRef(op); + return meta_->InMeta(var) ? 
meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {           \
+    Doc doc;                                                   \
+    doc << "(" << Print(op->a) << OpString;                    \
+    doc << Print(op->b) << ")";                                \
+    return doc;                                                \
+  }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc doc;
+  doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc doc;
+  doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc doc;
+  doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc doc;
+  doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc doc;
+  doc << "!" << Print(op->a);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  // close the argument list; the trailing ")" was missing
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  // drop the stray ")" after "]" that unbalanced the printed expression
+  doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index)
+      << "]";
+  if (!is_one(op->predicate)) {
+    doc << " if " << Print(op->predicate);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
+  Doc doc;
+  doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
+  Doc doc;
+  doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+  Doc doc;
+  doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
+  return doc;
+}
+
+inline const char* CallType2String(CallNode::CallType t) {
+  switch (t) {
+    case CallNode::Extern:
+      return "extern";
+    case CallNode::ExternCPlusPlus:
+      return "extern_cpp";
+    case CallNode::PureExtern:
+      return "pure_extern";
+    case CallNode::Halide:
+      return "halide";
+    case CallNode::Intrinsic:
+      return "intrin";
+    case CallNode::PureIntrinsic:
+      return "pure_intrin";
+  }
+  LOG(FATAL) << "Unknown CallType";
+  return "Unknown";
+}
+
+Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
+  Doc doc;
+  doc << "@" << Doc::Text(op->name) << "(";
+  std::vector<Doc> args;
+  for (const auto& arg : op->args) {
+    args.push_back(Print(arg));
+  }
+  doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype)
<< ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) + << ", index=" << op->value_index << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) { + Doc doc; + doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { + Doc doc; + doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis) + << ", " << op->value_index << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { + Doc doc; + doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { + Doc doc; + meta_collector_.Collect(op->node); + doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = " + << Print(op->value); + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { + Doc doc; + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" + << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) { + Doc doc; + doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); + if (!is_one(op->predicate)) { + doc << " if " << Print(op->predicate); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { + Doc doc; + doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " + << Print(op->condition) << PrintBody(op->body) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { + Doc doc; + doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extents) << ")"; + if (!is_one(op->condition)) { + doc << " if " << Print(op->condition); + } + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) { + Doc doc; + doc << "free(" << Print(op->buffer_var) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { + Doc doc; + doc << "if " << Print(op->condition) << PrintBody(op->then_case); + if (!is_one(op->condition) && op->else_case.defined()) { + doc << " else" << PrintBody(op->else_case); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) { + std::vector stmts; + Doc seq_doc, doc; + for (Stmt stmt : op->seq) { + seq_doc << Doc::NewLine() << Print(stmt); + } + doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { + Doc doc; + doc << Print(op->value); + return doc; +} + +inline const char* ForType2String(ForType t) { + switch (t) { + case ForType::Serial: + return "serial"; + case ForType::Parallel: + return "parallel"; + case ForType::Vectorized: + return "vectorized"; + case ForType::Unrolled: + return "unroll"; + } + LOG(FATAL) << "Unknown ForType"; + return "Unknown"; +} + +Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { + Doc doc; + doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " + << 
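+      // the upper bound prints as min + extent: TIR loops iterate the half-open range [min, min + extent)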
Print(op->min + op->extent) << ")"; + if (op->for_type != ForType::Serial) { + doc << " " << Doc::StrLiteral(ForType2String(op->for_type)); + } + doc << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { + Doc doc; + doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) { + Doc doc; + doc << PrintDType(node->dtype); + return doc; +} + +Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) { + Doc doc; + doc << "Pointer(" << Print(node->element_type) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; +} + +Doc TIRTextPrinter::PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); +} + +template +Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) { + Doc doc; + std::ostringstream os; + os << data; + if (dtype == DataType::Int(32)) { + doc << Doc::Text(os.str()); + } else { + if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) { + doc << ((data == 1) ? "True" : "False"); + return doc; + } + doc << Doc::Text(os.str()); + switch (dtype.code()) { + case kDLInt: + doc << "i"; + break; + case kDLUInt: + doc << "u"; + break; + case kDLFloat: + doc << "f"; + break; + } + doc << Doc::Text(std::to_string(dtype.bits())); + if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes())); + } + return doc; +} + +Doc TIRTextPrinter::GetUniqueName(std::string prefix) { + // std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { + } + } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} + +Doc TIRTextPrinter::AllocVar(const Var& var) { + const auto& it = memo_var_.find(var); + if (it != memo_var_.end()) { + return it->second; + } + std::string name = var->name_hint; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName(name); + memo_var_[var] = val; + return val << ": " << Print(GetType(var)); +} + +Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { + const auto& it = memo_buf_.find(buffer); + if (it != memo_buf_.end()) { + return it->second; + } + std::string name = buffer->name; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "buf_" + name; + } + Doc val = GetUniqueName(name); + memo_buf_[buffer] = val; + return val; +} + +Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + seq = vec[0]; + for (size_t i = 1; i < vec.size(); i++) { + seq << sep << vec[i]; + } + } + return seq; +} + +Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { + Doc doc; + if (body->IsInstance()) return Print(body); + doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}"; + return doc; +} + +} // namespace tir +} // namespace tvm diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 103ddcb31111e..587add36706fe 100644 --- 
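(Before moving on to the Relay analysis diffs: a standalone sketch of the renaming scheme GetUniqueName implements above, outside the TVM types and for illustration only. The first request for a prefix gets the prefix itself; later requests append "_1", "_2", and so on, skipping names already taken.)

#include <iostream>
#include <string>
#include <unordered_map>

std::string UniqueName(const std::string& prefix, std::unordered_map<std::string, int>* taken) {
  std::string unique = prefix;
  auto it = taken->find(prefix);
  if (it != taken->end()) {
    // bump the per-prefix counter until an unused "<prefix>_<n>" is found
    while (taken->count(unique = prefix + "_" + std::to_string(++it->second)) > 0) {
    }
  }
  (*taken)[unique] = 0;
  return unique;
}

int main() {
  std::unordered_map<std::string, int> taken;
  std::cout << UniqueName("x", &taken) << "\n";  // x
  std::cout << UniqueName("x", &taken) << "\n";  // x_1
  std::cout << UniqueName("x", &taken) << "\n";  // x_2
}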
a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -19,14 +19,13 @@ #include "annotated_region_set.h" -#include #include +#include #include #include #include - namespace tvm { namespace relay { @@ -39,8 +38,7 @@ AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { return AnnotatedRegion(nullptr); } -void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, - AnnotatedRegion dest) { +void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, AnnotatedRegion dest) { if (dest == src) { return; } @@ -104,12 +102,12 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { for (auto arg : args) { const CallNode* end = arg.as(); if (end && end->op == end_op_) { // Ignore closed regions. - continue; + continue; } region = region_set_->GetRegion(arg); if (region.defined()) { - break; + break; } } @@ -117,7 +115,7 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { for (auto arg : args) { const CallNode* end = arg.as(); if (end && end->op == end_op_) { // Ignore closed regions. - continue; + continue; } auto arg_region = region_set_->GetRegion(arg); @@ -171,9 +169,7 @@ class AnnotatedRegionSet::Creator : protected MixedModeVisitor { } } - void VisitExpr_(const TupleNode* op) { - AddToArgRegion(GetRef(op), op->fields); - } + void VisitExpr_(const TupleNode* op) { AddToArgRegion(GetRef(op), op->fields); } void VisitExpr_(const TupleGetItemNode* g) { Array args = {g->tuple}; @@ -227,15 +223,14 @@ TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode); TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet") -.set_body_typed([](Expr expr, Op begin, Op end) { - return AnnotatedRegionSet::Create(expr, begin, end); -}); + .set_body_typed([](Expr expr, Op begin, Op end) { + return AnnotatedRegionSet::Create(expr, begin, end); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetRegion") -.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { - return region_set->GetRegion(expr); -}); - + .set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { + return region_set->GetRegion(expr); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 3bd569387d461..f12db6ae05469 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -27,19 +27,19 @@ #ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ #define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ +#include #include #include #include -#include #include -#include #include +#include +#include #include #include #include #include -#include namespace tvm { namespace relay { @@ -61,29 +61,19 @@ class AnnotatedRegionNode : public Object { } /*! \brief Get the region ID. */ - int GetID() const { - return id_; - } + int GetID() const { return id_; } /*! \brief Get the region target. */ - std::string GetTarget() const { - return target_; - } + std::string GetTarget() const { return target_; } /*! \brief Get the region's inputs. */ - std::list GetInputs() const { - return ins_; - } + std::list GetInputs() const { return ins_; } /*! \brief Get the region's outputs. */ - std::list GetOutputs() const { - return outs_; - } + std::list GetOutputs() const { return outs_; } /*! \brief Get the region's nodes. 
*/ - std::unordered_set GetNodes() const { - return nodes_; - } + std::unordered_set GetNodes() const { return nodes_; } static constexpr const char* _type_key = "relay.AnnotatedRegion"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object); @@ -107,7 +97,7 @@ class AnnotatedRegionNode : public Object { /*! * \brief An object to hold the properties of a region as used by the * AnnotatedRegionSet class. This should be considered read-only. -*/ + */ class AnnotatedRegion : public ObjectRef { public: AnnotatedRegion() { @@ -116,9 +106,9 @@ class AnnotatedRegion : public ObjectRef { } /*! - * \brief Construct from an object pointer. - * \param n The object pointer. - */ + * \brief Construct from an object pointer. + * \param n The object pointer. + */ explicit AnnotatedRegion(ObjectPtr n) : ObjectRef(n) {} /*! \return Mutable pointers to the node. */ @@ -130,8 +120,7 @@ class AnnotatedRegion : public ObjectRef { }; class AnnotatedRegionSetNode : public Object { - using UnorderedRegionSet = - std::unordered_set; + using UnorderedRegionSet = std::unordered_set; // Create iterator alias for a RegionSet object. using iterator = UnorderedRegionSet::iterator; using const_iterator = UnorderedRegionSet::const_iterator; @@ -141,21 +130,13 @@ class AnnotatedRegionSetNode : public Object { AnnotatedRegionSetNode() = default; /*! \return The begin iterator */ - iterator begin() { - return regions_.begin(); - } + iterator begin() { return regions_.begin(); } /*! \return The end iterator */ - iterator end() { - return regions_.end(); - } + iterator end() { return regions_.end(); } /*! \return The const begin iterator */ - const_iterator begin() const { - return regions_.begin(); - } + const_iterator begin() const { return regions_.begin(); } /*! \return The const end iterator */ - const_iterator end() const { - return regions_.end(); - } + const_iterator end() const { return regions_.end(); } /*! * \brief Get the region that an expression belongs to. @@ -168,11 +149,11 @@ class AnnotatedRegionSetNode : public Object { AnnotatedRegion GetRegion(const Expr& expr) const; /*! - * \brief Merge src region into dest region. - * - * \param src The region to merge - will be erased. - * \param dest The region into which src will be merged. - */ + * \brief Merge src region into dest region. + * + * \param src The region to merge - will be erased. + * \param dest The region into which src will be merged. + */ void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest); void VisitAttrs(AttrVisitor* v) { @@ -214,8 +195,7 @@ class AnnotatedRegionSetNode : public Object { * to update and query regions. */ class AnnotatedRegionSet : public ObjectRef { - using UnorderedRegionSet = - std::unordered_set; + using UnorderedRegionSet = std::unordered_set; // Create iterator alias for a RegionSet object. using iterator = UnorderedRegionSet::iterator; using const_iterator = UnorderedRegionSet::const_iterator; @@ -227,10 +207,10 @@ class AnnotatedRegionSet : public ObjectRef { } /*! - * \brief Construct from an object pointer. - * - * \param n The object pointer. - */ + * \brief Construct from an object pointer. + * + * \param n The object pointer. + */ explicit AnnotatedRegionSet(ObjectPtr n) : ObjectRef(n) {} /*! \return The begin iterator. */ @@ -253,7 +233,7 @@ class AnnotatedRegionSet : public ObjectRef { } /*! \return The end iterator. 
*/ const_iterator end() const { - const auto *n = operator->(); + const auto* n = operator->(); CHECK(n); return n->end(); } @@ -267,7 +247,7 @@ class AnnotatedRegionSet : public ObjectRef { /*! \return The region an expression belongs to. */ AnnotatedRegion operator[](const Expr& expr) { - const auto *n = operator->(); + const auto* n = operator->(); CHECK(n); return n->GetRegion(expr); } @@ -280,9 +260,7 @@ class AnnotatedRegionSet : public ObjectRef { * * \return The created RegionSet for the expression. */ - static AnnotatedRegionSet Create(const Expr& expr, - const Op& begin, - const Op& end); + static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end); private: /*! \brief Helper class to construct a RegionSet from an expr.*/ diff --git a/src/relay/analysis/call_graph.cc b/src/relay/analysis/call_graph.cc index a12d23d88a303..0d3fedcde0f79 100644 --- a/src/relay/analysis/call_graph.cc +++ b/src/relay/analysis/call_graph.cc @@ -26,6 +26,7 @@ #include #include + #include #include #include @@ -72,22 +73,21 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const { const_iterator cit = call_graph_.find(gv); - CHECK(cit != call_graph_.end()) - << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint + << " not found in the call graph!"; return cit->second.get(); } CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) { const_iterator cit = call_graph_.find(gv); - CHECK(cit != call_graph_.end()) - << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint + << " not found in the call graph!"; return cit->second.get(); } BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const { CHECK(module->ContainGlobalVar(var->name_hint)) - << "GlobalVar " << var->name_hint - << " not found in the current ir module"; + << "GlobalVar " << var->name_hint << " not found in the current ir module"; return module->Lookup(var); } @@ -120,8 +120,8 @@ GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph) { CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) << "Cannot remove global var " << cg_node->GetNameHint() - << " from call graph, because it still calls " - << cg_node->size() << " other global functions"; + << " from call graph, because it still calls " << cg_node->size() + << " other global functions"; if (update_call_graph) { // Update the call graph by removing all edges that point to the node @@ -172,8 +172,7 @@ std::vector CallGraphNode::TopologicalOrder() const { << " with # refs = " << (*this)[it.first]->GetRefCount(); } } - LOG(FATAL) << "Expected " << module->functions.size() - << " globals, but received " + LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received " << ret.size(); } @@ -184,8 +183,7 @@ std::vector CallGraphNode::TopologicalOrder() const { // that are visited by previous CallGraphEntry entries can be memoized. This // helps us to make sure no entry will be visited multiple times when collecting // the nodes for an entire call graph. 
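(A standalone sketch of the memoization idea the comment above describes; the Entry type is invented for illustration and is not the TVM CallGraphEntry. Sharing one visited set across the walk guarantees an entry is emitted exactly once even when several callers reach it.)

#include <unordered_set>
#include <vector>

struct Entry {
  std::vector<Entry*> callees;
};

void TopoVisit(Entry* e, std::unordered_set<Entry*>* visited, std::vector<Entry*>* out) {
  if (!visited->insert(e).second) return;  // already emitted via another caller
  for (Entry* c : e->callees) TopoVisit(c, visited, out);
  out->push_back(e);  // post-order: callees come before their callers
}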
-std::vector CallGraphEntry::TopologicalOrder( - CallGraphEntrySet* visited) const { +std::vector CallGraphEntry::TopologicalOrder(CallGraphEntrySet* visited) const { std::vector ret; std::vector current_nodes; if (visited->find(this) == visited->end()) { @@ -234,8 +232,7 @@ inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) { // Remove an edge from the current global function to the callee. void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) { for (auto it = begin();; ++it) { - CHECK(it != end()) << "Cannot find global function " - << callee->name_hint << " to remove!"; + CHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!"; if (it->second->GetGlobalVar() == callee) { // Only remove one occurrence of the call site. it->second->DecRef(); @@ -260,8 +257,7 @@ void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) { } // Make sure all references to the callee are removed. CHECK_EQ(callee->GetRefCount(), 0U) - << "All references to " << callee->GetNameHint() - << " should have been removed"; + << "All references to " << callee->GetNameHint() << " should have been removed"; } void CallGraphEntry::Print(std::ostream& os) const { @@ -293,54 +289,51 @@ std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) { TVM_REGISTER_NODE_TYPE(CallGraphNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - CHECK(node); - p->stream << "CallGraph: \n" << GetRef(node); -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + CHECK(node); + p->stream << "CallGraph: \n" << GetRef(node); + }); -TVM_REGISTER_GLOBAL("relay.analysis.CallGraph") -.set_body_typed([](IRModule module) { +TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) { return CallGraph(module); }); -TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph") -.set_body_typed([](CallGraph call_graph) { +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) { std::stringstream ss; ss << call_graph; return ss.str(); }); -TVM_REGISTER_GLOBAL("relay.analysis.GetModule") -.set_body_typed([](CallGraph call_graph) { +TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) { return call_graph->module; }); TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - std::stringstream ss; - ss << *entry_node; - return ss.str(); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + std::stringstream ss; + ss << *entry_node; + return ss.str(); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return static_cast(entry_node->GetRefCount()); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->GetRefCount()); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return static_cast(entry_node->size()); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->size()); + }); 
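(The registrations in this hunk all follow one shape; a hypothetical extra entry, with an invented name and body, showing how set_body_typed wraps a typed C++ lambda as a PackedFunc that the Python frontend can fetch via tvm.get_global_func("..."):)

TVM_REGISTER_GLOBAL("relay.analysis.ExampleFunctionCount")  // name invented for illustration
    .set_body_typed([](CallGraph call_graph) {
      // count the global functions held by the call graph's module
      return static_cast<int>(call_graph->module->functions.size());
    });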
TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return entry_node->IsRecursive(); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return entry_node->IsRecursive(); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/call_graph.h b/src/relay/analysis/call_graph.h index 86bc6469c3164..387d2d3e046b4 100644 --- a/src/relay/analysis/call_graph.h +++ b/src/relay/analysis/call_graph.h @@ -32,6 +32,7 @@ #include #include #include + #include #include #include @@ -47,8 +48,7 @@ class CallGraph; class CallGraphNode : public Object { using CallGraphMap = - std::unordered_map, ObjectHash, - ObjectEqual>; + std::unordered_map, ObjectHash, ObjectEqual>; // Create iterator alias for a CallGraphNode object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -60,9 +60,7 @@ class CallGraphNode : public Object { /*! \brief Default constructor. */ CallGraphNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("module", &module); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("module", &module); } /*! * \brief Print the call graph. @@ -72,21 +70,13 @@ class CallGraphNode : public Object { void Print(std::ostream& os) const; /*! \return The begin iterator. */ - iterator begin() { - return call_graph_.begin(); - } + iterator begin() { return call_graph_.begin(); } /*! \return The end iterator. */ - iterator end() { - return call_graph_.end(); - } + iterator end() { return call_graph_.end(); } /*! \return The begin iterator. */ - const_iterator begin() const { - return call_graph_.begin(); - } + const_iterator begin() const { return call_graph_.begin(); } /*! \return The end iterator. */ - const_iterator end() const { - return call_graph_.end(); - } + const_iterator end() const { return call_graph_.end(); } /*! * \brief Get an element from the CallGraphNode using a GlobalVar. @@ -157,8 +147,7 @@ class CallGraphNode : public Object { * * \return The GlobalVar removed from the current module. */ - GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, - bool update_call_graph = false); + GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false); /*! * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for @@ -207,8 +196,7 @@ class CallGraphNode : public Object { */ class CallGraph : public ObjectRef { using CallGraphMap = - std::unordered_map, ObjectHash, - ObjectEqual>; + std::unordered_map, ObjectHash, ObjectEqual>; // Create iterator alias for a CallGraph object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -340,30 +328,20 @@ class CallGraphEntry { CallGraphEntry& operator=(const CallGraphEntry&) = delete; /*! \return The begin iterator */ - iterator begin() { - return called_globals_.begin(); - } + iterator begin() { return called_globals_.begin(); } /*! \return The end iterator */ - iterator end() { - return called_globals_.end(); - } + iterator end() { return called_globals_.end(); } /*! \return The const begin iterator */ - const_iterator begin() const { - return called_globals_.begin(); - } + const_iterator begin() const { return called_globals_.begin(); } /*! \return The const end iterator */ - const_iterator end() const { - return called_globals_.end(); - } + const_iterator end() const { return called_globals_.end(); } /*! 
* \brief Return if the list of called nodes is empty. * * \return true if the list is empty. Otherwise, false. */ - bool empty() const { - return called_globals_.empty(); - } + bool empty() const { return called_globals_.empty(); } /*! * \brief Return the size of the list that represents the nodes are called by @@ -371,9 +349,7 @@ class CallGraphEntry { * * \return The number of called nodes. */ - uint32_t size() const { - return static_cast(called_globals_.size()); - } + uint32_t size() const { return static_cast(called_globals_.size()); } /*! * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called @@ -400,27 +376,21 @@ class CallGraphEntry { * * \return The count. */ - uint32_t GetRefCount() const { - return ref_cnt_; - } + uint32_t GetRefCount() const { return ref_cnt_; } /*! * \brief Return the GlobalVar stored in the current CallGraphEntry. * * \return The GlobalVar. */ - GlobalVar GetGlobalVar() const { - return global_; - } + GlobalVar GetGlobalVar() const { return global_; } /*! * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry. * * \return The name hint of the global function. */ - std::string GetNameHint() const { - return global_->name_hint; - } + std::string GetNameHint() const { return global_->name_hint; } /*! * \brief Return if the global function corresponding to the current @@ -428,9 +398,7 @@ class CallGraphEntry { * * \return true if it is recursive. Otherwise, false. */ - bool IsRecursive() const { - return is_recursive_; - } + bool IsRecursive() const { return is_recursive_; } /*! * \brief Return if the global function corresponding to the current @@ -439,9 +407,7 @@ class CallGraphEntry { * * \return true if it is both a recursive function and an entry. Otherwise, false. */ - bool IsRecursiveEntry() const { - return GetRefCount() == 1 && IsRecursive(); - } + bool IsRecursiveEntry() const { return GetRefCount() == 1 && IsRecursive(); } /*! * \brief Return the topological order of the CallGraphEntry. diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc index 7e48d12d0cf3d..a583e9a63eb2d 100644 --- a/src/relay/analysis/dependency_graph.cc +++ b/src/relay/analysis/dependency_graph.cc @@ -22,7 +22,9 @@ * \brief Implementation of dependency graph APIs. 
*/ #include "dependency_graph.h" + #include + #include #include @@ -32,8 +34,7 @@ namespace relay { // Creator of DependencyGraph class DependencyGraph::Creator : private ExprFunctor { public: - explicit Creator(support::Arena* arena) - : arena_(arena) {} + explicit Creator(support::Arena* arena) : arena_(arena) {} DependencyGraph Create(const Expr& body) { this->VisitExpr(body); @@ -164,15 +165,15 @@ class DependencyGraph::Creator : private ExprFunctor { } } - void VisitExpr_(const VarNode* v) final { } + void VisitExpr_(const VarNode* v) final {} - void VisitExpr_(const GlobalVarNode* v) final { } + void VisitExpr_(const GlobalVarNode* v) final {} - void VisitExpr_(const ConstantNode* c) final { } + void VisitExpr_(const ConstantNode* c) final {} - void VisitExpr_(const OpNode* o) final { } + void VisitExpr_(const OpNode* o) final {} - void VisitExpr_(const ConstructorNode* c) final { } + void VisitExpr_(const ConstructorNode* c) final {} }; DependencyGraph DependencyGraph::Create(support::Arena* arena, const Expr& body) { diff --git a/src/relay/analysis/dependency_graph.h b/src/relay/analysis/dependency_graph.h index 5e2dc0c899d1f..4aad95e712411 100644 --- a/src/relay/analysis/dependency_graph.h +++ b/src/relay/analysis/dependency_graph.h @@ -25,16 +25,18 @@ #define TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_ #include + #include #include -#include "../transforms/let_list.h" + #include "../../support/arena.h" +#include "../transforms/let_list.h" namespace tvm { namespace relay { -using support::LinkNode; using support::LinkedList; +using support::LinkNode; /* DependencyGraph track input and output of an Expr. * Additionally, dummy scope is created to model scope. diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 95c2f731ff72d..9e94459d70184 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -21,11 +21,12 @@ * \file feature.cc * \brief Detect features used in Expr/Module */ -#include +#include #include #include #include -#include +#include + #include "../transforms/pass_util.h" namespace tvm { @@ -49,34 +50,30 @@ FeatureSet DetectFeature(const Expr& expr) { } } } -#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ - void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \ - STMT \ - fs += f##CONSTRUCT_NAME; \ - } -#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \ - ExprVisitor::VisitExpr_(op); \ - }) +#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ + void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; } +#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \ + DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); }) DETECT_DEFAULT_CONSTRUCT(Var) DETECT_DEFAULT_CONSTRUCT(GlobalVar) DETECT_DEFAULT_CONSTRUCT(Constant) DETECT_DEFAULT_CONSTRUCT(Tuple) DETECT_DEFAULT_CONSTRUCT(TupleGetItem) DETECT_CONSTRUCT(Function, { - if (!op->HasNonzeroAttr(attr::kPrimitive)) { - ExprVisitor::VisitExpr_(op); - } - }) + if (!op->HasNonzeroAttr(attr::kPrimitive)) { + ExprVisitor::VisitExpr_(op); + } + }) DETECT_DEFAULT_CONSTRUCT(Op) DETECT_DEFAULT_CONSTRUCT(Call) DETECT_CONSTRUCT(Let, { - for (const Var& v : FreeVars(op->value)) { - if (op->var == v) { - fs += fLetRec; - } + for (const Var& v : FreeVars(op->value)) { + if (op->var == v) { + fs += fLetRec; } - ExprVisitor::VisitExpr_(op); - }) + } + ExprVisitor::VisitExpr_(op); + }) DETECT_DEFAULT_CONSTRUCT(If) DETECT_DEFAULT_CONSTRUCT(RefCreate) DETECT_DEFAULT_CONSTRUCT(RefRead) @@ -104,8 +101,7 @@ Array PyDetectFeature(const 
Expr& expr, const IRModule& mod) { return static_cast>(fs); } -TVM_REGISTER_GLOBAL("relay.analysis.detect_feature") -.set_body_typed(PyDetectFeature); +TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index b4835ccb7a3c3..ac0abc0655578 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -31,9 +31,9 @@ * We check this by ensuring the `dtype` field of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ +#include #include #include -#include namespace tvm { namespace relay { @@ -51,40 +51,28 @@ struct KindChecker : TypeFunctor { this->err_reporter.RenderErrors(mod); } - void CheckKindMatches(const Type& t, const Type& outer, - Kind expected, const std::string& description) { + void CheckKindMatches(const Type& t, const Type& outer, Kind expected, + const std::string& description) { Kind k = this->VisitType(t); if (k != expected) { ReportFatalError(ErrorBuilder() - << "Incorrect kind for a " << description - << ". Type " << t << " inside " << outer - << " is of kind " << k - << " but was expected to be " - << expected); + << "Incorrect kind for a " << description << ". Type " << t << " inside " + << outer << " is of kind " << k << " but was expected to be " << expected); } } - Kind VisitType_(const IncompleteTypeNode* op) override { - return op->kind; - } + Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; } - Kind VisitType_(const TypeVarNode* op) override { - return op->kind; - } + Kind VisitType_(const TypeVarNode* op) override { return op->kind; } - Kind VisitType_(const GlobalTypeVarNode* op) override { - return op->kind; - } + Kind VisitType_(const GlobalTypeVarNode* op) override { return op->kind; } - Kind VisitType_(const TensorTypeNode* op) override { - return Kind::kType; - } + Kind VisitType_(const TensorTypeNode* op) override { return Kind::kType; } Kind VisitType_(const TupleTypeNode* op) override { // tuples should only contain normal types for (const Type& t : op->fields) { - CheckKindMatches(t, GetRef(op), Kind::kType, - "tuple member"); + CheckKindMatches(t, GetRef(op), Kind::kType, "tuple member"); } return Kind::kType; } @@ -117,8 +105,7 @@ struct KindChecker : TypeFunctor { Kind VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { - CheckKindMatches(t, GetRef(op), Kind::kType, - "argument to type relation"); + CheckKindMatches(t, GetRef(op), Kind::kType, "argument to type relation"); } return Kind::kConstraint; } @@ -128,9 +115,8 @@ struct KindChecker : TypeFunctor { TypeCall tc = GetRef(op); const auto* gtv = op->func.as(); if (gtv == nullptr) { - ReportFatalError( - ErrorBuilder() <<"The callee in " << tc - << " is not a global type var, but is " << op->func); + ReportFatalError(ErrorBuilder() << "The callee in " << tc + << " is not a global type var, but is " << op->func); } CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); @@ -143,9 +129,8 @@ struct KindChecker : TypeFunctor { auto var = GetRef(gtv); auto data = mod->LookupTypeDef(var); if (data->type_vars.size() != op->args.size()) { - ReportFatalError(ErrorBuilder() - << "Expected " << data->type_vars.size() << "arguments for " << tc - << "; got " << op->args.size()); + ReportFatalError(ErrorBuilder() << "Expected " << data->type_vars.size() << " arguments for " + << tc << "; got "
<< op->args.size()); } return Kind::kType; } @@ -164,9 +149,8 @@ struct KindChecker : TypeFunctor { for (const auto& con : op->constructors) { if (!con->belong_to.same_as(op->header)) { - ReportFatalError(ErrorBuilder() - << con << " has header " << con->belong_to - << " but " << op << " has header " << op->header); + ReportFatalError(ErrorBuilder() << con << " has header " << con->belong_to << " but " << op + << " has header " << op->header); } for (const Type& t : con->inputs) { @@ -176,9 +160,7 @@ struct KindChecker : TypeFunctor { return Kind::kTypeData; } - Kind Check(const Type& t) { - return this->VisitType(t); - } + Kind Check(const Type& t) { return this->VisitType(t); } }; Kind KindCheck(const Type& t, const IRModule& mod) { @@ -186,14 +168,13 @@ Kind KindCheck(const Type& t, const IRModule& mod) { return kc.Check(t); } -TVM_REGISTER_GLOBAL("relay.analysis.check_kind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args.size() == 1) { - *ret = KindCheck(args[0], IRModule({}, {})); - } else { - *ret = KindCheck(args[0], args[1]); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = KindCheck(args[0], IRModule({}, {})); + } else { + *ret = KindCheck(args[0], args[1]); + } +}); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc index fecde3c756695..882bba9e4bf9b 100644 --- a/src/relay/analysis/mac_count.cc +++ b/src/relay/analysis/mac_count.cc @@ -26,11 +26,12 @@ * otherwise the count is 0. */ -#include +#include #include #include -#include +#include #include + #include "../transforms/pattern_util.h" namespace tvm { @@ -52,8 +53,7 @@ inline int64_t GetCartesianProd(Array arr) { * \param call_node The call node. * \return The number of MACs.
*/ -using FMacCount = runtime::TypedPackedFunc< - int64_t(const Call& call_node)>; +using FMacCount = runtime::TypedPackedFunc<int64_t(const Call& call_node)>; //---------------------------------------------- // Per operator defs for MAC count //---------------------------------------------- @@ -65,30 +65,26 @@ int64_t ConvMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK_EQ(args.size(), 2) - << "The number of input arguments of a CONV 2D node should be 2."; + CHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; std::string data_layout = conv_2d_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK_NE(C_ind, -1) - << "There is no input channel dimension."; + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); - if (c_ind != -1) - input_channel *= static_cast(data_shape[c_ind].as()->value); + if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; - CHECK_EQ(kernel_size.size(), 2) - << "The dimension of the kernel in Conv 2D should be 2."; + CHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_attr->groups, 0) - << "The number of input channels is not divisble by groups."; - count *= input_channel/conv_2d_attr->groups; + << "The number of input channels is not divisible by groups."; + count *= input_channel / conv_2d_attr->groups; return count; }
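To make the count above concrete: ConvMacCount computes MACs = |output| x kh x kw x (Cin / groups). A standalone worked example with purely illustrative shapes:

#include <cstdint>
#include <iostream>
#include <vector>

// Product of all dimensions, as in GetCartesianProd above.
int64_t Prod(const std::vector<int64_t>& dims) {
  int64_t p = 1;
  for (int64_t d : dims) p *= d;
  return p;
}

int main() {
  // Assumed example: NCHW output 1x64x56x56, 3x3 kernel, 64 input
  // channels, groups = 1. The numbers are made up for illustration.
  std::vector<int64_t> output = {1, 64, 56, 56};
  std::vector<int64_t> kernel = {3, 3};
  int64_t input_channel = 64, groups = 1;

  // Same formula as ConvMacCount:
  //   count = GetCartesianProd(output_tensor) *
  //           GetCartesianProd(kernel_size) * (input_channel / groups)
  int64_t macs = Prod(output) * Prod(kernel) * (input_channel / groups);
  std::cout << macs << "\n";  // 115605504
}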
5."; + << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0) - << "The number of input channels is not divisble by groups."; - count *= input_channel/conv_2d_transpose_attr->groups; + << "The number of input channels is not divisble by groups."; + count *= input_channel / conv_2d_transpose_attr->groups; return count; } @@ -131,20 +125,18 @@ int64_t DenseMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK_EQ(args.size(), 2) - << "The number of input arguments of a Dense node should be 2."; + CHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2."; const auto* data_type = args[0]->checked_type().as(); const auto* weight_type = args[1]->checked_type().as(); Array data_shape = data_type->shape; Array weight_shape = weight_type->shape; CHECK(data_shape.size() == 2 && weight_shape.size() == 2) - << "The dimension of an input tensor to Dense node should be 2."; + << "The dimension of an input tensor to Dense node should be 2."; int64_t d1 = static_cast(data_shape[0].as()->value); int64_t d2 = static_cast(data_shape[1].as()->value); int64_t d3 = static_cast(weight_shape[0].as()->value); int64_t d4 = static_cast(weight_shape[1].as()->value); - CHECK_EQ(d2, d4) - << "The dimensions of input arguments do not match."; + CHECK_EQ(d2, d4) << "The dimensions of input arguments do not match."; int64_t count = d1 * d2 * d3; return count; } @@ -165,23 +157,17 @@ int64_t BatchMatmulMacCount(const Call& call_node) { return batch * m * k * n; } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FMacCount", ConvMacCount); +RELAY_REGISTER_OP("nn.conv2d").set_attr("FMacCount", ConvMacCount); -RELAY_REGISTER_OP("nn.conv2d_transpose") -.set_attr("FMacCount", Conv2dTransposeMacCount); +RELAY_REGISTER_OP("nn.conv2d_transpose").set_attr("FMacCount", Conv2dTransposeMacCount); -RELAY_REGISTER_OP("nn.dense") -.set_attr("FMacCount", DenseMacCount); +RELAY_REGISTER_OP("nn.dense").set_attr("FMacCount", DenseMacCount); -RELAY_REGISTER_OP("nn.batch_matmul") -.set_attr("FMacCount", BatchMatmulMacCount); +RELAY_REGISTER_OP("nn.batch_matmul").set_attr("FMacCount", BatchMatmulMacCount); class MacCounter : private ExprVisitor { public: - MacCounter() { - count_ = 0; - } + MacCounter() { count_ = 0; } static int64_t GetTotalMacNumber(const Expr& expr) { LOG(INFO) << "This pass only counts MACs in direct conv2d, " << "conv2d_transpose, dense, and batch_matmul ops"; @@ -192,8 +178,7 @@ class MacCounter : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { - static const auto& fprep = - Op::GetAttr("FMacCount"); + static const auto& fprep = Op::GetAttr("FMacCount"); auto f = fprep.get(call_node->op, nullptr); if (f != nullptr) count_ += f(GetRef(call_node)); ExprVisitor::VisitExpr_(call_node); @@ -202,12 +187,9 @@ class MacCounter : private ExprVisitor { int64_t count_; }; -int64_t GetTotalMacNumber(const Expr& expr) { - return MacCounter::GetTotalMacNumber(expr); -} +int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber") -.set_body_typed(GetTotalMacNumber); +TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber").set_body_typed(GetTotalMacNumber); } // namespace mac_count } // namespace relay diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 
eeb7fce18c52e..96dab6b604956 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -27,10 +27,11 @@ * code correctness, since hitting an unmatched case results in a * dynamic error unless exhaustiveness is checked in advance. */ -#include #include +#include #include #include + #include namespace tvm { @@ -154,17 +155,14 @@ Array> CartesianProduct(Array> fields) { } Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, - const Pattern& cand, - const IRModule& mod); + const Pattern& cand, const IRModule& mod); -Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, - const Pattern& cand, +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, const IRModule& mod); // Expands all wildcards in the candidate pattern once // Returns a list of all possible expansions. -Array ExpandWildcards(const Pattern& clause_pat, - const Pattern& cand, +Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, const IRModule& mod) { if (auto clause_ctor = clause_pat.as()) { return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); @@ -179,8 +177,7 @@ Array ExpandWildcards(const Pattern& clause_pat, // Use the pattern to decide which constructors to insert. // Returns a list of all possible expansions. Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, - const Pattern& cand, - const IRModule& mod) { + const Pattern& cand, const IRModule& mod) { auto gtv = Downcast(clause_ctor->constructor->belong_to); // for a wildcard node, create constructor nodes with wildcards for all args. @@ -203,9 +200,8 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { - values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i], - ctor_cand->patterns[i], - mod)); + values_by_field.push_back( + ExpandWildcards(clause_ctor->patterns[i], ctor_cand->patterns[i], mod)); } // generate new candidates using a cartesian product. @@ -219,8 +215,7 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // Expands all wildcards in the candidate pattern once. // Returns a list of all possible expansions. -Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, - const Pattern& cand, +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, const IRModule& mod) { // for a wildcard node, create constructor nodes with wildcards for all args. if (cand.as()) { @@ -236,9 +231,8 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, // for constructors, we will expand the wildcards in any field that is an ADT. 
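Each ExpandWildcards helper builds a per-field list of candidate expansions and then combines them with CartesianProduct, so every result picks exactly one candidate per field. The combination step in isolation, over plain std::vector as a simplified stand-in for Array<Array<Pattern>>:

#include <iostream>
#include <string>
#include <vector>

// Cartesian product of per-field candidate lists, mirroring
// CartesianProduct(Array<Array<Pattern>>) in this file.
template <typename T>
std::vector<std::vector<T>> CartesianProduct(const std::vector<std::vector<T>>& fields) {
  std::vector<std::vector<T>> ret = {{}};  // start from one empty combination
  for (const auto& field : fields) {
    std::vector<std::vector<T>> next;
    for (const auto& combo : ret) {
      for (const T& candidate : field) {
        auto extended = combo;           // keep the choices made so far
        extended.push_back(candidate);   // ... and pick one more field
        next.push_back(std::move(extended));
      }
    }
    ret = std::move(next);
  }
  return ret;
}

int main() {
  // Two fields whose wildcards each expanded to two candidates.
  auto combos = CartesianProduct<std::string>({{"Nil", "Cons _ _"}, {"0", "1"}});
  for (const auto& c : combos) std::cout << c[0] << " | " << c[1] << "\n";
  // Prints the 4 combinations: Nil|0, Nil|1, Cons _ _|0, Cons _ _|1.
}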
Array> values_by_field; for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { - values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i], - tuple_cand->patterns[i], - mod)); + values_by_field.push_back( + ExpandWildcards(clause_tuple->patterns[i], tuple_cand->patterns[i], mod)); } // generate new candidates using a cartesian product @@ -311,14 +305,13 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { // expose for testing only TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases") -.set_body_typed( - [](const Match& match, const IRModule& mod_ref) { - IRModule call_mod = mod_ref; - if (!call_mod.defined()) { - call_mod = IRModule({}, {}); - } - return UnmatchedCases(match, call_mod); - }); + .set_body_typed([](const Match& match, const IRModule& mod_ref) { + IRModule call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = IRModule({}, {}); + } + return UnmatchedCases(match, call_mod); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 650403ca5267b..05e231acecf90 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -21,26 +21,25 @@ * \file type_solver.cc * \brief Type solver implementations. */ -#include +#include "type_solver.h" + #include +#include #include -#include + #include +#include #include #include -#include "type_solver.h" namespace tvm { namespace relay { class TypeSolver::Reporter : public TypeReporterNode { public: - explicit Reporter(TypeSolver* solver) - : solver_(solver) {} + explicit Reporter(TypeSolver* solver) : solver_(solver) {} - void Assign(const Type& dst, const Type& src) final { - solver_->Unify(dst, src, location); - } + void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, location); } bool Assert(const IndexExpr& cond) final { if (const int64_t* pdiff = tir::as_const_int(cond)) { @@ -58,13 +57,9 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - TVM_DLL void SetLocation(const ObjectRef& ref) final { - location = ref; - } + TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; } - TVM_DLL IRModule GetModule() final { - return this->solver_->module_; - } + TVM_DLL IRModule GetModule() final { return this->solver_->module_; } private: /*! \brief The location to report unification errors at. 
*/ @@ -76,7 +71,7 @@ class TypeSolver::Reporter : public TypeReporterNode { class TypeSolver::OccursChecker : public TypeVisitor { public: explicit OccursChecker(TypeSolver* solver, TypeNode* var) - : solver_(solver), var_(var), found_(false) {} + : solver_(solver), var_(var), found_(false) {} bool Check(const Type& t) { VisitType(t); @@ -112,25 +107,24 @@ class TypeSolver::Unifier : public TypeFunctor { if (lhs->resolved_type.as()) { CHECK(!OccursCheck(lhs, rhs->resolved_type)) - << "Incomplete type " << lhs->resolved_type << " occurs in " - << rhs->resolved_type << ", cannot unify"; + << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type + << ", cannot unify"; solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { CHECK(!OccursCheck(rhs, lhs->resolved_type)) - << "Incomplete type " << rhs->resolved_type << " occurs in " - << lhs->resolved_type << ", cannot unify"; + << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type + << ", cannot unify"; solver_->MergeFromTo(rhs, lhs); return lhs->resolved_type; } else { Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); if (!resolved.defined()) { - solver_->ReportError( - ErrorBuilder() << "unable to unify: " - << "`" << PrettyPrint(lhs->resolved_type) << "` and `" - << PrettyPrint(rhs->resolved_type) << "`", - this->loc); + solver_->ReportError(ErrorBuilder() << "unable to unify: " + << "`" << PrettyPrint(lhs->resolved_type) << "` and `" + << PrettyPrint(rhs->resolved_type) << "`", + this->loc); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -227,14 +221,11 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->ReportError( - ErrorBuilder() << - "tensor type `" << PrettyPrint(tt1) << - "` has " << tt1->shape.size() << - " dimensions, while `" << - PrettyPrint(tt2) << - "` has " << tt2->shape.size() << - " dimensions", this->loc); + this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " + << tt1->shape.size() << " dimensions, while `" + << PrettyPrint(tt2) << "` has " << tt2->shape.size() + << " dimensions", + this->loc); return Type(nullptr); } @@ -259,12 +250,8 @@ class TypeSolver::Unifier : public TypeFunctor { ErrorBuilder err; err << "in particular "; for (auto mismatch : mismatches) { - err << "dimension " - << std::get<0>(mismatch) - << " conflicts " - << std::get<1>(mismatch) - << " does not match " - << std::get<2>(mismatch); + err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch) + << " does not match " << std::get<2>(mismatch); } Error error(err); this->solver_->ReportError(error, this->loc); @@ -293,9 +280,8 @@ class TypeSolver::Unifier : public TypeFunctor { Type VisitType_(const FuncTypeNode* op, const Type& tn) final { const auto* ftn = tn.as(); - if (!ftn - || op->arg_types.size() != ftn->arg_types.size() - || op->type_constraints.size() != ftn->type_constraints.size()) { + if (!ftn || op->arg_types.size() != ftn->arg_types.size() || + op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } @@ -316,10 +302,7 @@ class TypeSolver::Unifier : public TypeFunctor { subst_map.Set(op->type_params[i], IncompleteType(kType)); } - FuncType ft = FuncType(op->arg_types, - op->ret_type, - ft_type_params, - op->type_constraints); + FuncType ft = FuncType(op->arg_types, op->ret_type, ft_type_params, 
op->type_constraints); auto ft1 = Downcast(Bind(ft, subst_map)); auto ft2 = GetRef(ftn); @@ -333,8 +316,7 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); ++i) { - Type unified_constraint = Unify(ft1->type_constraints[i], - ft2->type_constraints[i]); + Type unified_constraint = Unify(ft1->type_constraints[i], ft2->type_constraints[i]); const auto* tcn = unified_constraint.as(); CHECK(tcn) << "Two type constraints unified into a non-constraint?" << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; @@ -397,12 +379,10 @@ class TypeSolver::Resolver : public TypeMutator { class TypeSolver::Propagator : public TypeFunctor { public: explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) - : solver_(solver), rels_(rels) {} + : solver_(solver), rels_(rels) {} // adds the relation node to t and all child types of t - void Propagate(const Type& t) { - VisitType(t); - } + void Propagate(const Type& t) { VisitType(t); } void UpdateRelSet(const Type& t) { TypeNode* tnode = solver_->GetTypeNode(t); @@ -532,10 +512,8 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver( - const GlobalVar& current_func, - const IRModule& module, - ErrorReporter* err_reporter) +TypeSolver::TypeSolver(const GlobalVar& current_func, const IRModule& module, + ErrorReporter* err_reporter) : reporter_(make_object(this)), current_func(current_func), err_reporter_(err_reporter), @@ -566,7 +544,7 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const ObjectRef& loc) { return unifier.Unify(dst, src); } -void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { +void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { CHECK(location.defined()); CHECK(current_func.defined()); err_reporter_->ReportAt(current_func, location, err); @@ -583,20 +561,19 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint, const ObjectRef // populate the type information. for (size_t i = 0; i < op->args.size(); ++i) { // insert link to the type list - LinkNode* tlink = arena_.make >(); + LinkNode* tlink = arena_.make>(); TypeNode* tnode = GetTypeNode(op->args[i]); tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - std::unordered_set singleton { rnode }; + std::unordered_set singleton{rnode}; Propagator prop(this, &singleton); prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. this->AddToQueue(rnode); } else { - LOG(FATAL) << "Do not know how to handle constraint type" - << constraint->GetTypeKey(); + LOG(FATAL) << "Do not know how to handle constraint type" << constraint->GetTypeKey(); } } @@ -642,11 +619,9 @@ bool TypeSolver::Solve() { rnode->resolved = false; } catch (const dmlc::Error& err) { rnode->resolved = false; - this->ReportError( - ErrorBuilder() << "an internal invariant was violated while " - << "typechecking your program " - << err.what(), - rnode->location); + this->ReportError(ErrorBuilder() << "an internal invariant was violated while " + << "typechecking your program " << err.what(), + rnode->location); } // Mark inqueue as false after the function call @@ -661,45 +636,40 @@ bool TypeSolver::Solve() { // Expose type solver only for debugging purposes. 
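The Unifier above behaves like union-find over type nodes: MergeFromTo links one equivalence class into another, GetTypeNode resolves a type to its current representative, and the occurs check refuses to merge an incomplete type into a term that already contains it (which would build an infinite type such as a = List[a]). A minimal union-find with path compression, using plain integers as hypothetical type nodes:

#include <iostream>
#include <numeric>
#include <vector>

// Toy union-find: unifying two "types" links their equivalence classes,
// and later lookups resolve to a single representative.
class UnionFind {
 public:
  explicit UnionFind(int n) : parent_(n) {
    std::iota(parent_.begin(), parent_.end(), 0);  // each node starts alone
  }
  int Find(int x) {  // path compression keeps lookup chains short
    return parent_[x] == x ? x : parent_[x] = Find(parent_[x]);
  }
  void Merge(int from, int into) { parent_[Find(from)] = Find(into); }

 private:
  std::vector<int> parent_;
};

int main() {
  UnionFind uf(4);
  uf.Merge(0, 1);  // unify type var 0 with type var 1
  uf.Merge(1, 2);  // ... and the merged class with var 2
  std::cout << (uf.Find(0) == uf.Find(2)) << "\n";  // 1: same class now
  std::cout << (uf.Find(0) == uf.Find(3)) << "\n";  // 0: var 3 untouched
}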
TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") -.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { - using runtime::PackedFunc; - using runtime::TypedPackedFunc; - ErrorReporter *err_reporter = new ErrorReporter(); - auto module = IRModule({}, {}); - auto dummy_fn_name = GlobalVar("test"); - module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); - auto solver = std::make_shared(dummy_fn_name, module, err_reporter); - - auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { - if (name == "Solve") { - return TypedPackedFunc([solver]() { - return solver->Solve(); - }); - } else if (name == "Unify") { - return TypedPackedFunc( - [module, solver, err_reporter](Type lhs, Type rhs) { - auto res = solver->Unify(lhs, rhs, lhs); - if (err_reporter->AnyErrors()) { - err_reporter->RenderErrors(module, true); - } - return res; - }); - } else if (name == "Resolve") { - return TypedPackedFunc([solver](Type t) { - return solver->Resolve(t); - }); - } else if (name == "AddConstraint") { - return TypedPackedFunc([solver](TypeConstraint c) { - Expr e = Var("dummy_var", - IncompleteType(Kind::kType)); + .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + ErrorReporter* err_reporter = new ErrorReporter(); + auto module = IRModule({}, {}); + auto dummy_fn_name = GlobalVar("test"); + module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); + auto solver = std::make_shared(dummy_fn_name, module, err_reporter); + + auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { + if (name == "Solve") { + return TypedPackedFunc([solver]() { return solver->Solve(); }); + } else if (name == "Unify") { + return TypedPackedFunc( + [module, solver, err_reporter](Type lhs, Type rhs) { + auto res = solver->Unify(lhs, rhs, lhs); + if (err_reporter->AnyErrors()) { + err_reporter->RenderErrors(module, true); + } + return res; + }); + } else if (name == "Resolve") { + return TypedPackedFunc([solver](Type t) { return solver->Resolve(t); }); + } else if (name == "AddConstraint") { + return TypedPackedFunc([solver](TypeConstraint c) { + Expr e = Var("dummy_var", IncompleteType(Kind::kType)); return solver->AddConstraint(c, e); }); - } else { - return PackedFunc(); - } - }; - *ret = runtime::TypedPackedFunc(mod); - }); + } else { + return PackedFunc(); + } + }; + *ret = runtime::TypedPackedFunc(mod); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 8ccc2c7244b03..9b7c06c7fce9a 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -24,21 +24,23 @@ #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ +#include +#include #include #include -#include -#include -#include + #include #include #include +#include + #include "../../support/arena.h" namespace tvm { namespace relay { -using support::LinkNode; using support::LinkedList; +using support::LinkNode; /*! * \brief Interface of type solver used in type inference. 
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 7ed2cda8d2e8f..57f6f3bb64810 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -29,12 +29,13 @@ #include #include #include + #include "../transforms/pass_util.h" namespace tvm { namespace relay { -template +template struct InsertionSet { std::unordered_set set; std::vector data; @@ -48,10 +49,8 @@ struct InsertionSet { class TypeVarTVisitor : public TypeVisitor { public: - TypeVarTVisitor( - InsertionSet* type_vars, - InsertionSet* bound_type_vars) - : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } + TypeVarTVisitor(InsertionSet* type_vars, InsertionSet* bound_type_vars) + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {} void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); @@ -150,8 +149,7 @@ class TypeVarEVisitor : private ExprVisitor { } void VisitType(const Type& t) final { - TypeVarTVisitor(&type_vars_, &bound_type_vars_) - .VisitType(t); + TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t); } private: @@ -205,9 +203,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { vars_.Insert(v); } - void VisitExpr_(const VarNode* var) final { - vars_.Insert(GetRef(var)); - } + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { @@ -222,13 +218,9 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { VisitExpr(op->body); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitPattern_(const PatternVarNode* op) final { - MarkBounded(op->var); - } + void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); } private: InsertionSet vars_; @@ -259,82 +251,66 @@ tvm::Array AllTypeVars(const Type& type, const IRModule& mod) { return TypeVarEVisitor(mod).All(type); } -tvm::Array FreeVars(const Expr& expr) { - return VarVisitor().Free(expr); -} +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } -tvm::Array BoundVars(const Expr& expr) { - return VarVisitor().Bound(expr); -} +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } -tvm::Array BoundVars(const Pattern& pat) { - return VarVisitor().Bound(pat); -} +tvm::Array BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); } -tvm::Array AllVars(const Expr& expr) { - return VarVisitor().All(expr); -} +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.free_vars") -.set_body_typed(FreeVars); +TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relay.analysis.bound_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - if (x.as()) { - *ret = BoundVars(Downcast(x)); - } else { - *ret = BoundVars(Downcast(x)); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + if (x.as()) { + *ret = BoundVars(Downcast(x)); + } else { + *ret = BoundVars(Downcast(x)); + } +}); -TVM_REGISTER_GLOBAL("relay.analysis.all_vars") -.set_body_typed(AllVars); +TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if 
(x.as()) { - *ret = FreeTypeVars(Downcast(x), mod); - } else { - *ret = FreeTypeVars(Downcast(x), mod); - } - }); - -TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = BoundTypeVars(Downcast(x), mod); - } else { - *ret = BoundTypeVars(Downcast(x), mod); - } - }); - -TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = AllTypeVars(Downcast(x), mod); - } else { - *ret = AllTypeVars(Downcast(x), mod); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = FreeTypeVars(Downcast(x), mod); + } else { + *ret = FreeTypeVars(Downcast(x), mod); + } +}); + +TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = BoundTypeVars(Downcast(x), mod); + } else { + *ret = BoundTypeVars(Downcast(x), mod); + } +}); + +TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = AllTypeVars(Downcast(x), mod); + } else { + *ret = AllTypeVars(Downcast(x), mod); + } +}); /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map -GetExprRefCount(const Expr& body) { +std::unordered_map GetExprRefCount(const Expr& body) { class ExprRefCounter : private MixedModeVisitor { public: - std::unordered_map - Get(const Expr& body) { + std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter_); } @@ -392,9 +368,7 @@ bool IsAllPositiveConstant(const Expr& expr) { } } else if (const auto* op = expr.as()) { // tail recursion. 
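GetExprRefCount above leans on the visitor's visit counter: every time a node is reached from some parent its count ticks up, so a shared subexpression ends with a count greater than one. The same idea over a toy DAG (the Expr struct here is a hypothetical stand-in, not Relay's):

#include <iostream>
#include <unordered_map>
#include <vector>

// Toy expression DAG node; children are shared, not owned.
struct Expr {
  std::vector<const Expr*> children;
};

// Count how many times each node is referenced. A node's children are
// expanded only on the first visit; later visits just bump the count.
void CountRefs(const Expr* e, std::unordered_map<const Expr*, size_t>* counter) {
  size_t& count = (*counter)[e];
  if (++count > 1) return;  // already expanded once
  for (const Expr* child : e->children) CountRefs(child, counter);
}

int main() {
  Expr shared;  // referenced twice below
  Expr left{{&shared}}, right{{&shared}};
  Expr root{{&left, &right}};

  std::unordered_map<const Expr*, size_t> counter;
  CountRefs(&root, &counter);
  std::cout << counter[&shared] << "\n";  // 2: two parents reference it
}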
- if (op->op == expand_dims_op || - op->op == reshape_op || - op->op == transpose_op || + if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op || op->op == squeeze_op) { return IsAllPositiveConstant(op->args[0]); } else { @@ -420,17 +394,11 @@ Type TypeSubst(const Type& type, const tvm::Map& subst_map) { Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { class TypeSubstMutator : public ExprMutator, public PatternMutator { public: - explicit TypeSubstMutator(const tvm::Map& subst_map) : subst_map_(subst_map) { } - Type VisitType(const Type& t) final { - return TypeSubst(t, subst_map_); - } - Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); - } + explicit TypeSubstMutator(const tvm::Map& subst_map) : subst_map_(subst_map) {} + Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); } + Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index f3a2cadb363fb..33f52c9a8397d 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -24,12 +24,12 @@ #include #include #include + #include namespace tvm { namespace relay { - //! brief make sure each Var is bound at most once in a scope. class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; @@ -41,9 +41,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { struct Scope { WellFormedChecker* wfc; - explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { - wfc->scope.push_back({{}}); - } + explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { wfc->scope.push_back({{}}); } ~Scope() { CHECK_GE(wfc->scope.size(), 0); for (const Var& v : wfc->scope.back()) { @@ -98,13 +96,9 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { VisitExpr(c->rhs); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitVar(const Var& v) final { - Bound(v); - } + void VisitVar(const Var& v) final { Bound(v); } void VisitExpr(const Expr& e) final { if (auto v = e.as()) { @@ -121,12 +115,9 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { } }; -bool WellFormed(const Expr& e) { - return WellFormedChecker().CheckWellFormed(e); -} +bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_GLOBAL("relay.analysis.well_formed") -.set_body_typed(WellFormed); +TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed(WellFormed); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index c26228ebeba34..ef273c30d296a 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -21,13 +21,14 @@ * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. 
*/ -#include #include -#include -#include +#include #include -#include #include +#include +#include +#include + #include #include "../../target/source/codegen_source_base.h" @@ -37,7 +38,6 @@ namespace tvm { namespace relay { namespace backend { - using TargetsMap = Map; using namespace tvm::relay::transform; @@ -63,17 +63,11 @@ struct GraphCodegen { } ~GraphCodegen() {} - void Init(runtime::Module* m, TargetsMap targets) { - CallFunc("init", m, targets); - } + void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } - void Codegen(const Function& func) { - CallFunc("codegen", func); - } + void Codegen(const Function& func) { CallFunc("codegen", func); } - std::string GetJSON() { - return CallFunc("get_graph_json", nullptr); - } + std::string GetJSON() { return CallFunc("get_graph_json", nullptr); } Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); @@ -96,13 +90,13 @@ struct GraphCodegen { protected: tvm::runtime::Module mod; - template - R CallFunc(const std::string &name, Args... args) { + template + R CallFunc(const std::string& name, Args... args) { auto pf = mod.GetFunction(name, false); return pf(std::forward(args)...); } - template - void CallFunc(const std::string &name, Args... args) { + template + void CallFunc(const std::string& name, Args... args) { auto pf = mod.GetFunction(name, false); pf(std::forward(args)...); return; @@ -121,29 +115,24 @@ class RelayBuildModule : public runtime::ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_graph_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGraphJSON(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); }); } else if (name == "get_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetModule(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); this->Build(args[0], args[1], args[2]); }); } else if (name == "list_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->ListParamNames(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); }); } else if (name == "get_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetParams(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -153,11 +142,11 @@ class RelayBuildModule : public runtime::ModuleNode { }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetIRModule(); + *rv = this->graph_codegen_->GetIRModule(); }); } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetExternalModules(); + *rv = 
this->graph_codegen_->GetExternalModules(); }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -175,18 +164,14 @@ * * \return const std::string graph_json */ - const std::string& GetGraphJSON() { - return ret_.graph_json; - } + const std::string& GetGraphJSON() { return ret_.graph_json; } /*! * \brief Get the Module object * * \return runtime::Module */ - runtime::Module GetModule() { - return ret_.mod; - } + runtime::Module GetModule() { return ret_.mod; } /*! * \brief List all parameter names * @@ -220,18 +205,14 @@ * \param name name of parameter * \param data_in input DLTensor */ - void SetParam(const std::string& name, runtime::NDArray data_in) { - params_[name] = data_in; - } + void SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } /*! * \brief type key * * \return const char* */ - const char* type_key() const final { - return "RelayBuildModule"; - } + const char* type_key() const final { return "RelayBuildModule"; } /*! * \brief Build relay IRModule for graph runtime * * \param mod Relay IRModule * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { + void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; BuildRelay(mod, params_); @@ -258,13 +237,10 @@ * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize( - IRModule relay_module, - const TargetsMap& targets, - const std::unordered_map& params) { + IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + const std::unordered_map& params) { if (params.size()) { - CHECK(relay_module->ContainGlobalVar("main")) - << "Missing the main entry function"; + CHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); auto new_main = BindParamsByName(main_func, params); @@ -328,8 +304,7 @@ // Handle heterogeneous compilation. transform::PassContext pass_ctx = PassContext::Current(); if (targets_.size() > 1) { - relay_module = - RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); + relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); } // Fuse the operations if it is needed. @@ -386,8 +361,7 @@ * * \return updated_module The updated module after device annotation. */ - IRModule RunDeviceAnnotationPass(const IRModule& relay_module, - int fallback_device) { + IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) { UpdateHeterogeneousInputs(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device); auto updated_module = rewrite(relay_module); @@ -416,12 +390,11 @@ break; } for (auto kv : annotation_map) { - CHECK_EQ(kv.second->value, dev_type) - << "Expressions in the function are " - << "annotated with various device types," - << "but not device copy operators " - << "found. 
Please check the " - << "RewriteAnnotation pass."; + CHECK_EQ(kv.second->value, dev_type) << "Expressions in the function are " + << "annotated with various device types," + << "but not device copy operators " + << "found. Please check the " + << "RewriteAnnotation pass."; } targets_.Set(0, CreateDefaultTarget(dev_type)); } @@ -435,9 +408,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param relay_module The Relay IR module. * \param params The parameters. */ - void BuildRelay( - IRModule relay_module, - const std::unordered_map& params) { + void BuildRelay(IRModule relay_module, + const std::unordered_map& params) { // Relay IRModule -> IRModule optimizations. relay_module = Optimize(relay_module, targets_, params); // Get the updated function. @@ -473,23 +445,19 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); } } else { - ret_.mod = tvm::build( - lowered_funcs, - target_host_, - BuildConfig::Current()); + ret_.mod = tvm::build(lowered_funcs, target_host_, BuildConfig::Current()); } Array ext_mods = graph_codegen_->GetExternalModules(); // Import all external runtime modules. - for (const auto& it : ext_mods) - ret_.mod.Import(it); + for (const auto& it : ext_mods) ret_.mod.Import(it); } private: Target GetTargetHost() { Target target_host = target_host_; if (!target_host_.defined()) { - for (const auto &it : targets_) { + for (const auto& it : targets_) { if (it.second->device_type == kDLCPU) { target_host = it.second; break; @@ -516,20 +484,19 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); }); TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Map params = args[1]; - std::unordered_map params_; - for (const auto& kv : params) { - params_[kv.first] = kv.second->data; - } - *rv = relay::backend::BindParamsByName(args[0], params_); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + Map params = args[1]; + std::unordered_map params_; + for (const auto& kv : params) { + params_[kv.first] = kv.second->data; + } + *rv = relay::backend::BindParamsByName(args[0], params_); + }); } // namespace backend } // namespace relay diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index b8e214255f46a..941a257c0003b 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -103,8 +103,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> for (Var param : prim_func->params) { Array inputs; if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder( - GetShape(ttype->shape), ttype->dtype); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); } else { @@ -114,8 +113,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> const auto* ttype = field.as(); // TODO(@icemelon): Allow recursive tuple CHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder( - GetShape(ttype->shape), ttype->dtype); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); } @@ -128,7 +126,7 @@ class 
ScheduleGetter : public backend::MemoizedExprTranslator> constexpr static size_t kMaxFuncNameLength = 80; if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); truncated_name << "_" << std::hash{}(candidate_name) << "_"; candidate_name = truncated_name.str(); } @@ -169,29 +167,31 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> CHECK(op->is_scalar()); void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = te::compute({}, [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, "compile_engine_const", topi::kBroadcast); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); scalars_.push_back(value->op); return {value}; } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = - Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttr("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); CHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -206,12 +206,10 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> } } if (count_tuple) { - CHECK_EQ(call_node->args.size(), 1U) - << "Only allow function with a single tuple input"; + CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } - CHECK(call_node->op.as()) - << "Primitive function only allows call into primitive ops"; + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); Array outputs; @@ -219,8 +217,8 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> // Skip fcompute for device copy operators as it is not registered. 
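The ConstantNode handler above reads a scalar out of a type-erased void* buffer by branching on the runtime dtype. The same dispatch in miniature, with made-up dtype tags (TVM's DataType covers more cases than shown here):

#include <cstdint>
#include <iostream>

enum class DType { kInt32, kInt64, kFloat32 };  // illustrative tags only

// Reinterpret a raw scalar according to its runtime dtype tag, as the
// static_cast chain in VisitExpr_(const ConstantNode*) does.
double ReadScalar(const void* data, DType dtype) {
  switch (dtype) {
    case DType::kInt32:
      return static_cast<double>(*static_cast<const int32_t*>(data));
    case DType::kInt64:
      return static_cast<double>(*static_cast<const int64_t*>(data));
    case DType::kFloat32:
      return static_cast<double>(*static_cast<const float*>(data));
  }
  return 0.0;  // unreachable for the tags above
}

int main() {
  int32_t i = 42;
  float f = 2.5f;
  std::cout << ReadScalar(&i, DType::kInt32) << "\n";   // 42
  std::cout << ReadScalar(&f, DType::kFloat32) << "\n"; // 2.5
}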
if (op == device_copy_op_) { const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype, - te::Operation(), 0)); + outputs.push_back( + te::TensorNode::make(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; @@ -230,8 +228,8 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> int op_pattern = fpattern[op]; if (op_pattern >= kCommReduce) { CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce) - << "Two complicated op in a primitive function " - << " master=" << master_op_ << " current=" << op; + << "Two complicated op in a primitive function " + << " master=" << master_op_ << " current=" << op; } if (op_pattern >= master_op_pattern_) { master_op_ = op; @@ -240,8 +238,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> master_implementation_ = impl; } if (outputs.size() != 1) { - const auto* tuple_type = - call_node->checked_type().as(); + const auto* tuple_type = call_node->checked_type().as(); CHECK(tuple_type) << "Expect output to be a tuple type"; CHECK_EQ(tuple_type->fields.size(), outputs.size()); } @@ -271,8 +268,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { - CHECK(field->checked_type().as()) - << "Only allow Tuple of Tensor"; + CHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); CHECK_EQ(res.size(), 1); fields.push_back(res[0]); @@ -328,15 +324,15 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> shape_inputs.push_back(shape_tensor); }; - if (const auto *ttype = param->checked_type().as()) { + if (const auto* ttype = param->checked_type().as()) { add_placeholder(ttype); } else { // flatten tuple of tensor type. 
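The op_pattern bookkeeping reformatted above elects one "master" operator per fused function: the operator with the highest TOpPattern value wins, and at most one operator at or above kCommReduce is allowed. A reduced sketch of that rule; the enum values are assumptions standing in for TVM's TOpPattern, not copied from its headers:

#include <cassert>

enum OpPattern { kElemWise, kBroadcast, kInjective, kCommReduce, kOutEWiseFusable, kOpaque };

struct MasterOpTracker {
  int master_pattern = -1;  // pattern of the current master op; -1 means none yet

  void Observe(int op_pattern) {
    // Reject a second "complicated" op (>= kCommReduce) in one primitive function.
    assert(!(op_pattern >= kCommReduce && master_pattern >= kCommReduce) &&
           "Two complicated ops in a primitive function");
    if (op_pattern >= master_pattern) master_pattern = op_pattern;  // highest pattern wins
  }
};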
- const auto *tuple_type = param->type_as(); + const auto* tuple_type = param->type_as(); // TODO(@icemelon): Support recursive tuple CHECK(tuple_type); for (Type field : tuple_type->fields) { - const auto *ttype = field.as(); + const auto* ttype = field.as(); CHECK(ttype); add_placeholder(ttype); } @@ -351,7 +347,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> constexpr static size_t kMaxFuncNameLength = 80; if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); truncated_name << "_" << std::hash{}(candidate_name) << "_"; candidate_name = truncated_name.str(); } @@ -427,28 +423,31 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> if (data_dependant) { void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute({}, [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, "data_const", topi::kBroadcast); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); scalars_.push_back(value); return {value}; } else { - auto value = tvm::te::compute({}, [&](const Array&) { - return tir::make_const(DataType::Int(64), 0); - }, "shape_const", topi::kBroadcast); + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); scalars_.push_back(value); return {value}; } @@ -456,18 +455,15 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array VisitExpr_(const CallNode* call_node) final { static auto fshape_func = Op::GetAttr("FShapeFunc"); - static auto tshape_data_dependant = Op::GetAttr( - "TShapeDataDependant"); - CHECK(call_node->op.as()) - << "Primitive function only allows call into primitive ops"; + static auto tshape_data_dependant = Op::GetAttr("TShapeDataDependant"); + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); CHECK(data_dependants_.empty() || !data_dependants_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependant shape func"; - CHECK_GT(fshape_func.count(op), 0) - << "Internal error, cannot find ShapeFunc for " << op->name; + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependant shape func"; + CHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find 
ShapeFunc for " << op->name; CHECK_GT(tshape_data_dependant.count(op), 0) - << "Internal error, cannot find TShapeDataDependant for " << op->name; + << "Internal error, cannot find TShapeDataDependant for " << op->name; data_dependants_.push_back(IsDataDependant(call_node)); // Visit all inputs @@ -482,8 +478,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } } if (count_tuple) { - CHECK_EQ(call_node->args.size(), 1U) - << "Only allow function with a single tuple input"; + CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } // Get output ndims auto ret_type = call_node->checked_type(); @@ -522,8 +517,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { - CHECK(field->checked_type().as()) - << "Only allow Tuple of Tensor"; + CHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); CHECK_EQ(res.size(), 1); fields.push_back(res[0]); @@ -549,9 +543,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. - CachedFunc Lower(const CCacheKey& key) { - return LowerInternal(key)->cached_func; - } + CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { @@ -612,9 +604,7 @@ class CompileEngineImpl : public CompileEngineNode { return ret; } - void Clear() final { - cache_.clear(); - } + void Clear() final { cache_.clear(); } // List all items in the cache. Array ListItems() { std::lock_guard lock(mutex_); @@ -638,7 +628,7 @@ class CompileEngineImpl : public CompileEngineNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key) { + CCacheValue LowerInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); @@ -655,10 +645,8 @@ class CompileEngineImpl : public CompileEngineNode { // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); - const auto name_node = - key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(name_node.defined()) - << "External function has not been attached a name yet."; + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); @@ -669,8 +657,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object( - *(cfunc.operator->())); + auto cache_node = make_object(*(cfunc.operator->())); // Skip lowering for device copy node. 
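LowerInternal, whose reformatted body runs through this hunk, is a guarded memoisation: take the mutex, look the (source function, target) key up in cache_, and lower only on a miss. A minimal sketch of the pattern with simplified stand-ins for CCacheKey and CCacheValue (the string "lowering" is a placeholder for the real work):

#include <mutex>
#include <string>
#include <unordered_map>

using CacheKey = std::string;  // stands in for CCacheKey (function + target)
struct CacheValue { std::string lowered; int use_count = 0; };

class LowerCache {
 public:
  const CacheValue& Lower(const CacheKey& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      ++it->second.use_count;  // hit: reuse the already-lowered entry
      return it->second;
    }
    CacheValue value;
    value.lowered = "lowered:" + key;  // miss: placeholder for the real lowering
    value.use_count = 1;
    return cache_.emplace(key, std::move(value)).first->second;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<CacheKey, CacheValue> cache_;
};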
const Expr body = (key->source_func)->body; @@ -689,13 +676,11 @@ class CompileEngineImpl : public CompileEngineNode { } // lower the function if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)( - cfunc->schedule, all_args, cache_node->func_name, key->source_func); + cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); } else { tvm::BuildConfig bcfg = BuildConfig::Create(); std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, - binds, bcfg); + cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds, bcfg); } value->cached_func = CachedFunc(cache_node); return value; @@ -719,8 +704,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object( - *(spair.second.operator->())); + auto cache_node = make_object(*(spair.second.operator->())); cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->target = key->target; @@ -771,57 +755,41 @@ class CompileEngineImpl : public CompileEngineNode { const CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. - static CompileEngine* inst = new CompileEngine( - make_object()); + static CompileEngine* inst = new CompileEngine(make_object()); return *inst; } TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") -.set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); -}); + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") -.set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); -}); + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") -.set_body_typed([]() { +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { return CompileEngine::Global(); }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") -.set_body_typed([](CompileEngine self) { +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { self->Clear(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->Lower(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->LowerShapeFunc(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") -.set_body_typed([](CompileEngine self) { - return self->LowerExternalFunctions(); -}); + .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->JIT(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") -.set_body_typed( - [](CompileEngine self){ 
+TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { return static_cast(self.operator->())->ListItems(); }); } // namespace relay diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 4a3a04d02dcd1..9abe80c363c56 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -27,13 +27,14 @@ #include #include -#include #include #include -#include #include -#include +#include +#include + #include +#include namespace tvm { namespace relay { @@ -150,9 +151,7 @@ class CCacheKey : public ObjectRef { */ TVM_DLL CCacheKey(Function source_func, Target target); - const CCacheKeyNode* operator->() const { - return static_cast(get()); - } + const CCacheKeyNode* operator->() const { return static_cast(get()); } // comparator inline bool operator==(const CCacheKey& other) const { CHECK(defined() && other.defined()); @@ -184,12 +183,8 @@ class CCacheValue : public ObjectRef { public: CCacheValue() {} explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { - return static_cast(get_mutable()); - } - const CCacheValueNode* operator->() const { - return static_cast(get()); - } + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } using ContainerType = CCacheValueNode; }; @@ -240,9 +235,7 @@ class CompileEngine : public ObjectRef { public: CompileEngine() {} explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { - return static_cast(get_mutable()); - } + CompileEngineNode* operator->() { return static_cast(get_mutable()); } using ContainerType = CompileEngineNode; /*! \brief The global compile engine. */ TVM_DLL static const CompileEngine& Global(); @@ -260,17 +253,15 @@ inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. 
hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine( - hash_, std::hash()(target->str())); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; return hash_; } -inline bool CCacheKeyNode::Equal( - const CCacheKeyNode* other) const { +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); + tvm::StructuralEqual()(this->source_func, other->source_func); } } // namespace relay @@ -278,7 +269,7 @@ inline bool CCacheKeyNode::Equal( namespace std { // overload hash -template<> +template <> struct hash<::tvm::relay::CCacheKey> { size_t operator()(const ::tvm::relay::CCacheKey& key) const { CHECK(key.defined()); diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index ed36fda90a21f..b8803d45d0f6f 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -151,8 +151,8 @@ class CodegenC : public MemoizedExprTranslator>, public Code for (size_t i = 0; i < out_shape.size(); ++i) { out_size *= out_shape[i]; } - buf_stream << dtype << "* " << out << - " = (" << dtype << "*)std::malloc(4 * " << out_size << ");"; + buf_stream << dtype << "* " << out << " = (" << dtype << "*)std::malloc(4 * " << out_size + << ");"; buf_decl_.push_back(buf_stream.str()); decl_stream << ", " << out << ");"; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 92263861d3593..2ee68cec8b40c 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -25,9 +25,10 @@ #define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ #include -#include #include +#include #include + #include #include #include @@ -69,8 +70,7 @@ class CSourceModuleCodegenBase { * \return An external symbol. */ std::string GetExtSymbol(const Function& func) const { - const auto name_node = - func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } @@ -124,8 +124,7 @@ class CodegenCBase { * * \endcode */ - void GenerateBackendCFunc(const std::string& func_name, - const Array& args, + void GenerateBackendCFunc(const std::string& func_name, const Array& args, const Output& out) { // Print signature code_stream_ << "\n"; @@ -158,8 +157,8 @@ class CodegenCBase { code_stream_ << "}\n\n"; // Generate the macro - code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " - << func_name << "_wrapper_);\n\n"; + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name + << "_wrapper_);\n\n"; } /*! @@ -187,8 +186,7 @@ class CodegenCBase { */ std::string JitImpl(const std::string& ext_func_id, const Array& args, const std::vector& buf_decl, - const std::vector& body, - const std::vector& out) { + const std::vector& body, const std::vector& out) { // Create the signature. 
For example, it could be: // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} code_stream_ << "extern \"C\" void " << ext_func_id << "_("; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 5e45e94223f87..3db5dc440c8f2 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -174,7 +174,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C // Allocate large arrays on the static section to avoid stakc overflow. // Note that this would probably increase compilation time as the source // file could be really large. - buf_stream << "static float " << output.name << "[" << num_elems <<"] = {"; + buf_stream << "static float " << output.name << "[" << num_elems << "] = {"; for (int64_t i = 0; i < num_elems - 1; i++) { buf_stream << ptr[i] << ","; } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 736509d2d97fe..820e17f8a4987 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -22,10 +22,11 @@ * \brief Memory index assignment pass for executing * the program in the graph runtime. */ -#include +#include #include #include -#include +#include + #include "../../support/arena.h" namespace tvm { @@ -60,9 +61,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { } } - void VisitExpr_(const ConstantNode* op) final { - this->CreateToken(op, false); - } + void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); } void VisitExpr_(const VarNode* op) final { // Do nothing. @@ -96,9 +95,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { token_map_[op] = {tok[op->index]}; } - void VisitExpr_(const IfNode* op) final { - LOG(FATAL) << "if is not supported."; - } + void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } void VisitExpr_(const LetNode* op) final { auto token = GetToken(op->value); @@ -131,12 +128,11 @@ class StorageAllocaBaseVisitor : public ExprVisitor { class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: - explicit StorageAllocaInit(support::Arena* arena) - : arena_(arena) {} + explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ - std::unordered_map > - GetInitTokenMap(const Function& func) { + std::unordered_map > GetInitTokenMap( + const Function& func) { node_device_map_ = CollectDeviceInfo(func); this->Run(func); return std::move(token_map_); @@ -145,12 +141,11 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateToken(const ExprNode* op, bool can_realloc) final { CHECK(!token_map_.count(op)); std::vector tokens; - int device_type = node_device_map_.count(GetRef(op)) - ? node_device_map_[GetRef(op)]->value - : 0; + int device_type = + node_device_map_.count(GetRef(op)) ? node_device_map_[GetRef(op)]->value : 0; if (const auto* tuple_type = op->checked_type().as()) { for (Type t : tuple_type->fields) { const auto* ttype = t.as(); @@ -227,10 +222,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { - LOG(FATAL) - << num_annotated_nodes << " out of " << num_nodes - << "expressions are assigned with virtual device types. 
Either all " - "or none of the expressions are expected to be annotated."; + LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes + << "expressions are assigned with virtual device types. Either all " + "or none of the expressions are expected to be annotated."; } return smap; } @@ -296,12 +290,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); - CHECK(pval != nullptr) - << "Cannot allocate memory symbolic tensor shape " - << ttype->shape; - CHECK_GE(*pval, 0) - << "Cannot allocate memory for tensor with negative shape" - << *pval; + CHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; + CHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; size *= static_cast(pval[0]); } size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); @@ -324,7 +314,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageToken *tok = it->second; + StorageToken* tok = it->second; if (tok->device_type != prototype->device_type) continue; CHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy @@ -337,7 +327,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageToken *tok = it->second; + StorageToken* tok = it->second; if (tok->device_type != prototype->device_type) continue; CHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy @@ -390,8 +380,7 @@ Map > GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } -TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") -.set_body_typed(GraphPlanMemory); +TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7b686c76e3e70..c8ec1bf1e767b 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -44,7 +44,7 @@ class GraphInputNode; class GraphOpNode; using IntegerArray = Array; -using ShapeVector = std::vector >; +using ShapeVector = std::vector>; using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; @@ -70,8 +70,7 @@ class GraphNodeRef { public: GraphNodeRef() {} GraphNodeRef(int ident, int index, int version = 0) - : ident_(ident), index_(index), version_(version) {} - + : ident_(ident), index_(index), version_(version) {} inline void Save(dmlc::JSONWriter* writer) const { writer->BeginArray(); @@ -81,9 +80,7 @@ class GraphNodeRef { writer->EndArray(); } - inline void Load(dmlc::JSONReader* reader) { - LOG(FATAL) << "Not implemented."; - } + inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; } protected: int ident_; @@ -136,11 +133,8 @@ class GraphInputNode : public GraphNode { class GraphOpNode : public GraphNode { public: GraphOpNode() {} - GraphOpNode(const std::string& name, - const GraphAttrs& nd_attrs, - const std::string& op_name, - const std::vector& inputs, - const GraphAttrs& attrs, + GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name, + const std::vector& inputs, const GraphAttrs& attrs, 
size_t num_outputs = 1) { name_ = name; attrs_ = nd_attrs; @@ -173,8 +167,7 @@ class GraphOpNode : public GraphNode { const GraphAttrs& nd_attrs, const std::string& op_name, const std::vector& inputs, - const GraphAttrs& attrs, - size_t num_outputs = 1) { + const GraphAttrs& attrs, size_t num_outputs = 1) { auto ptr = std::make_shared(name, nd_attrs, op_name, inputs, attrs, num_outputs); return std::dynamic_pointer_cast(ptr); } @@ -335,8 +328,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, - const std::string& op_name, + std::vector GraphAddCallNode(const CallNode* op, const std::string& op_name, const std::string& func_name) { std::vector inputs; for (auto arg : op->args) { @@ -345,11 +337,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator(op)); } @@ -384,11 +372,11 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; // Normal Relay Function if (targets_.size() == 1) { - // homogeneous execution. + // homogeneous execution. const auto& it = targets_.begin(); target = (*it).second; } else { @@ -400,8 +388,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorstr()] = IRModule::Empty(); } lowered_funcs_[target->str()]->Update(lowered_func->funcs); - return GraphAddCallNode(op, - _GetUniqueName(lowered_func->func_name), - lowered_func->func_name); + return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name); } std::vector VisitExpr_(const LetNode* op) override { @@ -560,37 +545,34 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator& sptr_to_self) { - if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 2) - << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; - void* mod = args[0]; - Map tmp = args[1]; - TargetsMap targets; - for (const auto& it : tmp) { - auto dev_type = it.first.as(); - CHECK(dev_type); - targets[dev_type->value] = it.second; - } - codegen_ = std::make_shared( - reinterpret_cast(mod), targets); - }); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 2) << "The expected of arguments are: " + << "runtime::Module mod and Map targets"; + void* mod = args[0]; + Map tmp = args[1]; + TargetsMap targets; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + CHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = + std::make_shared(reinterpret_cast(mod), targets); + }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; this->output_ = this->codegen_->Codegen(func); }); } else if (name == "get_graph_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->output_.graph_json; - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; }); } else if (name == "list_params_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Array ret; - for (const auto &kv : this->output_.params) { + for (const auto& kv : this->output_.params) { ret.push_back(kv.first); } *rv = ret; @@ -614,9 +596,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { } } - const char* type_key() const final { - return 
"RelayGraphRuntimeCodegenModule"; - } + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: std::shared_ptr codegen_; @@ -629,9 +609,7 @@ runtime::Module CreateGraphCodegenMod() { } TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CreateGraphCodegenMod(); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); }); } // namespace backend } // namespace relay diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 465f788449e2b..c5299975371fb 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -21,16 +21,16 @@ * \file src/relay/interpreter.cc * \brief An interpreter for the Relay IR. */ -#include -#include -#include -#include -#include -#include +#include #include #include +#include #include -#include +#include +#include +#include +#include +#include #include "compile_engine.h" @@ -39,8 +39,7 @@ namespace relay { using namespace runtime; -InterpreterClosure::InterpreterClosure(tvm::Map env, - Function func) { +InterpreterClosure::InterpreterClosure(tvm::Map env, Function func) { ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); @@ -48,10 +47,10 @@ InterpreterClosure::InterpreterClosure(tvm::Map env, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; + }); inline const PackedFunc& GetPackedFunc(const std::string& name) { const PackedFunc* pf = tvm::runtime::Registry::Get(name); @@ -69,10 +68,10 @@ RecClosure::RecClosure(InterpreterClosure clos, Var bind) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RecClosureObj(" << node->clos << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RecClosureObj(" << node->clos << ")"; + }); RefValue::RefValue(ObjectRef value) { ObjectPtr n = make_object(); @@ -80,21 +79,19 @@ RefValue::RefValue(ObjectRef value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.RefValue") -.set_body_typed([](ObjectRef value){ +TVM_REGISTER_GLOBAL("relay._make.RefValue").set_body_typed([](ObjectRef value) { return RefValue(value); }); TVM_REGISTER_NODE_TYPE(RefValueObj); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefValueObj(" << node->value << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefValueObj(" << node->value << ")"; + }); -ConstructorValue::ConstructorValue(int32_t tag, - tvm::Array fields, +ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; @@ -104,19 +101,17 @@ ConstructorValue::ConstructorValue(int32_t tag, } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") -.set_body_typed([](int32_t tag, tvm::Array fields, - Constructor constructor) { - return ConstructorValue(tag, fields, constructor); -}); + 
.set_body_typed([](int32_t tag, tvm::Array fields, Constructor constructor) { + return ConstructorValue(tag, fields, constructor); + }); TVM_REGISTER_NODE_TYPE(ConstructorValueObj); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorValueObj(" << node->tag << "," - << node->fields << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")"; + }); /*! * \brief A stack frame in the Relay interpreter. @@ -161,9 +156,7 @@ struct Stack { */ struct LocalFrame { Stack& st; - explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { - st.frames.push_back(fr); - } + explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { st.frames.push_back(fr); } ~LocalFrame() { st.frames.pop_back(); } }; }; @@ -213,9 +206,8 @@ InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { // contains DAG in dataflow-form. // // Conversion to ANF is recommended before running the interpretation. -class Interpreter : - public ExprFunctor, - PatternFunctor { +class Interpreter : public ExprFunctor, + PatternFunctor { public: Interpreter(IRModule mod, DLContext context, Target target) : mod_(mod), @@ -232,21 +224,13 @@ class Interpreter : return f(); } - void extend(const Var& id, ObjectRef v) { - stack_.current_frame().locals.Set(id, v); - } + void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); } - ObjectRef Lookup(const Var& local) { - return stack_.Lookup(local); - } + ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); } - ObjectRef Eval(const Expr& expr) { - return VisitExpr(expr); - } + ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); } - ObjectRef VisitExpr_(const VarNode* var_node) final { - return Lookup(GetRef(var_node)); - } + ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef(var_node)); } ObjectRef VisitExpr_(const GlobalVarNode* op) final { return Eval(mod_->Lookup(GetRef(op))); @@ -260,9 +244,7 @@ class Interpreter : return ObjectRef(); } - ObjectRef VisitExpr_(const ConstantNode* op) final { - return op->data.CopyTo(context_); - } + ObjectRef VisitExpr_(const ConstantNode* op) final { return op->data.CopyTo(context_); } ObjectRef VisitExpr_(const TupleNode* op) final { std::vector values; @@ -302,8 +284,7 @@ class Interpreter : return MakeClosure(func); } - Array ComputeDynamicShape(const Function& func, - const Array& args) { + Array ComputeDynamicShape(const Function& func, const Array& args) { CCacheKey key(func, Target::Create("llvm")); auto cfunc = engine_->LowerShapeFunc(key); size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); @@ -319,26 +300,26 @@ class Interpreter : cpu_ctx.device_id = 0; auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) { - auto nd_array = Downcast(val); - if (need_shape) { - int64_t ndim = nd_array.Shape().size(); - NDArray shape_arr; - if (ndim == 0) { - shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx); - } else { - shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); - int64_t* data = reinterpret_cast(shape_arr->data); - for (auto j = 0; j < ndim; ++j) { - data[j] = nd_array.Shape()[j]; - } - } - inputs[i] = shape_arr; - setter(i, shape_arr); + auto nd_array = Downcast(val); + if (need_shape) { + int64_t ndim = nd_array.Shape().size(); + NDArray shape_arr; + if (ndim == 0) { + shape_arr = 
NDArray::Empty({}, DataType::Int(64), cpu_ctx); } else { - auto arr = nd_array.CopyTo(cpu_ctx); - inputs[i] = arr; - setter(i, arr); + shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); + int64_t* data = reinterpret_cast(shape_arr->data); + for (auto j = 0; j < ndim; ++j) { + data[j] = nd_array.Shape()[j]; + } } + inputs[i] = shape_arr; + setter(i, shape_arr); + } else { + auto arr = nd_array.CopyTo(cpu_ctx); + inputs[i] = arr; + setter(i, arr); + } }; size_t arg_counter = 0; @@ -367,17 +348,16 @@ class Interpreter : } } } - CHECK_EQ(arg_counter, cfunc->inputs.size()) - << "Shape function input sizes mismatch"; + CHECK_EQ(arg_counter, cfunc->inputs.size()) << "Shape function input sizes mismatch"; auto fset_shape_output = [&](size_t i, Type val_type) { - // TODO(@icemelon): allow recursive tuple - const TensorTypeNode* rtype = val_type.as(); - CHECK(rtype != nullptr); - int64_t ndim = rtype->shape.size(); - auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); - outputs[i] = arr; - setter(arg_counter + i, arr); + // TODO(@icemelon): allow recursive tuple + const TensorTypeNode* rtype = val_type.as(); + CHECK(rtype != nullptr); + int64_t ndim = rtype->shape.size(); + auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); + outputs[i] = arr; + setter(arg_counter + i, arr); }; auto ret_type = func->body->checked_type(); @@ -392,8 +372,7 @@ class Interpreter : auto tt = Downcast(ret_type); fset_shape_output(0, tt); } - CHECK_EQ(cfunc->outputs.size(), out_cnt) - << "Shape function output sizes mismatch"; + CHECK_EQ(cfunc->outputs.size(), out_cnt) << "Shape function output sizes mismatch"; PackedFunc shape_func; Module m; @@ -419,8 +398,7 @@ class Interpreter : return out_shapes; } - ObjectRef InvokePrimitiveOp(const Function& func, - const Array& args) { + ObjectRef InvokePrimitiveOp(const Function& func, const Array& args) { const auto* call_node = func->body.as(); if (call_node && call_node->op == debug_op_) { @@ -451,8 +429,7 @@ class Interpreter : if (const auto* tuple_type = func->body->checked_type().as()) { arg_len += tuple_type->fields.size(); } else { - CHECK(func->body->checked_type().as()) - << func->body->checked_type(); + CHECK(func->body->checked_type().as()) << func->body->checked_type(); arg_len += 1; } std::vector values(arg_len); @@ -463,16 +440,14 @@ class Interpreter : const auto nd_array = Downcast(val); setter(i, nd_array); DLContext arg_ctx = nd_array->ctx; - CHECK(arg_ctx.device_type == context_.device_type && - arg_ctx.device_id == context_.device_id) - << "Interpreter expect context to be " - << context_ << ", but get " << arg_ctx; + CHECK(arg_ctx.device_type == context_.device_type && arg_ctx.device_id == context_.device_id) + << "Interpreter expect context to be " << context_ << ", but get " << arg_ctx; }; int arg_counter = 0; for (ObjectRef arg : args) { if (arg->IsInstance()) { - fset_input(arg_counter++, arg); + fset_input(arg_counter++, arg); } else { auto adt = Downcast(arg); for (size_t i = 0; i < adt.size(); ++i) { @@ -547,8 +522,7 @@ class Interpreter : } // Invoke the closure - ObjectRef Invoke(const InterpreterClosure& closure, - const tvm::Array& args, + ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. 
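The fset_input closure reindented in this hunk feeds shape functions in two modes: when only the shape is needed, the tensor's shape is materialised as a 1-D int64 buffer (a rank-0 tensor becomes an empty buffer); when the values themselves are needed, the data is forwarded. A plain-C++ sketch of that split, with simplified stand-ins for NDArray:

#include <cstdint>
#include <vector>

struct ShapeFuncInput {
  std::vector<int64_t> shape_buffer;  // filled when only the shape is needed
  const void* data = nullptr;         // forwarded when the values are needed
};

ShapeFuncInput PrepareInput(const std::vector<int64_t>& shape, const void* data,
                            bool need_shape) {
  ShapeFuncInput input;
  if (need_shape) {
    input.shape_buffer = shape;  // rank-0 tensors yield an empty buffer
  } else {
    input.data = data;
  }
  return input;
}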
if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { @@ -625,11 +599,9 @@ class Interpreter : ObjectRef VisitExpr_(const TupleGetItemNode* op) final { ObjectRef val = Eval(op->tuple); const auto* adt_obj = val.as(); - CHECK(adt_obj) - << "interal error: when evaluating TupleGetItem expected an ADT value"; + CHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value"; auto adt = GetRef(adt_obj); - CHECK_LT(static_cast(op->index), adt.size()) - << "internal error: index out of bounds"; + CHECK_LT(static_cast(op->index), adt.size()) << "internal error: index out of bounds"; return adt[op->index]; } @@ -665,9 +637,7 @@ class Interpreter : } } - ObjectRef VisitExpr_(const RefCreateNode* op) final { - return RefValue(Eval(op->value)); - } + ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValue(Eval(op->value)); } ObjectRef VisitExpr_(const RefReadNode* op) final { ObjectRef r = Eval(op->ref); @@ -718,9 +688,7 @@ class Interpreter : return true; } - bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { - return true; - } + bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; } bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final { extend(op->var, v); @@ -754,17 +722,11 @@ class Interpreter : const Op& shape_of_op_; }; - -TypedPackedFunc -CreateInterpreter( - IRModule mod, - DLContext context, - Target target) { +TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target) { if (mod.defined()) { // eta expand to support constructors in argument position - transform::Sequential seq({ - transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)}); + transform::Sequential seq({transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)}); transform::PassContext pass_ctx = transform::PassContext::Current(); tvm::With ctx(pass_ctx); mod = seq(mod); @@ -779,8 +741,7 @@ CreateInterpreter( return TypedPackedFunc(packed); } -TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") -.set_body_typed(CreateInterpreter); +TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter").set_body_typed(CreateInterpreter); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index e517fee3a4af9..cd760b867f6b6 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -22,86 +22,77 @@ * \brief Implementation and registration of parameter dictionary * serializing/deserializing functions. */ -#include +#include "param_dict.h" + #include +#include #include -#include #include - -#include "param_dict.h" - - +#include namespace tvm { namespace relay { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict") -.set_body([](TVMArgs args, TVMRetValue *rv) { - CHECK_EQ(args.size() % 2, 0u); - // `args` is in the form "key, value, key, value, ..." 
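The save routine that follows writes a small container format: a magic header, a reserved word, the name list, a tensor count, then one serialized tensor per name. A byte-level sketch with a plain vector in place of dmlc::Stream; the magic value and the exact string encoding are assumptions for illustration, not dmlc's wire format:

#include <cstdint>
#include <string>
#include <vector>

constexpr uint64_t kNDArrayListMagic = 0xF7E58D4F05049CB7;  // assumed header value

void WriteU64(std::vector<uint8_t>* out, uint64_t v) {
  const uint8_t* p = reinterpret_cast<const uint8_t*>(&v);
  out->insert(out->end(), p, p + sizeof(v));
}

void WriteString(std::vector<uint8_t>* out, const std::string& s) {
  WriteU64(out, s.size());
  out->insert(out->end(), s.begin(), s.end());
}

// Layout: header | reserved | #names | names... | #tensors | tensor blobs...
std::vector<uint8_t> SaveParamDict(const std::vector<std::string>& names,
                                   const std::vector<std::vector<uint8_t>>& tensor_blobs) {
  std::vector<uint8_t> bytes;
  WriteU64(&bytes, kNDArrayListMagic);
  WriteU64(&bytes, 0);  // reserved
  WriteU64(&bytes, names.size());
  for (const auto& n : names) WriteString(&bytes, n);
  WriteU64(&bytes, tensor_blobs.size());
  for (const auto& t : tensor_blobs) bytes.insert(bytes.end(), t.begin(), t.end());
  return bytes;
}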
- size_t num_params = args.size() / 2; - std::vector names; - names.reserve(num_params); - std::vector arrays; - arrays.reserve(num_params); - for (size_t i = 0; i < num_params * 2; i += 2) { - names.emplace_back(args[i].operator std::string()); - arrays.emplace_back(args[i + 1].operator DLTensor*()); - } - std::string bytes; - dmlc::MemoryStringStream strm(&bytes); - dmlc::Stream* fo = &strm; - uint64_t header = kTVMNDArrayListMagic, reserved = 0; - fo->Write(header); - fo->Write(reserved); - fo->Write(names); - { - uint64_t sz = static_cast(arrays.size()); - fo->Write(sz); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::SaveDLTensor(fo, arrays[i]); - } +TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size() % 2, 0u); + // `args` is in the form "key, value, key, value, ..." + size_t num_params = args.size() / 2; + std::vector names; + names.reserve(num_params); + std::vector arrays; + arrays.reserve(num_params); + for (size_t i = 0; i < num_params * 2; i += 2) { + names.emplace_back(args[i].operator std::string()); + arrays.emplace_back(args[i + 1].operator DLTensor*()); + } + std::string bytes; + dmlc::MemoryStringStream strm(&bytes); + dmlc::Stream* fo = &strm; + uint64_t header = kTVMNDArrayListMagic, reserved = 0; + fo->Write(header); + fo->Write(reserved); + fo->Write(names); + { + uint64_t sz = static_cast(arrays.size()); + fo->Write(sz); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(fo, arrays[i]); } - TVMByteArray arr; - arr.data = bytes.c_str(); - arr.size = bytes.length(); - *rv = arr; - }); + } + TVMByteArray arr; + arr.data = bytes.c_str(); + arr.size = bytes.length(); + *rv = arr; +}); -TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string bytes = args[0]; - std::vector names; - dmlc::MemoryStringStream memstrm(&bytes); - dmlc::Stream* strm = &memstrm; - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; - CHECK(strm->Read(&names)) - << "Invalid parameters file format"; - uint64_t sz; - strm->Read(&sz, sizeof(sz)); - size_t size = static_cast(sz); - CHECK(size == names.size()) - << "Invalid parameters file format"; - tvm::Array ret; - for (size_t i = 0; i < size; ++i) { - tvm::runtime::NDArray temp; - temp.Load(strm); - auto n = tvm::make_object(); - n->name = std::move(names[i]); - n->array = temp; - ret.push_back(NamedNDArray(n)); - } - *rv = ret; - }); +TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string bytes = args[0]; + std::vector names; + dmlc::MemoryStringStream memstrm(&bytes); + dmlc::Stream* strm = &memstrm; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; + CHECK(strm->Read(&names)) << "Invalid parameters file format"; + uint64_t sz; + strm->Read(&sz, sizeof(sz)); + size_t size = static_cast(sz); + CHECK(size == names.size()) << "Invalid parameters file format"; + tvm::Array ret; + for (size_t i = 0; i < size; ++i) { + tvm::runtime::NDArray temp; + temp.Load(strm); + auto n = tvm::make_object(); + n->name = std::move(names[i]); + n->array = temp; + ret.push_back(NamedNDArray(n)); + } + *rv 
= ret; +}); TVM_REGISTER_NODE_TYPE(NamedNDArrayNode); diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index c829e546b90bb..384201f946486 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -25,9 +25,9 @@ #define TVM_RELAY_BACKEND_PARAM_DICT_H_ #include -#include #include #include +#include #include diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 5c98728b1c25b..b19d272f68482 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -215,5 +215,4 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth, } // namespace relay } // namespace tvm - #endif // TVM_RELAY_BACKEND_UTILS_H_ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 9cdd365152971..b2a5e83ef43cc 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -22,27 +22,29 @@ * \brief A compiler from relay::Module to the VM byte code. */ -#include +#include "compiler.h" + +#include #include +#include #include #include #include -#include #include #include -#include -#include +#include +#include #include #include #include #include #include -#include "../utils.h" + #include "../../backend/compile_engine.h" -#include "../../transforms/pass_util.h" #include "../../op/op_common.h" -#include "compiler.h" +#include "../../transforms/pass_util.h" +#include "../utils.h" namespace tvm { namespace relay { @@ -93,8 +95,7 @@ struct AccessField : MatchValue { // Runtime register num after compiling the access field path RegName reg{-1}; - AccessField(MatchValuePtr parent, size_t index) - : parent(parent), index(index) {} + AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {} ~AccessField() {} }; @@ -115,8 +116,7 @@ struct VarBinding : ConditionNode { Var var; MatchValuePtr val; - VarBinding(Var var, MatchValuePtr val) - : var(var), val(val) {} + VarBinding(Var var, MatchValuePtr val) : var(var), val(val) {} ~VarBinding() {} }; @@ -131,9 +131,7 @@ struct TagCompare : ConditionNode { /*! 
\brief The expected tag */ int target_tag; - TagCompare(MatchValuePtr obj, size_t target) - : obj(obj), target_tag(target) { - } + TagCompare(MatchValuePtr obj, size_t target) : obj(obj), target_tag(target) {} ~TagCompare() {} }; @@ -143,10 +141,8 @@ using TreeLeafNode = relay::TreeLeafNode; using TreeLeafFatalNode = relay::TreeLeafFatalNode; using TreeBranchNode = relay::TreeBranchNode; -TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, - Pattern pattern, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) { +TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern, + TreeObjectPtr then_branch, TreeObjectPtr else_branch) { if (pattern.as()) { // We ignore wildcard binding since it's not producing new vars return then_branch; @@ -176,11 +172,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, } } -TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, - Clause clause, - TreeObjectPtr else_branch) { - return BuildDecisionTreeFromPattern(data, clause->lhs, - TreeLeafNode::Make(clause->rhs), else_branch); +TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause, + TreeObjectPtr else_branch) { + return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs), + else_branch); } TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { @@ -196,12 +191,11 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array ToAllocTensorShape(NDArray shape) { std::vector raw_shape; CHECK_EQ(shape->ndim, 1u); - CHECK_EQ(shape->dtype.code, 0U) - << "The dtype of constant shape must be int32 or int64, but got " - << DLDataType2String(shape->dtype); + CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " + << DLDataType2String(shape->dtype); CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32) - << "The dtype of constant shape must be int32 or int64, but got" - << DLDataType2String(shape->dtype); + << "The dtype of constant shape must be int32 or int64, but got" + << DLDataType2String(shape->dtype); if (shape->dtype.bits == 64) { int64_t* int_ptr = reinterpret_cast(shape->data); @@ -217,7 +211,6 @@ std::vector ToAllocTensorShape(NDArray shape) { return raw_shape; } - class VMFunctionCompiler : ExprFunctor { public: VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) @@ -310,11 +303,7 @@ class VMFunctionCompiler : ExprFunctor { } // TODO(@jroesch): use correct tag - Emit(Instruction::AllocADT( - 0, - tuple->fields.size(), - fields_registers, - NewRegister())); + Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister())); } void VisitExpr_(const MatchNode* match_node) { @@ -415,52 +404,46 @@ class VMFunctionCompiler : ExprFunctor { for (auto input : inputs) { auto reg = var_register_map_.find(Downcast(input)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } for (auto output : outputs) { auto reg = var_register_map_.find(Downcast(output)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } - Emit(Instruction::InvokePacked(op_index, - argument_registers.size(), - outputs.size(), - argument_registers)); + 
Emit(Instruction::InvokePacked(op_index, argument_registers.size(), outputs.size(), + argument_registers)); } - void EmitInvokeTVMOp(const Function& func, - const Expr& inputs, - const Expr& outputs) { + void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) - << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; + << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); - CHECK(input_tuple) - << "internal error: invoke_tvm_op inputs must be a tuple," - << "please file a bug in the memory manifestation pass"; + CHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple," + << "please file a bug in the memory manifestation pass"; auto output_tuple = outputs.as(); - CHECK(output_tuple) - << "internal error: invoke_tvm_op outputs must be a tuple," - << "please file a bug in the memory manifestation pass"; + CHECK(output_tuple) << "internal error: invoke_tvm_op outputs must be a tuple," + << "please file a bug in the memory manifestation pass"; for (auto input : input_tuple->fields) { auto reg = var_register_map_.find(Downcast(input)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } for (auto output : output_tuple->fields) { auto reg = var_register_map_.find(Downcast(output)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } @@ -500,10 +483,8 @@ class VMFunctionCompiler : ExprFunctor { } } - Emit(Instruction::InvokePacked(op_index, - argument_registers.size(), - output_tuple->fields.size(), - argument_registers)); + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), + argument_registers)); } void VisitExpr_(const CallNode* call_node) { @@ -514,70 +495,73 @@ class VMFunctionCompiler : ExprFunctor { // allocation operations. if (op.as()) { OpMatch matcher; - matcher.Match("memory.invoke_tvm_op", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); - }).Match("memory.alloc_tensor", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); - - // Get the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) - << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - // The storage will be passed dynamically. - this->VisitExpr(args[0]); - auto storage_register = last_register_; - - // If the shape is constant then we will emit a static tensor allocation instruction. - auto const_shape = args[1].as(); - - if (const_shape) { - NDArray shape = const_shape->data; - // TODO(@jroesch): we need to get an RFC done to standarize shape dtype - std::vector raw_shape = ToAllocTensorShape(shape); - // Add context field. 
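ToAllocTensorShape, defined earlier in this file and used just above to turn a constant shape tensor into raw integers, accepts a rank-1 int32 or int64 buffer and widens it into std::vector<int64_t>. A pointer-level sketch of the same conversion:

#include <cstdint>
#include <vector>

std::vector<int64_t> ToShape(const void* data, int bits, int64_t ndim) {
  std::vector<int64_t> shape;
  shape.reserve(static_cast<size_t>(ndim));
  if (bits == 64) {
    const int64_t* p = static_cast<const int64_t*>(data);
    shape.assign(p, p + ndim);
  } else {  // otherwise assume 32-bit ints, matching the dtype CHECKs above
    const int32_t* p = static_cast<const int32_t*>(data);
    for (int64_t i = 0; i < ndim; ++i) shape.push_back(p[i]);
  }
  return shape;
}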
- Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); - } else { - this->VisitExpr(args[1]); - auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg( - storage_register, - shape_register, - dtype, - NewRegister())); - } - }).Match("memory.alloc_storage", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); - // Compute the size of the allocation. - this->VisitExpr(args[0]); - auto size_register = last_register_; - - this->VisitExpr(args[1]); - auto alignment_register = last_register_; - - // Get the dtype hint from the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) - << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister())); - }).Match("memory.shape_func", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - auto shape_func = Downcast(args[0]); - auto inputs = Downcast(args[1]); - auto outputs = Downcast(args[2]); - EmitShapeFunc(shape_func, inputs->fields, outputs->fields); - }).Match("memory.kill", - [](const Array& args, const Attrs& attrs, const Array& type_arg) { - LOG(FATAL) << "memory.kill is not yet supported"; - }); + matcher + .Match("memory.invoke_tvm_op", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); + }) + .Match( + "memory.alloc_tensor", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 2); + + // Get the attributes. + auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + // The storage will be passed dynamically. + this->VisitExpr(args[0]); + auto storage_register = last_register_; + + // If the shape is constant then we will emit a static tensor allocation + // instruction. + auto const_shape = args[1].as(); + + if (const_shape) { + NDArray shape = const_shape->data; + // TODO(@jroesch): we need to get an RFC done to standarize shape dtype + std::vector raw_shape = ToAllocTensorShape(shape); + // Add context field. + Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); + } else { + this->VisitExpr(args[1]); + auto shape_register = last_register_; + Emit(Instruction::AllocTensorReg(storage_register, shape_register, dtype, + NewRegister())); + } + }) + .Match("memory.alloc_storage", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 2); + // Compute the size of the allocation. + this->VisitExpr(args[0]); + auto size_register = last_register_; + + this->VisitExpr(args[1]); + auto alignment_register = last_register_; + + // Get the dtype hint from the attributes. 
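The matcher.Match(...).Match(...) chain being reindented here is a small fluent dispatch table: one handler per memory-op name, invoked with the matching call's arguments. A reduced sketch of such an API (OpMatch, Call, and the handler signature are simplified stand-ins, not TVM's actual types):

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct Call { std::string op_name; std::vector<int> args; };

class OpMatch {
 public:
  using Handler = std::function<void(const std::vector<int>&)>;

  OpMatch& Match(const std::string& op_name, Handler h) {
    handlers_[op_name] = std::move(h);
    return *this;  // allow chained .Match(...).Match(...)
  }

  void operator()(const Call& call) const {
    auto it = handlers_.find(call.op_name);
    if (it != handlers_.end()) it->second(call.args);
  }

 private:
  std::unordered_map<std::string, Handler> handlers_;
};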
+ auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, + NewRegister())); + }) + .Match("memory.shape_func", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + auto shape_func = Downcast(args[0]); + auto inputs = Downcast(args[1]); + auto outputs = Downcast(args[2]); + EmitShapeFunc(shape_func, inputs->fields, outputs->fields); + }) + .Match("memory.kill", + [](const Array& args, const Attrs& attrs, const Array& type_arg) { + LOG(FATAL) << "memory.kill is not yet supported"; + }); matcher(GetRef(call_node)); return; } @@ -600,14 +584,13 @@ class VMFunctionCompiler : ExprFunctor { auto it = context_->global_map.find(global); CHECK(it != context_->global_map.end()); DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint - << " with func_index=" << it->second; + << " with func_index=" << it->second; // TODO(tvm-team): // Think about mixed call into global that is not a relay::Function // perhaps establish as an invariance(all functions in mod must be relay::Function) auto func = Downcast(context_->module->Lookup(global)); - if (IsClosure(func)) { auto arity = func->params.size(); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); @@ -738,9 +721,7 @@ class VMFunctionCompiler : ExprFunctor { Target target_host_; }; - -PackedFunc VMCompiler::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); @@ -753,9 +734,8 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, this->Codegen(); }); } else if (name == "get_executable") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(exec_); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -786,11 +766,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { - CHECK_EQ(targets.size(), 1) - << "Currently VM compiler doesn't support heterogeneous compilation"; +void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { + CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); CHECK(base_func->IsInstance()) @@ -867,7 +844,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // eta expand to support constructors in argument position pass_seqs.push_back(transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)); + /* expand_constructor */ true, /* expand_global_var */ false)); pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { @@ -949,7 +926,7 @@ void VMCompiler::Codegen() { LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; return; } - auto const &cached_funcs = context_.cached_funcs; + auto 
const& cached_funcs = context_.cached_funcs; if (cached_funcs.size() == 0) { return; } @@ -999,8 +976,7 @@ runtime::Module CreateVMCompiler() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay._vm._VMCompiler") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._vm._VMCompiler").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateVMCompiler(); }); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index c1040f1ed18ea..7faab9d349602 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -28,9 +28,11 @@ #include #include #include -#include #include #include +#include +#include + #include #include #include @@ -38,8 +40,9 @@ #include #include #include -#include "../../../runtime/vm/profiler/vm.h" + #include "../../../runtime/vm/naive_allocator.h" +#include "../../../runtime/vm/profiler/vm.h" #include "../../backend/compile_engine.h" #include "../../transforms/pass_util.h" @@ -79,17 +82,13 @@ struct VMCompilerContext { std::unordered_map seen_funcs; }; - class VMCompiler : public runtime::ModuleNode { public: virtual ~VMCompiler() {} - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - const char* type_key() const { - return "VMCompiler"; - } + const char* type_key() const { return "VMCompiler"; } /*! * \brief Set the parameters @@ -107,9 +106,7 @@ class VMCompiler : public runtime::ModuleNode { to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host); + void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 12113b0683f2b..8e960a7385273 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -24,9 +24,10 @@ #include #include -#include #include #include +#include + #include #include @@ -125,18 +126,13 @@ struct PrimitiveInliner : ExprMutator { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - DLOG(INFO) << "Before inlining primitives: " << global - << std::endl << AsText(func, false); + DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); - func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(global, func, true); - DLOG(INFO) << "After inlining primitives: " << global - << std::endl << AsText(func, false); + DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false); } } return module_; @@ -149,16 +145,13 @@ namespace transform { Pass InlinePrimitives() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::vm::PrimitiveInliner(m).Inline(); - }; + [=](IRModule m, PassContext pc) { return relay::vm::PrimitiveInliner(m).Inline(); }; auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); // Eliminate dead code for each function after inlining. 
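// Editorial sketch (not part of this patch): Sequential runs its child passes
// in order under a single PassContext, so the composite returned below can be
// applied to a module directly:
//
//   IRModule mod = /* hypothetical module containing primitive calls */;
//   transform::Pass pass = InlinePrimitives();
//   mod = pass(mod);  // inlines primitives, then drops dead bindings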
return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); } -TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives") -.set_body_typed(InlinePrimitives); +TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives").set_body_typed(InlinePrimitives); } // namespace transform diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index bfbefd57a3105..1d3fff7a7ba6a 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -24,12 +24,13 @@ #include #include +#include #include #include -#include -#include #include #include +#include + #include #include @@ -44,9 +45,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0) != 0; -} +bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -85,8 +84,7 @@ class LambdaLifter : public ExprMutator { if (!letrec_.empty() && var == letrec_.back()) { auto it = lambda_map_.find(var); CHECK(it != lambda_map_.end()); - return Call(it->second, call->args, call_node->attrs, - call_node->type_args); + return Call(it->second, call->args, call_node->attrs, call_node->type_args); } } return std::move(call); @@ -153,18 +151,15 @@ class LambdaLifter : public ExprMutator { if (captured_vars.size() == 0 && free_type_vars.size() == 0) { lifted_func = Function(body->params, body->body, body->ret_type, body->type_params); } else { - lifted_func = - Function(captured_vars, body, func->func_type_annotation(), free_type_vars); + lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars); lifted_func = MarkClosure(lifted_func); } CHECK(lifted_func.defined()); - if (module_->ContainGlobalVar(name)) { const auto existing_func = module_->Lookup(name); - CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) - << "lifted function hash collision"; + CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) << "lifted function hash collision"; // If an identical function already exists, use its global var. 
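// Editorial note (illustrative, not part of this patch): GenerateName keys the
// lifted global on the function's structural hash, so structurally identical
// lambdas lift to a single definition. Roughly:
//
//   let %f = fn (%x) { add(%x, %y) }   // captures free var %y
//   ==>  def @lifted_nameN(%y) { fn (%x) { add(%x, %y) } }  // marked kClosure
//        let %f = @lifted_nameN(%y)    // becomes a closure allocation in the VM
//
// where N stands for the hash-derived suffix.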
global = module_->GetGlobalVar(name); } else { @@ -192,10 +187,7 @@ class LambdaLifter : public ExprMutator { if (auto* n = pair.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(pair.first, func, true); } @@ -215,14 +207,11 @@ namespace transform { Pass LambdaLift() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::vm::LambdaLifter(m).Lift(); - }; + [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); }; return CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_REGISTER_GLOBAL("relay._transform.LambdaLift") -.set_body_typed(LambdaLift); +TVM_REGISTER_GLOBAL("relay._transform.LambdaLift").set_body_typed(LambdaLift); } // namespace transform diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index c2fe37f15453f..64ddbe3c2881b 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -22,12 +22,13 @@ * \brief Remove unused global relay functions in a relay module. */ +#include #include #include -#include -#include #include #include +#include + #include #include #include @@ -48,10 +49,7 @@ struct CallTracer : ExprVisitor { // Record the expressions that are being visited std::unordered_set visiting_; - explicit CallTracer(const IRModule& module) - : module_{module}, - called_funcs_{}, - visiting_{} {} + explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {} void VisitExpr_(const GlobalVarNode* op) final { called_funcs_.insert(op->name_hint); @@ -86,8 +84,7 @@ struct CallTracer : ExprVisitor { * * \return The module with dead functions removed. */ -IRModule RemoveUnusedFunctions(const IRModule& module, - Array entry_funcs) { +IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { auto funcs = CallTracer(module).Trace(entry); @@ -108,15 +105,14 @@ IRModule RemoveUnusedFunctions(const IRModule& module, namespace transform { Pass RemoveUnusedFunctions(Array entry_functions) { - runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { return relay::vm::RemoveUnusedFunctions(m, entry_functions); }; return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); } -TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions") -.set_body_typed(RemoveUnusedFunctions); +TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions); } // namespace transform diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 11c2cbb772fcb..d808351e841c8 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -21,8 +21,8 @@ * \file src/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). 
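 *
 * Editorial sketch (not in the original header): the nodes defined below
 * compose into match expressions roughly as
 *
 *   Clause c(PatternConstructor(ctor, {PatternVar(v)}), body);
 *   Match m(scrutinee, {c}, true);  // complete match
 *
 * where ctor, v, body, and scrutinee are hypothetical Constructor, Var,
 * Expr, and Expr handles respectively.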
*/ -#include #include +#include namespace tvm { namespace relay { @@ -34,15 +34,12 @@ PatternWildcard::PatternWildcard() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard") -.set_body_typed([]() { - return PatternWildcard(); -}); +TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard").set_body_typed([]() { return PatternWildcard(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "PatternWildcardNode()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "PatternWildcardNode()"; + }); PatternVar::PatternVar(tvm::relay::Var var) { ObjectPtr n = make_object(); @@ -52,19 +49,17 @@ PatternVar::PatternVar(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternVar") -.set_body_typed([](tvm::relay::Var var) { +TVM_REGISTER_GLOBAL("relay.ir.PatternVar").set_body_typed([](tvm::relay::Var var) { return PatternVar(var); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternVarNode(" << node->var << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternVarNode(" << node->var << ")"; + }); -PatternConstructor::PatternConstructor(Constructor constructor, - tvm::Array patterns) { +PatternConstructor::PatternConstructor(Constructor constructor, tvm::Array patterns) { ObjectPtr n = make_object(); n->constructor = std::move(constructor); n->patterns = std::move(patterns); @@ -74,16 +69,15 @@ PatternConstructor::PatternConstructor(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor") -.set_body_typed([](Constructor constructor, tvm::Array patterns) { - return PatternConstructor(constructor, patterns); -}); + .set_body_typed([](Constructor constructor, tvm::Array patterns) { + return PatternConstructor(constructor, patterns); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternConstructorNode(" << node->constructor - << ", " << node->patterns << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")"; + }); PatternTuple::PatternTuple(tvm::Array patterns) { ObjectPtr n = make_object(); @@ -93,16 +87,15 @@ PatternTuple::PatternTuple(tvm::Array patterns) { TVM_REGISTER_NODE_TYPE(PatternTupleNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternTuple") -.set_body_typed([](tvm::Array patterns) { +TVM_REGISTER_GLOBAL("relay.ir.PatternTuple").set_body_typed([](tvm::Array patterns) { return PatternTuple(patterns); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternTupleNode(" << node->patterns << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternTupleNode(" << node->patterns << ")"; + }); Clause::Clause(Pattern lhs, Expr rhs) { ObjectPtr n = make_object(); @@ -113,17 +106,15 @@ Clause::Clause(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); -TVM_REGISTER_GLOBAL("relay.ir.Clause") -.set_body_typed([](Pattern lhs, Expr rhs) { 
+TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) { return Clause(lhs, rhs); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ClauseNode(" << node->lhs << ", " - << node->rhs << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; + }); Match::Match(Expr data, tvm::Array clauses, bool complete) { ObjectPtr n = make_object(); @@ -136,16 +127,16 @@ Match::Match(Expr data, tvm::Array clauses, bool complete) { TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay.ir.Match") -.set_body_typed([](Expr data, tvm::Array clauses, bool complete) { - return Match(data, clauses, complete); -}); + .set_body_typed([](Expr data, tvm::Array clauses, bool complete) { + return Match(data, clauses, complete); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "MatchNode(" << node->data << ", " - << node->clauses << ", " << node->complete << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete + << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 76a3f9d4446e9..37b0ff5e747f2 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -23,8 +23,8 @@ */ #include -#include #include +#include namespace tvm { namespace relay { @@ -39,8 +39,7 @@ Id::Id(std::string name_hint) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.NodeSetSpan") -.set_body_typed([](ObjectRef node_ref, Span sp) { +TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { rn->span = sp; } else if (auto* rn = node_ref.as()) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 169db62eee269..5ac5805e4a2b5 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -38,19 +38,18 @@ Constant::Constant(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay.ir.Constant") -.set_body_typed([](runtime::NDArray data) { +TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) { return Constant(data); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PackedFunc* fprint = Registry::Get("relay._constant_repr"); - CHECK(fprint) << "unable to find printing function for constants"; - std::string data = (*fprint)(GetRef(node)); - p->stream << "Constant(" << data << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PackedFunc* fprint = Registry::Get("relay._constant_repr"); + CHECK(fprint) << "unable to find printing function for constants"; + std::string data = (*fprint)(GetRef(node)); + p->stream << "Constant(" << data << ")"; + }); TensorType ConstantNode::tensor_type() const { auto dtype = DataType(data->dtype); @@ -58,8 +57,7 @@ TensorType ConstantNode::tensor_type() const { for (int i = 0; i < data->ndim; i++) { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); - shape.push_back( - 
tvm::IntImm(DataType::Int(32), data->shape[i])); + shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i])); } return TensorType(shape, dtype); @@ -73,17 +71,15 @@ Tuple::Tuple(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relay.ir.Tuple") -.set_body_typed([](tvm::Array fields) { +TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields) { return Tuple(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Tuple(" << node->fields << ")"; - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Tuple(" << node->fields << ")"; + }); Var::Var(Id vid, Type type_annotation) { ObjectPtr n = make_object(); @@ -94,21 +90,20 @@ Var::Var(Id vid, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay.ir.Var") -.set_body_typed([](std::string str, Type type_annotation) { +TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](std::string str, Type type_annotation) { return Var(str, type_annotation); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Var(" << node->name_hint(); - if (node->type_annotation.defined()) { - p->stream << ", ty="; - p->Print(node->type_annotation); - } - p->stream << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Var(" << node->name_hint(); + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->Print(node->type_annotation); + } + p->stream << ")"; + }); Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { ObjectPtr n = make_object(); @@ -122,16 +117,16 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") -.set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { - return Call(op, args, attrs, type_args); -}); + .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { + return Call(op, args, attrs, type_args); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " + << node->type_args << ")"; + }); Let::Let(Var var, Expr value, Expr body) { ObjectPtr n = make_object(); @@ -143,17 +138,15 @@ Let::Let(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay.ir.Let") -.set_body_typed([](Var var, Expr value, Expr body) { +TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { return Let(var, value, body); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "LetNode(" << node->var << ", " << node->value - << ", " << node->body << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; + 
}); If::If(Expr cond, Expr true_branch, Expr false_branch) { ObjectPtr n = make_object(); @@ -166,16 +159,16 @@ If::If(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") -.set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, true_branch, false_branch); -}); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { + return If(cond, true_branch, false_branch); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IfNode(" << node->cond << ", " << node->true_branch - << ", " << node->false_branch << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " + << node->false_branch << ")"; + }); TupleGetItem::TupleGetItem(Expr tuple, int index) { ObjectPtr n = make_object(); @@ -186,16 +179,15 @@ TupleGetItem::TupleGetItem(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem") -.set_body_typed([](Expr tuple, int index) { +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { return TupleGetItem(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; + }); RefCreate::RefCreate(Expr value) { ObjectPtr n = make_object(); @@ -205,16 +197,15 @@ RefCreate::RefCreate(Expr value) { TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay.ir.RefCreate") -.set_body_typed([](Expr value) { +TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { return RefCreate(value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefCreateNode(" << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefCreateNode(" << node->value << ")"; + }); RefRead::RefRead(Expr ref) { ObjectPtr n = make_object(); @@ -224,16 +215,13 @@ RefRead::RefRead(Expr ref) { TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay.ir.RefRead") -.set_body_typed([](Expr ref) { - return RefRead(ref); -}); +TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefReadNode(" << node->ref << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefReadNode(" << node->ref << ")"; + }); RefWrite::RefWrite(Expr ref, Expr value) { ObjectPtr n = make_object(); @@ -244,24 +232,21 @@ RefWrite::RefWrite(Expr ref, Expr value) { TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay.ir.RefWrite") -.set_body_typed([](Expr ref, Expr value) { +TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { return RefWrite(ref, value); }); 
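// Editorial sketch (not part of this patch): globals registered this way are
// resolved by name through the runtime registry, which is how the Python
// frontend reaches these constructors:
//
//   const runtime::PackedFunc* f = runtime::Registry::Get("relay.ir.RefWrite");
//   CHECK(f != nullptr) << "global not registered";
//   Expr e = (*f)(some_ref, some_value);  // hypothetical Expr arguments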
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; + }); -TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize") -.set_body_typed([](TempExpr temp) { +TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) { return temp->Realize(); }); -TVM_REGISTER_GLOBAL("relay.ir.Any") -.set_body_typed([]() { return Any::make(); }); +TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any::make(); }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index cb5d06f2932c9..18fd1c711dd03 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -154,9 +154,7 @@ bool MixedModeMutator::CheckVisited(const Expr& expr) { } } -Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { - return ExprMutator::VisitExpr(expr); -} +Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { return ExprMutator::VisitExpr(expr); } Expr MixedModeMutator::VisitExpr(const Expr& expr) { auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; @@ -178,6 +176,7 @@ class PostOrderRewriter : public MixedModeMutator { auto post = ExprFunctor::VisitExpr(expr); return rewriter_->Rewrite(expr, post); } + protected: ExprRewriter* rewriter_; }; @@ -208,17 +207,11 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const ConstantNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const OpNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef(op); } Expr ExprMutator::VisitExpr_(const TupleNode* op) { tvm::Array fields; @@ -257,9 +250,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { auto ret_type = this->VisitType(op->ret_type); auto body = this->Mutate(op->body); - if (all_ty_params_unchanged && - all_params_unchanged && - ret_type.same_as(op->ret_type) && + if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body)) { return GetRef(op); } else { @@ -297,9 +288,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) { auto value = this->Mutate(op->value); auto body = this->Mutate(op->body); - if (var.same_as(op->var) && - value.same_as(op->value) && - body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(var, value, body); @@ -310,10 +299,9 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { auto guard = this->Mutate(op->cond); auto true_b = this->Mutate(op->true_branch); auto false_b = this->Mutate(op->false_branch); - if (op->cond.same_as(guard) && - op->true_branch.same_as(true_b) && + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b)) { - return GetRef(op);; + return GetRef(op); } else { return If(guard, true_b, false_b); } @@ -356,9 +344,7 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { } } -Expr 
ExprMutator::VisitExpr_(const ConstructorNode* c) { - return GetRef(c); -} +Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } Expr ExprMutator::VisitExpr_(const MatchNode* m) { std::vector clauses; @@ -394,11 +380,9 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { } } -void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { -} +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {} -void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { -} +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {} void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { for (auto field : op->fields) { @@ -440,17 +424,11 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { void ExprVisitor::VisitExpr_(const OpNode* op) { return; } -void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { - this->VisitExpr(op->tuple); -} +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { - this->VisitExpr(op->ref); -} +void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); } void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { this->VisitExpr(op->ref); @@ -501,30 +479,23 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit") -.set_body_typed([](Expr expr, PackedFunc f) { - PostOrderVisit(expr, [f](const Expr& n) { - f(n); - }); - }); +TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); // Implement bind. 
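// Editorial sketch (not part of this patch): Bind substitutes free variables
// according to a map; for a Function it also drops the now-bound parameters:
//
//   tvm::Map<Var, Expr> subst;
//   subst.Set(x, some_constant);        // x: hypothetical free Var
//   Expr specialized = Bind(f, subst);  // f: hypothetical Function expression
//
// some_constant stands for any Expr, e.g. a relay::Constant.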
class ExprBinder : public ExprMutator, PatternMutator { public: - explicit ExprBinder(const tvm::Map& args_map) - : args_map_(args_map) { - } + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} Expr VisitExpr_(const LetNode* op) final { - CHECK(!args_map_.count(op->var)) - << "Cannot bind an internel variable in let"; + CHECK(!args_map_.count(op->var)) << "Cannot bind an internal variable in let"; return ExprMutator::VisitExpr_(op); } Expr VisitExpr_(const FunctionNode* op) final { for (Var param : op->params) { - CHECK(!args_map_.count(param)) - << "Cannnot bind an internal function parameter"; + CHECK(!args_map_.count(param)) << "Cannot bind an internal function parameter"; } return ExprMutator::VisitExpr_(op); } @@ -539,9 +510,7 @@ class ExprBinder : public ExprMutator, PatternMutator { } } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); @@ -549,8 +518,7 @@ class ExprBinder : public ExprMutator, PatternMutator { } Var VisitVar(const Var& v) final { - CHECK(!args_map_.count(v)) - << "Cannnot bind an internal pattern variable"; + CHECK(!args_map_.count(v)) << "Cannot bind an internal pattern variable"; return v; } @@ -567,15 +535,10 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.push_back(param); } } - if (new_body.same_as(func->body) && - new_params.size() == func->params.size()) { + if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { return expr; } - auto ret = Function(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); } @@ -585,11 +548,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.push_back(v); } } - ret = Function(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { @@ -597,15 +556,14 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -TVM_REGISTER_GLOBAL("relay.ir.Bind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef input = args[0]; - if (input->IsInstance()) { - *ret = Bind(Downcast(input), args[1]); - } else { - CHECK(input->IsInstance()); - *ret = Bind(Downcast(input), args[1]); - } - }); +TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef input = args[0]; + if (input->IsInstance()) { + *ret = Bind(Downcast(input), args[1]); + } else { + CHECK(input->IsInstance()); + *ret = Bind(Downcast(input), args[1]); + } +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 12a80c5698af2..5312e6d48447c 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -26,11 +26,8 @@ namespace tvm { namespace relay { -Function::Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array type_params, - DictAttrs attrs) { +Function::Function(tvm::Array params, Expr body, Type ret_type, + tvm::Array type_params, DictAttrs attrs) { ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); @@ -45,34 +42,29 @@ 
Function::Function(tvm::Array params, FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { - Type param_type = (param->type_annotation.defined()) ? param->type_annotation - : IncompleteType(Kind::kType); + Type param_type = + (param->type_annotation.defined()) ? param->type_annotation : IncompleteType(Kind::kType); param_types.push_back(param_type); } - Type ret_type = (this->ret_type.defined()) ? this->ret_type - : IncompleteType(Kind::kType); + Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteType(Kind::kType); return FuncType(param_types, ret_type, this->type_params, {}); } TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") -.set_body_typed([](tvm::Array params, - Expr body, - Type ret_type, - tvm::Array ty_params, - tvm::DictAttrs attrs) { - return Function(params, body, ret_type, ty_params, attrs); -}); + .set_body_typed([](tvm::Array params, Expr body, Type ret_type, + tvm::Array ty_params, tvm::DictAttrs attrs) { + return Function(params, body, ret_type, ty_params, attrs); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type - << ", " << node->body << ", " << node->type_params << ", " - << node->attrs << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body + << ", " << node->type_params << ", " << node->attrs << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/op_strategy.cc b/src/relay/ir/op_strategy.cc index 4e407dbed6552..989e3a6127e92 100644 --- a/src/relay/ir/op_strategy.cc +++ b/src/relay/ir/op_strategy.cc @@ -31,21 +31,18 @@ TVM_REGISTER_NODE_TYPE(OpImplementationNode); TVM_REGISTER_NODE_TYPE(OpSpecializationNode); TVM_REGISTER_NODE_TYPE(OpStrategyNode); -Array OpImplementation::Compute(const Attrs& attrs, - const Array& inputs, +Array OpImplementation::Compute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return (*this)->fcompute(attrs, inputs, out_type); } -te::Schedule OpImplementation::Schedule(const Attrs& attrs, - const Array &outs, +te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array& outs, const Target& target) { return (*this)->fschedule(attrs, outs, target); } void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, - tvm::relay::FTVMSchedule fschedule, - std::string name, + tvm::relay::FTVMSchedule fschedule, std::string name, int plevel) { auto n = make_object(); n->fcompute = fcompute; @@ -55,9 +52,7 @@ void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, (*this)->implementations.push_back(OpImplementation(n)); } -void OpStrategy::AddImplementation(FTVMCompute fcompute, - FTVMSchedule fschedule, - std::string name, +void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name, int plevel) { auto curr_cond = te::SpecializedCondition::Current(); auto self = this->operator->(); @@ -77,38 +72,37 @@ void OpStrategy::AddImplementation(FTVMCompute fcompute, } TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpImplementation imp = args[0]; - Attrs attrs = args[1]; - Array inputs = args[2]; - Type out_type = args[3]; - *rv = imp.Compute(attrs, inputs, out_type); 
-}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpImplementation imp = args[0]; + Attrs attrs = args[1]; + Array inputs = args[2]; + Type out_type = args[3]; + *rv = imp.Compute(attrs, inputs, out_type); + }); TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpImplementation imp = args[0]; - Attrs attrs = args[1]; - Array outs = args[2]; - Target target = args[3]; - *rv = imp.Schedule(attrs, outs, target); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpImplementation imp = args[0]; + Attrs attrs = args[1]; + Array outs = args[2]; + Target target = args[3]; + *rv = imp.Schedule(attrs, outs, target); + }); -TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy") -.set_body([](TVMArgs args, TVMRetValue* rv) { - ObjectPtr n = make_object(); - *rv = OpStrategy(n); +TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = make_object(); + *rv = OpStrategy(n); }); TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpStrategy strategy = args[0]; - FTVMCompute compute = args[1]; - FTVMSchedule schedule = args[2]; - std::string name = args[3]; - int plevel = args[4]; - strategy.AddImplementation(compute, schedule, name, plevel); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpStrategy strategy = args[0]; + FTVMCompute compute = args[1]; + FTVMSchedule schedule = args[2]; + std::string name = args[3]; + int plevel = args[4]; + strategy.AddImplementation(compute, schedule, name, plevel); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index 6795884ef438f..8c366bad641a8 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -27,13 +27,9 @@ namespace tvm { namespace relay { -Pattern PatternMutator::Mutate(const Pattern& pat) { - return (*this)(pat); -} +Pattern PatternMutator::Mutate(const Pattern& pat) { return (*this)(pat); } -Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { - return GetRef(op); -} +Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { return GetRef(op); } Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { return PatternVar(VisitVar(op->var)); @@ -55,28 +51,20 @@ Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { return PatternTuple(pat); } -Type PatternMutator::VisitType(const Type& t) { - return t; -} +Type PatternMutator::VisitType(const Type& t) { return t; } Var PatternMutator::VisitVar(const Var& v) { if (var_map_.count(v) == 0) { - var_map_.insert(std::pair(v, - Var(v->name_hint(), - VisitType(v->type_annotation)))); + var_map_.insert(std::pair(v, Var(v->name_hint(), VisitType(v->type_annotation)))); } return var_map_.at(v); } -Constructor PatternMutator::VisitConstructor(const Constructor& v) { - return v; -} +Constructor PatternMutator::VisitConstructor(const Constructor& v) { return v; } -void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } +void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) {} -void PatternVisitor::VisitPattern_(const PatternVarNode* op) { - VisitVar(op->var); -} +void PatternVisitor::VisitPattern_(const PatternVarNode* op) { VisitVar(op->var); } void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { VisitConstructor(op->constructor); @@ -91,11 +79,9 @@ void PatternVisitor::VisitPattern_(const PatternTupleNode* op) { } } -void 
PatternVisitor::VisitType(const Type& t) { } +void PatternVisitor::VisitType(const Type& t) {} -void PatternVisitor::VisitVar(const Var& v) { - VisitType(v->type_annotation); -} +void PatternVisitor::VisitVar(const Var& v) { VisitType(v->type_annotation); } void PatternVisitor::VisitConstructor(const Constructor& c) { for (const auto& inp : c->inputs) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 06dd2b16661f1..6b99c93a15d14 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -22,10 +22,9 @@ * \brief Relay specific transformation passes. */ #include -#include #include #include - +#include namespace tvm { namespace relay { @@ -56,9 +55,7 @@ class FunctionPassNode : public PassNode { FunctionPassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a function pass on given pass context. @@ -113,14 +110,11 @@ FunctionPass::FunctionPass( } // Perform Module -> Module optimizations at the Function level. -IRModule FunctionPassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); - DLOG(INFO) << "Executing function pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; + DLOG(INFO) << "Executing function pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. @@ -130,9 +124,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, // only picks up relay::Function if (auto* n = it.second.as()) { Function func = GetRef(n); - auto updated_func = SkipFunction(func) - ? func - : pass_func(func, updated_mod, pass_ctx); + auto updated_func = SkipFunction(func) ? 
func : pass_func(func, updated_mod, pass_ctx); updates.push_back({it.first, updated_func}); } } @@ -146,14 +138,12 @@ IRModule FunctionPassNode::operator()(IRModule mod, bool FunctionPassNode::SkipFunction(const Function& func) const { return (func->GetAttr(attr::kCompiler).defined()) || - func->GetAttr(attr::kSkipOptimization, 0) != 0; + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { + int opt_level, const std::string& name, const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return FunctionPass(pass_func, pass_info); } @@ -161,18 +151,17 @@ Pass CreateFunctionPass( TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") -.set_body_typed([](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return FunctionPass(pass_func, pass_info); -}); + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Function pass: " << info->name - << " at the optimization level " << info->opt_level; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name << " at the optimization level " + << info->opt_level; + }); } // namespace transform } // namespace relay diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index 5b03ceec6ccf3..a240974208737 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -21,17 +21,15 @@ * \file argsort.cc * \brief Argsort operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ArgsortAttrs); -bool ArgsortRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ArgsortRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const ArgsortAttrs* param = attrs.as(); @@ -39,18 +37,14 @@ bool ArgsortRel(const Array& types, const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "Argsort: expect input type to be TensorType but get " - << types[0]; + << "Argsort: expect input type to be TensorType but get " << types[0]; return false; } reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; } -Expr MakeArgsort(Expr data, - int axis, - bool is_ascend, - DataType dtype) { +Expr MakeArgsort(Expr data, int axis, bool is_ascend, DataType dtype) { auto attrs = make_object(); attrs->axis = axis; attrs->is_ascend = is_ascend; @@ -59,19 +53,17 @@ Expr MakeArgsort(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.argsort") -.set_body_typed(MakeArgsort); +TVM_REGISTER_GLOBAL("relay.op._make.argsort").set_body_typed(MakeArgsort); RELAY_REGISTER_OP("argsort") -.describe(R"doc(Returns the indices that would sort an + .describe(R"doc(Returns the indices that would sort an input array along the given axis. 
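
For example (illustration added editorially): argsort on [[3, 1, 2]] along the
last axis with is_ascend=true returns the index tensor [[1, 2, 0]].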
)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(6) -.add_type_rel("Argsort", ArgsortRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(6) + .add_type_rel("Argsort", ArgsortRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 225575c69b001..f641f84aeb136 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -21,17 +21,15 @@ * \file topk.cc * \brief TopK operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(TopKAttrs); -bool TopKRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const TopKAttrs* param = attrs.as(); @@ -66,12 +64,7 @@ bool TopKRel(const Array& types, return true; } -Expr MakeTopK(Expr data, - int k, - int axis, - std::string ret_type, - bool is_ascend, - DataType dtype) { +Expr MakeTopK(Expr data, int k, int axis, std::string ret_type, bool is_ascend, DataType dtype) { auto attrs = make_object(); attrs->k = k; attrs->axis = axis; @@ -82,19 +75,16 @@ Expr MakeTopK(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.topk") -.set_body_typed(MakeTopK); +TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); RELAY_REGISTER_OP("topk") -.describe(R"doc(Get the top k elements in an input tensor along the given axis. + .describe(R"doc(Get the top k elements in an input tensor along the given axis. )doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(6) -.add_type_rel("TopK", TopKRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(6) + .add_type_rel("TopK", TopKRel); } // namespace relay } // namespace tvm - diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index dd1bcdc1b9eb7..2e93b586baa1a 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -23,12 +23,12 @@ * \brief Registration of annotation operators. 
*/ -#include +#include #include #include #include #include -#include +#include #include "../../transforms/infer_layout_util.h" #include "../type_relations.h" @@ -40,48 +40,46 @@ namespace relay { TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") -.set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr data, int device_type) { + auto attrs = make_object(); + attrs->device_type = device_type; + static const Op& op = Op::Get("on_device"); + return Call(op, {data}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("on_device") -.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout); + .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); } -TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") -.set_body_typed([](Expr data) { - return StopFusion(data); +TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion").set_body_typed([](Expr data) { + return StopFusion(data); }); RELAY_REGISTER_OP("annotation.stop_fusion") -.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input data.") -.add_type_rel("Identity", IdentityRel) -.set_support_level(10) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .describe( + R"code(Annotate an expression to prevent it being fused with previous expressions.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); // relay.annotation.cast_hint TVM_REGISTER_NODE_TYPE(CastHintAttrs); @@ -94,134 +92,127 @@ Expr CastHint(Expr data, DataType dtype) { } RELAY_REGISTER_OP("annotation.cast_hint") -.describe(R"code(Annotate an expression to be cast into specific data type.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input data.") -.add_type_rel("Identity", IdentityRel) -.set_support_level(10) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - + 
.describe( + R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); RELAY_REGISTER_OP("annotation.bitpack_start") -.describe(R"code( + .describe(R"code( Mark the start of bitpacking. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); RELAY_REGISTER_OP("annotation.bitpack_end") -.describe(R"code( + .describe(R"code( Mark the end of bitpacking. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") -.set_body_typed([](Expr data) { + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint").set_body_typed([](Expr data) { static const Op& op = Op::Get("annotation.checkpoint"); return Call(op, {data}, Attrs{}, {}); }); RELAY_REGISTER_OP("annotation.checkpoint") -.describe(R"code( + .describe(R"code( Mark a checkpoint for checkpointing memory optimization. 
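(Editorial note: values produced inside a checkpointed region may be
recomputed during the backward pass instead of being kept alive, trading
compute for memory.)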
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - Array outputs; - for (size_t i = 0; i < inputs.size(); ++i) { - outputs.push_back(topi::identity(inputs[i])); - } - return outputs; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + Array outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + outputs.push_back(topi::identity(inputs[i])); + } + return outputs; + }); TVM_REGISTER_NODE_TYPE(CompilerAttrs); RELAY_REGISTER_OP("annotation.compiler_begin") -.describe(R"code( + .describe(R"code( Beginning of a region that is handled by a given compiler. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin") -.set_body_typed([](Expr expr, std::string compiler) { - auto attrs = make_object(); - attrs->compiler = compiler; - static const Op& op = Op::Get("annotation.compiler_begin"); - return Call(op, {expr}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_object(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_begin"); + return Call(op, {expr}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("annotation.compiler_end") -.describe(R"code( + .describe(R"code( End of a region that is handled by a given compiler. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end") -.set_body_typed([](Expr expr, std::string compiler) { - auto attrs = make_object(); - attrs->compiler = compiler; - static const Op& op = Op::Get("annotation.compiler_end"); - return Call(op, {expr}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr expr, std::string compiler) { + auto attrs = make_object(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_end"); + return Call(op, {expr}, Attrs(attrs), {}); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index 8e8586f9d213d..790f1ee7e2a09 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -22,36 +22,37 @@ * \brief Property def of nn operators. */ -#include -#include -#include #include +#include +#include +#include + #include -#include "./type_relations.h" + #include "./op_common.h" +#include "./type_relations.h" namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(DebugAttrs); -Array DebugCompute(const Attrs& attrs, - const Array& inputs, +Array DebugCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return Array{ topi::identity(inputs[0]) }; + return Array{topi::identity(inputs[0])}; } RELAY_REGISTER_OP("debug") -.describe(R"code(Enter the interpreter's debugger. + .describe(R"code(Enter the interpreter's debugger. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("program", "Tuple", "The program to execute before debugging.") -.set_support_level(1) -.set_attrs_type() -.add_type_rel("Debug", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("FTVMCompute", DebugCompute); + .set_num_inputs(1) + .add_argument("program", "Tuple", "The program to execute before debugging.") + .set_support_level(1) + .set_attrs_type() + .add_type_rel("Debug", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("FTVMCompute", DebugCompute); Expr MakeDebug(Expr expr, std::string name) { auto dattrs = make_object(); @@ -64,9 +65,7 @@ Expr MakeDebug(Expr expr, std::string name) { return Call(op, {expr}, Attrs(dattrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.debug") -.set_body_typed(MakeDebug); +TVM_REGISTER_GLOBAL("relay.op._make.debug").set_body_typed(MakeDebug); } // namespace relay } // namespace tvm - diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 4aae549f217bc..923965f981925 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -26,14 +26,14 @@ * used as "barrier" to avoid fusing operators belonging to differen devices. 
*/ -#include #include #include #include #include +#include -#include "type_relations.h" #include "../transforms/infer_layout_util.h" +#include "type_relations.h" namespace tvm { namespace relay { @@ -42,27 +42,25 @@ namespace relay { TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_GLOBAL("relay.op._make.device_copy") -.set_body_typed([](Expr data, int src_dev_type, - int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - return Call(op, {data}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { + auto attrs = make_object(); + attrs->src_dev_type = src_dev_type; + attrs->dst_dev_type = dst_dev_type; + static const Op& op = Op::Get("device_copy"); + return Call(op, {data}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("device_copy") -.describe(R"code( + .describe(R"code( Copy data from one tensor to another. The source and destination might be on different devices. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/dilation2d.cc b/src/relay/op/image/dilation2d.cc index 7146f3736dd68..43ec856826741 100644 --- a/src/relay/op/image/dilation2d.cc +++ b/src/relay/op/image/dilation2d.cc @@ -21,9 +21,10 @@ * \file dilation2d.cc * \brief Morphological dilation operator */ -#include -#include #include +#include +#include + #include "../op_common.h" namespace tvm { @@ -32,27 +33,20 @@ namespace relay { // relay.image.dilation2d TVM_REGISTER_NODE_TYPE(Dilation2DAttrs); -template -Array > Dilation2DInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +template +Array > Dilation2DInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { const T* params = attrs.as(); - return Array >{{params->data_layout, params->kernel_layout}, - {params->data_layout}}; + return Array >{{params->data_layout, params->kernel_layout}, {params->data_layout}}; } // Positional relay function to create dilation2d operator // used by frontend FFI. 
-Expr MakeDilation2D(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilations, - std::string data_layout, - std::string kernel_layout, +Expr MakeDilation2D(Expr data, Expr weight, Array strides, Array padding, + Array dilations, std::string data_layout, std::string kernel_layout, DataType out_dtype) { auto attrs = make_object(); attrs->strides = std::move(strides); @@ -67,7 +61,7 @@ Expr MakeDilation2D(Expr data, template bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); @@ -113,15 +107,13 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -136,26 +128,24 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d") -.set_body_typed(MakeDilation2D); - +TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d").set_body_typed(MakeDilation2D); RELAY_REGISTER_OP("image.dilation2d") -.describe(R"code(Computes grayscale dilation of 4D input and 3D filter. + .describe(R"code(Computes grayscale dilation of 4D input and 3D filter. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, in_channels, height, width) if `layout` is `NCHW`. - **weight**: (in_channels, height, width) - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
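// [Editor's aside -- illustrative sketch, not part of this patch.] The
// Dilation2DRel hunk above sizes each spatial output dim as
// indexdiv(d + pad - dilated_ksize, stride) + 1, where
// dilated_ksize = 1 + (k - 1) * dilation. The same arithmetic over plain
// ints, with compile-time spot checks (the helper name OutDim is hypothetical):
constexpr int OutDim(int in, int pad, int k, int dilation, int stride) {
  return (in + pad - (1 + (k - 1) * dilation)) / stride + 1;
}
static_assert(OutDim(32, 2, 3, 1, 1) == 32, "3x3 kernel, pad 1+1, stride 1 keeps 32");
static_assert(OutDim(32, 2, 3, 1, 2) == 16, "stride 2 halves the 32-wide input");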
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Dilation2D", Dilation2DRel) -.set_attr("FInferCorrectLayout", - Dilation2DInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Dilation2D", Dilation2DRel) + .set_attr("FInferCorrectLayout", + Dilation2DInferCorrectLayout); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index c8f9762566004..efd815b842f24 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -21,9 +21,10 @@ * \file resize.cc * \brief Image resize operators */ -#include -#include #include +#include +#include + #include "../op_common.h" namespace tvm { @@ -31,9 +32,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); -bool ResizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -46,8 +45,8 @@ bool ResizeRel(const Array& types, const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) - << "Resize only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Resize only support input layouts that are convertible from NCHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, param->size[0]); @@ -59,20 +58,14 @@ bool ResizeRel(const Array& types, } // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - out_dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, - Array size, - std::string layout, - std::string method, - std::string coordinate_transformation_mode, - DataType out_dtype) { +Expr MakeResize(Expr data, Array size, std::string layout, std::string method, + std::string coordinate_transformation_mode, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); @@ -83,13 +76,10 @@ Expr MakeResize(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.image._make.resize") -.set_body_typed(MakeResize); - +TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize); RELAY_REGISTER_OP("image.resize") -.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. 
- **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -102,26 +92,22 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(5) -.add_type_rel("Resize", ResizeRel) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize", ResizeRel) + .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); -bool CropAndResizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CropAndResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* boxes = types[1].as(); const auto* box_indices = types[2].as(); - if (data == nullptr || boxes == nullptr || - box_indices == nullptr) return false; + if (data == nullptr || boxes == nullptr || box_indices == nullptr) return false; const CropAndResizeAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -142,19 +128,12 @@ bool CropAndResizeRel(const Array& types, oshape.Set(3, crop_size[1]); auto bshape = layout_converter.BackwardShape(oshape); // assign output type - reporter->Assign(types[3], - TensorType(layout_converter.BackwardShape(oshape), - out_dtype)); + reporter->Assign(types[3], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } -Expr MakeCropAndResize(Expr data, - Expr boxes, - Expr box_indices, - Array crop_size, - std::string layout, - std::string method, - double extrapolation_value, +Expr MakeCropAndResize(Expr data, Expr boxes, Expr box_indices, Array crop_size, + std::string layout, std::string method, double extrapolation_value, DataType out_dtype) { auto attrs = make_object(); attrs->crop_size = std::move(crop_size); @@ -166,12 +145,11 @@ Expr MakeCropAndResize(Expr data, return Call(op, {data, boxes, box_indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize") -.set_body_typed(MakeCropAndResize); - +TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize").set_body_typed(MakeCropAndResize); RELAY_REGISTER_OP("image.crop_and_resize") - .describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. + .describe( + R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. 
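// [Editor's aside -- illustrative sketch, not part of this patch.] ResizeRel
// above shows a pattern repeated throughout these files: BijectiveLayout maps
// the incoming layout onto canonical NCHW, the spatial dims are rewritten
// there, and BackwardShape maps the result back out. Traced by hand for an
// assumed NHWC input with size = {112, 112}:
//   (1, 224, 224, 3)   --ForwardShape-->   (1, 3, 224, 224)
//   oshape.Set(2, 112); oshape.Set(3, 112)   ->   (1, 3, 112, 112)
//   --BackwardShape-->  (1, 112, 112, 3)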
- **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -184,14 +162,14 @@ RELAY_REGISTER_OP("image.crop_and_resize") for layout NHWC (batch_size, crop_size[0], crop_size[1], channels) )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("boxes", "Tensor", "The boxes tensor.") -.add_argument("box_indices", "Tensor", "The box indices tensor.") -.set_attrs_type() -.set_support_level(5) -.add_type_rel("CropAndResize", CropAndResizeRel) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("boxes", "Tensor", "The boxes tensor.") + .add_argument("box_indices", "Tensor", "The box indices tensor.") + .set_attrs_type() + .set_support_level(5) + .add_type_rel("CropAndResize", CropAndResizeRel) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c7ffc95c05d57..ec96e23a01fb4 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -23,11 +23,11 @@ */ #include -#include #include #include #include #include +#include #include "../../transforms/infer_layout_util.h" #include "../op_common.h" @@ -109,12 +109,11 @@ std::vector FromConstShape(Constant konst) { runtime::NDArray shape = konst->data; std::vector raw_shape; CHECK_EQ(shape->ndim, 1u); - CHECK_EQ(shape->dtype.code, 0U) - << "The dtype of constant shape must be int32 or int64, but got " - << runtime::DLDataType2String(shape->dtype); + CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " + << runtime::DLDataType2String(shape->dtype); CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32) - << "The dtype of constant shape must be int32 or int64, but got" - << runtime::DLDataType2String(shape->dtype); + << "The dtype of constant shape must be int32 or int64, but got" + << runtime::DLDataType2String(shape->dtype); if (shape->dtype.bits == 32) { const int32_t* int_ptr = reinterpret_cast(shape->data); @@ -331,14 +330,12 @@ Expr ToTupleType(const Type& t, const std::vector& exprs) { } } -TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType") -.set_body_typed([](Type type) { +TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType").set_body_typed([](Type type) { auto types = FlattenTupleType(type); return Array(types.begin(), types.end()); }); -TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType") -.set_body_typed([](Type type, Expr expr) { +TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType").set_body_typed([](Type type, Expr expr) { auto exprs = FromTupleType(type, expr); return Array(exprs.begin(), exprs.end()); }); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d2174579cc31b..08637d9aa0001 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -22,12 +22,12 @@ * \brief Property def of bitserial operators. */ -#include #include #include +#include -#include "../op_common.h" #include "../../transforms/infer_layout_util.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -109,11 +109,11 @@ efficient implementation of bitserial operations. packed must be divisible by number of bits. - **out**: Packed tensor with shape appropriately compressed. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(2) -.add_type_rel("BitPack", BitPackRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(2) + .add_type_rel("BitPack", BitPackRel); // relay.nn.bitserial_conv2d TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs); @@ -137,10 +137,8 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr Array oshape({dshape_nchw[0], param->channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set( - 2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); - oshape.Set( - 3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); + oshape.Set(2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); + oshape.Set(3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); DataType out_dtype = param->out_dtype; oshape = trans_in_layout.BackwardShape(oshape); // assign output type @@ -187,14 +185,14 @@ on some platforms. - **out**: Output with same layout as input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("BinaryConv2D", BinaryConv2DRel) -.set_attr("FInferCorrectLayout", - BinaryConv2DInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("BinaryConv2D", BinaryConv2DRel) + .set_attr("FInferCorrectLayout", + BinaryConv2DInferCorrectLayout); // relay.nn.bitserial_dense TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs); @@ -248,12 +246,12 @@ RELAY_REGISTER_OP("nn.bitserial_dense") - **out**: `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "2D Tensor", "Input data.") -.add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(1) -.add_type_rel("BinaryDense", BinaryDenseRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "2D Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("BinaryDense", BinaryDenseRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index b3e1772e904f0..4a307c506a850 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -21,32 +21,25 @@ * \file convolution.cc * \brief Convolution operators */ -#include -#include +#include "convolution.h" + #include +#include +#include + #include #include "../../transforms/infer_layout_util.h" #include "../op_common.h" -#include "convolution.h" namespace tvm { namespace relay { template -Expr MakeConv(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, - std::string op_name) { +Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, std::string kernel_layout, + std::string out_layout, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -63,19 +56,10 @@ Expr MakeConv(Expr data, } template -Expr MakeConvWinograd(Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, +Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; @@ -93,9 +77,7 @@ Expr MakeConvWinograd(Expr data, return Call(op, {data, weight}, Attrs(attrs), {}); } -Expr MakeConvWinogradWeightTransform(Expr weight, - int tile_size, - std::string op_name) { +Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; const Op& op = Op::Get(op_name); @@ -103,20 +85,11 @@ Expr MakeConvWinogradWeightTransform(Expr weight, } template -Expr MakeConvTranspose(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype, - std::string op_name) { +Expr MakeConvTranspose(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -134,21 +107,11 @@ Expr MakeConvTranspose(Expr data, } 
template -Expr MakeDeformableConv(Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, - std::string op_name) { +Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, int deformable_groups, + int groups, int channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, std::string out_layout, + DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = strides; attrs->padding = padding; @@ -165,32 +128,21 @@ Expr MakeDeformableConv(Expr data, return Call(op, {data, offset, weight}, Attrs{attrs}, {}); } - // relay.nn.conv1d TVM_REGISTER_NODE_TYPE(Conv1DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv1d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv1d"); + }); RELAY_REGISTER_OP("nn.conv1d") -.describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). + .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. @@ -202,40 +154,29 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_width) if `layout` is `NCW`. 
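// [Editor's aside -- illustrative sketch, not part of this patch.] For the
// out_width above, Conv1DRel (see the convolution.h hunk later in this patch)
// evaluates indexdiv(w + padding[0] + padding[1] - dilated_ksize, strides[0]) + 1.
// Spot-checked over plain ints (the helper name Conv1DOutWidth is hypothetical):
constexpr int Conv1DOutWidth(int w, int pad_l, int pad_r, int k, int dilation,
                             int stride) {
  return (w + pad_l + pad_r - (1 + (k - 1) * dilation)) / stride + 1;
}
static_assert(Conv1DOutWidth(100, 1, 1, 3, 1, 1) == 100, "same padding keeps 100");
static_assert(Conv1DOutWidth(100, 2, 2, 3, 2, 1) == 100, "dilation 2 widens the tap span to 5");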
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv1D", Conv1DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv1D", Conv1DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv2d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv2d"); + }); RELAY_REGISTER_OP("nn.conv2d") -.describe(R"code(2D convolution layer (e.g. spatial convolution over images). + .describe(R"code(2D convolution layer (e.g. spatial convolution over images). This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. @@ -247,40 +188,29 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv2D", Conv2DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv3d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv3d"); + }); RELAY_REGISTER_OP("nn.conv3d") -.describe(R"code(3D convolution layer (e.g. 
convolution over 3D image data, + .describe(R"code(3D convolution layer (e.g. convolution over 3D image data, like Magnetic Resonance Imaging (MRI) data in medicine). This layer creates a convolution kernel that is convolved @@ -293,40 +223,30 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv3D", Conv3DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv3D", Conv3DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - return MakeConvTranspose( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); + }); RELAY_REGISTER_OP("nn.conv2d_transpose") -.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). + .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). 
The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -347,40 +267,31 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout) -.add_type_rel("Conv2DTranspose", Conv2DTransposeRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout) + .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); // relay.nn.conv1d_transpose TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - return MakeConvTranspose( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); + }); RELAY_REGISTER_OP("nn.conv1d_transpose") -.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). + .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -400,39 +311,29 @@ said convolution. 
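// [Editor's aside -- illustrative sketch, not part of this patch.] The
// conv2d_transpose docstring above gives
// out_width = (width-1)*strides[1] - 2*padding[1] + kernel_size[1] + output_padding[1],
// which inverts the forward relation (in + 2*pad - k)/stride + 1. Spot-checked
// over plain ints (the helper name TransposedOutDim is hypothetical):
constexpr int TransposedOutDim(int in, int stride, int pad, int k, int out_pad) {
  return (in - 1) * stride - 2 * pad + k + out_pad;
}
static_assert(TransposedOutDim(16, 2, 1, 4, 0) == 32, "stride 2 doubles a 16-wide input");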
out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv1DTranspose", Conv1DTransposeRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv1DTranspose", Conv1DTransposeRel); // relay.nn.contrib_conv2d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") -.set_body_typed([](Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConvWinograd( - data, weight, tile_size, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform"); -}); - + .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv2d_winograd_without_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") -.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. + .describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. This operator assumes the weight tensor is already pre-transformed by nn.contrib_conv2d_winograd_weight_transform. 
@@ -443,64 +344,54 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") - **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinograd", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinograd", Conv2DWinogradRel) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); // relay.nn.contrib_conv2d_winograd_weight_transform TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") -.set_body_typed([](Expr weight, - int tile_size) { - return MakeConvWinogradWeightTransform( - weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform"); -}); + .set_body_typed([](Expr weight, int tile_size) { + return MakeConvWinogradWeightTransform(weight, tile_size, + "nn.contrib_conv2d_winograd_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") -.describe(R"code(Weight transformation of winograd fast convolution algorithm. + .describe(R"code(Weight transformation of winograd fast convolution algorithm. Separate this into another operator in order to enable Precompute Pass to compute the weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); // relay.nn.contrib_conv3d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform") -.set_body_typed([](Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConvWinograd( - data, weight, tile_size, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform"); -}); + .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv3d_winograd_without_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") -.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. 
+ .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. This operator assumes the weight tensor is already pre-transformed by nn.contrib_conv3d_winograd_weight_transform. @@ -511,22 +402,21 @@ RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") - **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv3DWinograd", Conv3DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv3DWinograd", Conv3DWinogradRel) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); // relay.nn.contrib_conv3d_winograd_weight_transform TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform") -.set_body_typed([](Expr weight, - int tile_size) { - return MakeConvWinogradWeightTransform( - weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform"); -}); + .set_body_typed([](Expr weight, int tile_size) { + return MakeConvWinogradWeightTransform(weight, tile_size, + "nn.contrib_conv3d_winograd_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform") .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm. @@ -536,18 +426,16 @@ weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs); -Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, - int convolution_algorithm, +Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { auto attrs = make_object(); attrs->convolution_algorithm = convolution_algorithm; @@ -557,99 +445,75 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform") -.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); + .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform") -.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. + .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. Separate this into another symbol in order to enable Precompute Pass to compute the weight transformation in advance. 
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv2d_NCHWc"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") -.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. + .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. - **data**: Input is 5D packed tensor. - **weight**: 6D packed tensor. - **out**: Output is 5D packed tensor )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // Positional relay function to create depthwise conv2d NCHWc operator // used by frontend FFI. 
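// [Editor's aside -- illustrative sketch, not part of this patch.] The NCHWc
// ops registered around here take "5D packed" data: the channel axis of an
// NCHW tensor is split into fixed-width inner blocks, so with an assumed
// inner width of 8 (layout NCHW8c) a tensor (N, C, H, W) becomes
// (N, C/8, H, W, 8). The block count restated (the helper name ChannelBlocks
// is hypothetical):
constexpr int ChannelBlocks(int channels, int inner) { return channels / inner; }
static_assert(ChannelBlocks(64, 8) == 8, "NCHW8c splits 64 channels into 8 blocks");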
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_depthwise_conv2d_NCHWc"); + }); RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") -.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. + .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. - **data**: Input is 5D packed tensor. - **weight**: 6D packed tensor. - **out**: Output is 5D packed tensor )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2D", Conv2DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs); @@ -673,36 +537,26 @@ along the channel axis, and also evenly split `weight` along the first dimension the convolution on the *i*-th part of the data with the *i*-th weight part. The output is obtained by concating all the *g* results. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("offset", "Tensor", "The offset tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(5) -.add_type_rel("DeformableConv2D", DeformableConv2DRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("offset", "Tensor", "The offset tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(5) + .add_type_rel("DeformableConv2D", DeformableConv2DRel); // Positional relay function to create deformable_conv2d operator // used by frontend FFI. 
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d") -.set_body_typed([](Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeDeformableConv( - data, offset, weight, strides, padding, dilation, - deformable_groups, groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); -}); + .set_body_typed([](Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, int deformable_groups, + int groups, int channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, std::string out_layout, + DataType out_dtype) { + return MakeDeformableConv( + data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index f33cd7e4d7c76..5dc649b988825 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -35,7 +35,6 @@ namespace tvm { namespace relay { - // Standard convolution operator shape relations template bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -92,7 +91,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, auto wshape = trans_kernel_layout.ForwardShape(weight->shape); if (param->kernel_size.defined()) { // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) ) + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) << "Conv1D: shape of weight is inconsistent with kernel_size, " << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } @@ -110,7 +109,8 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, if (!dshape_ncw[2].as()) { oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize, - param->strides[0]) + 1); + param->strides[0]) + + 1); } else { oshape.Set(2, dshape_ncw[2]); } @@ -159,8 +159,8 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); bool is_depthwise = false; if (param->groups > 1) { - CHECK(weight && weight->shape.defined()) << - "Weight shape must be specified when groups is greater than 1."; + CHECK(weight && weight->shape.defined()) + << "Weight shape must be specified when groups is greater than 1."; Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { @@ -222,15 +222,13 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); 
} @@ -336,22 +334,19 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[3].as()) { - oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, - param->strides[2]) + 1); + oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); } else { oshape.Set(4, dshape_ncdhw[4]); } @@ -365,7 +360,6 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } - // Winograd convolution shape relations inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -378,15 +372,14 @@ inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_i CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - std::vector oshape { + std::vector oshape{ param->tile_size + data->shape[2] - 1, param->tile_size + data->shape[3] - 1, data->shape[0], data->shape[1], }; - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } @@ -404,7 +397,7 @@ inline bool Conv3DWinogradWeightTransformRel(const Array& types, int num_i // Shape of packed weights depends on whether depth is being transformed or not. Array oshape({0, 0, 0, data->shape[0], data->shape[1]}); auto* depth_imm = data->shape[2].as(); - bool transform_depth = (depth_imm->value > 2)&&(depth_imm->value < 8); + bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8); if (transform_depth) { oshape.Set(0, param->tile_size + data->shape[2] - 1); oshape.Set(1, param->tile_size + data->shape[3] - 1); @@ -449,10 +442,8 @@ inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, int return true; } -template -bool Conv2DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, +template +bool Conv2DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -467,13 +458,13 @@ bool Conv2DWinogradRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? 
param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); @@ -508,14 +499,12 @@ bool Conv2DWinogradRel(const Array& types, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, (dshape_nchw[2] + pad_h - - dilated_ksize_y) / param->strides[0] + 1); + oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, (dshape_nchw[3] + pad_w - - dilated_ksize_x) / param->strides[1] + 1); + oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -530,11 +519,8 @@ bool Conv2DWinogradRel(const Array& types, return true; } - -template -bool Conv3DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, +template +bool Conv3DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -549,13 +535,13 @@ bool Conv3DWinogradRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIDHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); @@ -591,20 +577,17 @@ bool Conv3DWinogradRel(const Array& types, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, (dshape_ncdhw[2] + pad_d - - dilated_ksize_d) / param->strides[0] + 1); + oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[2].as()) { - oshape.Set(3, (dshape_ncdhw[3] + pad_h - - dilated_ksize_y) / param->strides[1] + 1); + oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, (dshape_ncdhw[4] + pad_w - - dilated_ksize_x) / param->strides[2] + 1); + oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1); } else { oshape.Set(4, dshape_ncdhw[4]); } @@ -619,12 +602,9 @@ bool Conv3DWinogradRel(const Array& types, return true; } - // Transposed convolution shape relations template -bool Conv1DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -641,19 +621,19 @@ bool Conv1DTransposeRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCW." 
- << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCW." - << " But got " << out_layout; + << "Conv only support output layouts that are convertible from NCW." + << " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -664,9 +644,8 @@ bool Conv1DTransposeRel(const Array& types, CHECK_EQ(param->kernel_size.size(), 1); CHECK_EQ(param->dilation.size(), 1); - Array wshape({dshape_ncw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0]}); + Array wshape( + {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]}); wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; @@ -683,14 +662,12 @@ bool Conv1DTransposeRel(const Array& types, // check the size CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) << "Conv1D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[1])) << "Conv1D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); + << " channels=" << param->channels << " wshape=" << Array(wshape); } CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); channels = wshape[1]; @@ -700,8 +677,8 @@ bool Conv1DTransposeRel(const Array& types, IndexExpr pad_w; GetPaddingWidth(param->padding, &pad_w); Array oshape({dshape_ncw[0], channels, 0}); - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - - pad_w + param->output_padding[0])); + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + + param->output_padding[0])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -712,11 +689,8 @@ bool Conv1DTransposeRel(const Array& types, return true; } - template -bool Conv2DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -733,19 +707,19 @@ bool Conv2DTransposeRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." 
- << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -756,10 +730,8 @@ bool Conv2DTransposeRel(const Array& types, CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); - Array wshape({dshape_nchw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0], - param->kernel_size[1]}); + Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1]}); wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; @@ -778,14 +750,12 @@ bool Conv2DTransposeRel(const Array& types, CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[1])) << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); + << " channels=" << param->channels << " wshape=" << Array(wshape); } CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); channels = wshape[1]; @@ -796,10 +766,10 @@ bool Conv2DTransposeRel(const Array& types, Array oshape({dshape_nchw[0], channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - - pad_h + param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - - pad_w + param->output_padding[1])); + oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + + param->output_padding[0])); + oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + + param->output_padding[1])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -810,7 +780,6 @@ bool Conv2DTransposeRel(const Array& types, return true; } - // Deformable Convolution shape relations. 
template bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -830,11 +799,8 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& if (param->kernel_size.defined() && param->channels.defined()) { CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); - Array wshape( - {param->channels, - indexdiv(data->shape[1], param->groups), - param->kernel_size[0], - param->kernel_size[1]}); + Array wshape({param->channels, indexdiv(data->shape[1], param->groups), + param->kernel_size[0], param->kernel_size[1]}); channels = param->channels; ksize_y = param->kernel_size[0]; ksize_x = param->kernel_size[1]; @@ -852,14 +818,12 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape; + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[0])) << "DeformableConv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << wshape; + << " channels=" << param->channels << " wshape=" << wshape; } CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); channels = wshape[0]; @@ -873,15 +837,13 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); - oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); + oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); DataType out_dtype = param->out_dtype; // infer offset shape - Array offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, - oshape[2], oshape[3]}); + Array offset_shape( + {data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]}); reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); if (out_dtype.bits() == 0) { out_dtype = data->dtype; @@ -891,23 +853,20 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& return true; } - -template -Array > ConvInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +template +Array > ConvInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { const T* params = attrs.as(); // We always make other operators to fit the layouts of convolution layers // So this inference ignores all inputs - return Array >{{params->data_layout, params->kernel_layout}, - {params->out_layout == "" ? - params->data_layout : params->out_layout}}; + return Array >{ + {params->data_layout, params->kernel_layout}, + {params->out_layout == "" ? 
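
// Worked example: besides the output type, DeformableConv2DRel above infers
// the offset tensor's type — one (dy, dx) pair per kernel tap per deformable
// group, so its channel dim is 2 * kh * kw * deformable_groups and its
// spatial dims match the conv output. Sketch (illustrative numbers):
#include <cstdio>
int main() {
  long n = 1, kh = 3, kw = 3, deformable_groups = 4, out_h = 56, out_w = 56;
  long offset_c = 2 * kh * kw * deformable_groups;  // 72 for these numbers
  std::printf("offset shape = (%ld, %ld, %ld, %ld)\n", n, offset_c, out_h, out_w);
}
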
params->data_layout : params->out_layout}}; } - } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 5cdca8011aa2d..670878d278041 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -22,20 +22,23 @@ * \brief Property def of nn operators. */ -#include -#include -#include -#include +#include "nn.h" + #include #include -#include #include -#include +#include +#include +#include +#include +#include + #include -#include "../type_relations.h" +#include + #include "../../transforms/infer_layout_util.h" #include "../op_common.h" -#include "nn.h" +#include "../type_relations.h" namespace tvm { namespace relay { @@ -43,9 +46,7 @@ namespace relay { // relay.nn.bias_add TVM_REGISTER_NODE_TYPE(BiasAddAttrs); -bool BiasAddRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -61,45 +62,36 @@ bool BiasAddRel(const Array& types, << "axis " << param->axis << " is out of range"; // assign output type - reporter->Assign(types[1], TensorType( - {data->shape[axis]}, data->dtype)); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); reporter->Assign(types[2], types[0]); return true; } - // Positional relay function to create dense operator used by frontend FFI. -Expr MakeBiasAdd(Expr data, - Expr bias, - int axis) { +Expr MakeBiasAdd(Expr data, Expr bias, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); return Call(op, {data, bias}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add") -.set_body_typed(MakeBiasAdd); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd); RELAY_REGISTER_OP("nn.bias_add") -.describe(R"code(Add bias to an axis of the input. + .describe(R"code(Add bias to an axis of the input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("bias", "1D Tensor", "Bias.") -.set_support_level(1) -.add_type_rel("BiasAdd", BiasAddRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; -}); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("bias", "1D Tensor", "Bias.") + .set_support_level(1) + .add_type_rel("BiasAdd", BiasAddRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; + }); // relay.nn.fifo_buffer TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs); @@ -111,9 +103,7 @@ Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { return Call(op, {input, buffer}, Attrs(attrs), {}); } -bool FIFOBufferRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool FIFOBufferRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* input = types[0].as(); @@ -125,9 +115,8 @@ bool FIFOBufferRel(const Array& types, CHECK(param != nullptr); CHECK_EQ(input->shape.size(), buffer->shape.size()); - const size_t buffer_axis - = static_cast(param->axis < 0 ? 
static_cast(buffer->shape.size()) + param->axis - : param->axis); + const size_t buffer_axis = static_cast( + param->axis < 0 ? static_cast(buffer->shape.size()) + param->axis : param->axis); reporter->Assert(buffer_axis < buffer->shape.size()); for (size_t i = 0; i < buffer->shape.size(); ++i) { @@ -143,11 +132,10 @@ bool FIFOBufferRel(const Array& types, return true; } -TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer") -.set_body_typed(MakeFIFOBuffer); +TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer").set_body_typed(MakeFIFOBuffer); RELAY_REGISTER_OP("nn.fifo_buffer") -.describe(R"code(FIFO buffer + .describe(R"code(FIFO buffer Compute equivalent of ``` @@ -159,23 +147,18 @@ Useful for * Encoding explicit re-use of computation in convolution ops operated on a sliding window input * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "Latest input") -.add_argument("buffer", "Tensor", - "Buffer storing latest [length_buffer] inputs") -.set_support_level(3) -.add_type_rel("FIFOBuffer", FIFOBufferRel); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Latest input") + .add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs") + .set_support_level(3) + .add_type_rel("FIFOBuffer", FIFOBufferRel); // relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); // Positional relay function to create dense operator used by frontend FFI. -Expr MakeDense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { +Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; @@ -183,70 +166,58 @@ Expr MakeDense(Expr data, return Call(op, {data, weight}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.dense") -.set_body_typed(MakeDense); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense); RELAY_REGISTER_OP("nn.dense") -.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` - **out**: `(x1, x2, ..., xn, units)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(1) -.add_type_rel("Dense", DenseRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("Dense", DenseRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. -Expr MakeLeakyRelu(Expr data, - double alpha) { +Expr MakeLeakyRelu(Expr data, double alpha) { auto attrs = make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("nn.leaky_relu"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu") -.set_body_typed(MakeLeakyRelu); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu").set_body_typed(MakeLeakyRelu); RELAY_REGISTER_OP("nn.leaky_relu") -.describe(R"code(Leaky version of a Rectified Linear Unit. + .describe(R"code(Leaky version of a Rectified Linear Unit. `y = x > 0 ? 
x : alpha * x` )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input data.") -.set_support_level(3) -.add_type_rel("Identity", IdentityRel) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return Array{ topi::leaky_relu(inputs[0], param->alpha) }; -}); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(3) + .add_type_rel("Identity", IdentityRel) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return Array{topi::leaky_relu(inputs[0], param->alpha)}; + }); // relay.prelu TVM_REGISTER_NODE_TYPE(PReluAttrs); -bool PReluRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool PReluRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -256,7 +227,7 @@ bool PReluRel(const Array& types, CHECK(param != nullptr); CHECK(param->axis < static_cast(data->shape.size())) - << "Wrong axis (" << param->axis << ")value."; + << "Wrong axis (" << param->axis << ")value."; // assign alpha type Array alpha_shape({data->shape[param->axis]}); @@ -267,72 +238,59 @@ bool PReluRel(const Array& types, return true; } -template -Array > PReluInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { - +template +Array> PReluInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { CHECK_EQ(old_in_layouts.size(), 2U); CHECK_EQ(old_in_types.size(), 2U); Layout data_layout = old_in_layouts[0]; if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 2U); } - return Array >{{data_layout, Layout("C")}, - {data_layout}}; + return Array>{{data_layout, Layout("C")}, {data_layout}}; } // Positional relay function to create prelu operator used by frontend FFI. -Expr MakePRelu(Expr data, - Expr alpha, - int axis) { +Expr MakePRelu(Expr data, Expr alpha, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.prelu"); return Call(op, {data, alpha}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu") -.set_body_typed(MakePRelu); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu").set_body_typed(MakePRelu); RELAY_REGISTER_OP("nn.prelu") -.describe(R"code(Parametric version of a Rectified Linear Unit. + .describe(R"code(Parametric version of a Rectified Linear Unit. It accepts two arguments: an input ``x`` and a channelwise slope ``alpha`` and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`, where :math:`*` is an channelwise multiplication for each sample in the batch. 
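
// Worked example: both activations above are elementwise selects,
// x > 0 ? x : alpha * x; leaky_relu uses a single scalar alpha, while prelu
// indexes alpha along `axis` (one slope per channel). Sketch (illustrative,
// channel-major layout):
#include <cstdio>
static float prelu1(float x, float alpha) { return x > 0 ? x : alpha * x; }
int main() {
  float alpha[2] = {0.1f, 0.25f};                // one slope per channel
  float x[2][3] = {{-1, 0.5f, -2}, {3, -4, 1}};
  for (int c = 0; c < 2; ++c)
    for (int i = 0; i < 3; ++i) std::printf("%g ", prelu1(x[c][i], alpha[c]));
  std::printf("\n");
}
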
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "Input data.") -.add_argument("alpha", "Tensor", "Input channelwise alpha.") -.set_support_level(3) -.add_type_rel("PRelu", PReluRel) -.set_attr("FInferCorrectLayout", PReluInferCorrectLayout) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return Array{ topi::prelu(inputs[0], inputs[1], param->axis)}; -}); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Input data.") + .add_argument("alpha", "Tensor", "Input channelwise alpha.") + .set_support_level(3) + .add_type_rel("PRelu", PReluRel) + .set_attr("FInferCorrectLayout", PReluInferCorrectLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return Array{topi::prelu(inputs[0], inputs[1], param->axis)}; + }); // relay.softmax TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); -TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") -.set_body_typed([](Expr data, int axis) { +TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); return Call(op, {data}, Attrs(attrs), {}); }); - RELAY_REGISTER_OP("nn.softmax") .describe(R"code(Softmax layer. @@ -343,16 +301,14 @@ RELAY_REGISTER_OP("nn.softmax") - **data**: The input data )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel); // relay.nn.log_softmax -TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") -.set_body_typed([](Expr data, int axis) { +TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); @@ -369,26 +325,22 @@ RELAY_REGISTER_OP("nn.log_softmax") - **data**: The input data )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - CHECK(param->axis == -1 || param->axis == static_cast(inputs[0].ndim()) - 1) - << "log_softmax currently only works on last dimension"; - return Array{ topi::nn::log_softmax(inputs[0]) }; -}); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + CHECK(param->axis == -1 || param->axis == static_cast(inputs[0].ndim()) - 1) + << "log_softmax currently only works on last dimension"; + return Array{topi::nn::log_softmax(inputs[0])}; + }); // relay.nn.batch_flatten -bool BatchFlattenRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchFlattenRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = 
types[0].as(); @@ -418,13 +370,10 @@ Expr MakeBatchFlatten(Expr data) { return Call(op, {data}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten") -.set_body_typed(MakeBatchFlatten); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten").set_body_typed(MakeBatchFlatten); RELAY_REGISTER_OP("nn.batch_flatten") -.describe(R"code(Flattens the input into a 2-D array. + .describe(R"code(Flattens the input into a 2-D array. For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes the input array into an output array of shape ``(d1, d2*...*dk)``. @@ -445,53 +394,42 @@ Example:: [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]] )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("BatchFlatten", BatchFlattenRel) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - return Array{ topi::nn::flatten(inputs[0]) }; -}); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("BatchFlatten", BatchFlattenRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::nn::flatten(inputs[0])}; + }); // relu -TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") -.set_body_typed([](Expr data) { - static const Op& op = Op::Get("nn.relu"); - return Call(op, {data}, Attrs(), {}); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.relu").set_body_typed([](Expr data) { + static const Op& op = Op::Get("nn.relu"); + return Call(op, {data}, Attrs(), {}); +}); RELAY_REGISTER_OP("nn.relu") -.describe(R"code(Returns the relu input array, computed element-wise. + .describe(R"code(Returns the relu input array, computed element-wise. .. math:: max(x, 0) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - return Array{ topi::relu(inputs[0], 0.0f) }; -}); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::relu(inputs[0], 0.0f)}; + }); // Positional relay function to create LRN operator used by frontend FFI. TVM_REGISTER_NODE_TYPE(LRNAttrs); -Expr MakeLRN(Expr data, - int size, - int axis, - double alpha, - double beta, - double bias) { +Expr MakeLRN(Expr data, int size, int axis, double alpha, double beta, double bias) { auto attrs = make_object(); attrs->size = size; attrs->axis = axis; @@ -502,11 +440,10 @@ Expr MakeLRN(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn") -.set_body_typed(MakeLRN); +TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn").set_body_typed(MakeLRN); RELAY_REGISTER_OP("nn.lrn") -.describe(R"code(LRN layer. + .describe(R"code(LRN layer. Normalize the input in a local region across or within feature maps. Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta, @@ -519,20 +456,16 @@ centered at that value (zero padding is added where necessary). - **data**: The input tensor. 
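
// Worked example: the LRN formula above divides each value by
// (bias + (alpha / n) * sum_i x_i^2)^beta over a window of n channels
// centered at the value, zero-padded at the edges. Sketch (illustrative,
// 1-D over channels only):
#include <cmath>
#include <cstdio>
int main() {
  const int C = 5, n = 3;
  float x[C] = {1, 2, 3, 4, 5}, alpha = 1e-4f, beta = 0.75f, bias = 1.0f;
  for (int c = 0; c < C; ++c) {
    float ss = 0;
    for (int i = c - n / 2; i <= c + n / 2; ++i)
      if (i >= 0 && i < C) ss += x[i] * x[i];    // zero padding outside [0, C)
    std::printf("%g ", x[c] / std::pow(bias + alpha / n * ss, beta));
  }
  std::printf("\n");
}
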
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Identity", IdentityRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Identity", IdentityRel); // Positional relay function to create L2Normalize operator used by frontend FFI. TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); -Expr MakeL2Normalize(Expr data, - double eps, - Array axis) { +Expr MakeL2Normalize(Expr data, double eps, Array axis) { auto attrs = make_object(); attrs->eps = eps; attrs->axis = std::move(axis); @@ -540,11 +473,10 @@ Expr MakeL2Normalize(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize") -.set_body_typed(MakeL2Normalize); +TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize").set_body_typed(MakeL2Normalize); RELAY_REGISTER_OP("nn.l2_normalize") -.describe(R"code(L2 Normalization layer. + .describe(R"code(L2 Normalization layer. Normalizes along dimension axis using an L2 norm @@ -553,19 +485,17 @@ Normalizes along dimension axis using an L2 norm - **data**: The input tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Identity", IdentityRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .add_type_rel("Identity", IdentityRel); // Dropout TVM_REGISTER_NODE_TYPE(DropoutAttrs); -bool DropoutRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DropoutRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -585,22 +515,21 @@ Expr MakeDropout(Expr data, double rate) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout") -.set_body_typed(MakeDropout); +TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout").set_body_typed(MakeDropout); RELAY_REGISTER_OP("nn.dropout") -.describe(R"code(Applies the dropout operation to the input array. + .describe(R"code(Applies the dropout operation to the input array. During training, each element of the input is set to zero with probability ``p``. The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input to which dropout will be applied.") -.set_support_level(1) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Dropout", DropoutRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") + .set_support_level(1) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .add_type_rel("Dropout", DropoutRel); // batch_norm TVM_REGISTER_NODE_TYPE(BatchNormAttrs); @@ -639,9 +568,7 @@ Array> BatchNormInferCorrectLayout(const Attrs& attrs, {ret, c_layout, c_layout}}; } -bool BatchNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 6); const auto* data = types[0].as(); @@ -663,8 +590,7 @@ bool BatchNormRel(const Array& types, // output is a tuple of the normed data (same shape as input), new running mean, // and new running average (the latter two are both vectors of length dim) std::vector fields; - auto vec_ty = TensorType(Array({data->shape[axis]}), - data->dtype); + auto vec_ty = TensorType(Array({data->shape[axis]}), data->dtype); fields.push_back(TensorType(data->shape, data->dtype)); fields.push_back(vec_ty); fields.push_back(vec_ty); @@ -672,8 +598,8 @@ bool BatchNormRel(const Array& types, return true; } -Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, - int axis, double epsilon, bool center, bool scale) { +Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis, + double epsilon, bool center, bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -683,11 +609,10 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm") -.set_body_typed(MakeBatchNorm); +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm").set_body_typed(MakeBatchNorm); RELAY_REGISTER_OP("nn.batch_norm") -.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). + .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). Normalizes the input at each batch, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1. @@ -723,24 +648,21 @@ axis to be the last item in the input shape. .. note:: This operator can be optimized away for inference. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "Input to which batch_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.add_argument("moving_mean", "Tensor", "Running mean of input.") -.add_argument("moving_var", "Tensor", "Running variance of input.") -.set_attr("FInferCorrectLayout", BatchNormInferCorrectLayout) -.set_support_level(1) -.add_type_rel("BatchNorm", BatchNormRel); - + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") + .set_attr("FInferCorrectLayout", BatchNormInferCorrectLayout) + .set_support_level(1) + .add_type_rel("BatchNorm", BatchNormRel); // instance_norm TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); -bool InstanceNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool InstanceNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -755,8 +677,8 @@ bool InstanceNormRel(const Array& types, return true; } -Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, - bool center, bool scale) { +Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, + bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -767,12 +689,12 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon } TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeInstanceNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeInstanceNorm, args, rv); + }); RELAY_REGISTER_OP("nn.instance_norm") -.describe(R"code(Instance Normalization (Ulyanov and et al., 2016) + .describe(R"code(Instance Normalization (Ulyanov and et al., 2016) Applies instance normalization to the n-dimensional input array. .. math:: @@ -796,21 +718,18 @@ to be the last item in the input shape. This operator can be optimized away for inference. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which instance_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("InstanceNorm", InstanceNormRel); - + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("InstanceNorm", InstanceNormRel); // layer_norm TVM_REGISTER_NODE_TYPE(LayerNormAttrs); -bool LayerNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool LayerNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -825,8 +744,8 @@ bool LayerNormRel(const Array& types, return true; } -Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, - bool center, bool scale) { +Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, + bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -837,27 +756,25 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, } TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLayerNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeLayerNorm, args, rv); + }); RELAY_REGISTER_OP("nn.layer_norm") -.describe(R"code( + .describe(R"code( )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which layer_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("LayerNorm", LayerNormRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("LayerNorm", LayerNormRel); // group_norm TVM_REGISTER_NODE_TYPE(GroupNormAttrs); -bool GroupNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GroupNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -872,10 +789,10 @@ bool GroupNormRel(const Array& types, return true; } -Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, - int axis, double epsilon, bool center, bool scale) { +Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, double epsilon, + bool center, bool scale) { auto attrs = make_object(); - attrs->num_groups = num_groups; + attrs->num_groups = num_groups; attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -885,12 +802,12 @@ Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, } TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGroupNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + 
runtime::detail::unpack_call(MakeGroupNorm, args, rv); + }); RELAY_REGISTER_OP("nn.group_norm") -.describe(R"code( + .describe(R"code( Group normalization normalizes over group of channels for each training examples. We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put all the channels into a single group, group normalization becomes Layer normalization. @@ -917,19 +834,16 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,). This operator can be optimized away for inference. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which group_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("GroupNorm", GroupNormRel); - + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("GroupNorm", GroupNormRel); // relay.nn.batch_matmul -bool BatchMatmulRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* x = types[0].as(); @@ -938,12 +852,10 @@ bool BatchMatmulRel(const Array& types, CHECK(x->shape.size() == 3 && y->shape.size() == 3); CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "BatchDot: batch dimension doesn't match, " - << " x shape=" << x->shape - << ", y shape=" << y->shape; + << " x shape=" << x->shape << ", y shape=" << y->shape; CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape - << ", y shape=" << y->shape; + << " x shape=" << x->shape << ", y shape=" << y->shape; Array oshape = x->shape; oshape.Set(2, y->shape[1]); @@ -953,21 +865,16 @@ bool BatchMatmulRel(const Array& types, return true; } - // Positional relay function to create batch_matmul operator used by frontend FFI. -Expr MakeBatchMatmul(Expr x, - Expr y) { +Expr MakeBatchMatmul(Expr x, Expr y) { static const Op& op = Op::Get("nn.batch_matmul"); return Call(op, {x, y}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul") -.set_body_typed(MakeBatchMatmul); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") -.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` + .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` are data in batch. .. math:: @@ -979,34 +886,31 @@ are data in batch. - **out**: `(b, m, n)`. 
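
// Worked example: per the docstring above, nn.batch_matmul contracts the
// last axis of both inputs — x:(b, m, k) with y:(b, n, k), i.e. y is used
// transposed — which is why BatchMatmulRel checks the batch dims and
// x.shape[2] == y.shape[2]. Sketch (illustrative, b = 1 for brevity):
#include <cstdio>
int main() {
  const int M = 2, N = 2, K = 3;
  float x[M][K] = {{1, 2, 3}, {4, 5, 6}};
  float y[N][K] = {{1, 0, 1}, {0, 1, 0}};
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0;
      for (int k = 0; k < K; ++k) acc += x[i][k] * y[j][k];  // X * Y^T
      std::printf("out[%d][%d] = %g\n", i, j, acc);
    }
}
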
)code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "3D Tensor", "First input.") -.add_argument("y", "3D Tensor", "Second input.") -.set_support_level(10) -.add_type_rel("BatchMatmul", BatchMatmulRel); - + .set_num_inputs(2) + .add_argument("x", "3D Tensor", "First input.") + .add_argument("y", "3D Tensor", "Second input.") + .set_support_level(10) + .add_type_rel("BatchMatmul", BatchMatmulRel); // relay.nn.cross_entropy -bool CrossEntropyRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* x = types[0].as(); const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; CHECK(x->shape.size() == 2 && y->shape.size() == 2) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[1], y->shape[1])) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; // assign output type reporter->Assign(types[2], TensorType({}, x->dtype)); return true; @@ -1018,22 +922,61 @@ Expr MakeCrossEntropy(Expr predictions, Expr targets) { return Call(op, {predictions, targets}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy") -.set_body_typed(MakeCrossEntropy); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy").set_body_typed(MakeCrossEntropy); RELAY_REGISTER_OP("nn.cross_entropy") -.describe(R"code( + .describe(R"code( Computes cross entropy given predictions and targets. Do log on the data - do not accept logits. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "1D Tensor", "Predictions.") -.add_argument("y", "1D Tensor", "Targets.") -.set_support_level(10) -.add_type_rel("CrossEntropy", CrossEntropyRel); + .set_num_inputs(2) + .add_argument("x", "1D Tensor", "Predictions.") + .add_argument("y", "1D Tensor", "Targets.") + .set_support_level(10) + .add_type_rel("CrossEntropy", CrossEntropyRel); + +// relay.nn.dilate +TVM_REGISTER_NODE_TYPE(DilateAttrs); + +bool DilateRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* x = types[0].as(); + const DilateAttrs* param = attrs.as(); + if (x == nullptr) return false; + CHECK_EQ(x->shape.size(), param->strides.size()); + + std::vector oshape; + for (size_t i = 0; i < param->strides.size(); ++i) { + if (!x->shape[i].as()) { + oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1); + } else { + oshape.push_back(x->shape[i]); + } + } + reporter->Assign(types[1], TensorType(Array(oshape), x->dtype)); + return true; +} + +// Positional relay function to create dilate operator used by frontend FFI. 
+Expr MakeDilate(Expr data, Array strides) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + static const Op& op = Op::Get("nn.dilate"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate); + +RELAY_REGISTER_OP("nn.dilate") + .describe(R"code( +Dilate data with zeros. +)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("x", "1D Tensor", "Data to dilate.") + .set_support_level(10) + .add_type_rel("Dilate", DilateRel); // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { @@ -1041,21 +984,19 @@ Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { return Call(op, {predictions, targets}, Attrs(), {}); } - TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits") -.set_body_typed(MakeCrossEntropyWithLogits); - + .set_body_typed(MakeCrossEntropyWithLogits); RELAY_REGISTER_OP("nn.cross_entropy_with_logits") -.describe(R"code( + .describe(R"code( Computes cross entropy given predictions and targets. Accept logits. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "1D Tensor", "Predictions.") -.add_argument("y", "1D Tensor", "Targets.") -.set_support_level(10) -.add_type_rel("CrossEntropy", CrossEntropyRel); + .set_num_inputs(2) + .add_argument("x", "1D Tensor", "Predictions.") + .add_argument("y", "1D Tensor", "Targets.") + .set_support_level(10) + .add_type_rel("CrossEntropy", CrossEntropyRel); // Depth to space and space to depth TVM_REGISTER_NODE_TYPE(SubPixelAttrs); @@ -1083,8 +1024,7 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr oshape.Set(3, oshape[3] * block_size); // Assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } @@ -1141,8 +1081,7 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr oshape.Set(3, indexdiv(oshape[3], block_size)); // Assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index dc876e863ad04..0fb02638db077 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -24,6 +24,11 @@ #ifndef TVM_RELAY_OP_NN_NN_H_ #define TVM_RELAY_OP_NN_NN_H_ +#include +#include +#include +#include + #include namespace tvm { @@ -58,8 +63,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight == nullptr) return false; Array wshape = weight->shape; CHECK(static_cast(weight->shape.size()) == 2); - CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], - weight->shape[1])) + CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) << "DenseRel: input dimension doesn't match," << " data shape=" << data->shape << ", weight shape=" << weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index abff06ef9d887..e416a066a468d 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -21,12 +21,14 @@ * \file pad.cc * \brief Implementation of operator pad */ +#include +#include +#include #include #include -#include -#include -#include + #include + #include 
"../op_common.h" namespace tvm { @@ -35,13 +37,11 @@ namespace relay { // relay.nn.pad TVM_REGISTER_NODE_TYPE(PadAttrs); -Array > PadInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array> PadInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. - PadAttrs *params = const_cast(attrs.as()); + PadAttrs* params = const_cast(attrs.as()); Layout ret; // If new_in_layouts are defined, this code tries to modify the layout. @@ -108,12 +108,10 @@ Array > PadInferCorrectLayout( } } - return Array >{{ret}, {ret}}; + return Array>{{ret}, {ret}}; } -bool PadRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool PadRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -124,28 +122,26 @@ bool PadRel(const Array& types, // check that pad widths match lengths CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; // each pad width element should be a pair of positive integers std::vector oshape; for (size_t i = 0; i < param->pad_width.size(); i++) { CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; auto width1 = tir::as_const_int(param->pad_width[i][0]); auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; + CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; if (!data->shape[i].as()) { auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); @@ -155,21 +151,17 @@ bool PadRel(const Array& types, } } - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } -Array PadCompute(const Attrs& attrs, - const Array& inputs, +Array PadCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); auto pad_width = param->pad_width; - CHECK(pad_width.size() == inputs[0].ndim() && - pad_width[0].size() == 2) - << "Illegal pad_width"; + CHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width"; Array pad_before; for (size_t i = 0; i < pad_width.size(); ++i) { 
pad_before.push_back(pad_width[i][0]); @@ -179,18 +171,13 @@ Array PadCompute(const Attrs& attrs, pad_after.push_back(pad_width[i][1]); } const auto* out_ttype = out_type.as(); - return Array{ topi::pad(inputs[0], pad_before, pad_after, - tvm::tir::make_const(out_ttype->dtype, param->pad_value), - "T_pad", - topi::kElementWise, - param->pad_mode) }; + return Array{topi::pad(inputs[0], pad_before, pad_after, + tvm::tir::make_const(out_ttype->dtype, param->pad_value), + "T_pad", topi::kElementWise, param->pad_mode)}; } // Handler to create a call to the padding op used by front-end FFI -Expr MakePad(Expr data, - Array > pad_width, - double pad_value, - std::string pad_mode) { +Expr MakePad(Expr data, Array> pad_width, double pad_value, std::string pad_mode) { auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); @@ -199,29 +186,25 @@ Expr MakePad(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.pad") -.set_body_typed(MakePad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad); RELAY_REGISTER_OP("nn.pad") -.describe(R"code(Pad for n-D tensor. + .describe(R"code(Pad for n-D tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("Pad", PadRel) -.set_attr("FInferCorrectLayout", PadInferCorrectLayout) -.set_attr("TOpPattern", kInjective) -.set_attr("FTVMCompute", PadCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Pad", PadRel) + .set_attr("FInferCorrectLayout", PadInferCorrectLayout) + .set_attr("TOpPattern", kInjective) + .set_attr("FTVMCompute", PadCompute); // relay.nn.mirror_pad TVM_REGISTER_NODE_TYPE(MirrorPadAttrs); -bool MirrorPadRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MirrorPadRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -232,40 +215,37 @@ bool MirrorPadRel(const Array& types, // check that pad widths match lengths CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; // each pad width element should be a pair of positive integers std::vector oshape; for (size_t i = 0; i < param->pad_width.size(); i++) { CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; auto width1 = tir::as_const_int(param->pad_width[i][0]); auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << 
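
// Worked example: PadRel and MirrorPadRel above both grow each dimension by
// its (before, after) pair — out_i = in_i + width1_i + width2_i — after
// checking that the widths are non-negative constants. Sketch (illustrative):
#include <cstdio>
int main() {
  long shape[2] = {4, 5};
  long pad_width[2][2] = {{1, 1}, {0, 2}};       // (before, after) per dim
  for (int i = 0; i < 2; ++i)
    std::printf("dim %d -> %ld\n", i, shape[i] + pad_width[i][0] + pad_width[i][1]);
}
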
*width2 << "."; + CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); oshape.push_back(data->shape[i] + padding); } - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } // Handler to create a call to the padding op used by front-end FFI -Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mode) { +Expr MakeMirrorPad(Expr data, Array> pad_width, std::string mode) { auto attrs = make_object(); attrs->mode = mode; attrs->pad_width = std::move(pad_width); @@ -273,19 +253,18 @@ Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mo return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad") -.set_body_typed(MakeMirrorPad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad); RELAY_REGISTER_OP("nn.mirror_pad") -.describe(R"code(MirrorPad for n-D tensor. + .describe(R"code(MirrorPad for n-D tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MirrorPad", MirrorPadRel) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MirrorPad", MirrorPadRel) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index c20793d9ac289..dd649514ad196 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,12 +21,14 @@ * \file pooling.cc * \brief Pooling operators */ -#include +#include +#include #include #include -#include -#include +#include + #include + #include "../../transforms/infer_layout_util.h" namespace tvm { @@ -37,13 +39,12 @@ TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template -Array > PoolInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array > PoolInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. - T *params = const_cast(attrs.as()); + T* params = const_cast(attrs.as()); if (new_in_layouts.defined()) { // Set the pool with the new layout. 
@@ -56,12 +57,8 @@ Array > PoolInferCorrectLayout( } template -Expr MakeMaxPool(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, +Expr MakeMaxPool(Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, std::string op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); @@ -74,14 +71,9 @@ Expr MakeMaxPool(Expr data, } template -Expr MakeAvgPool(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad, - std::string op_name) { +Expr MakeAvgPool(Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad, std::string op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -94,9 +86,7 @@ Expr MakeAvgPool(Expr data, } template -bool Pool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -112,8 +102,7 @@ bool Pool2DRel(const Array& types, Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -140,8 +129,9 @@ bool Pool2DRel(const Array& types, oshape[hidx] = dshape[hidx]; } else { if (param->ceil_mode) { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; + oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + param->strides[0] - 1) / + param->strides[0]) + + 1; } else { oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; } @@ -150,8 +140,9 @@ bool Pool2DRel(const Array& types, oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + - param->strides[1] - 1) / param->strides[1]) + 1; + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + param->strides[1] - 1) / + param->strides[1]) + + 1; } else { oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1; } @@ -162,9 +153,8 @@ bool Pool2DRel(const Array& types, return true; } -template -Array Pool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); @@ -182,9 +172,7 @@ Array Pool2DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool2d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U || - inputs[0].ndim() == 6U) + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool2D only support 4-D input (e.g., NCHW)" << " or 5-D input (e.g. NCHWc on for vector instructions)" << " or 6-D input (e.g. 
NCHWnc for tensor accelerators)";
@@ -199,30 +187,23 @@ Array<te::Tensor> Pool2DCompute(const Attrs& attrs,
   }
   if (mode == topi::nn::kAvgPool) {
     bool count_include_pad = reinterpret_cast<const AvgPool2DAttrs*>(param)->count_include_pad;
-    return Array<te::Tensor>{
-      topi::nn::pool(inputs[0], pool_size, strides, padding,
-                     mode, ceil_mode, layout.name(), count_include_pad)};
+    return Array<te::Tensor>{topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode,
+                                            layout.name(), count_include_pad)};
   } else {
     return Array<te::Tensor>{
-      topi::nn::pool(inputs[0], pool_size, strides, padding,
-                     mode, ceil_mode, layout.name())};
+        topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())};
   }
 }
 
 TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d")
-.set_body_typed([](Expr data,
-                   Array<IndexExpr> pool_size,
-                   Array<IndexExpr> strides,
-                   Array<IndexExpr> padding,
-                   std::string layout,
-                   bool ceil_mode) {
-  return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
-                                     "nn.max_pool2d");
-});
-
+    .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+                       Array<IndexExpr> padding, std::string layout, bool ceil_mode) {
+      return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+                                         "nn.max_pool2d");
+    });
 
 RELAY_REGISTER_OP("nn.max_pool2d")
-.describe(R"code(Max pooling operation for two dimensional data.
+    .describe(R"code(Max pooling operation for two dimensional data.
 
 - **data**: This depends on the `layout` parameter. Input is 4D array of shape
             (batch_size, channels, height, width) if `layout` is `NCHW`.
@@ -242,30 +223,25 @@ RELAY_REGISTER_OP("nn.max_pool2d")
            equation.
 
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<MaxPool2DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(2)
-.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool2DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
-
+    .set_attrs_type<MaxPool2DAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(2)
+    .add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PoolInferCorrectLayout<MaxPool2DAttrs>)
+    .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);
 
 // AvgPool2D
 TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d")
-.set_body_typed([](Expr data,
-                   Array<IndexExpr> pool_size,
-                   Array<IndexExpr> strides,
-                   Array<IndexExpr> padding,
-                   std::string layout,
-                   bool ceil_mode,
-                   bool count_include_pad) {
-  return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
-                                     count_include_pad, "nn.avg_pool2d");
-});
+    .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
+                       Array<IndexExpr> padding, std::string layout, bool ceil_mode,
+                       bool count_include_pad) {
+      return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
+                                         count_include_pad, "nn.avg_pool2d");
+    });
 
 RELAY_REGISTER_OP("nn.avg_pool2d")
-.describe(R"code(
+    .describe(R"code(
 Average pooling operation for two dimensional data.
 
 - **data**: This depends on the `layout` parameter. Input is 4D array of shape
@@ -286,24 +262,24 @@ Average pooling operation for two dimensional data.
            equation.
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool2D", Pool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool2DCompute); // relay.nn.global_pool_2d & relay.nn.max_pool_2d TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); -bool GlobalPool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GlobalPool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) << "Pool2D only support input >= 2-D: input must have height and width"; @@ -313,8 +289,7 @@ bool GlobalPool2DRel(const Array& types, Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -327,44 +302,38 @@ bool GlobalPool2DRel(const Array& types, return true; } - -template -Array GlobalPool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array GlobalPool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) - << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; + << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "global_avg_pool2d does not support input split on height"; + << "global_avg_pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "global_avg_pool2d does not support input split on width"; + << "global_avg_pool2d does not support input split on width"; CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; - return Array{ - topi::nn::global_pool(inputs[0], mode, layout.name()) }; + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + return Array{topi::nn::global_pool(inputs[0], mode, layout.name())}; } -Expr MakeGlobalAvgPool2D(Expr data, - std::string layout) { +Expr MakeGlobalAvgPool2D(Expr data, std::string layout) { auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d") -.set_body_typed(MakeGlobalAvgPool2D); 
+TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d").set_body_typed(MakeGlobalAvgPool2D); // GlobalAvgPool RELAY_REGISTER_OP("nn.global_avg_pool2d") -.describe(R"code(Global average pooling operation for 2D data. + .describe(R"code(Global average pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -372,30 +341,26 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", GlobalPool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool -Expr MakeGlobalMaxPool2D(Expr data, - std::string layout) { +Expr MakeGlobalMaxPool2D(Expr data, std::string layout) { auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d") -.set_body_typed(MakeGlobalMaxPool2D); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d").set_body_typed(MakeGlobalMaxPool2D); RELAY_REGISTER_OP("nn.global_max_pool2d") -.describe(R"code(Global max pooling operation for 2D data. + .describe(R"code(Global max pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -403,44 +368,40 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", GlobalPool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", GlobalPool2DCompute); // relay.nn.adaptive_pool_2d TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); -bool AdaptivePool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool AdaptivePool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) - << "Pool2D only support input >= 2-D: input must have height and width"; + << "Pool2D only support input >= 2-D: input must have height and width"; const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". 
Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); Array oshape(dshape); auto output_size = param->output_size; - CHECK_LE(output_size.size(), 2U) - << "output_size can have up to 2 elements."; + CHECK_LE(output_size.size(), 2U) << "output_size can have up to 2 elements."; IndexExpr output_height, output_width; if (output_size.empty()) { output_height = dshape[hidx]; @@ -461,24 +422,23 @@ bool AdaptivePool2DRel(const Array& types, return true; } -template -Array AdaptivePool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array AdaptivePool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) - << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; + << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool2d does not support input split on height"; + << "Adaptive pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool2d does not support input split on width"; + << "Adaptive pool2d does not support input split on width"; CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; auto output_size = param->output_size; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); @@ -494,15 +454,12 @@ Array AdaptivePool2DCompute(const Attrs& attrs, output_height = output_size[0]; output_width = output_size[1]; } - return Array{ - topi::nn::adaptive_pool(inputs[0], Array{ output_height, output_width }, - mode, layout.name()) }; + return Array{topi::nn::adaptive_pool( + inputs[0], Array{output_height, output_width}, mode, layout.name())}; } // relay.nn.adaptive_avg_pool2d -Expr MakeAdaptiveAvgPool2D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -510,11 +467,10 @@ Expr MakeAdaptiveAvgPool2D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d") -.set_body_typed(MakeAdaptiveAvgPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d").set_body_typed(MakeAdaptiveAvgPool2D); RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") - .describe(R"code(Adaptive average pooling operation for 2D data. + .describe(R"code(Adaptive average pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -528,19 +484,17 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool2DCompute); // relay.nn.adaptive_max_pool2d -Expr MakeAdaptiveMaxPool2D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -548,11 +502,10 @@ Expr MakeAdaptiveMaxPool2D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d") -.set_body_typed(MakeAdaptiveMaxPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d").set_body_typed(MakeAdaptiveMaxPool2D); RELAY_REGISTER_OP("nn.adaptive_max_pool2d") - .describe(R"code(Adaptive max pooling operation for 2D data. + .describe(R"code(Adaptive max pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -566,45 +519,43 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool2d") (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool2DCompute); TVM_REGISTER_NODE_TYPE(AdaptivePool3DAttrs); -bool AdaptivePool3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool AdaptivePool3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 3U) - << "Pool3D only support input >= 3-D: input must have depth, height and width"; + << "Pool3D only support input >= 3-D: input must have depth, height and width"; const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && - !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool3D layout must have D, H and W, which cannot be split"; + !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) + << "Invalid layout " << layout + << ". 
Pool3D layout must have D, H and W, which cannot be split"; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); Array oshape(dshape); auto output_size = param->output_size; - CHECK_LE(output_size.size(), 3U) - << "output_size can have up to 3 elements."; + CHECK_LE(output_size.size(), 3U) << "output_size can have up to 3 elements."; IndexExpr output_depth, output_height, output_width; if (output_size.empty()) { output_depth = dshape[didx]; @@ -629,26 +580,25 @@ bool AdaptivePool3DRel(const Array& types, return true; } -template -Array AdaptivePool3DCompute(const Attrs& attrs, - const Array& inputs, +template +Array AdaptivePool3DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCDHW("NCDHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) - << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; + << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) - << "Adaptive pool3d does not support input split on depth"; + << "Adaptive pool3d does not support input split on depth"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool3d does not support input split on height"; + << "Adaptive pool3d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool3d does not support input split on width"; + << "Adaptive pool3d does not support input split on width"; CHECK(inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) - << "Pool3D only support 5-D input (e.g., NCDHW)" - << " or 6-D input (last dimension is a split of channel)"; + << "Pool3D only support 5-D input (e.g., NCDHW)" + << " or 6-D input (last dimension is a split of channel)"; auto output_size = param->output_size; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); @@ -669,16 +619,12 @@ Array AdaptivePool3DCompute(const Attrs& attrs, output_width = output_size[2]; } - auto osize = Array{ output_depth, output_height, output_width }; - return Array { - topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name()) - }; + auto osize = Array{output_depth, output_height, output_width}; + return Array{topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name())}; } // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveMaxPool3D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveMaxPool3D(Expr data, Array output_size, std::string layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -686,11 +632,10 @@ Expr MakeAdaptiveMaxPool3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d") -.set_body_typed(MakeAdaptiveMaxPool3D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d").set_body_typed(MakeAdaptiveMaxPool3D); RELAY_REGISTER_OP("nn.adaptive_max_pool3d") - .describe(R"code(Adaptive max pooling operation for 3D data. + .describe(R"code(Adaptive max pooling operation for 3D data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. 
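Reviewer note: `AdaptivePool2DRel` above and `AdaptivePool3DRel` below share the same `output_size` convention, which the rewrapped branches implement. A short sketch of that convention under stated assumptions; the `resolve_output_size` helper is illustrative only and not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// output_size empty: keep the input's spatial extents.
// output_size with one element: that value applies to every spatial axis.
// Otherwise: one value per spatial axis (2 for Pool2D, 3 for Pool3D).
std::vector<int64_t> resolve_output_size(const std::vector<int64_t>& spatial_in,
                                         const std::vector<int64_t>& output_size) {
  if (output_size.empty()) return spatial_in;
  if (output_size.size() == 1)
    return std::vector<int64_t>(spatial_in.size(), output_size[0]);
  assert(output_size.size() == spatial_in.size());
  return output_size;
}

int main() {
  assert(resolve_output_size({14, 14}, {}) == (std::vector<int64_t>{14, 14}));
  assert(resolve_output_size({14, 14}, {7}) == (std::vector<int64_t>{7, 7}));
  assert(resolve_output_size({16, 14, 14}, {4, 7, 7})[0] == 4);
  return 0;
}
```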
@@ -704,19 +649,17 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool3d")
            (batch_size, channels, output_depth, output_height, output_width)  if `layout` is `NCDHW`.
 
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<AdaptivePool3DAttrs>()
-.set_num_inputs(1)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(10)
-.add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               PoolInferCorrectLayout<AdaptivePool3DAttrs>)
-.set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kMaxPool>);
+    .set_attrs_type<AdaptivePool3DAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(10)
+    .add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   PoolInferCorrectLayout<AdaptivePool3DAttrs>)
+    .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kMaxPool>);
 
 // relay.nn.adaptive_avg_pool3d
-Expr MakeAdaptiveAvgPool3D(Expr data,
-                           Array<IndexExpr> output_size,
-                           std::string layout) {
+Expr MakeAdaptiveAvgPool3D(Expr data, Array<IndexExpr> output_size, std::string layout) {
   auto attrs = make_object<AdaptivePool3DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
@@ -724,11 +667,10 @@ Expr MakeAdaptiveAvgPool3D(Expr data,
   return Call(op, {data}, Attrs(attrs), {});
 }
 
-TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d")
-.set_body_typed(MakeAdaptiveAvgPool3D);
+TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d").set_body_typed(MakeAdaptiveAvgPool3D);
 
 RELAY_REGISTER_OP("nn.adaptive_avg_pool3d")
-  .describe(R"code(Adaptive avg pooling operation for 3D data.
+    .describe(R"code(Adaptive avg pooling operation for 3D data.
 
 - **data**: This depends on the `layout` parameter. Input is 5D array of shape
             (batch_size, channels, depth, height, width) if `layout` is `NCDHW`.
 - **output_size**: If this argument is not provided, input depth, height and width will be used
                    as output depth, height and width.
@@ -740,15 +682,14 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool3d")
 - **out**: This depends on the `layout` parameter. Output is 5D array of shape
            (batch_size, channels, output_depth, output_height, output_width)  if `layout` is `NCDHW`.
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool3DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool3DCompute); bool Pool2DGradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -763,8 +704,7 @@ bool Pool2DGradRel(const Array& types, int num_inputs, const Attrs& attrs, } template -Array Pool2DGradCompute(const Attrs& attrs, - const Array& inputs, +Array Pool2DGradCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); @@ -802,17 +742,18 @@ Array Pool2DGradCompute(const Attrs& attrs, if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; return Array{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + mode, ceil_mode, layout.name(), + count_include_pad)}; } else { return Array{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + mode, ceil_mode, layout.name())}; } } - // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, - Array strides, Array padding, std::string layout, bool ceil_mode) { + Array strides, Array padding, std::string layout, + bool ceil_mode) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -825,7 +766,6 @@ Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad); - RELAY_REGISTER_OP("nn.max_pool2d_grad") .describe(R"code(Gradient of max pooling operation for two dimensional data. @@ -849,18 +789,17 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") (batch_size, channels, height, width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2DGrad", Pool2DGradRel) -.set_attr("FTVMCompute", Pool2DGradCompute); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2DGrad", Pool2DGradRel) + .set_attr("FTVMCompute", Pool2DGradCompute); // AvgPool2DGrad Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, - Array strides, Array padding, std::string layout, bool ceil_mode, - bool count_include_pad) { + Array strides, Array padding, std::string layout, + bool ceil_mode, bool count_include_pad) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -874,7 +813,6 @@ Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad); - RELAY_REGISTER_OP("nn.avg_pool2d_grad") .describe(R"code(Gradient of average pooling operation for two dimensional data. 
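Reviewer note: `nn.max_pool2d_grad` and `nn.avg_pool2d_grad` both lower to `topi::nn::pool_grad` (rewrapped above), with `count_include_pad` threaded through `Pool2DGradCompute` for the average case. As background only, a toy 1-D illustration of the average-pooling backward rule; this is an independent sketch under the usual calculus for average pooling, not TVM's implementation:

```cpp
#include <vector>

// Each output gradient is divided by the window's element count and added to
// every input element in that window. With count_include_pad=true the divisor
// stays the full window size even when the window overlaps padding.
std::vector<float> avg_pool1d_grad(const std::vector<float>& out_grad,
                                   int in_len, int pool, int stride, int pad,
                                   bool count_include_pad) {
  std::vector<float> in_grad(in_len, 0.0f);
  for (int o = 0; o < static_cast<int>(out_grad.size()); ++o) {
    int start = o * stride - pad;
    int lo = start < 0 ? 0 : start;
    int hi = start + pool > in_len ? in_len : start + pool;
    float divisor = count_include_pad ? pool : (hi - lo);
    for (int i = lo; i < hi; ++i) in_grad[i] += out_grad[o] / divisor;
  }
  return in_grad;
}

int main() {
  // pool=2, stride=2, pad=0 over 4 inputs: each input receives grad/2.
  auto g = avg_pool1d_grad({1.0f, 1.0f}, 4, 2, 2, 0, false);
  return g[0] == 0.5f ? 0 : 1;
}
```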
@@ -898,22 +836,19 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad") (batch_size, channels, height, width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2DGrad", Pool2DGradRel) -.set_attr("FTVMCompute", Pool2DGradCompute); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2DGrad", Pool2DGradRel) + .set_attr("FTVMCompute", Pool2DGradCompute); // relay.nn.max_pool1d & relay.nn.avg_pool1d TVM_REGISTER_NODE_TYPE(MaxPool1DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool1DAttrs); template -bool Pool1DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool1DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -921,15 +856,13 @@ bool Pool1DRel(const Array& types, if (data == nullptr) return false; const auto dshape = data->shape; - CHECK_GE(dshape.size(), 1U) - << "Pool1D only support input >= 1-D: input must have width"; + CHECK_GE(dshape.size(), 1U) << "Pool1D only support input >= 1-D: input must have width"; const auto param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool1D layout must have W, which cannot be split"; + << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split"; const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -949,8 +882,9 @@ bool Pool1DRel(const Array& types, oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + param->strides[0] - 1) / + param->strides[0]) + + 1; } else { oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0]) / param->strides[0]) + 1; } @@ -961,10 +895,8 @@ bool Pool1DRel(const Array& types, return true; } - -template -Array Pool1DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool1DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCW("NCW"); const auto* param = attrs.as(); @@ -980,9 +912,7 @@ Array Pool1DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool1d does not support input split on width"; - CHECK(inputs[0].ndim() == 3U || - inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U) + CHECK(inputs[0].ndim() == 3U || inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) << "Pool1D only support 3-D input (e.g., NCW)" << " or 4-D input (e.g. NCWc on for vector instructions)" << " or 5-D input (e.g. 
NCWnc for tensor accelerators)"; @@ -993,29 +923,23 @@ Array Pool1DCompute(const Attrs& attrs, if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, + ceil_mode, layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool1d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool1d"); + }); RELAY_REGISTER_OP("nn.max_pool1d") -.describe(R"code(Max pooling operation for one dimensional data. + .describe(R"code(Max pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, channels, width) if `layout` is `NCW`. @@ -1033,30 +957,25 @@ RELAY_REGISTER_OP("nn.max_pool1d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool1D", Pool1DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool1DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool1D", Pool1DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool1DCompute); // AvgPool1D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool1d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool1d"); + }); RELAY_REGISTER_OP("nn.avg_pool1d") -.describe(R"code( + .describe(R"code( Average pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 3D array of shape @@ -1075,23 +994,20 @@ Average pooling operation for one dimensional data. equation. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool1D", Pool1DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool1DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool1D", Pool1DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool1DCompute); // relay.nn.max_pool3d & relay.nn.avg_pool3d TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs); template -bool Pool3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1108,8 +1024,8 @@ bool Pool3DRel(const Array& types, CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool3D layout must have D, H and W, which cannot be split"; + << "Invalid layout " << layout + << ". Pool3D layout must have D, H and W, which cannot be split"; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); @@ -1143,8 +1059,9 @@ bool Pool3DRel(const Array& types, oshape[ii] = dshape[ii]; } else { if (param->ceil_mode) { - oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + - param->strides[i] - 1) / param->strides[i]) + 1; + oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + param->strides[i] - 1) / + param->strides[i]) + + 1; } else { oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i]) / param->strides[i]) + 1; } @@ -1156,10 +1073,8 @@ bool Pool3DRel(const Array& types, return true; } - -template -Array Pool3DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool3DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCDHW("NCDHW"); const auto* param = attrs.as(); @@ -1179,9 +1094,7 @@ Array Pool3DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool3d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U || - inputs[0].ndim() == 6U) + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool3D only support 5-D input (e.g., NCDHW)" << " or 6-D input (e.g. NCDHWc on for vector instructions)" << " or 7-D input (e.g. 
NCDHWnc for tensor accelerators)"; @@ -1197,29 +1110,23 @@ Array Pool3DCompute(const Attrs& attrs, } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, + ceil_mode, layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool3d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool3d"); + }); RELAY_REGISTER_OP("nn.max_pool3d") -.describe(R"code(Max pooling operation for three dimensional data. + .describe(R"code(Max pooling operation for three dimensional data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. @@ -1240,30 +1147,25 @@ RELAY_REGISTER_OP("nn.max_pool3d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool3D", Pool3DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool3DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool3D", Pool3DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool3DCompute); // AvgPool3D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool3d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, std::string layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool3d"); + }); RELAY_REGISTER_OP("nn.avg_pool3d") -.describe(R"code( + .describe(R"code( Average pooling operation for three dimensional data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape @@ -1285,13 +1187,13 @@ Average pooling operation for three dimensional data. equation. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool3D", Pool3DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool3DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool3D", Pool3DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool3DCompute); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index c761c3f8466ea..0aca00ce80a4b 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -22,9 +22,10 @@ * \brief Property def of nn.sparse_dense operator. */ -#include -#include #include +#include +#include + #include #include "../../transforms/infer_layout_util.h" @@ -53,9 +54,8 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs if (weight_data->shape.size() == 3) { // BSR case. - Array oshape({ - data->shape[0], - (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); + Array oshape( + {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } @@ -71,32 +71,32 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig } TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSparseDense, args, rv); -}); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeSparseDense, args, rv); + }); RELAY_REGISTER_OP("nn.sparse_dense") -.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. + .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` - **out**: `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(4) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("weight_data", "1D Tensor", "Weight data matrix.") -.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") -.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") -.set_support_level(1) -.add_type_rel("SparseDense", SparseDenseRel); + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight_data", "1D Tensor", "Weight data matrix.") + .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") + .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") + .set_support_level(1) + .add_type_rel("SparseDense", SparseDenseRel); // relay.nn.sparse_transpose TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs); bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* sparse_data = types[0].as(); CHECK_EQ(sparse_data->shape.size(), 1); @@ -119,24 +119,22 @@ Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indp return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose") -.set_body_typed(MakeSparseTranspose); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose").set_body_typed(MakeSparseTranspose); RELAY_REGISTER_OP("nn.sparse_transpose") -.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix + .describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix - **input**: `(N, N)` - **out**: `(N, N)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") -.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") -.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") -.set_support_level(1) -.add_type_rel("SparseTranspose", SparseTransposeRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") + .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") + .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") + .set_support_level(1) + .add_type_rel("SparseTranspose", SparseTransposeRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 63bd42d8f508d..7f5e68390e756 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -21,11 +21,13 @@ * \file upsampling.cc * \brief upsampling operator */ -#include -#include #include +#include #include +#include + #include + #include "../op_common.h" namespace tvm { @@ -35,13 +37,12 @@ TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs); template -Array > UpsamplingInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array > UpsamplingInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. 
- T *params = const_cast(attrs.as()); + T* params = const_cast(attrs.as()); if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 1); @@ -49,12 +50,12 @@ Array > UpsamplingInferCorrectLayout( Layout raw_layout(params->layout); Layout input = new_in_layouts[0]; if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && - input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && - !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&& + input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && + !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) && (input.IndexOf(LayoutAxis::Get('D')) == -1 || - (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && - !input.Contains(LayoutAxis::Get('d'))))) { - params->layout = input.name(); // modify self to follow the input layout + (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && + !input.Contains(LayoutAxis::Get('d'))))) { + params->layout = input.name(); // modify self to follow the input layout } } @@ -62,9 +63,7 @@ Array > UpsamplingInferCorrectLayout( return Array >{{inferred_layout}, {inferred_layout}}; } -bool UpSamplingRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -78,29 +77,22 @@ bool UpSamplingRel(const Array& types, auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) - << "UpSampling only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "UpSampling only support input layouts that are convertible from NCHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } - // Positional relay function to create upsampling operator // used by frontend FFI. -Expr MakeUpSampling(Expr data, - double scale_h, - double scale_w, - std::string layout, - std::string method, - bool align_corners) { +Expr MakeUpSampling(Expr data, double scale_h, double scale_w, std::string layout, + std::string method, bool align_corners) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -111,12 +103,11 @@ Expr MakeUpSampling(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling") -.set_body_typed(MakeUpSampling); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling").set_body_typed(MakeUpSampling); RELAY_REGISTER_OP("nn.upsampling") -.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. + .describe( + R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. 
- **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -130,20 +121,17 @@ RELAY_REGISTER_OP("nn.upsampling") (batch_size, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("UpSampling", UpSamplingRel) -.set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("UpSampling", UpSamplingRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); // UpSampling3D -bool UpSampling3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UpSampling3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -157,8 +145,8 @@ bool UpSampling3DRel(const Array& types, auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(layout_converter.defined()) - << "UpSampling3D only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; + << "UpSampling3D only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); @@ -166,21 +154,14 @@ bool UpSampling3DRel(const Array& types, oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } // Positional relay function to create upsampling3d operator // used by frontend FFI. -Expr MakeUpSampling3D(Expr data, - double scale_d, - double scale_h, - double scale_w, - std::string layout, - std::string method, - std::string coordinate_transformation_mode) { +Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, std::string layout, + std::string method, std::string coordinate_transformation_mode) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -192,12 +173,10 @@ Expr MakeUpSampling3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d") -.set_body_typed(MakeUpSampling3D); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D); RELAY_REGISTER_OP("nn.upsampling3d") -.describe(R"code(Perform upsampling on input array with nearest neighbour or + .describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. - **data**: data is 5D array of shape @@ -212,14 +191,14 @@ bilinear interpolation. 
(batch_size, in_depth*scale, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("UpSampling3D", UpSampling3DRel) -.set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("UpSampling3D", UpSampling3DRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 2d89d778e62cd..b560aa341aab9 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -28,11 +28,13 @@ #include #include #include -#include + #include #include -#include "type_relations.h" +#include + #include "../transforms/infer_layout_util.h" +#include "type_relations.h" namespace tvm { namespace relay { @@ -47,21 +49,18 @@ namespace relay { * \param OpName the name of registry. */ -#define RELAY_REGISTER_UNARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") \ - .add_type_rel("Identity", IdentityRel) \ - .set_attr("TOpPattern", kElemWise) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - ElemwiseArbitraryLayout) \ - +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") \ + .add_type_rel("Identity", IdentityRel) \ + .set_attr("TOpPattern", kElemWise) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) /*! Quick helper macro * - Expose a positional make function to construct the node. @@ -73,42 +72,37 @@ namespace relay { * * \param OpName the name of registry. */ -#define RELAY_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("Broadcast", BroadcastRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) +#define RELAY_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." 
OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("Broadcast", BroadcastRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", BinaryBroadcastLayout) // Comparisons -#define RELAY_REGISTER_CMP_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("BroadcastComp", BroadcastCompRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) - +#define RELAY_REGISTER_CMP_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("BroadcastComp", BroadcastCompRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", BinaryBroadcastLayout) /*! \brief A helper class for matching and rewriting operators. */ -template +template class OpMatch { public: using MatchFunc = @@ -157,8 +151,7 @@ inline void GetPaddingWidth(const Array& padding, IndexExpr* pad_w) { } else if (padding.size() == 2) { *pad_w = padding[0] + padding[1]; } else { - CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " - << padding.size(); + CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size(); } } @@ -175,8 +168,7 @@ inline void GetPaddingHeightWidth(const Array& padding, IndexExpr* pa *pad_h = padding[0] + padding[2]; *pad_w = padding[1] + padding[3]; } else { - CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " - << padding.size(); + CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size(); } } @@ -196,8 +188,7 @@ inline void GetPaddingDepthHeightWidth(const Array& padding, IndexExp *pad_h = padding[1] + padding[4]; *pad_w = padding[2] + padding[5]; } else { - CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " - << padding.size(); + CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size(); } } diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 0f47c9aa25534..026dfc21dd5fb 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -21,166 +21,145 @@ * \file binary.cc * \brief binary broadcast operators. 
*/
+#include <topi/broadcast.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
-#include <topi/broadcast.h>
-#include "../type_relations.h"
+
 #include "../op_common.h"
+#include "../type_relations.h"
 
 namespace tvm {
 namespace relay {
 
-#define RELAY_BINARY_COMPUTE(FTOPI)                        \
-  [] (const Attrs& attrs,                                  \
-      const Array<te::Tensor>& inputs,                     \
-      const Type& out_type) -> Array<te::Tensor> {         \
-    CHECK_EQ(inputs.size(), 2U);                           \
-    return {FTOPI(inputs[0], inputs[1])};                  \
-  }                                                        \
+#define RELAY_BINARY_COMPUTE(FTOPI)                                \
+  [](const Attrs& attrs, const Array<te::Tensor>& inputs,          \
+     const Type& out_type) -> Array<te::Tensor> {                  \
+    CHECK_EQ(inputs.size(), 2U);                                   \
+    return {FTOPI(inputs[0], inputs[1])};                          \
+  }
 
 // Addition
 RELAY_REGISTER_BINARY_OP("add")
-.describe("Elementwise add with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
+    .describe("Elementwise add with broadcasting")
+    .set_support_level(1)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
 
 // Subtraction
 RELAY_REGISTER_BINARY_OP("subtract")
-.describe("Elementwise subtract with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
+    .describe("Elementwise subtract with broadcasting")
+    .set_support_level(1)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
 
 // Right shift
 RELAY_REGISTER_BINARY_OP("right_shift")
-.describe("Elementwise right shift with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
-
+    .describe("Elementwise right shift with broadcasting")
+    .set_support_level(4)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
 
 RELAY_REGISTER_BINARY_OP("left_shift")
-.describe("Elementwise left shift with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
-
+    .describe("Elementwise left shift with broadcasting")
+    .set_support_level(4)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
 
 RELAY_REGISTER_BINARY_OP("maximum")
-.describe("Elementwise maximum of two tensors with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
-
+    .describe("Elementwise maximum of two tensors with broadcasting")
+    .set_support_level(4)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
 
 RELAY_REGISTER_BINARY_OP("minimum")
-.describe("Elementwise minimum of two tensors with broadcasting")
-.set_support_level(4)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
-
+    .describe("Elementwise minimum of two tensors with broadcasting")
+    .set_support_level(4)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
 
 RELAY_REGISTER_BINARY_OP("divide")
-.describe("Elementwise divide with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
-
+    .describe("Elementwise divide with broadcasting")
+    .set_support_level(1)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
 
 RELAY_REGISTER_BINARY_OP("floor_divide")
-.describe("Elementwise floor divide with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide));
-
+    .describe("Elementwise floor divide with broadcasting")
+    .set_support_level(1)
+    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide));
 
 RELAY_REGISTER_BINARY_OP("multiply")
-.describe("Elementwise multiply with broadcasting")
-.set_support_level(1)
-.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));
-
+    .describe("Elementwise multiply with broadcasting")
+    .set_support_level(1)
+    
.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); RELAY_REGISTER_BINARY_OP("power") -.describe("Elementwise power with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); - + .describe("Elementwise power with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); RELAY_REGISTER_BINARY_OP("mod") -.describe("Elementwise mod with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); - + .describe("Elementwise mod with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); RELAY_REGISTER_BINARY_OP("floor_mod") - .describe("Elementwise floor mod with broadcasting") - .set_support_level(1) - .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod)); - + .describe("Elementwise floor mod with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod)); RELAY_REGISTER_BINARY_OP("logical_and") -.describe("Elementwise logical AND with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and)); - + .describe("Elementwise logical AND with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and)); RELAY_REGISTER_BINARY_OP("logical_or") -.describe("Elementwise logical OR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); - + .describe("Elementwise logical OR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); RELAY_REGISTER_BINARY_OP("logical_xor") -.describe("Elementwise logical XOR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor)); - + .describe("Elementwise logical XOR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor)); RELAY_REGISTER_BINARY_OP("bitwise_and") -.describe("Elementwise bitwise AND with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); - + .describe("Elementwise bitwise AND with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); RELAY_REGISTER_BINARY_OP("bitwise_or") -.describe("Elementwise bitwise OR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); - + .describe("Elementwise bitwise OR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); RELAY_REGISTER_BINARY_OP("bitwise_xor") -.describe("Elementwise bitwise XOR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); - + .describe("Elementwise bitwise XOR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); RELAY_REGISTER_CMP_OP("equal") -.describe("Elementwise equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); - + .describe("Elementwise equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); RELAY_REGISTER_CMP_OP("not_equal") -.describe("Elementwise not equal with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); - + .describe("Elementwise 
not equal with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); RELAY_REGISTER_CMP_OP("less") -.describe("Elementwise less than with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); - + .describe("Elementwise less than with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); RELAY_REGISTER_CMP_OP("less_equal") -.describe("Elementwise less than or equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); - + .describe("Elementwise less than or equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); RELAY_REGISTER_CMP_OP("greater") -.describe("Elementwise greater than compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); - + .describe("Elementwise greater than compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); RELAY_REGISTER_CMP_OP("greater_equal") -.describe("Elementwise greater than or equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); + .describe("Elementwise greater than or equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 3f220fb64ad5d..d526cef5bf626 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -21,13 +21,15 @@ * \file reduce.cc * \brief Reduction operators. */ -#include -#include -#include #include #include -#include +#include +#include +#include + #include +#include + #include "../op_common.h" #include "../type_relations.h" @@ -37,14 +39,13 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); /*! -* \brief GetReduceAxes, get the new axis from indim and other arguments -* \param indim Number of dimensions of input data. -* \param axis The input axis vector. -* \param exclude Whether 'axis' input given is the excluded axis. -* \return r_axes The new reduced axes of the output. -*/ -inline std::vector GetReduceAxes(const uint32_t indim, - const Array& inaxis, + * \brief GetReduceAxes, get the new axis from indim and other arguments + * \param indim Number of dimensions of input data. + * \param axis The input axis vector. + * \param exclude Whether 'axis' input given is the excluded axis. + * \return r_axes The new reduced axes of the output. 
+ */ +inline std::vector GetReduceAxes(const uint32_t indim, const Array& inaxis, bool exclude) { if (!inaxis.defined()) { std::vector r_axes(indim); @@ -60,16 +61,13 @@ inline std::vector GetReduceAxes(const uint32_t indim, } // Check out of bounds error - CHECK(axis >= 0) - << "Axis out of bounds in reduce operator."; - CHECK(axis < indim) - << "Axis out of bounds in reduce operator."; + CHECK(axis >= 0) << "Axis out of bounds in reduce operator."; + CHECK(axis < indim) << "Axis out of bounds in reduce operator."; in_axes.push_back(axis); } CHECK(in_axes[in_axes.size() - 1] < indim) - << "Reduction axis " << in_axes[in_axes.size() - 1] - << " exceeds input dimensions " << indim; + << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input dimensions " << indim; std::sort(in_axes.begin(), in_axes.end()); @@ -81,18 +79,16 @@ inline std::vector GetReduceAxes(const uint32_t indim, std::vector r_axes(r_size); for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) { if (j < in_axes.size() && in_axes[j] == i) { - ++j; - continue; + ++j; + continue; } r_axes[k++] = i; } return r_axes; } - // Get axis under exclude condition. -Array GetExcludeAxes(size_t indim, - const Array& inaxis) { +Array GetExcludeAxes(size_t indim, const Array& inaxis) { CHECK(inaxis.defined()) << "Cannot set exclude when axis=None"; std::vector axis_flag(indim, true); for (auto i : inaxis) { @@ -101,10 +97,8 @@ Array GetExcludeAxes(size_t indim, axis = axis + static_cast(indim); } // Check out of bounds error - CHECK_GE(axis, 0) - << "Axis out of bounds in reduce operator."; - CHECK_LT(axis, static_cast(indim)) - << "Axis out of bounds in reduce operator."; + CHECK_GE(axis, 0) << "Axis out of bounds in reduce operator."; + CHECK_LT(axis, static_cast(indim)) << "Axis out of bounds in reduce operator."; axis_flag[axis] = false; } @@ -177,34 +171,32 @@ Array> ReduceInferCorrectLayout(const Attrs& attrs, return Array>{{ret}, {ret}}; } -template -Array ReduceCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type, - F f) { +template +Array ReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); if (inputs[0]->shape.size() == 0) { - return { topi::identity(inputs[0]) }; + return {topi::identity(inputs[0])}; } auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); if (axes.size() == 0) { - return { topi::identity(inputs[0]) }; + return {topi::identity(inputs[0])}; } } - return { f(inputs[0], axes, param->keepdims, false) }; + return {f(inputs[0], axes, param->keepdims, false)}; } /*! -* \brief ReduceShapeImpl get the outshape for the reduction operator -* \param in_shape Shape of input data. -* \param param ReduceAttrs details. -* \param reporter The reporter to report solution to. -* \return oshape Output shape inferred. -*/ -inline std::vector ReduceShapeImpl(const std::vector &in_shape, + * \brief ReduceShapeImpl get the outshape for the reduction operator + * \param in_shape Shape of input data. + * \param param ReduceAttrs details. + * \param reporter The reporter to report solution to. + * \return oshape Output shape inferred. 
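GetReduceAxes and GetExcludeAxes above implement the two halves of reduce-axis handling: negative axes wrap around, and `exclude` flips the axis set so the reduction runs over every dimension not listed. A standalone sketch of that logic, with plain int64_t values standing in for tvm's Integer (the function name is hypothetical):

    #include <cstdio>
    #include <vector>

    // Negative axes wrap (axis -1 on a 3-d tensor means axis 2); exclude=true
    // inverts the set so the reduction covers the complement of `axes`.
    std::vector<int64_t> ReduceAxesSketch(int64_t indim, std::vector<int64_t> axes, bool exclude) {
      std::vector<bool> reduce(indim, false);
      for (int64_t a : axes) {
        if (a < 0) a += indim;
        reduce[a] = true;
      }
      std::vector<int64_t> out;
      for (int64_t i = 0; i < indim; ++i) {
        if (reduce[i] != exclude) out.push_back(i);  // flip the set when exclude=true
      }
      return out;
    }

    int main() {
      // indim=3, axis={1}, exclude=true -> reduce over axes {0, 2}
      for (int64_t a : ReduceAxesSketch(3, {1}, true)) std::printf("%lld ", (long long)a);
    }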
+ */ +inline std::vector ReduceShapeImpl(const std::vector& in_shape, const ReduceAttrs* param, const TypeReporter& reporter) { uint32_t indim = in_shape.size(); @@ -225,9 +217,9 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } if (is_dynamic_input) { - CHECK(reporter->Assert(max_shape < tir::make_const( - DataType::Int(64), std::numeric_limits::max()))) - << "The maximum possible index of reduced shape cannot be more than int32 max."; + CHECK(reporter->Assert(max_shape < + tir::make_const(DataType::Int(64), std::numeric_limits::max()))) + << "The maximum possible index of reduced shape cannot be more than int32 max."; } if (param->keepdims) { @@ -255,16 +247,14 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } /*! -* \brief ArgReduceRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool ArgReduceRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { + * \brief ArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; @@ -281,15 +271,13 @@ bool ArgReduceRel(const Array& types, } /*! -* \brief ReduceRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool ReduceRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief ReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ReduceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -305,70 +293,57 @@ bool ReduceRel(const Array& types, return true; } -#define RELAY_REGISTER_REDUCE_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([]( \ - Expr data, \ - Array axis, \ - bool keepdims, \ - bool exclude) { \ - auto attrs = make_object(); \ - attrs->axis = std::move(axis); \ - attrs->keepdims = keepdims; \ - attrs->exclude = exclude; \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(attrs), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") - - -Array ArgMaxCompute(const Attrs& attrs, - const Array& inputs, +#define RELAY_REGISTER_REDUCE_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." 
OpName) \ + .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ + auto attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = keepdims; \ + attrs->exclude = exclude; \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {data}, Attrs(attrs), {}); \ + }); \ + RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") + +Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::argmax); } - RELAY_REGISTER_REDUCE_OP("argmax") -.describe(R"code(Creates an operation that finds the indices of the maximum + .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel) -.set_attr("FTVMCompute", ArgMaxCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("ArgReduce", ArgReduceRel) + .set_attr("FTVMCompute", ArgMaxCompute) + .set_attr("TOpPattern", kCommReduce); -Array ArgMinCompute(const Attrs& attrs, - const Array& inputs, +Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::argmin); } RELAY_REGISTER_REDUCE_OP("argmin") -.describe(R"code(Creates an operation that finds the indices of the minimum + .describe(R"code(Creates an operation that finds the indices of the minimum values over a given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel) -.set_attr("FTVMCompute", ArgMinCompute) -.set_attr("TOpPattern", kCommReduce); - -Array SumCompute(const Attrs& attrs, - const Array& inputs, + .set_attrs_type() + .set_support_level(4) + .add_type_rel("ArgReduce", ArgReduceRel) + .set_attr("FTVMCompute", ArgMinCompute) + .set_attr("TOpPattern", kCommReduce); + +Array SumCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::sum); } - RELAY_REGISTER_REDUCE_OP("sum") -.describe(R"code(Computes the sum of array elements over given axes. + .describe(R"code(Computes the sum of array elements over given axes. Example:: @@ -385,23 +360,20 @@ Example:: [ 12. 19. 27.] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) -.set_attr("FTVMCompute", SumCompute) -.set_attr("TOpPattern", kCommReduce); - - -Array AllCompute(const Attrs& attrs, - const Array& inputs, + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) + .set_attr("FTVMCompute", SumCompute) + .set_attr("TOpPattern", kCommReduce); + +Array AllCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::all); } - RELAY_REGISTER_REDUCE_OP("all") -.describe(R"code(Computes the logical AND of boolean array elements over given axes. + .describe(R"code(Computes the logical AND of boolean array elements over given axes. 
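Every operator registered through RELAY_REGISTER_REDUCE_OP also gets a packed global named "relay.op._make.<op>" that builds the Call node with the shared ReduceAttrs. Assuming the registry API at this revision, a sum over axis 1 with keepdims could be constructed roughly like this (an untested sketch, not part of the diff):

    #include <tvm/relay/expr.h>
    #include <tvm/runtime/registry.h>

    #include <stdexcept>

    tvm::relay::Expr MakeSumExample(tvm::relay::Expr data) {
      // RELAY_REGISTER_REDUCE_OP installed this packed func in the global registry.
      const tvm::runtime::PackedFunc* make_sum = tvm::runtime::Registry::Get("relay.op._make.sum");
      if (make_sum == nullptr) throw std::runtime_error("relay.op._make.sum is not registered");
      tvm::Array<tvm::Integer> axis = {tvm::Integer(1)};  // reduce over axis 1
      bool keepdims = true, exclude = false;
      return (*make_sum)(data, axis, keepdims, exclude);
    }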
Example:: @@ -422,22 +394,19 @@ Example:: [False, True, False]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", AllCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", AllCompute) + .set_attr("TOpPattern", kCommReduce); -Array AnyCompute(const Attrs& attrs, - const Array& inputs, +Array AnyCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::any); } - RELAY_REGISTER_REDUCE_OP("any") -.describe(R"code(Computes the logical OR of boolean array elements over given axes. + .describe(R"code(Computes the logical OR of boolean array elements over given axes. Example:: @@ -458,56 +427,49 @@ Example:: [False, True, True]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", AnyCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", AnyCompute) + .set_attr("TOpPattern", kCommReduce); -Array MaxCompute(const Attrs& attrs, - const Array& inputs, +Array MaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::max); } RELAY_REGISTER_REDUCE_OP("max") -.describe(R"code(Computes the max of array elements over given axes. + .describe(R"code(Computes the max of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MaxCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MaxCompute) + .set_attr("TOpPattern", kCommReduce); -Array MinCompute(const Attrs& attrs, - const Array& inputs, +Array MinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::min); } - RELAY_REGISTER_REDUCE_OP("min") -.describe(R"code(Computes the min of array elements over given axes. + .describe(R"code(Computes the min of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MinCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MinCompute) + .set_attr("TOpPattern", kCommReduce); -Array ProdCompute(const Attrs& attrs, - const Array& inputs, +Array ProdCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::prod); } RELAY_REGISTER_REDUCE_OP("prod") -.describe(R"code(Computes the products of array elements over given axes. + .describe(R"code(Computes the products of array elements over given axes. 
Example:: @@ -522,32 +484,27 @@ Example:: [ 36 480 2058] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", ProdCompute) -.set_attr("TOpPattern", kCommReduce); + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", ProdCompute) + .set_attr("TOpPattern", kCommReduce); - -Array MeanCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type) { +Array MeanCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); auto axes = param->axis; - for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), - param->axis, - param->exclude)) { + for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) { count *= inputs[0]->shape[i]; } auto res = ReduceCompute(attrs, inputs, out_type, topi::sum); return {topi::divide(res[0], count)}; } - RELAY_REGISTER_REDUCE_OP("mean") -.describe(R"code(Computes the mean of array elements over given axes. + .describe(R"code(Computes the mean of array elements over given axes. Example:: @@ -562,16 +519,13 @@ Example:: [ 2. 3.16666667 4.5] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MeanCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MeanCompute) + .set_attr("TOpPattern", kCommReduce); -bool VarianceRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -593,8 +547,7 @@ bool VarianceRel(const Array& types, return true; } -Array VarianceCompute(const Attrs& attrs, - const Array& inputs, +Array VarianceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); @@ -602,9 +555,7 @@ Array VarianceCompute(const Attrs& attrs, auto axes = param->axis; auto data = inputs[0]; auto mean = inputs[1]; - for (int64_t i : GetReduceAxes(data->shape.size(), - param->axis, - param->exclude)) { + for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) { count *= data->shape[i]; } std::vector expand_shape; @@ -614,11 +565,7 @@ Array VarianceCompute(const Attrs& attrs, return {var}; } -Expr MakeVariance(Expr data, - Expr mean, - Array axis, - bool keepdims, - bool exclude) { +Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; @@ -627,23 +574,22 @@ Expr MakeVariance(Expr data, return Call(op, {data, mean}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._variance") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeVariance, args, rv); }); RELAY_REGISTER_OP("variance") -.describe(R"code(Computes the variance of array elements over given axes. + .describe(R"code(Computes the variance of array elements over given axes. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("mean", "Tensor", "The mean tensor.") -.add_type_rel("Variance", VarianceRel) -.set_attr("FTVMCompute", VarianceCompute) -.set_attr("TOpPattern", kCommReduce); + .set_attrs_type() + .set_support_level(4) + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("mean", "Tensor", "The mean tensor.") + .add_type_rel("Variance", VarianceRel) + .set_attr("FTVMCompute", VarianceCompute) + .set_attr("TOpPattern", kCommReduce); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 2c9010cee2765..836a155325be7 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -21,24 +21,27 @@ * \file transform.cc * \brief Transform operators. */ -#include +#include "transform.h" + +#include +#include +#include +#include +#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include +#include +#include +#include + #include -#include "../op_common.h" + #include "../../../arith/compute_expr.h" #include "../../transforms/infer_layout_util.h" #include "../../transforms/pattern_util.h" -#include "transform.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -47,115 +50,95 @@ using tir::IntImmNode; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); -bool CastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; + << "cast: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); - reporter->Assign(types[1], TensorType( - data->shape, param->dtype)); + reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; } -Array CastCompute(const Attrs& attrs, - const Array& inputs, +Array CastCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const CastAttrs *param = attrs.as(); + const CastAttrs* param = attrs.as(); CHECK(param != nullptr); DataType dtype = param->dtype; - return { topi::cast(inputs[0], dtype) }; + return {topi::cast(inputs[0], dtype)}; } -Expr MakeCast(Expr data, - DataType dtype) { +Expr MakeCast(Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("cast"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.ir.cast") -.set_body_typed(MakeCast); +TVM_REGISTER_GLOBAL("relay.ir.cast").set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") -.describe(R"code(Cast the data into a new data type. + .describe(R"code(Cast the data into a new data type. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Cast", CastRel) -.set_attr("FTVMCompute", CastCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Cast", CastRel) + .set_attr("FTVMCompute", CastCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.cast_like -bool CastLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CastLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; + << "cast: expect input type to be TensorType but get " << types[0]; return false; } const auto* dtype_like = types[1].as(); if (dtype_like == nullptr) { CHECK(types[1].as()) - << "cast: expect input type to be TensorType but get " - << types[1]; + << "cast: expect input type to be TensorType but get " << types[1]; return false; } reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype)); return true; } - -Array CastLikeCompute(const Attrs& attrs, - const Array& inputs, +Array CastLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::cast(inputs[0], inputs[1]->dtype) }; + return {topi::cast(inputs[0], inputs[1]->dtype)}; } - -Expr MakeCastLike(Expr data, - Expr dtype_like) { +Expr MakeCastLike(Expr data, Expr dtype_like) { static const Op& op = Op::Get("cast_like"); return Call(op, {data, dtype_like}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.ir.cast_like") -.set_body_typed(MakeCastLike); +TVM_REGISTER_GLOBAL("relay.ir.cast_like").set_body_typed(MakeCastLike); RELAY_REGISTER_OP("cast_like") -.describe(R"code(Cast the data into the type of another tensor. + .describe(R"code(Cast the data into the type of another tensor. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("dtype_like", "Tensor", "The tensor to cast to.") -.set_support_level(3) -.add_type_rel("CastLike", CastLikeRel) -.set_attr("FTVMCompute", CastLikeCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - - -Array ReinterpretCompute(const Attrs& attrs, - const Array& inputs, + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("dtype_like", "Tensor", "The tensor to cast to.") + .set_support_level(3) + .add_type_rel("CastLike", CastLikeRel) + .set_attr("FTVMCompute", CastLikeCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +Array ReinterpretCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const CastAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -175,44 +158,39 @@ TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, }); RELAY_REGISTER_OP("reinterpret") -.describe(R"code(Reinterpret the data into a new data type. + .describe(R"code(Reinterpret the data into a new data type. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reinterpret", CastRel) -.set_attr("FTVMCompute", ReinterpretCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reinterpret", CastRel) + .set_attr("FTVMCompute", ReinterpretCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.expand_dims TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); -bool ExpandDimsRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ExpandDimsRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "expand_dims: expect input type to be TensorType but get " - << types[0]; + << "expand_dims: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int axis = param->axis; const int num_newaxis = param->num_newaxis; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; const int pivot = axis < 0 ? ndim + axis + 1 : axis; std::vector oshape; oshape.reserve(ndim + num_newaxis); @@ -229,17 +207,14 @@ bool ExpandDimsRel(const Array& types, return true; } -Array ExpandDimsCompute(const Attrs& attrs, - const Array& inputs, +Array ExpandDimsCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ExpandDimsAttrs *param = attrs.as(); + const ExpandDimsAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::expand_dims(inputs[0], param->axis, param->num_newaxis) }; + return {topi::expand_dims(inputs[0], param->axis, param->num_newaxis)}; } -Expr MakeExpandDims(Expr data, - int axis, - int num_newaxis) { +Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { auto attrs = make_object(); attrs->axis = axis; attrs->num_newaxis = num_newaxis; @@ -247,75 +222,68 @@ Expr MakeExpandDims(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.expand_dims") -.set_body_typed(MakeExpandDims); +TVM_REGISTER_GLOBAL("relay.op._make.expand_dims").set_body_typed(MakeExpandDims); RELAY_REGISTER_OP("expand_dims") -.describe(R"code(Insert `num_newaxis` axises at the position given by `axis` + .describe(R"code(Insert `num_newaxis` axises at the position given by `axis` - **data**: The input data to the operator. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("ExpandDims", ExpandDimsRel) -.set_attr("FTVMCompute", ExpandDimsCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("ExpandDims", ExpandDimsRel) + .set_attr("FTVMCompute", ExpandDimsCompute) + .set_attr("TOpPattern", kBroadcast); // relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); -Array ConcatenateCompute(const Attrs& attrs, - const Array& inputs, +Array ConcatenateCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ConcatenateAttrs *param = attrs.as(); + const ConcatenateAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::concatenate(inputs, param->axis) }; + return {topi::concatenate(inputs, param->axis)}; } -Expr MakeConcatenate(Expr data, - int axis) { +Expr MakeConcatenate(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("concatenate"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.concatenate") -.set_body_typed(MakeConcatenate); +TVM_REGISTER_GLOBAL("relay.op._make.concatenate").set_body_typed(MakeConcatenate); RELAY_REGISTER_OP("concatenate") -.describe(R"code(Concatenate the input tensors along the given axis. + .describe(R"code(Concatenate the input tensors along the given axis. - **data** : A list of tensors. - **axis** : The axis along which the tensors are concatenated. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(1) -.add_type_rel("Concatenate", ConcatenateRel) -.set_attr("FInferCorrectLayout", ConcatenateLayout) -.set_attr("FTVMCompute", ConcatenateCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(1) + .add_type_rel("Concatenate", ConcatenateRel) + .set_attr("FInferCorrectLayout", ConcatenateLayout) + .set_attr("FTVMCompute", ConcatenateCompute) + .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); -bool StackRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StackRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TupleType but get " - << types[0]; + << "cast: expect input type to be TupleType but get " << types[0]; return false; } const auto* param = attrs.as(); @@ -324,11 +292,9 @@ bool StackRel(const Array& types, // Sanity check: axis int axis = param->axis; - CHECK(-ndim <= axis && axis < ndim) - << "stack only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; - axis = axis < 0 ? ndim + axis + 1: axis; + CHECK(-ndim <= axis && axis < ndim) << "stack only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; + axis = axis < 0 ? ndim + axis + 1 : axis; // Sanity check: ndim and dtype. 
const DataType dtype = first->dtype; @@ -341,8 +307,9 @@ bool StackRel(const Array& types, for (size_t j = 0; j < first->shape.size(); ++j) { if (j == static_cast(axis)) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error("relay.stack requires all tensors have the same shape " - "on non-stacking axes"); + throw Error( + "relay.stack requires all tensors have the same shape " + "on non-stacking axes"); } } @@ -361,55 +328,49 @@ bool StackRel(const Array& types, return true; } -Array StackCompute(const Attrs& attrs, - const Array& inputs, +Array StackCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const StackAttrs *param = attrs.as(); + const StackAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::stack(inputs, param->axis) }; + return {topi::stack(inputs, param->axis)}; } -Expr MakeStack(Expr data, - int axis) { +Expr MakeStack(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("stack"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.stack") -.set_body_typed(MakeStack); +TVM_REGISTER_GLOBAL("relay.op._make.stack").set_body_typed(MakeStack); RELAY_REGISTER_OP("stack") -.describe(R"code(Stack the input tensors along the given axis. + .describe(R"code(Stack the input tensors along the given axis. - **data** : A list of tensors. - **axis** : The axis along which the tensors are stacked. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(3) -.add_type_rel("Stack", StackRel) -.set_attr("FTVMCompute", StackCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(3) + .add_type_rel("Stack", StackRel) + .set_attr("FTVMCompute", StackCompute) + .set_attr("TOpPattern", kInjective); /* relay.transpose */ TVM_REGISTER_NODE_TYPE(TransposeAttrs); -bool TransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "transpose: expect input type to be TensorType but get " - << types[0]; + << "transpose: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); @@ -417,8 +378,8 @@ bool TransposeRel(const Array& types, const Array& axes = param->axes; // check dimension match CHECK(!axes.defined() || static_cast(axes.size()) == ndim) - << "Dimension mismatch: axes has " << axes.size() << " elements" - << ", but data.ndim = " << ndim; + << "Dimension mismatch: axes has " << axes.size() << " elements" + << ", but data.ndim = " << ndim; // construct int_axes std::vector int_axes; int_axes.reserve(ndim); @@ -433,9 +394,8 @@ bool TransposeRel(const Array& types, int64_t axis = e; // sanity check for axis and ndim CHECK(-ndim <= axis && axis < ndim) - << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; axis = axis < 0 ? 
axis + ndim : axis; // sanity check for duplication CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; @@ -452,40 +412,37 @@ bool TransposeRel(const Array& types, return true; } -Array TransposeCompute(const Attrs& attrs, - const Array& inputs, +Array TransposeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ topi::transpose(inputs[0], param->axes) }; + return Array{topi::transpose(inputs[0], param->axes)}; } -Expr MakeTranspose(Expr data, - Array axes) { +Expr MakeTranspose(Expr data, Array axes) { auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.transpose") -.set_body_typed(MakeTranspose); +TVM_REGISTER_GLOBAL("relay.op._make.transpose").set_body_typed(MakeTranspose); RELAY_REGISTER_OP("transpose") -.describe(R"code(Permutes the dimensions of an array. + .describe(R"code(Permutes the dimensions of an array. - **data**: The input data to the operator. - **axes**: The target axes order, reverse order if not specified. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Transpose", TransposeRel) -.set_attr("FTVMCompute", TransposeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Transpose", TransposeRel) + .set_attr("FTVMCompute", TransposeCompute) + .set_attr("TOpPattern", kInjective); /* relay.reshape */ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); @@ -528,9 +485,7 @@ double ToScalar(const runtime::NDArray& array, int i = 0) { return -std::numeric_limits::infinity(); } -bool ReshapeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const auto* param = attrs.as(); if (param->reverse) { @@ -543,8 +498,7 @@ bool ReshapeRel(const Array& types, const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "reshape: expect input type to be TensorType but get " - << types[0]; + << "reshape: expect input type to be TensorType but get " << types[0]; return false; } @@ -599,8 +553,7 @@ bool ReshapeRel(const Array& types, oshape.push_back(data_shape[src_idx++]); } else if (svalue == -1) { // inference based on rest - CHECK_LT(infer_idx, 0) - << "One and only one dim can be inferred"; + CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred"; infer_idx = i; oshape.push_back(1); ++src_idx; @@ -634,8 +587,7 @@ bool ReshapeRel(const Array& types, Integer d1 = newshape[++i]; Integer d2 = newshape[++i]; if (d1->value == -1) { - CHECK(d2->value != -1) - << "Split dims cannot both be -1."; + CHECK(d2->value != -1) << "Split dims cannot both be -1."; used_output_dims.insert(oshape.size()); if (d0.as()) { oshape.push_back(Any::make()); @@ -691,16 +643,15 @@ bool ReshapeRel(const Array& types, } if (param->reverse) { - reporter->Assign(types[1], TensorType( - Array(oshape.rbegin(), oshape.rend()), data->dtype)); + reporter->Assign(types[1], + TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); } else { reporter->Assign(types[2], TensorType(oshape, data->dtype)); } return true; } -Array ReshapeCompute(const Attrs& attrs, - const Array& inputs, +Array ReshapeCompute(const Attrs& 
attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); @@ -712,11 +663,10 @@ Array ReshapeCompute(const Attrs& attrs, newshape.push_back(val); } } - return { topi::reshape(inputs[0], newshape) }; + return {topi::reshape(inputs[0], newshape)}; } -Expr MakeReshape(Expr data, - Expr newshape) { +Expr MakeReshape(Expr data, Expr newshape) { auto attrs = make_object(); attrs->newshape = newshape; attrs->reverse = false; @@ -724,11 +674,10 @@ Expr MakeReshape(Expr data, return Call(op, {data, newshape}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.reshape") -.set_body_typed(MakeReshape); +TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape); RELAY_REGISTER_OP("reshape") -.describe(R"code(Reshapes the input array. + .describe(R"code(Reshapes the input array. Example:: @@ -778,27 +727,24 @@ Example:: - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("newshape", "Tensor", "The shape of output tensor.") -.set_support_level(3) -.add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); - + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("newshape", "Tensor", "The shape of output tensor.") + .set_support_level(3) + .add_type_rel("Reshape", ReshapeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); /*! -* \brief ReshapeLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool ReshapeLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief ReshapeLikeRel User defined type constraint function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return False if the relation has not been resolved, it might be resolved later. + * True if this relation has been resolved. + */ +bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -819,43 +765,36 @@ bool ReshapeLikeRel(const Array& types, } if (is_static_shape) { CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) - << "Reshape inputs size should be compatible."; + << "Reshape inputs size should be compatible."; } reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype)); return true; } - -Expr MakeReshapeLike(Expr data, - Expr shape_like) { +Expr MakeReshapeLike(Expr data, Expr shape_like) { static const Op& op = Op::Get("reshape_like"); return Call(op, {data, shape_like}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.reshape_like") -.set_body_typed(MakeReshapeLike); - +TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike); RELAY_REGISTER_OP("reshape_like") -.describe(R"code(Reshapes the input array by the size of another array. + .describe(R"code(Reshapes the input array by the size of another array. 
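Among the newshape codes documented above, -1 is the inferred dimension: ReshapeRel accepts at most one -1 and solves for it so the element counts match. A standalone sketch of that inference (hypothetical helper, plain ints):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Replace a single -1 in `newshape` with total_elements / product_of_known_dims.
    std::vector<int> InferNegOneSketch(const std::vector<int>& in_shape, std::vector<int> newshape) {
      int total = 1, known = 1, infer_idx = -1;
      for (int d : in_shape) total *= d;
      for (size_t i = 0; i < newshape.size(); ++i) {
        if (newshape[i] == -1) {
          assert(infer_idx < 0 && "One and only one dim can be inferred");
          infer_idx = static_cast<int>(i);
        } else {
          known *= newshape[i];
        }
      }
      if (infer_idx >= 0) newshape[infer_idx] = total / known;
      return newshape;
    }

    int main() {
      // data.shape = (2, 3, 4), newshape = (6, -1) -> result.shape = (6, 4)
      for (int d : InferNegOneSketch({2, 3, 4}, {6, -1})) std::printf("%d ", d);
    }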
For an input array with shape ``(d1, d2, ..., dk)``, the `reshape_like` operation reshapes the input array into an output array with the same shape as the second input array. .. note:: Sizes for both arrays should be compatible. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("shape_like", "Tensor", "Shape tensor.") -.set_support_level(3) -.add_type_rel("ReshapeLike", ReshapeLikeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape_like", "Tensor", "Shape tensor.") + .set_support_level(3) + .add_type_rel("ReshapeLike", ReshapeLikeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); // ArgWhere -bool ArgWhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ArgWhereRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); @@ -869,35 +808,36 @@ bool ArgWhereRel(const Array& types, return true; } -TVM_REGISTER_GLOBAL("relay.op._make.argwhere") -.set_body_typed([](Expr data) { +TVM_REGISTER_GLOBAL("relay.op._make.argwhere").set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); return Call(op, {data}, Attrs(), {}); }); RELAY_REGISTER_OP("argwhere") -.describe(R"doc(Find the indices of elements of a tensor that are + .describe(R"doc(Find the indices of elements of a tensor that are non-zero)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("condition", "Tensor", "The input condition tensor.") -.add_type_rel("ArgWhere", ArgWhereRel) -.set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kOpaque) -.set_support_level(10); + .set_num_inputs(1) + .add_argument("condition", "Tensor", "The input condition tensor.") + .add_type_rel("ArgWhere", ArgWhereRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kOpaque) + .set_support_level(10); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); -bool TakeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TakeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + return false; + } const auto* indices = types[1].as(); - CHECK(indices != nullptr); + if (indices == nullptr) { + return false; + } CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; const auto param = attrs.as(); CHECK(param != nullptr); @@ -913,9 +853,8 @@ bool TakeRel(const Array& types, const auto ndim_indices = static_cast(indices->shape.size()); int axis = static_cast(param->axis->value); if (axis < 0) axis += ndim_data; - CHECK_LE(axis, ndim_data) - << "axis should be with in data shape" - << ", but got = " << axis; + CHECK_LE(axis, ndim_data) << "axis should be within data shape" + << ", but got = " << axis; oshape.reserve(ndim_data - 1 + ndim_indices); for (int i = 0; i < axis; ++i) { @@ -924,7 +863,7 @@ bool TakeRel(const Array& types, for (int i = 0; i < ndim_indices; ++i) { oshape.emplace_back(indices->shape[i]); } - for (int i = axis+1; i < ndim_data; ++i) { + for (int i = axis + 1; i < ndim_data; ++i) { oshape.emplace_back(data->shape[i]); } @@ -932,22 +871,18 @@ bool TakeRel(const Array& types, return true; } -Array TakeCompute(const Attrs& attrs, - const Array& inputs, +Array
TakeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); if (!param->axis.defined()) { - return Array{ topi::take(inputs[0], inputs[1], param->mode) }; + return Array{topi::take(inputs[0], inputs[1], param->mode)}; } else { - return Array{ topi::take(inputs[0], inputs[1], param->axis, param->mode) }; + return Array{topi::take(inputs[0], inputs[1], param->axis, param->mode)}; } } -Expr MakeTake(Expr data, - Expr indices, - Integer axis, - std::string mode) { +Expr MakeTake(Expr data, Expr indices, Integer axis, std::string mode) { auto attrs = make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); @@ -955,11 +890,10 @@ Expr MakeTake(Expr data, return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.take") -.set_body_typed(MakeTake); +TVM_REGISTER_GLOBAL("relay.op._make.take").set_body_typed(MakeTake); RELAY_REGISTER_OP("take") -.describe(R"code(Take elements from an array along an axis. + .describe(R"code(Take elements from an array along an axis. When axis is not None, this function does the same thing as 'fancy' indexing (indexing arrays using arrays); however, it can be easier to use if you need @@ -981,22 +915,19 @@ Examples:: [ 4., 3.]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("indices", "Tensor", "The indices tensor.") -.set_support_level(3) -.add_type_rel("Take", TakeRel) -.set_attr("FTVMCompute", TakeCompute) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .set_support_level(3) + .add_type_rel("Take", TakeRel) + .set_attr("FTVMCompute", TakeCompute) + .set_attr("TOpPattern", kInjective); // Init ops TVM_REGISTER_NODE_TYPE(InitOpAttrs); -bool FullRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const InitOpAttrs* param = attrs.as(); @@ -1011,23 +942,19 @@ bool FullRel(const Array& types, } CHECK_EQ(fill_value->shape.size(), 0) - << "Fill value should be a scalar but has dimension " - << fill_value->shape.size() << "."; + << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << "."; reporter->Assign(types[1], TensorType(param->shape, out_dtype)); return true; } -Array FullCompute(const Attrs& attrs, - const Array& inputs, +Array FullCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); - return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) }; + return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())}; } -Expr MakeFull(Expr fill_value, - Array shape, - DataType dtype) { +Expr MakeFull(Expr fill_value, Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -1035,24 +962,21 @@ Expr MakeFull(Expr fill_value, return Call(op, {fill_value}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.full") -.set_body_typed(MakeFull); +TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull); RELAY_REGISTER_OP("full") -.describe(R"code(Fill array with scalar value. + .describe(R"code(Fill array with scalar value. 
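TakeRel composes the output shape by splicing the indices shape into the data shape at `axis`: data dims before the axis, then all indices dims, then the data dims after it. A standalone sketch (hypothetical helper, plain ints):

    #include <cstdio>
    #include <vector>

    // Output rank is (data rank - 1) + indices rank: the gathered axis is
    // replaced by the whole indices shape.
    std::vector<int> TakeShapeSketch(const std::vector<int>& data, const std::vector<int>& indices,
                                     int axis) {
      if (axis < 0) axis += static_cast<int>(data.size());
      std::vector<int> out(data.begin(), data.begin() + axis);
      out.insert(out.end(), indices.begin(), indices.end());
      out.insert(out.end(), data.begin() + axis + 1, data.end());
      return out;
    }

    int main() {
      // data (2, 3, 4), indices (5, 6), axis=1 -> (2, 5, 6, 4)
      for (int d : TakeShapeSketch({2, 3, 4}, {5, 6}, 1)) std::printf("%d ", d);
    }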
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("fill_value", "double", "The value to fill.") -.set_support_level(3) -.add_type_rel("Full", FullRel) -.set_attr("FTVMCompute", FullCompute) -.set_attr("TOpPattern", kElemWise); - -bool InitOpRel(const Array& types, - int num_inputs, - const Attrs& attrs, + .set_attrs_type() + .set_num_inputs(1) + .add_argument("fill_value", "double", "The value to fill.") + .set_support_level(3) + .add_type_rel("Full", FullRel) + .set_attr("FTVMCompute", FullCompute) + .set_attr("TOpPattern", kElemWise); + +bool InitOpRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 1); const InitOpAttrs* param = attrs.as(); @@ -1061,8 +985,7 @@ bool InitOpRel(const Array& types, return true; } -Expr MakeZeros(Array shape, - DataType dtype) { +Expr MakeZeros(Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -1070,20 +993,18 @@ Expr MakeZeros(Array shape, return Call(op, {}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.zeros") -.set_body_typed(MakeZeros); +TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros); RELAY_REGISTER_OP("zeros") -.describe(R"code(Fill array with zeros. + .describe(R"code(Fill array with zeros. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(0) -.set_support_level(3) -.add_type_rel("InitOp", InitOpRel); + .set_attrs_type() + .set_num_inputs(0) + .set_support_level(3) + .add_type_rel("InitOp", InitOpRel); -Expr MakeOnes(Array shape, - DataType dtype) { +Expr MakeOnes(Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -1091,21 +1012,18 @@ Expr MakeOnes(Array shape, return Call(op, {}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.ones") -.set_body_typed(MakeOnes); +TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes); RELAY_REGISTER_OP("ones") -.describe(R"code(Fill array with ones. + .describe(R"code(Fill array with ones. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(0) -.set_support_level(3) -.add_type_rel("InitOp", InitOpRel); - -bool FullLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + .set_attrs_type() + .set_num_inputs(0) + .set_support_level(3) + .add_type_rel("InitOp", InitOpRel); + +bool FullLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -1118,47 +1036,42 @@ bool FullLikeRel(const Array& types, } CHECK_EQ(fill_value->shape.size(), 0) - << "The fill value should be a scalar but here it has dimension " - << fill_value->shape.size() << "."; + << "The fill value should be a scalar but here it has dimension " << fill_value->shape.size() + << "."; reporter->Assign(types[2], TensorType(data->shape, data->dtype)); return true; } -Array FullLikeCompute(const Attrs& attrs, - const Array& inputs, +Array FullLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::full_like(inputs[0], inputs[1]()) }; + return {topi::full_like(inputs[0], inputs[1]())}; } -Expr MakeFullLike(Expr data, - Expr fill_value) { +Expr MakeFullLike(Expr data, Expr fill_value) { static const Op& op = Op::Get("full_like"); return Call(op, {data, fill_value}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.full_like") -.set_body_typed(MakeFullLike); +TVM_REGISTER_GLOBAL("relay.op._make.full_like").set_body_typed(MakeFullLike); RELAY_REGISTER_OP("full_like") -.describe(R"code(Return an scalar value array with the same shape + .describe(R"code(Return an scalar value array with the same shape and type as the input array. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("fill_value", "double", "Scalar value to fill.") -.set_support_level(3) -.add_type_rel("FullLike", FullLikeRel) -.set_attr("FTVMCompute", FullLikeCompute) -.set_attr("TOpPattern", kElemWise); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("fill_value", "double", "Scalar value to fill.") + .set_support_level(3) + .add_type_rel("FullLike", FullLikeRel) + .set_attr("FTVMCompute", FullLikeCompute) + .set_attr("TOpPattern", kElemWise); // arange operator TVM_REGISTER_NODE_TYPE(ArangeAttrs); -bool ArangeRel(const Array& types, - int num_inputs, - const Attrs& raw_attrs, +bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const ArangeAttrs* attrs = raw_attrs.as(); @@ -1168,16 +1081,14 @@ bool ArangeRel(const Array& types, reporter->Assign(types[1], types[2]); reporter->Assign(types[2], TensorType({}, attrs->dtype)); - if ((cstart = attrs->start.as()) && - (cstop = attrs->stop.as()) && + if ((cstart = attrs->start.as()) && (cstop = attrs->stop.as()) && (cstep = attrs->step.as())) { double start = ToScalar(cstart->data); double stop = ToScalar(cstop->data); double step = ToScalar(cstep->data); int32_t num_elem = static_cast(std::ceil((stop - start) / step)); - CHECK_GT(num_elem, 0) - << "Invalid arange attributes (start, stop, step): " << attrs->start - << ", " << attrs->stop << ", " << attrs->step; + CHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start + << ", " << attrs->stop << ", " << attrs->step; reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); return true; } else { @@ -1186,32 +1097,28 @@ bool ArangeRel(const Array& types, } } -inline te::Tensor 
DynamicArange(const te::Tensor& start, - const te::Tensor& stop, - const te::Tensor& step, - tvm::DataType dtype, - std::string name = "tensor", - std::string tag = topi::kInjective) { +inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, + const te::Tensor& step, tvm::DataType dtype, + std::string name = "tensor", std::string tag = topi::kInjective) { tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); - return te::compute({num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start[0] + step[0] * indices[0]); - }, name, tag); + return te::compute( + {num_elem}, + [&](const Array& indices) { + return tvm::cast(dtype, start[0] + step[0] * indices[0]); + }, + name, tag); } -Array ArangeCompute(const Attrs& attrs, - const Array& inputs, +Array ArangeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const ArangeAttrs* param = attrs.as(); te::Tensor start = inputs[0]; - te::Tensor stop = inputs[1]; + te::Tensor stop = inputs[1]; te::Tensor step = inputs[2]; - return { DynamicArange(start, stop, step, param->dtype) }; + return {DynamicArange(start, stop, step, param->dtype)}; } -Expr MakeArange(Expr start, - Expr stop, - Expr step, - DataType dtype) { +Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) { auto attrs = make_object(); attrs->start = start; attrs->stop = stop; @@ -1221,8 +1128,7 @@ Expr MakeArange(Expr start, return Call(op, {start, stop, step}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.arange") -.set_body_typed(MakeArange); +TVM_REGISTER_GLOBAL("relay.op._make.arange").set_body_typed(MakeArange); // An issue with the existing design is that we require dependency // to type the operator precisely. @@ -1238,45 +1144,40 @@ TVM_REGISTER_GLOBAL("relay.op._make.arange") // In general I think we should avoid this pattern, and introduce // a secondary shape analysis to recover more precise information. RELAY_REGISTER_OP("arange") -.describe(R"code(Returns evenly spaced values within a given interval. + .describe(R"code(Returns evenly spaced values within a given interval. 
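+
+An illustrative pseudo-call (hypothetical values); the element count follows the ceil((stop - start) / step) rule in ArangeRel above::
+
+  arange(start=1, stop=7, step=2) = [1, 3, 5]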
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.set_support_level(3) -.add_type_rel("Arange", ArangeRel) -.set_attr("FTVMCompute", ArangeCompute) -// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape -.set_attr("TOpPattern", kOpaque) -.set_attr("AnyCodegenStrategy", kVariableDimensions); + .set_attrs_type() + .set_num_inputs(3) + .set_support_level(3) + .add_type_rel("Arange", ArangeRel) + .set_attr("FTVMCompute", ArangeCompute) + // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape + .set_attr("TOpPattern", kOpaque) + .set_attr("AnyCodegenStrategy", kVariableDimensions); // repeat operator TVM_REGISTER_NODE_TYPE(RepeatAttrs); -bool RepeatRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool RepeatRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "repeat: expect input type to be TensorType but get " - << types[0]; + << "repeat: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int repeats = param->repeats; const int axis = param->axis; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; const int pivot = axis < 0 ? ndim + axis : axis; std::vector oshape; oshape.reserve(ndim + repeats); @@ -1291,17 +1192,14 @@ bool RepeatRel(const Array& types, return true; } -Array RepeatCompute(const Attrs& attrs, - const Array& inputs, +Array RepeatCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const RepeatAttrs *param = attrs.as(); + const RepeatAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::repeat(inputs[0], param->repeats, param->axis) }; + return {topi::repeat(inputs[0], param->repeats, param->axis)}; } -Expr MakeRepeat(Expr data, - int repeats, - int axis) { +Expr MakeRepeat(Expr data, int repeats, int axis) { auto attrs = make_object(); attrs->repeats = repeats; attrs->axis = axis; @@ -1309,50 +1207,45 @@ Expr MakeRepeat(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.repeat") -.set_body_typed(MakeRepeat); +TVM_REGISTER_GLOBAL("relay.op._make.repeat").set_body_typed(MakeRepeat); RELAY_REGISTER_OP("repeat") -.describe(R"code(Repeat elements of an array `repeats` times along axis `axis` + .describe(R"code(Repeat elements of an array `repeats` times along axis `axis` - **data**: The input data to the operator. 
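+
+Example (hypothetical values; the semantics mirror numpy.repeat)::
+
+  repeat([[1, 2], [3, 4]], repeats=2, axis=1) = [[1, 1, 2, 2],
+                                                 [3, 3, 4, 4]]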
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Repeat", RepeatRel) -.set_attr("FTVMCompute", RepeatCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Repeat", RepeatRel) + .set_attr("FTVMCompute", RepeatCompute) + .set_attr("TOpPattern", kBroadcast); // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); -bool TileRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TileRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "tile: expect input type to be TensorType but get " - << types[0]; + << "tile: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const size_t ndim = data->shape.size(); const Array& reps = param->reps; // check dimension match - CHECK(reps.defined()) - << "repetition array is not defined. data.ndim = " << ndim; + CHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; const size_t rndim = reps.size(); for (size_t i = 0; i < rndim; ++i) { if (const tvm::tir::IntImmNode* val = reps[i].as()) { - CHECK_GT(val->value, 0) - << "Tile reps value should always be larger than 0, but get: " << val->value; + CHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but get: " + << val->value; } } size_t tndim = (ndim > rndim) ? ndim : rndim; @@ -1401,103 +1294,91 @@ bool TileRel(const Array& types, return true; } -Array TileCompute(const Attrs& attrs, - const Array& inputs, +Array TileCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const TileAttrs *param = attrs.as(); + const TileAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::tile(inputs[0], param->reps) }; + return {topi::tile(inputs[0], param->reps)}; } -Expr MakeTile(Expr data, - Array reps) { +Expr MakeTile(Expr data, Array reps) { auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.tile") -.set_body_typed(MakeTile); +TVM_REGISTER_GLOBAL("relay.op._make.tile").set_body_typed(MakeTile); RELAY_REGISTER_OP("tile") -.describe(R"code(Repeat the whole array multiple times. + .describe(R"code(Repeat the whole array multiple times. - **data**: The input data to the operator. 
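+
+Example (hypothetical values; reps gives the number of repetitions along each axis)::
+
+  tile([[1, 2], [3, 4]], reps=(2, 2)) = [[1, 2, 1, 2],
+                                         [3, 4, 3, 4],
+                                         [1, 2, 1, 2],
+                                         [3, 4, 3, 4]]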
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Tile", TileRel) -.set_attr("FTVMCompute", TileCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Tile", TileRel) + .set_attr("FTVMCompute", TileCompute) + .set_attr("TOpPattern", kBroadcast); // reverse operator TVM_REGISTER_NODE_TYPE(ReverseAttrs); -bool ReverseRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool ReverseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "reverse: expect input type to be TensorType but get " - << types[0]; + << "reverse: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int axis = param->axis; CHECK(-ndim <= axis && axis < ndim) - << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; reporter->Assign(types[1], types[0]); return true; } -Array ReverseCompute(const Attrs& attrs, - const Array& inputs, +Array ReverseCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ReverseAttrs *param = attrs.as(); + const ReverseAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::flip(inputs[0], param->axis) }; + return {topi::flip(inputs[0], param->axis)}; } -Expr MakeReverse(Expr data, - int axis) { +Expr MakeReverse(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("reverse"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.reverse") -.set_body_typed(MakeReverse); +TVM_REGISTER_GLOBAL("relay.op._make.reverse").set_body_typed(MakeReverse); RELAY_REGISTER_OP("reverse") -.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. + .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. - **data**: The input data to the operator. 
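+
+Example (hypothetical values)::
+
+  reverse([[1, 2], [3, 4], [5, 6]], axis=0) = [[5, 6],
+                                               [3, 4],
+                                               [1, 2]]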
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reverse", ReverseRel) -.set_attr("FTVMCompute", ReverseCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reverse", ReverseRel) + .set_attr("FTVMCompute", ReverseCompute) + .set_attr("TOpPattern", kInjective); // where operator -bool WhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool WhereRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4U); const auto* condition = types[0].as(); @@ -1511,17 +1392,16 @@ bool WhereRel(const Array& types, CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size"; if (cond_shape.size() != x_shape.size()) { - CHECK_EQ(cond_shape.size(), 1) - << "Shape of condition " << condition->shape - << " must be either equal to x or has dimension of 1."; + CHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape + << " must be either equal to the shape of x or have dimension of 1."; } for (size_t i = 0; i < x_shape.size(); i++) { CHECK(reporter->AssertEQ(x_shape[i], y_shape[i])) << "x and y must have the same shape: " << x_shape << " vs " << y_shape; if (i < cond_shape.size()) { - CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) - << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; + CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) + << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; } } reporter->Assign(types[3], TensorType(x_shape, x->dtype)); @@ -1534,17 +1414,15 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { return Call(op, {condition, x, y}); } -Array WhereCompute(const Attrs& attrs, - const Array& inputs, +Array WhereCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::where(inputs[0], inputs[1], inputs[2]) }; + return {topi::where(inputs[0], inputs[1], inputs[2])}; } -TVM_REGISTER_GLOBAL("relay.op._make.where") -.set_body_typed(MakeWhere); +TVM_REGISTER_GLOBAL("relay.op._make.where").set_body_typed(MakeWhere); RELAY_REGISTER_OP("where") -.describe(R"code( + .describe(R"code( Return the elements, either from x or y, depending on the condition. 
Given three ndarrays, condition, x, and y, return an ndarray with the elements @@ -1572,34 +1450,28 @@ Examples:: where(cond, x, y) = [[1, 2], [7, 8]] )code" TVM_ADD_FILELINE) -.add_argument("condition", "Tensor", "Condition array") -.add_argument("x", "Tensor", "First array to be selected") -.add_argument("y", "Tensor", "Second array to be selected") -.set_num_inputs(3) -.set_support_level(4) -.add_type_rel("Where", WhereRel) -.set_attr("FTVMCompute", WhereCompute) -.set_attr("TOpPattern", kBroadcast); - + .add_argument("condition", "Tensor", "Condition array") + .add_argument("x", "Tensor", "First array to be selected") + .add_argument("y", "Tensor", "Second array to be selected") + .set_num_inputs(3) + .set_support_level(4) + .add_type_rel("Where", WhereRel) + .set_attr("FTVMCompute", WhereCompute) + .set_attr("TOpPattern", kBroadcast); // Squeeze TVM_REGISTER_NODE_TYPE(SqueezeAttrs); -Expr MakeSqueeze(Expr data, - Array axis) { +Expr MakeSqueeze(Expr data, Array axis) { auto attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.squeeze") -.set_body_typed(MakeSqueeze); - +TVM_REGISTER_GLOBAL("relay.op._make.squeeze").set_body_typed(MakeSqueeze); -bool SqueezeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1623,7 +1495,7 @@ bool SqueezeRel(const Array& types, } } else { // pair up original shape with a boolean which control whether it will be in the final shape. - std::vector > original_shape; + std::vector> original_shape; for (const auto& e : data->shape) { original_shape.push_back(std::pair(e, true)); } @@ -1650,78 +1522,70 @@ bool SqueezeRel(const Array& types, return true; } -Array SqueezeCompute(const Attrs& attrs, - const Array& inputs, +Array SqueezeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const SqueezeAttrs *param = attrs.as(); + const SqueezeAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::squeeze(inputs[0], param->axis) }; + return {topi::squeeze(inputs[0], param->axis)}; } - RELAY_REGISTER_OP("squeeze") -.describe(R"code(Squeeze the input tensor at the dimensions given by axes + .describe(R"code(Squeeze the input tensor at the dimensions given by axes - **data**: The input data to the operator. 
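+
+Illustrative pseudo-calls for a hypothetical input of shape (1, 2, 1); omitting axis removes every unit dimension::
+
+  squeeze([[[1], [2]]]) = [1, 2]
+  squeeze([[[1], [2]]], axis=[0]) = [[1], [2]]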
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Squeeze", SqueezeRel) -.set_attr("FTVMCompute", SqueezeCompute) -.set_attr("TOpPattern", kInjective); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Squeeze", SqueezeRel) + .set_attr("FTVMCompute", SqueezeCompute) + .set_attr("TOpPattern", kInjective); // CollapseSumLike: -> B where BroadCast(A, B) = A -bool CollapseSumLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CollapseSumLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); reporter->Assign(types[2], types[1]); return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter); } -Expr MakeCollapseSumLike(Expr data, - Expr collapse_type) { +Expr MakeCollapseSumLike(Expr data, Expr collapse_type) { static const Op& op = Op::Get("collapse_sum_like"); return Call(op, {data, collapse_type}, Attrs(), {}); } -Array CollapseSumLikeCompute(const Attrs& attrs, - const Array& inputs, +Array CollapseSumLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::collapse_sum(inputs[0], out_ttype->shape) }; + return {topi::collapse_sum(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like") -.set_body_typed(MakeCollapseSumLike); +TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like").set_body_typed(MakeCollapseSumLike); RELAY_REGISTER_OP("collapse_sum_like") -.describe(R"code(Collapse the first input to match the shape of the second input. + .describe(R"code(Collapse the first input to match the shape of the second input. 
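+
+A sketch of the semantics (hypothetical values; the second argument contributes only its shape, which must be broadcast-compatible with the first)::
+
+  data = [[1, 2], [3, 4], [5, 6]]
+  like = [7, 8]
+  collapse_sum_like(data, like) = [9, 12]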
)code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") -.set_support_level(10) -.add_type_rel("CollapseSumLike", CollapseSumLikeRel) -.set_attr("FTVMCompute", CollapseSumLikeCompute) -.set_attr("TOpPattern", kCommReduce); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") + .set_support_level(10) + .add_type_rel("CollapseSumLike", CollapseSumLikeRel) + .set_attr("FTVMCompute", CollapseSumLikeCompute) + .set_attr("TOpPattern", kCommReduce); // BroadCastTo: -> B where BroadCast(A, B) = B -bool BroadCastToRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); auto ioattrs = attrs.as(); CHECK(ioattrs); auto intt = types[0].as(); - if (intt == nullptr) { return false; } + if (intt == nullptr) { + return false; + } auto type = TensorType(ioattrs->shape, intt->dtype); reporter->Assign(types[1], type); return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); @@ -1734,87 +1598,75 @@ Expr MakeBroadCastTo(Expr data, Array shape) { return Call(op, {data}, Attrs(attrs), {}); } -Array BroadCastToCompute(const Attrs& attrs, - const Array& inputs, +Array BroadCastToCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { auto ioattrs = attrs.as(); CHECK(ioattrs != nullptr); - return { topi::broadcast_to(inputs[0], ioattrs->shape) }; + return {topi::broadcast_to(inputs[0], ioattrs->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to") -.set_body_typed(MakeBroadCastTo); +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo); RELAY_REGISTER_OP("broadcast_to") -.describe(R"code(Broadcast the first input to match the shape argument. + .describe(R"code(Broadcast the first input to match the shape argument. 
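+
+Example (hypothetical values)::
+
+  broadcast_to([1, 2, 3], shape=(2, 3)) = [[1, 2, 3],
+                                           [1, 2, 3]]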
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(4) -.add_type_rel("BroadCastTo", BroadCastToRel) -.set_attr("FTVMCompute", BroadCastToCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(4) + .add_type_rel("BroadCastTo", BroadCastToRel) + .set_attr("FTVMCompute", BroadCastToCompute) + .set_attr("TOpPattern", kBroadcast); // BroadCastToLike: -> B where BroadCast(A, B) = B -bool BroadCastToLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadCastToLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); reporter->Assign(types[2], types[1]); return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); } -Expr MakeBroadCastToLike(Expr data, - Expr broadcast_type) { +Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) { static const Op& op = Op::Get("broadcast_to_like"); return Call(op, {data, broadcast_type}, Attrs(), {}); } -Array BroadCastToLikeCompute(const Attrs& attrs, - const Array& inputs, +Array BroadCastToLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::broadcast_to(inputs[0], out_ttype->shape) }; + return {topi::broadcast_to(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like") -.set_body_typed(MakeBroadCastToLike); +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like").set_body_typed(MakeBroadCastToLike); RELAY_REGISTER_OP("broadcast_to_like") -.describe(R"code(Broadcast the first input to match the shape of the second input. + .describe(R"code(Broadcast the first input to match the shape of the second input. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") -.set_support_level(10) -.add_type_rel("BroadCastToLike", BroadCastToLikeRel) -.set_attr("FTVMCompute", BroadCastToLikeCompute) -.set_attr("TOpPattern", kBroadcast); - + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") + .set_support_level(10) + .add_type_rel("BroadCastToLike", BroadCastToLikeRel) + .set_attr("FTVMCompute", BroadCastToLikeCompute) + .set_attr("TOpPattern", kBroadcast); // Adapter function to make int array. 
Array GetIntArray(Array arr) { for (size_t i = 0; i < arr.size(); ++i) { - CHECK(!arr[i].defined() || arr[i].as()) - << "Expect an int array"; + CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Downcast >(arr); + return Downcast>(arr); } - // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); -bool StridedSliceRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; - const StridedSliceAttrs *param = attrs.as(); + const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); auto dshape = data->shape; @@ -1862,12 +1714,8 @@ bool StridedSliceRel(const Array& types, int64_t begin_v = begin_vec[i]; int64_t end_v = end_vec[i]; - if ((stride_v == 1 && - begin_v == 0 && - end_v == max_range) || - (stride_v == -1 && - begin_v == max_range && - end_v == 0)) { + if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || + (stride_v == -1 && begin_v == max_range && end_v == 0)) { // Quick path, do not slice this dimension. oshape[i] = dshape[i]; continue; @@ -1876,8 +1724,7 @@ bool StridedSliceRel(const Array& types, // Require concrete integer as symbolic inference of min/max // can get complicated and not very helpful. const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - CHECK(p_dim_size) - << "strided_slice requires sliced dimension to be concrete int"; + CHECK(p_dim_size) << "strided_slice requires sliced dimension to be concrete int"; int64_t dim_size = p_dim_size[0]; begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; end_v = (end_v < 0) ? dim_size + end_v : end_v; @@ -1885,16 +1732,14 @@ bool StridedSliceRel(const Array& types, int64_t slice_range, step; if (stride_v < 0) { if (end_v < -1) end_v = -1; - CHECK_LT(end_v, begin_v) - << "strided_slice get empty slice at axis " << i; + CHECK_LT(end_v, begin_v) << "strided_slice get empty slice at axis " << i; begin_v = std::min(dim_size - 1, begin_v); slice_range = begin_v - end_v; step = -stride_v; } else { if (begin_v < 0) begin_v = 0; CHECK_GE(stride_v, 0); - CHECK_LT(begin_v, end_v) - << "strided_slice get empty slice at axis " << i; + CHECK_LT(begin_v, end_v) << "strided_slice get empty slice at axis " << i; end_v = std::min(dim_size, end_v); slice_range = end_v - begin_v; step = stride_v; @@ -1905,13 +1750,10 @@ bool StridedSliceRel(const Array& types, return true; } - -Array > StridedSliceInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { - +Array> StridedSliceInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Array> old_in_shapes; for (auto old_in_t : old_in_types) { CHECK(old_in_t.as()); @@ -1930,7 +1772,7 @@ Array > StridedSliceInferCorrectLayout( auto shape = old_in_shapes[0]; // NOTE: Discard "const" qualifier here. - auto *params = const_cast(attrs.as()); + auto* params = const_cast(attrs.as()); Array new_begin, new_end; @@ -1953,8 +1795,8 @@ Array > StridedSliceInferCorrectLayout( } } int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0; - int64_t end = params->end[i].defined() ? params->end[i]->value : - shape[i].as()->value; + int64_t end = + params->end[i].defined() ? 
params->end[i]->value : shape[i].as()->value; if (begin % factor || end % factor) { // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; @@ -1970,12 +1812,8 @@ Array > StridedSliceInferCorrectLayout( return {{layout}, {layout}}; } - // Positional relay function to create StridedSlice operator used by frontend FFI. -Expr MakeStridedSlice(Expr data, - Array begin, - Array end, - Array strides) { +Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides) { auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); @@ -1984,20 +1822,15 @@ Expr MakeStridedSlice(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -Array StridedSliceCompute(const Attrs& attrs, - const Array& inputs, +Array StridedSliceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const StridedSliceAttrs *param = attrs.as(); + const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); return Array{ - topi::strided_slice(inputs[0], param->begin, param->end, param->strides) - }; + topi::strided_slice(inputs[0], param->begin, param->end, param->strides)}; } - -TVM_REGISTER_GLOBAL("relay.op._make.strided_slice") -.set_body_typed(MakeStridedSlice); - +TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice); RELAY_REGISTER_OP("strided_slice") .describe(R"code(Strided slice of an array. @@ -2023,40 +1856,32 @@ Examples:: [[ 5., 6.], [ 7., 8.]]] )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(4) -.set_attrs_type() -.add_type_rel("StridedSlice", StridedSliceRel) -.set_attr("FTVMCompute", StridedSliceCompute) -.set_attr("TOpPattern", kInjective) -.set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(4) + .set_attrs_type() + .add_type_rel("StridedSlice", StridedSliceRel) + .set_attr("FTVMCompute", StridedSliceCompute) + .set_attr("TOpPattern", kInjective) + .set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); // strided_set -bool StridedSetRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StridedSetRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 6); reporter->Assign(types[5], types[0]); return true; } -Expr MakeStridedSet(Expr data, - Expr v, - Expr begin, - Expr end, - Expr strides) { +Expr MakeStridedSet(Expr data, Expr v, Expr begin, Expr end, Expr strides) { static const Op& op = Op::Get("strided_set"); return Call(op, {data, v, begin, end, strides}, {}); } -TVM_REGISTER_GLOBAL("relay.op._make.strided_set") -.set_body_typed(MakeStridedSet); - +TVM_REGISTER_GLOBAL("relay.op._make.strided_set").set_body_typed(MakeStridedSet); RELAY_REGISTER_OP("strided_set") - .describe(R"code(Strided set of an array. + .describe(R"code(Strided set of an array. 
Example:: x = [[ 1., 4., 7., 10.], @@ -2071,22 +1896,20 @@ Example:: [ 2., 44., 55., 66.], [ 3., 6., 9., 12.]] )code" TVM_ADD_FILELINE) -.set_num_inputs(5) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("v", "Tensor", "The data to set.") -.add_argument("begin", "Tensor", "Indices for the start of the slice.") -.add_argument("end", "Tensor", "Indices indicating the end of the slice.") -.add_argument("strides", "Tensor", "The strides values.") -.set_support_level(4) -.set_attr("TOpPattern", kInjective) -.add_type_rel("StridedSet", StridedSetRel); + .set_num_inputs(5) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("v", "Tensor", "The data to set.") + .add_argument("begin", "Tensor", "Indices for the start of the slice.") + .add_argument("end", "Tensor", "Indices indicating the end of the slice.") + .add_argument("strides", "Tensor", "The strides values.") + .set_support_level(4) + .set_attr("TOpPattern", kInjective) + .add_type_rel("StridedSet", StridedSetRel); // relay.split TVM_REGISTER_NODE_TYPE(SplitAttrs); -bool SplitRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); @@ -2099,21 +1922,19 @@ bool SplitRel(const Array& types, if (axis < 0) { axis += data->shape.size(); } - CHECK_LT(axis, data->shape.size()) - << "axis should be within the input dimension range."; - CHECK_GE(axis, 0) - << "axis should be within the input dimension range."; + CHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; + CHECK_GE(axis, 0) << "axis should be within the input dimension range."; if (const IntImmNode* sections = param->indices_or_sections.as()) { - CHECK(reporter->Assert(indexmod(data->shape[axis], - sections->value) == tir::make_zero(DataType::Int(64)))) + CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == + tir::make_zero(DataType::Int(64)))) << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; for (int i = 0; i < sections->value; ++i) { - std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = indexdiv(oshape[axis], sections->value); - auto vec_type = TensorType(oshape, data->dtype); - fields.push_back(vec_type); + std::vector oshape(data->shape.begin(), data->shape.end()); + oshape[axis] = indexdiv(oshape[axis], sections->value); + auto vec_type = TensorType(oshape, data->dtype); + fields.push_back(vec_type); } reporter->Assign(types[1], TupleType(Array(fields))); } else { @@ -2140,25 +1961,21 @@ bool SplitRel(const Array& types, return true; } -Array SplitCompute(const Attrs& attrs, - const Array& inputs, +Array SplitCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto param = attrs.as(); CHECK(param != nullptr); if (const IntImmNode* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; - return Array{ - topi::split_sections(inputs[0], num_sections, param->axis) }; + return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { - auto indices = Downcast >(param->indices_or_sections); - return Array{ topi::split(inputs[0], indices, param->axis) }; + auto indices = Downcast>(param->indices_or_sections); + return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, - ObjectRef indices_or_sections, - int axis) { +Expr MakeSplit(Expr data, ObjectRef 
indices_or_sections, int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -2166,22 +1983,20 @@ Expr MakeSplit(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - if (args.type_codes[1] == kDLInt) { - // Note: we change it from Int(64) to Int(32) for now as - // combine_parallel_dense will transform the graph with Int(32). - // More invetigation is needs to check which one we should use. - *rv = MakeSplit(args[0], - tir::make_const(DataType::Int(32), static_cast(args[1])), - args[2]); - } else { - *rv = MakeSplit(args[0], args[1], args[2]); - } +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { + if (args.type_codes[1] == kDLInt) { + // Note: we change it from Int(64) to Int(32) for now as + // combine_parallel_dense will transform the graph with Int(32). + // More investigation is needed to check which one we should use. + *rv = + MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); + } else { + *rv = MakeSplit(args[0], args[1], args[2]); + } }); RELAY_REGISTER_OP("split") -.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. + .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. Indices or sections to split into. Accepts an int or a tuple If indices_or_sections is an integer, the input will be divided equally @@ -2191,29 +2006,26 @@ If indices_or_sections is a tuple of sorted integers, the entries indicate where along axis the array is split. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Split", SplitRel) -.set_attr("FTVMCompute", SplitCompute) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Split", SplitRel) + .set_attr("FTVMCompute", SplitCompute) + .set_attr("TOpPattern", kInjective); // relay.slice_like TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); /*! -* \brief SliceLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool SliceLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief SliceLikeRel User defined type constraint function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return False if the relation has not been resolved, it might be resolved later. + * True if this relation has been resolved. 
+ */ +bool SliceLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -2238,8 +2050,8 @@ bool SliceLikeRel(const Array& types, if (i < target_shape.size()) { oshape[i] = target_shape[i]; CHECK(reporter->Assert(oshape[i] <= dshape[i])) - << "End index of axis " << i << " exceeds input shape: " - << oshape[i] << " vs " << dshape[i]; + << "End index of axis " << i << " exceeds input shape: " << oshape[i] << " vs " + << dshape[i]; } } } else { @@ -2250,12 +2062,11 @@ bool SliceLikeRel(const Array& types, axis += dshape.size(); } CHECK(axis < static_cast(target_shape.size())) - << "Axis " << axis << " exceeds dimension " - << target_shape.size() << " of target_shape."; + << "Axis " << axis << " exceeds dimension " << target_shape.size() << " of target_shape."; oshape[axis] = target_shape[axis]; CHECK(reporter->Assert(oshape[axis] <= dshape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << oshape[axis] << " vs " << dshape[axis]; + << "End index of axis " << axis << " exceeds input shape: " << oshape[axis] << " vs " + << dshape[axis]; } } @@ -2263,18 +2074,14 @@ bool SliceLikeRel(const Array& types, return true; } - -Expr MakeSliceLike(Expr data, - Expr shape_like, - Array axes) { +Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("slice_like"); return Call(op, {data, shape_like}, Attrs(attrs), {}); } -Array SliceLikeCompute(const Attrs& attrs, - const Array& inputs, +Array SliceLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); @@ -2290,11 +2097,10 @@ Array SliceLikeCompute(const Attrs& attrs, for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { end_idx.Set(i, target_shape[i]); - CHECK_LE(topi::GetConstInt(end_idx[i]), - topi::GetConstInt(src_shape[i])) - << "End index of axis " << i << " exceeds input shape: " - << topi::GetConstInt(end_idx[i]) << " vs " - << topi::GetConstInt(src_shape[i]); + CHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) + << "End index of axis " << i + << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " + << topi::GetConstInt(src_shape[i]); } } } else { @@ -2303,77 +2109,64 @@ Array SliceLikeCompute(const Attrs& attrs, axis = static_cast(src_shape.size()) + axis; } end_idx.Set(axis, target_shape[axis]); - CHECK_LE(topi::GetConstInt(end_idx[axis]), - topi::GetConstInt(src_shape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << topi::GetConstInt(end_idx[axis]) << " vs " - << topi::GetConstInt(src_shape[axis]); + CHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) + << "End index of axis " << axis + << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs " + << topi::GetConstInt(src_shape[axis]); } } - return Array{ - topi::strided_slice(inputs[0], - GetIntArray(begin_idx), - GetIntArray(end_idx), - GetIntArray(strides)) - }; + return Array{topi::strided_slice(inputs[0], GetIntArray(begin_idx), + GetIntArray(end_idx), GetIntArray(strides))}; } - -TVM_REGISTER_GLOBAL("relay.op._make.slice_like") -.set_body_typed(MakeSliceLike); - +TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike); RELAY_REGISTER_OP("slice_like") -.describe(R"code(Slice the first input respect to the second input. 
+ .describe(R"code(Slice the first input with respect to the second input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("shape_like", "Tensor", "Shape tensor.") -.set_support_level(10) -.add_type_rel("SliceLike", SliceLikeRel) -.set_attr("FTVMCompute", SliceLikeCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape_like", "Tensor", "Shape tensor.") + .set_support_level(10) + .add_type_rel("SliceLike", SliceLikeRel) + .set_attr("FTVMCompute", SliceLikeCompute) + .set_attr("TOpPattern", kInjective); // relay.layout_transform TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); -Array LayoutTransformCompute(const Attrs& attrs, - const Array& inputs, +Array LayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ - topi::layout_transform(inputs[0], param->src_layout, param->dst_layout) - }; + return Array{topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)}; } -bool LayoutTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool LayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + CHECK(types[0].as()) + << "LayoutTransform: expect input data type to be TensorType but get " << types[0]; + return false; + } const LayoutTransformAttrs* params = attrs.as(); Layout src_layout(params->src_layout); Layout dst_layout(params->dst_layout); - CHECK(src_layout.defined() && dst_layout.defined()) - << "cannot convert from/to undefined layout"; - + CHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout); CHECK(layout_converter.defined()) - << "cannot convert from " << params->src_layout << " to " << params->dst_layout; + << "cannot convert from " << params->src_layout << " to " << params->dst_layout; const auto& out_shape = layout_converter.ForwardShape(data->shape); reporter->Assign(types[1], TensorType(out_shape, data->dtype)); return true; } -Expr MakeLayoutTransform(Expr data, - std::string src_layout, - std::string dst_layout) { +Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout) { auto attrs = make_object(); attrs->src_layout = std::move(src_layout); attrs->dst_layout = std::move(dst_layout); @@ -2381,27 +2174,24 @@ Expr MakeLayoutTransform(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.layout_transform") -.set_body_typed(MakeLayoutTransform); +TVM_REGISTER_GLOBAL("relay.op._make.layout_transform").set_body_typed(MakeLayoutTransform); RELAY_REGISTER_OP("layout_transform") -.describe(R"code(Transform the input data layout. + .describe(R"code(Transform the input data layout. 
For transforming from NCHW to NCHW16c, the `__layout_transform__` operator reshapes the input array by output[n, c, h, w, C] = data[n, c*16 + C, h, w] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("layout_transform", LayoutTransformRel) -.set_support_level(5) -.set_attr("FTVMCompute", LayoutTransformCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("layout_transform", LayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", LayoutTransformCompute); /* relay._contrib_reverse_reshape */ -Expr MakeReverseReshape(Expr data, - Expr newshape) { +Expr MakeReverseReshape(Expr data, Expr newshape) { auto attrs = make_object(); attrs->newshape = newshape; attrs->reverse = true; @@ -2409,11 +2199,10 @@ Expr MakeReverseReshape(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape") -.set_body_typed(MakeReverseReshape); +TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape); RELAY_REGISTER_OP("_contrib_reverse_reshape") -.describe(R"code(Reshapes the input array where the special values are inferred from + .describe(R"code(Reshapes the input array where the special values are inferred from right to left. Example:: @@ -2426,18 +2215,16 @@ example below:: - data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("Reshape", ReshapeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); // gather_nd operator -bool GatherNDRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); @@ -2445,48 +2232,40 @@ bool GatherNDRel(const Array& types, const auto* indices = types[1].as(); if (data == nullptr) { CHECK(types[0].as()) - << "GatherND: expect input data type to be TensorType but get " - << types[0]; + << "GatherND: expect input data type to be TensorType but get " << types[0]; return false; } if (indices == nullptr) { CHECK(types[1].as()) - << "GatherND: expect indices type to be TensorType but get " - << types[1]; + << "GatherND: expect indices type to be TensorType but get " << types[1]; return false; } const size_t ndim = data->shape.size(); const IntImmNode* mdim = indices->shape[0].as(); const size_t kdim = indices->shape.size() - 1; - CHECK(size_t(mdim->value) <= ndim) - << "GatherND: indices shape does satisfy."; + CHECK(size_t(mdim->value) <= ndim) << "GatherND: the first dimension of indices must not exceed data.ndim."; Array oshape; - for (size_t i = 1; i < kdim + 1; ++i) - oshape.push_back(indices->shape[i]); - for (size_t i = mdim->value; i < ndim; ++i) - oshape.push_back(data->shape[i]); + for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); + for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]); reporter->Assign(types[2], TensorType(oshape, data->dtype)); return true; } -Array 
GatherNDCompute(const Attrs& attrs, - const Array& inputs, +Array GatherNDCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::gather_nd(inputs[0], inputs[1]) }; + return {topi::gather_nd(inputs[0], inputs[1])}; } -Expr MakeGatherND(Expr data, - Expr indices) { +Expr MakeGatherND(Expr data, Expr indices) { static const Op& op = Op::Get("gather_nd"); return Call(op, {data, indices}, {}); } -TVM_REGISTER_GLOBAL("relay.op._make.gather_nd") -.set_body_typed(MakeGatherND); +TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND); RELAY_REGISTER_OP("gather_nd") -.describe(R"code(Gather elements or slices from data and store to + .describe(R"code(Gather elements or slices from data and store to a tensor whose shape is defined by indices. Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with @@ -2494,19 +2273,17 @@ shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("GatherND", GatherNDRel) -.set_attr("FTVMCompute", GatherNDCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("GatherND", GatherNDRel) + .set_attr("FTVMCompute", GatherNDCompute) + .set_attr("TOpPattern", kInjective); // relay.sequence_mask TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs); -bool SequenceMaskRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SequenceMaskRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, valid_length, result] CHECK_EQ(types.size(), 3); @@ -2523,19 +2300,15 @@ bool SequenceMaskRel(const Array& types, return true; } -Array SequenceMaskCompute(const Attrs& attrs, - const Array& inputs, +Array SequenceMaskCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); return Array{ - topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) }; + topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis)}; } -Expr MakeSequenceMask(Expr data, - Expr valid_length, - double mask_value, - int axis) { +Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) { auto attrs = make_object(); attrs->mask_value = std::move(mask_value); attrs->axis = std::move(axis); @@ -2543,11 +2316,11 @@ Expr MakeSequenceMask(Expr data, return Call(op, {data, valid_length}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask") -.set_body_typed(MakeSequenceMask); +TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask").set_body_typed(MakeSequenceMask); RELAY_REGISTER_OP("sequence_mask") -.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value. + .describe( + R"code(Sets all elements outside the expected length of the sequence to a constant value. This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...] and returns an array of the same shape. 
@@ -2595,21 +2368,19 @@ Examples:: [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.") -.set_support_level(10) -.add_type_rel("SequenceMask", SequenceMaskRel) -.set_attr("FTVMCompute", SequenceMaskCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.") + .set_support_level(10) + .add_type_rel("SequenceMask", SequenceMaskRel) + .set_attr("FTVMCompute", SequenceMaskCompute) + .set_attr("TOpPattern", kInjective); // relay.one_hot TVM_REGISTER_NODE_TYPE(OneHotAttrs); -bool OneHotRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool OneHotRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [indices, on_value, off_value, result] CHECK_EQ(types.size(), 4); @@ -2635,27 +2406,15 @@ bool OneHotRel(const Array& types, return true; } -Array OneHotCompute(const Attrs& attrs, - const Array& inputs, +Array OneHotCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array { - topi::one_hot(inputs[0], - inputs[1](), - inputs[2](), - param->depth, - param->axis, - param->dtype) - }; -} - -Expr MakeOneHot(Expr indices, - Expr on_value, - Expr off_value, - int depth, - int axis, - DataType dtype) { + return Array{ + topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype)}; +} + +Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype) { auto attrs = make_object(); attrs->depth = std::move(depth); attrs->axis = axis; @@ -2664,11 +2423,10 @@ Expr MakeOneHot(Expr indices, return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.one_hot") -.set_body_typed(MakeOneHot); +TVM_REGISTER_GLOBAL("relay.op._make.one_hot").set_body_typed(MakeOneHot); RELAY_REGISTER_OP("one_hot") -.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, + .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value 1, other locations take value 0. Final dimension is x depth. **indices** Locations to set to 1. @@ -2682,42 +2440,36 @@ RELAY_REGISTER_OP("one_hot") **axis** Axis to fill. 
**dtype**)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("indices", "Tensor", "Locations to set to on_value.") -.add_argument("on_value", "Expr", "Value to fill at indices.") -.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") -.set_support_level(10) -.add_type_rel("OneHot", OneHotRel) -.set_attr("FTVMCompute", OneHotCompute) -.set_attr("TOpPattern", kOutEWiseFusable); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("indices", "Tensor", "Locations to set to on_value.") + .add_argument("on_value", "Expr", "Value to fill at indices.") + .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") + .set_support_level(10) + .add_type_rel("OneHot", OneHotRel) + .set_attr("FTVMCompute", OneHotCompute) + .set_attr("TOpPattern", kOutEWiseFusable); /* relay.unravel_index */ -bool UnRavelIndexRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UnRavelIndexRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* indices = types[0].as(); if (indices == nullptr) { CHECK(types[0].as()) - << "unravel_index: expect input type to be TensorType but get " - << types[0]; + << "unravel_index: expect input type to be TensorType but get " << types[0]; return false; } - CHECK(indices->dtype.is_int()) - << "indices of unravel_index must be tensor of integer"; + CHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer"; const auto* shape = types[1].as(); if (shape == nullptr) { CHECK(types[1].as()) - << "unravel_index: expect input type to be TensorType but get " - << types[1]; + << "unravel_index: expect input type to be TensorType but get " << types[1]; return false; } - CHECK(indices->dtype.is_int()) - << "shape of unravel_index must be tensor of integer"; + CHECK(indices->dtype.is_int()) << "shape of unravel_index must be tensor of integer"; Array indices_shape; Array shape_shape; @@ -2733,32 +2485,30 @@ bool UnRavelIndexRel(const Array& types, return true; } -Array UnRavelIndexCompute(const Attrs& attrs, - const Array& inputs, +Array UnRavelIndexCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return Array{topi::unravel_index(inputs[0], inputs[1])}; } -Expr MakeUnRavelIndex(Expr data, - Expr shape) { +Expr MakeUnRavelIndex(Expr data, Expr shape) { static const Op& op = Op::Get("unravel_index"); return Call(op, {data, shape}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.unravel_index") -.set_body_typed(MakeUnRavelIndex); +TVM_REGISTER_GLOBAL("relay.op._make.unravel_index").set_body_typed(MakeUnRavelIndex); RELAY_REGISTER_OP("unravel_index") -.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. + .describe( + R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. 
Example:: - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]] )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.set_support_level(3) -.add_type_rel("UnRavelIndexRel", UnRavelIndexRel) -.set_attr("FTVMCompute", UnRavelIndexCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .set_support_level(3) + .add_type_rel("UnRavelIndexRel", UnRavelIndexRel) + .set_attr("FTVMCompute", UnRavelIndexCompute) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index a64dcd5a6b301..62433c297e8e8 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -26,33 +26,32 @@ #include #include -#include +#include + #include #include #include #include #include +#include namespace tvm { namespace relay { template -bool ConcatenateRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); /* If we receive a tuple we can continue, if we receive * anything but an incomplete type we should signal an * error. - */ + */ const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { throw Error( - ErrorBuilder() - << "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0])); + ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0])); } else if (types[0].as() != nullptr) { return false; } @@ -69,10 +68,8 @@ bool ConcatenateRel(const Array& types, // Sanity check: axis int axis = param->axis; if (!(-ndim <= axis && axis < ndim)) { - throw Error(ErrorBuilder() << - "concatenate only accepts `axis` in [-ndim, ndim)" << - ", but got axis = " << axis << - ", and ndim = " << ndim); + throw Error(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim); } axis = axis < 0 ? ndim + axis : axis; @@ -94,14 +91,15 @@ bool ConcatenateRel(const Array& types, for (size_t j = 0; j < first->shape.size(); ++j) { if (j == static_cast(axis)) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error("relay.concatenate requires all tensors have the same shape " - "on non-concatenating axes"); + throw Error( + "relay.concatenate requires all tensors have the same shape " + "on non-concatenating axes"); } } // Calculate shape std::vector oshape(first->shape.begin(), first->shape.end()); - IndexExpr &concat_dim = oshape[axis]; + IndexExpr& concat_dim = oshape[axis]; bool has_any = false; if (concat_dim.as()) { has_any = true; @@ -125,11 +123,10 @@ bool ConcatenateRel(const Array& types, return true; } -static inline Array> ConcatenateLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +static inline Array> ConcatenateLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { ConcatenateAttrs* param = const_cast(attrs.as()); Array> old_in_shapes; @@ -141,8 +138,8 @@ static inline Array> ConcatenateLayout( } } - size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : - static_cast(param->axis); + size_t axis = + param->axis < 0 ? 
param->axis + old_in_shapes[0].size() : static_cast(param->axis); Layout ret; bool is_new_layout_selected = false; @@ -175,11 +172,11 @@ static inline Array> ConcatenateLayout( } if (ret.ndim() <= axis || !ret[axis].IsPrimal()) { - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } } - return Array > {Array(old_in_layouts.size(), ret), {ret}}; + return Array>{Array(old_in_layouts.size(), ret), {ret}}; } } // namespace relay diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 152c7693fd816..ccf6dd2fa1b1f 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -21,366 +21,393 @@ * \file unary.cc * \brief Unary operators. */ -#include -#include -#include #include #include -#include "../type_relations.h" +#include +#include +#include + #include "../op_common.h" +#include "../type_relations.h" namespace tvm { namespace relay { -#define RELAY_UNARY_COMPUTE(FTOPI) \ - [] (const Attrs& attrs, \ - const Array& inputs, \ - const Type& out_type) -> Array { \ - return {FTOPI(inputs[0])}; \ - } \ - +#define RELAY_UNARY_COMPUTE(FTOPI) \ + [](const Attrs& attrs, const Array& inputs, \ + const Type& out_type) -> Array { return {FTOPI(inputs[0])}; } RELAY_REGISTER_UNARY_OP("log") -.describe(R"code(Returns the log input array, computed element-wise. + .describe(R"code(Returns the log input array, computed element-wise. .. math:: log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); RELAY_REGISTER_UNARY_OP("log2") -.describe(R"code(Returns the log to base 2 of input array, computed element-wise. + .describe(R"code(Returns the log to base 2 of input array, computed element-wise. .. math:: log2(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2)); RELAY_REGISTER_UNARY_OP("log10") -.describe(R"code(Returns the log to base 10 of input array, computed element-wise. + .describe(R"code(Returns the log to base 10 of input array, computed element-wise. .. math:: log10(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10)); RELAY_REGISTER_UNARY_OP("tan") -.describe(R"code(Returns the tan of input array, computed element-wise. + .describe(R"code(Returns the tan of input array, computed element-wise. .. math:: Y = tan(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan)); RELAY_REGISTER_UNARY_OP("cos") -.describe(R"code(Returns the cos of input array, computed element-wise. + .describe(R"code(Returns the cos of input array, computed element-wise. .. math:: Y = cos(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); RELAY_REGISTER_UNARY_OP("cosh") -.describe(R"code(Returns the cosh of input array, computed element-wise. + .describe(R"code(Returns the cosh of input array, computed element-wise. .. 
math:: Y = cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh)); RELAY_REGISTER_UNARY_OP("sin") -.describe(R"code(Returns the sin of input array, computed element-wise. + .describe(R"code(Returns the sin of input array, computed element-wise. .. math:: Y = sin(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); RELAY_REGISTER_UNARY_OP("sinh") -.describe(R"code(Returns the sinh of input array, computed element-wise. + .describe(R"code(Returns the sinh of input array, computed element-wise. .. math:: Y = sinh(X) +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh)); + +RELAY_REGISTER_UNARY_OP("acos") +.describe(R"code(Returns the acos of input array, computed element-wise. + +.. math:: + Y = acos(X) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::acos)); + + +RELAY_REGISTER_UNARY_OP("acosh") +.describe(R"code(Returns the acosh of input array, computed element-wise. + +.. math:: + Y = acosh(X) + )code" TVM_ADD_FILELINE) .set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh)); +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::acosh)); + + +RELAY_REGISTER_UNARY_OP("asin") +.describe(R"code(Returns the asin of input array, computed element-wise. + +.. math:: + Y = asin(X) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::asin)); + + +RELAY_REGISTER_UNARY_OP("asinh") +.describe(R"code(Returns the asinh of input array, computed element-wise. + +.. math:: + Y = asinh(X) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::asinh)); RELAY_REGISTER_UNARY_OP("atan") -.describe(R"code(Returns the atan of input array, computed element-wise. + .describe(R"code(Returns the atan of input array, computed element-wise. .. math:: Y = atan(X) +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan)); + +RELAY_REGISTER_UNARY_OP("atanh") +.describe(R"code(Returns the atanh of input array, computed element-wise. + +.. math:: + Y = atanh(X) + )code" TVM_ADD_FILELINE) .set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan)); +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atanh)); RELAY_REGISTER_UNARY_OP("exp") -.describe(R"code(Returns the exp input array, computed element-wise. + .describe(R"code(Returns the exp input array, computed element-wise. .. math:: \exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); RELAY_REGISTER_UNARY_OP("fast_exp") -.describe(R"code(Returns the fast_exp input array, computed element-wise. + .describe(R"code(Returns the fast_exp input array, computed element-wise. .. math:: \fast_exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); RELAY_REGISTER_UNARY_OP("erf") -.describe(R"code(Returns the error function value for input array, computed element-wise. 
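All of the unary registrations in this file funnel through the RELAY_UNARY_COMPUTE macro reformatted above, which wraps a TOPI elementwise function in an FTVMCompute lambda. A cut-down, dependency-free analogue of that pattern (the Tensor and Compute aliases are stand-ins for TVM's te::Tensor machinery, not real TVM names):

#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

using Tensor = std::vector<double>;
using Compute = std::function<std::vector<Tensor>(const std::vector<Tensor>&)>;

// Like RELAY_UNARY_COMPUTE: turn a scalar function F into a compute lambda
// that maps it over the single input tensor and returns a one-element array.
#define UNARY_COMPUTE(F)                                              \
  Compute([](const std::vector<Tensor>& inputs) {                     \
    Tensor out(inputs[0].size());                                     \
    for (size_t i = 0; i < out.size(); ++i) out[i] = F(inputs[0][i]); \
    return std::vector<Tensor>{out};                                  \
  })

int main() {
  Compute log_compute = UNARY_COMPUTE(std::log);
  auto result = log_compute({Tensor{1.0, std::exp(1.0)}});
  assert(std::fabs(result[0][0]) < 1e-12);        // log(1) == 0
  assert(std::fabs(result[0][1] - 1.0) < 1e-12);  // log(e) == 1
}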
+ .describe(R"code(Returns the error function value for input array, computed element-wise. .. math:: \erf(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); RELAY_REGISTER_UNARY_OP("fast_erf") -.describe(R"code(Returns the error function value for input array, computed element-wise. + .describe(R"code(Returns the error function value for input array, computed element-wise. .. math:: \fast_erf(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); RELAY_REGISTER_UNARY_OP("sqrt") -.describe(R"code(Returns the sqrt input array, computed element-wise. + .describe(R"code(Returns the sqrt input array, computed element-wise. .. math:: sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); RELAY_REGISTER_UNARY_OP("rsqrt") -.describe(R"code(Returns the rsqrt input array, computed element-wise. + .describe(R"code(Returns the rsqrt input array, computed element-wise. .. math:: 1/sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt)); RELAY_REGISTER_UNARY_OP("zeros_like") -.describe(R"code(Returns an array of zeros, with same type and shape as the input. + .describe(R"code(Returns an array of zeros, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(4); + .set_support_level(4); RELAY_REGISTER_UNARY_OP("ones_like") -.describe(R"code(Returns an array of ones, with same type and shape as the input. + .describe(R"code(Returns an array of ones, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(4); + .set_support_level(4); RELAY_REGISTER_UNARY_OP("sigmoid") -.describe(R"code(Returns the sigmoid input array, computed element-wise. + .describe(R"code(Returns the sigmoid input array, computed element-wise. .. math:: sigmoid(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); RELAY_REGISTER_UNARY_OP("copy") -.describe(R"code(Copy a tensor. + .describe(R"code(Copy a tensor. )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); -TVM_REGISTER_GLOBAL("relay.op._make.clip") -.set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_object(); - attrs->a_min = a_min; - attrs->a_max = a_max; - static const Op& op = Op::Get("clip"); +TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_min, double a_max) { + auto attrs = make_object(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); return Call(op, {a}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("clip") -.describe(R"code(Clip tensor values. + .describe(R"code(Clip tensor values. 
This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kElemWise) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attrs_type() -.set_support_level(3); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kElemWise) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attrs_type() + .set_support_level(3); RELAY_REGISTER_UNARY_OP("floor") -.describe(R"code(Returns the floor of input array, computed element-wise. + .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); RELAY_REGISTER_UNARY_OP("ceil") -.describe(R"code(Returns the ceil of input array, computed element-wise. + .describe(R"code(Returns the ceil of input array, computed element-wise. .. math:: ceil(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); RELAY_REGISTER_UNARY_OP("trunc") -.describe(R"code(Returns the trunc of input array, computed element-wise. + .describe(R"code(Returns the trunc of input array, computed element-wise. .. math:: trunc(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); RELAY_REGISTER_UNARY_OP("round") -.describe(R"code(Returns the round of input array, computed element-wise. + .describe(R"code(Returns the round of input array, computed element-wise. .. math:: round(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); RELAY_REGISTER_UNARY_OP("sign") -.describe(R"code(Returns the sign of input array, computed element-wise. + .describe(R"code(Returns the sign of input array, computed element-wise. .. numpy:: sign(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); RELAY_REGISTER_UNARY_OP("abs") -.describe(R"code(Returns the abs of input array, computed element-wise. + .describe(R"code(Returns the abs of input array, computed element-wise. .. math:: abs(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); RELAY_REGISTER_UNARY_OP("tanh") -.describe(R"code(Returns the tanh of input array, computed element-wise. + .describe(R"code(Returns the tanh of input array, computed element-wise. .. 
math:: Y = sinh(X) / cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); RELAY_REGISTER_UNARY_OP("fast_tanh") -.describe(R"code(Returns the fast_tanh of input array, computed element-wise. + .describe(R"code(Returns the fast_tanh of input array, computed element-wise. .. math:: Y = sinh(X) / cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); RELAY_REGISTER_UNARY_OP("negative") -.describe(R"code(Returns the numeric negative of input array, computed element-wise. + .describe(R"code(Returns the numeric negative of input array, computed element-wise. .. math:: -(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); RELAY_REGISTER_UNARY_OP("logical_not") -.describe(R"code(Returns the logical inverse of input array, computed element-wise. + .describe(R"code(Returns the logical inverse of input array, computed element-wise. .. math:: !(x) )code" TVM_ADD_FILELINE) -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); - + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); RELAY_REGISTER_UNARY_OP("bitwise_not") -.describe(R"code(Returns the bitwise inverse of input array, computed element-wise. + .describe(R"code(Returns the bitwise inverse of input array, computed element-wise. .. math:: ~(x) )code" TVM_ADD_FILELINE) -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); - + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); // shape_of TVM_REGISTER_NODE_TYPE(ShapeOfAttrs); -bool ShapeOfRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ShapeOfRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); - CHECK(tt != nullptr); + if (tt == nullptr) { + return false; + } const auto* param = attrs.as(); CHECK(param != nullptr); auto rank_shape = RankShape(tt->shape); @@ -388,8 +415,7 @@ bool ShapeOfRel(const Array& types, return true; } -Array ShapeOfCompute(const Attrs& attrs, - const Array& inputs, +Array ShapeOfCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { CHECK_EQ(inputs.size(), 1); const auto* param = attrs.as(); @@ -397,8 +423,7 @@ Array ShapeOfCompute(const Attrs& attrs, return {topi::shape(inputs[0], param->dtype)}; } -TVM_REGISTER_GLOBAL("relay.op._make.shape_of") -.set_body_typed([](Expr data, DataType dtype) { +TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); @@ -406,29 +431,25 @@ TVM_REGISTER_GLOBAL("relay.op._make.shape_of") }); RELAY_REGISTER_OP("shape_of") -.describe(R"code(Returns a tensor representing the shape of a tensor. + .describe(R"code(Returns a tensor representing the shape of a tensor. 
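The one behavioural change in this hunk is ShapeOfRel returning false when the input type is still unresolved, rather than CHECK-failing: in Relay's type inference a relation that reports false is simply revisited once more types become known. A toy model of that protocol (all names here are invented for illustration, not TVM's solver):

#include <cassert>
#include <optional>

struct Types { std::optional<int> input_rank, output_rank; };

// A relation returns false while its inputs are unresolved so the solver can
// retry it later -- the behaviour the new ShapeOfRel code above opts into,
// instead of crashing on a not-yet-inferred input type.
bool ShapeOfRelToy(Types& t) {
  if (!t.input_rank) return false;  // defer: input type not known yet
  t.output_rank = 1;                // shape_of always yields a rank-1 tensor
  return true;
}

int main() {
  Types t;
  assert(!ShapeOfRelToy(t));  // first pass: deferred, no crash
  t.input_rank = 4;           // inference later resolves the input
  assert(ShapeOfRelToy(t) && *t.output_rank == 1);
}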
)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type()
-.add_argument("data", "Tensor", "The input tensor.")
-.add_type_rel("ShapeOf", ShapeOfRel)
-.set_attr("TOpIsStateful", false)
-// Use kOpaque for shape_of op for now since it won't be performance critic,
-// and it makes things easier for dynamic shape func
-.set_attr("TOpPattern", kOpaque)
-.set_attr("FInferCorrectLayout",
- ElemwiseArbitraryLayout)
-.set_support_level(10)
-.set_attr("FTVMCompute", ShapeOfCompute);
-
+ .set_num_inputs(1)
+ .set_attrs_type()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("ShapeOf", ShapeOfRel)
+ .set_attr("TOpIsStateful", false)
+ // Use kOpaque for shape_of op for now since it won't be performance critical,
+ // and it makes things easier for dynamic shape func
+ .set_attr("TOpPattern", kOpaque)
+ .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_support_level(10)
+ .set_attr("FTVMCompute", ShapeOfCompute);

TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs);

-bool NdarraySizeRel(const Array& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
+bool NdarraySizeRel(const Array& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
 CHECK_EQ(num_inputs, 1);
 auto tt = types[0].as();
 CHECK(tt != nullptr);
@@ -438,8 +459,7 @@ bool NdarraySizeRel(const Array& types,
 return true;
}

-Array NdarraySizeCompute(const Attrs& attrs,
- const Array& inputs,
+Array NdarraySizeCompute(const Attrs& attrs, const Array& inputs,
 const Type& out_type) {
 CHECK_EQ(inputs.size(), 1);
 const auto* param = attrs.as();
@@ -447,8 +467,7 @@ Array NdarraySizeCompute(const Attrs& attrs,
 return Array{topi::ndarray_size(inputs[0], param->dtype)};
}

-TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size")
-.set_body_typed([](Expr data, DataType dtype) {
+TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size").set_body_typed([](Expr data, DataType dtype) {
 auto attrs = make_object();
 attrs->dtype = dtype;
 static const Op& op = Op::Get("ndarray_size");
@@ -456,46 +475,45 @@ TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size")
});

RELAY_REGISTER_OP("ndarray_size")
-.describe(R"code(Returns a tensor representing the number of elements of input tensor.
+ .describe(R"code(Returns a tensor representing the number of elements of input tensor.

)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.set_attrs_type()
-.add_argument("data", "Tensor", "The input tensor.")
-.add_type_rel("NdarraySize", NdarraySizeRel)
-.set_attr("TOpIsStateful", false)
-.set_attr("TOpPattern", kInjective)
-.set_attr("FInferCorrectLayout",
-ElemwiseArbitraryLayout)
-.set_support_level(10)
-.set_attr("FTVMCompute", NdarraySizeCompute);
+ .set_num_inputs(1)
+ .set_attrs_type()
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_type_rel("NdarraySize", NdarraySizeRel)
+ .set_attr("TOpIsStateful", false)
+ .set_attr("TOpPattern", kInjective)
+ .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout)
+ .set_support_level(10)
+ .set_attr("FTVMCompute", NdarraySizeCompute);

RELAY_REGISTER_UNARY_OP("isnan")
-.describe(R"code(Returns whether the input contains any NaN, computed element-wise.
+ .describe(R"code(Returns whether the input contains any NaN, computed element-wise.

..
math:: isnan(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); RELAY_REGISTER_UNARY_OP("isfinite") -.describe(R"code(Returns the finiteness of input, computed element-wise. + .describe(R"code(Returns the finiteness of input, computed element-wise. .. math:: isfinite(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite)); RELAY_REGISTER_UNARY_OP("isinf") -.describe(R"code(Returns the infiniteness of input, computed element-wise. + .describe(R"code(Returns the infiniteness of input, computed element-wise. .. math:: isinf(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index e2e7f4994349b..677683c94021e 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,19 +22,19 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include "./type_relations.h" + #include -#include #include #include +#include + #include -#include "./type_relations.h" namespace tvm { namespace relay { -bool IdentityRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { for (size_t i = 1; i < types.size(); ++i) { reporter->Assign(types[i], types[0]); @@ -42,8 +42,7 @@ bool IdentityRel(const Array& types, return true; } -bool EqualCheck(const IndexExpr& lhs, - const IndexExpr& rhs) { +bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) { IndexExpr diff = lhs - rhs; if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; @@ -64,9 +63,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { return false; } -Type ConcreteBroadcast(const TensorType& t1, - const TensorType& t2, - DataType output_dtype) { +Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); @@ -87,9 +84,7 @@ Type ConcreteBroadcast(const TensorType& t1, } else if (EqualCheck(s1, s2)) { oshape.push_back(s1); } else { - throw Error(ErrorBuilder() - << "Incompatible broadcast type " - << t1 << " and " << t2); + throw Error(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2); } } @@ -98,13 +93,10 @@ Type ConcreteBroadcast(const TensorType& t1, for (; i <= max_ndim; ++i) { oshape.push_back(rshape[max_ndim - i]); } - return TensorType(Array( - oshape.rbegin(), oshape.rend()), output_dtype); + return TensorType(Array(oshape.rbegin(), oshape.rend()), output_dtype); } -bool BroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { 
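ConcreteBroadcast above (used by the BroadcastRel relation whose body continues below) implements the usual NumPy broadcasting rule: align the two shapes from the trailing axis, let size-1 axes stretch, require equality otherwise, and copy the leading dimensions of the longer shape. The same rule over static shapes, as a self-contained sketch:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// NumPy-style broadcast of two static shapes: walk from the trailing axis,
// a size-1 axis stretches, equal sizes pass through, anything else fails.
std::vector<int64_t> BroadcastShape(const std::vector<int64_t>& a,
                                    const std::vector<int64_t>& b) {
  size_t na = a.size(), nb = b.size(), i = 1;
  std::vector<int64_t> rev;  // built back-to-front, reversed at the end
  for (; i <= std::min(na, nb); ++i) {
    int64_t s1 = a[na - i], s2 = b[nb - i];
    if (s1 == 1) rev.push_back(s2);
    else if (s2 == 1 || s1 == s2) rev.push_back(s1);
    else throw std::invalid_argument("incompatible broadcast");
  }
  const std::vector<int64_t>& longer = na >= nb ? a : b;  // leading dims
  for (size_t n = std::max(na, nb); i <= n; ++i) rev.push_back(longer[n - i]);
  return std::vector<int64_t>(rev.rbegin(), rev.rend());
}

int main() {
  // (2, 1, 3) broadcast with (4, 3) gives (2, 4, 3).
  assert((BroadcastShape({2, 1, 3}, {4, 3}) == std::vector<int64_t>{2, 4, 3}));
}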
CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] @@ -112,17 +104,15 @@ bool BroadcastRel(const Array& types, if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], - ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); + reporter->Assign( + types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); return true; } } return false; } -bool BroadcastCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] @@ -130,17 +120,15 @@ bool BroadcastCompRel(const Array& types, if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], - ConcreteBroadcast(GetRef(t0), GetRef(t1), DataType::Bool())); + reporter->Assign(types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), + DataType::Bool())); return true; } } return false; } -bool IdentityCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { if (auto* t0 = types[0].as()) { Type out_type = TensorType(GetRef(t0)->shape, DataType::Bool()); @@ -154,7 +142,7 @@ Array RankShape(const Array& shape) { if (shape.size() == 0) { return {}; } else { - return { tvm::Integer(shape.size()) }; + return {tvm::Integer(shape.size())}; } } diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 48a545bddd0b2..acd4b2dae1be6 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -27,6 +27,7 @@ #include #include + #include namespace tvm { @@ -40,9 +41,7 @@ namespace relay { * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool IdentityRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); /*! @@ -55,9 +54,7 @@ bool IdentityRel(const Array& types, * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool BroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); /*! @@ -74,15 +71,11 @@ bool BroadcastRel(const Array& types, * \param reporter The reporter. * \return true whether relation has been resolved. 
*/ -bool BroadcastCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); -bool IdentityCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter); +bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter); Array RankShape(const Array& shape); diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index cafe9b6dd0c3e..18a2edb4540ab 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -21,46 +21,38 @@ * \file multibox_op.cc * \brief Multibox related operators */ -#include -#include #include +#include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs); -bool MultiboxPriorRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MultiboxPriorRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); const MultiBoxPriorAttrs* param = attrs.as(); const auto& dshape = data->shape; CHECK_EQ(dshape.size(), 4) << "Input data should be 4D: " - "[batch, channel, height, width]"; + "[batch, channel, height, width]"; IndexExpr in_height = dshape[2]; IndexExpr in_width = dshape[3]; int num_sizes = static_cast(param->sizes.size()); int num_ratios = static_cast(param->ratios.size()); // since input sizes are same in each batch, we could share MultiBoxPrior - std::vector oshape( - {1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); + std::vector oshape({1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); // assign output type reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } - -Expr MakeMultiBoxPrior(Expr data, - Array sizes, - Array ratios, - Array steps, - Array offsets, - bool clip) { +Expr MakeMultiBoxPrior(Expr data, Array sizes, Array ratios, + Array steps, Array offsets, bool clip) { auto attrs = make_object(); attrs->sizes = std::move(sizes); attrs->ratios = std::move(ratios); @@ -71,25 +63,20 @@ Expr MakeMultiBoxPrior(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior") -.set_body_typed(MakeMultiBoxPrior); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior").set_body_typed(MakeMultiBoxPrior); RELAY_REGISTER_OP("vision.multibox_prior") -.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." + .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." 
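The output shape computed in MultiboxPriorRel above follows from anchoring num_sizes + num_ratios - 1 prior boxes at every spatial location of the feature map, each described by 4 coordinates. The arithmetic, spelled out (function name is illustrative):

#include <cassert>
#include <cstdint>

// Anchors per feature map: one box per (size, ratio) combination in which
// either the size or the ratio is the first entry, hence the -1 overlap.
int64_t NumPriors(int64_t height, int64_t width, int num_sizes, int num_ratios) {
  return height * width * (num_sizes + num_ratios - 1);
}

int main() {
  // A 4x4 feature map with 3 sizes and 2 ratios yields 64 anchors,
  // so the relation assigns the shape (1, 64, 4).
  assert(NumPriors(4, 4, 3, 2) == 64);
}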
)doc" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(5) -.add_type_rel("MultiBoxPrior", MultiboxPriorRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("MultiBoxPrior", MultiboxPriorRel); TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs); -bool MultiBoxTransformLocRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MultiBoxTransformLocRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); @@ -102,20 +89,15 @@ bool MultiBoxTransformLocRel(const Array& types, const auto& loc_shape = loc_pred->shape; const auto& anchor_shape = anchor->shape; - CHECK_EQ(cls_shape.size(), 3U) - << "The dimension of class probability should be 3, but received " - << cls_shape.size(); + CHECK_EQ(cls_shape.size(), 3U) << "The dimension of class probability should be 3, but received " + << cls_shape.size(); CHECK_EQ(loc_shape.size(), 2U) - << "The dimension of location prediction should be 2, but received " - << loc_shape.size(); + << "The dimension of location prediction should be 2, but received " << loc_shape.size(); CHECK_EQ(anchor_shape.size(), 3U) - << "The dimension of anchor should be 3, but received " - << anchor_shape.size(); + << "The dimension of anchor should be 3, but received " << anchor_shape.size(); - CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) - << "Number of anchors mismatch found"; - CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) - << "# anchors mismatch with # loc."; + CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) << "Number of anchors mismatch found"; + CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) << "# anchors mismatch with # loc."; CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0."; CHECK(reporter->AssertEQ(anchor_shape[2], 4)); @@ -130,12 +112,8 @@ bool MultiBoxTransformLocRel(const Array& types, return true; } -Expr MakeMultiBoxTransformLoc(Expr cls_prob, - Expr loc_pred, - Expr anchor, - bool clip, - double threshold, - Array variances) { +Expr MakeMultiBoxTransformLoc(Expr cls_prob, Expr loc_pred, Expr anchor, bool clip, + double threshold, Array variances) { auto attrs = make_object(); attrs->clip = std::move(clip); attrs->threshold = std::move(threshold); @@ -145,18 +123,18 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, } TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc") -.set_body_typed(MakeMultiBoxTransformLoc); + .set_body_typed(MakeMultiBoxTransformLoc); RELAY_REGISTER_OP("vision.multibox_transform_loc") -.describe(R"doc("Location transformation for multibox detection." + .describe(R"doc("Location transformation for multibox detection." 
)doc" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("cls_prob", "Tensor", "Class probabilities.") -.add_argument("loc_pred", "Tensor", "Location regression predictions.") -.add_argument("anchor", "Tensor", "Multibox prior anchor boxes") -.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel) -.set_support_level(5); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("cls_prob", "Tensor", "Class probabilities.") + .add_argument("loc_pred", "Tensor", "Location regression predictions.") + .add_argument("anchor", "Tensor", "Multibox prior anchor boxes") + .add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel) + .set_support_level(5); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 25743f98bc0b3..b1aaaf01ae9c0 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -21,17 +21,15 @@ * \file nms.cc * \brief Non-maximum suppression operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs); -bool GetValidCountRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -48,10 +46,7 @@ bool GetValidCountRel(const Array& types, return true; } -Expr MakeGetValidCounts(Expr data, - double score_threshold, - int id_index, - int score_index) { +Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { auto attrs = make_object(); attrs->score_threshold = score_threshold; attrs->id_index = id_index; @@ -60,33 +55,26 @@ Expr MakeGetValidCounts(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts") -.set_body_typed(MakeGetValidCounts); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts); RELAY_REGISTER_OP("vision.get_valid_counts") -.describe(R"doc(Get valid count of bounding boxes given + .describe(R"doc(Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. 
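What get_valid_counts computes, in miniature: count the boxes whose score passes the threshold and stably move them to the front. A sketch over bare scores (TVM operates on whole [id, score, coordinates] rows and also returns the compacted data; this reduction is for illustration):

#include <algorithm>
#include <cassert>
#include <vector>

// Count scores above the threshold and compact them to the front,
// preserving their relative order.
int GetValidCounts(std::vector<float>& scores, float score_threshold) {
  auto valid_end = std::stable_partition(
      scores.begin(), scores.end(),
      [score_threshold](float s) { return s > score_threshold; });
  return static_cast<int>(valid_end - scores.begin());
}

int main() {
  std::vector<float> scores{0.9f, 0.1f, 0.8f, 0.2f};
  int n = GetValidCounts(scores, 0.5f);
  assert(n == 2 && scores[0] == 0.9f && scores[1] == 0.8f);
}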
)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input data.") -.set_support_level(5) -.add_type_rel("GetValidCount", GetValidCountRel); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(5) + .add_type_rel("GetValidCount", GetValidCountRel); TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); -bool NMSRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); - const NonMaximumSuppressionAttrs* param = - attrs.as(); + const NonMaximumSuppressionAttrs* param = attrs.as(); const auto& dshape = data->shape; const auto& vshape = valid_count->shape; CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; @@ -102,18 +90,9 @@ bool NMSRel(const Array& types, return true; } - -Expr MakeNMS(Expr data, - Expr valid_count, - int max_output_size, - double iou_threshold, - bool force_suppress, - int top_k, - int coord_start, - int score_index, - int id_index, - bool return_indices, - bool invalid_to_bottom) { +Expr MakeNMS(Expr data, Expr valid_count, int max_output_size, double iou_threshold, + bool force_suppress, int top_k, int coord_start, int score_index, int id_index, + bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; @@ -128,21 +107,18 @@ Expr MakeNMS(Expr data, return Call(op, {data, valid_count}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression") -.set_body_typed(MakeNMS); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); RELAY_REGISTER_OP("vision.non_max_suppression") -.describe(R"doc(Non-maximum suppression. The input boxes should + .describe(R"doc(Non-maximum suppression. The input boxes should be in the format of [class_id, score, left, top, right, bottom]. Set id_index to be -1 to ignore class_id axis. 
)doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "Input data.") -.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") -.set_support_level(5) -.add_type_rel("NMS", NMSRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "Input data.") + .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") + .set_support_level(5) + .add_type_rel("NMS", NMSRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 6b221a279bacc..efedb5ef3d7ad 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -21,9 +21,9 @@ * \file rcnn_op.cc * \brief Faster RCNN and Mask RCNN operators */ +#include #include #include -#include namespace tvm { namespace relay { @@ -36,6 +36,8 @@ bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* rois = types[1].as(); + CHECK(data); + CHECK(rois); const auto& dshape = data->shape; const auto& rshape = rois->shape; CHECK(roi_align_attrs); @@ -60,8 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spa return Call(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align") -.set_body_typed(MakeROIAlign); +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align").set_body_typed(MakeROIAlign); RELAY_REGISTER_OP("vision.roi_align") .describe(R"doc(ROI Align operator. @@ -73,16 +74,16 @@ RELAY_REGISTER_OP("vision.roi_align") - **out**: This depends on the `layout` parameter. Output is 4D array of shape (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. )doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("rois", "Tensor", "The input rois") -.set_support_level(5) -.add_type_rel("ROIAlign", ROIAlignRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", "The input rois") + .set_support_level(5) + .add_type_rel("ROIAlign", ROIAlignRel); TVM_REGISTER_NODE_TYPE(ROIPoolAttrs); bool ROIPoolRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { auto roi_pool_attrs = attrs.as(); CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -110,8 +111,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spat return Call(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool") -.set_body_typed(MakeROIPool); +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool").set_body_typed(MakeROIPool); RELAY_REGISTER_OP("vision.roi_pool") .describe(R"doc(ROI Pool operator. @@ -123,11 +123,11 @@ RELAY_REGISTER_OP("vision.roi_pool") - **out**: This depends on the `layout` parameter. Output is 4D array of shape (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. 
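The suppression behind the non_max_suppression relation registered above happens in TOPI, but the algorithm is the classic greedy loop over intersection-over-union. A minimal sketch, with an illustrative Box struct and threshold:

#include <algorithm>
#include <cassert>
#include <vector>

struct Box { float score, l, t, r, b; };

// Intersection-over-union of two axis-aligned boxes.
float IoU(const Box& a, const Box& b) {
  float w = std::max(0.f, std::min(a.r, b.r) - std::max(a.l, b.l));
  float h = std::max(0.f, std::min(a.b, b.b) - std::max(a.t, b.t));
  float inter = w * h;
  float uni = (a.r - a.l) * (a.b - a.t) + (b.r - b.l) * (b.b - b.t) - inter;
  return uni > 0 ? inter / uni : 0.f;
}

// Greedy NMS: keep the best-scoring box, drop everything overlapping it
// beyond iou_threshold, repeat with the remainder.
std::vector<Box> Nms(std::vector<Box> boxes, float iou_threshold) {
  std::sort(boxes.begin(), boxes.end(),
            [](const Box& x, const Box& y) { return x.score > y.score; });
  std::vector<Box> kept;
  for (const Box& c : boxes) {
    bool suppressed = false;
    for (const Box& k : kept)
      if (IoU(c, k) > iou_threshold) { suppressed = true; break; }
    if (!suppressed) kept.push_back(c);
  }
  return kept;
}

int main() {
  std::vector<Box> boxes{{0.9f, 0, 0, 10, 10}, {0.8f, 1, 1, 11, 11},
                         {0.7f, 50, 50, 60, 60}};
  auto kept = Nms(boxes, 0.5f);
  assert(kept.size() == 2);  // the two heavily overlapping boxes collapse
}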
)doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("rois", "Tensor", "The input rois") -.set_support_level(5) -.add_type_rel("ROIPool", ROIPoolRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", "The input rois") + .set_support_level(5) + .add_type_rel("ROIPool", ROIPoolRel); TVM_REGISTER_NODE_TYPE(ProposalAttrs); @@ -153,16 +153,14 @@ bool ProposalRel(const Array& types, int num_inputs, const Attrs& attrs, auto batch = cls_prob->shape[0]; - std::vector oshape( - {batch * proposal_attrs->rpn_post_nms_top_n, 5}); + std::vector oshape({batch * proposal_attrs->rpn_post_nms_top_n, 5}); reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype)); return true; } Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array scales, Array ratios, int feature_stride, double threshold, - int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, - bool iou_loss) { + int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, bool iou_loss) { auto attrs = make_object(); attrs->scales = scales; attrs->ratios = ratios; @@ -176,8 +174,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array return Call(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal") -.set_body_typed(MakeProposal); +TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal").set_body_typed(MakeProposal); RELAY_REGISTER_OP("vision.proposal") .describe(R"code(Generate region proposals via RPN. @@ -187,12 +184,12 @@ RELAY_REGISTER_OP("vision.proposal") - **im_info**: 2-D with shape [batch, 3]. - **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5]. )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") -.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") -.add_argument("im_info", "Tensor", "Image size and scale") -.set_support_level(5) -.add_type_rel("Proposal", ProposalRel); + .set_num_inputs(3) + .add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") + .add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") + .add_argument("im_info", "Tensor", "Image size and scale") + .set_support_level(5) + .add_type_rel("Proposal", ProposalRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 58596778de1d0..e54473f68ef75 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -21,10 +21,12 @@ * \file yolo.cc * \brief Yolo related operators */ -#include -#include #include +#include +#include + #include + #include "../op_common.h" #include "../type_relations.h" @@ -34,15 +36,13 @@ namespace relay { TVM_REGISTER_NODE_TYPE(YoloReorgAttrs); /*! -* \brief YoloReorgRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool YoloReorgRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief YoloReorgRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. 
+ * \param reporter The reporter to report solution to.
+ * \return false if this relation cannot be resolved. true if this relation has been resolved.
+ */
+bool YoloReorgRel(const Array& types, int num_inputs, const Attrs& attrs,
 const TypeReporter& reporter) {
 CHECK_EQ(types.size(), 2);
 const auto* data = types[0].as();
@@ -60,34 +60,29 @@ bool YoloReorgRel(const Array& types,
 return true;
}

-Expr MakeYoloReorg(Expr data,
- Integer stride) {
+Expr MakeYoloReorg(Expr data, Integer stride) {
 auto attrs = make_object();
 attrs->stride = stride;
 static const Op& op = Op::Get("vision.yolo_reorg");
 return Call(op, {data}, Attrs(attrs), {});
}

-
-TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg")
-.set_body_typed(MakeYoloReorg);
-
+TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg").set_body_typed(MakeYoloReorg);

RELAY_REGISTER_OP("vision.yolo_reorg")
-.describe(R"doc("Yolo reorg operation. This layer reorganize the output.
+ .describe(R"doc("Yolo reorg operation. This layer reorganizes the output.
Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
-.add_argument("data", "Tensor", "The input tensor.")
-.set_num_inputs(1)
-.set_support_level(5)
-.set_attrs_type()
-.add_type_rel("YoloReorg", YoloReorgRel)
-.set_attr("FTVMCompute", [](const Attrs& attrs,
- const Array& inputs,
- const Type& out_type) {
- const auto* params = attrs.as();
- CHECK(params != nullptr);
- return Array{ topi::vision::reorg(inputs[0], params->stride) };
-});
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_num_inputs(1)
+ .set_support_level(5)
+ .set_attrs_type()
+ .add_type_rel("YoloReorg", YoloReorgRel)
+ .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs,
+ const Type& out_type) {
+ const auto* params = attrs.as();
+ CHECK(params != nullptr);
+ return Array{topi::vision::reorg(inputs[0], params->stride)};
+ });

} // namespace relay
} // namespace tvm
diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc
index d8752d8030d75..b0dc3e4af5c4a 100644
--- a/src/relay/qnn/op/add.cc
+++ b/src/relay/qnn/op/add.cc
@@ -23,6 +23,7 @@
 */
#include
#include
+
#include "op_common.h"

namespace tvm {
@@ -44,7 +45,6 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args,
 // Get the input dtype and shape.
 QnnBinaryOpTensorType input_type(arg_types, 0);
-
 // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
 // the start, we can insert requantize at the end if both input tensors have same qnn params. In
 // that case, we can first add the tensors, subtract the zero point, and requantize at the end.
@@ -65,18 +65,14 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args,
 // Q_c = Q_a' + Q_b' - zp_c
 // The add op is done in int32 precision.

-
 // Requantize LHS if necessary. Computes Q_a'
- auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale,
- args.lhs_zero_point,
- args.output_scale, args.output_zero_point,
- input_type.shape);
+ auto requantized_lhs =
+ RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale,
+ args.output_zero_point, input_type.shape);

 // Requantize RHS if necessary. Computes Q_b'
- auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale,
- args.rhs_zero_point,
- args.output_scale, args.output_zero_point,
- input_type.shape);
+ auto requantized_rhs =
+ RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale,
+ args.output_zero_point, input_type.shape);

 // Computes Q_a' + Q_b'
 auto output = Add(requantized_lhs, requantized_rhs);
@@ -92,9 +88,9 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args,

// QNN Addition operator.
QNN_REGISTER_BINARY_OP("add")
-.describe("Elementwise add with with broadcasting for quantized tensors.")
-.set_support_level(11)
-.set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize);
+ .describe("Elementwise add with broadcasting for quantized tensors.")
+ .set_support_level(11)
+ .set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize);

} // namespace qnn
} // namespace relay
diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
index 650dcb962d444..bda8cf8787934 100644
--- a/src/relay/qnn/op/concatenate.cc
+++ b/src/relay/qnn/op/concatenate.cc
@@ -22,13 +22,14 @@
 * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
 */

-#include
 #include
 #include
 #include
+#include
+
 #include "../../op/tensor/transform.h"
-#include "../../transforms/pattern_util.h"
 #include "../../transforms/infer_layout_util.h"
+#include "../../transforms/pattern_util.h"
 #include "../util.h"

namespace tvm {
namespace relay {
@@ -42,10 +43,9 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at
 // Check the scale and zero point types
 const auto* input_scales_tuple = types[1].as();
 if (input_scales_tuple == nullptr) {
- throw Error(
- ErrorBuilder()
- << "qnn concatenate requires a tuple of scales as the second argument, found "
- << PrettyPrint(types[1]));
+ throw Error(ErrorBuilder()
+ << "qnn concatenate requires a tuple of scales as the second argument, found "
+ << PrettyPrint(types[1]));
 }
 for (const auto& input_scale : input_scales_tuple->fields) {
 CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx]
@@ -53,10 +53,9 @@
 const auto* input_zero_points_tuple = types[2].as();
 if (input_zero_points_tuple == nullptr) {
- throw Error(
- ErrorBuilder()
- << "qnn concatenate requires a tuple of zero_points as the third argument, found "
- << PrettyPrint(types[2]));
+ throw Error(ErrorBuilder()
+ << "qnn concatenate requires a tuple of zero_points as the third argument, found "
+ << PrettyPrint(types[2]));
 }
 for (const auto& input_zero_point : input_zero_points_tuple->fields) {
 CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx]
 }
@@ -113,9 +112,8 @@ Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Ex
 auto attrs = make_object();
 attrs->axis = axis;
 static const Op& op = Op::Get("qnn.concatenate");
- return Call(op,
- {data, input_scales, input_zero_points, output_scale, output_zero_point},
- Attrs(attrs), {});
+ return Call(op, {data, input_scales, input_zero_points, output_scale, output_zero_point},
+ Attrs(attrs), {});
}

/*
@@ -149,8 +147,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args,
 // If the output qnn params do not match the input qnn params, we can call requantize on the input
 // expr first, followed by a concatenate on the requantized input exprs.
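That is, elementwise: bring every input onto the output (scale, zero_point) first, after which a plain concatenate is exact. A float-simulated sketch (TVM's Requantize uses fixed-point arithmetic and saturating casts; the helper names here are illustrative):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Move a value from (in_scale, in_zp) to (out_scale, out_zp) while
// preserving the real number it represents. Saturation omitted for brevity.
int8_t RequantizeValue(int8_t v, float in_scale, int32_t in_zp,
                       float out_scale, int32_t out_zp) {
  float real = in_scale * (static_cast<int32_t>(v) - in_zp);
  return static_cast<int8_t>(std::lround(real / out_scale) + out_zp);
}

// Requantize each input tensor to the output params, then concatenate.
std::vector<int8_t> QConcat(const std::vector<std::vector<int8_t>>& inputs,
                            const std::vector<float>& scales,
                            const std::vector<int32_t>& zps,
                            float out_scale, int32_t out_zp) {
  std::vector<int8_t> out;
  for (size_t i = 0; i < inputs.size(); ++i)
    for (int8_t v : inputs[i])
      out.push_back(RequantizeValue(v, scales[i], zps[i], out_scale, out_zp));
  return out;
}

int main() {
  // Two one-element inputs, both representing the real value 2.0, land on
  // the same representation once requantized to scale 0.5, zero point 0.
  auto out = QConcat({{4}, {8}}, {0.5f, 0.25f}, {0, 0}, 0.5f, 0);
  assert(out.size() == 2 && out[0] == 4 && out[1] == 4);
}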
- auto tuple_data = data.as(); - CHECK(tuple_data != nullptr); + Array tuple_exprs; + if (data->IsInstance()) { + tuple_exprs = data.as()->fields; + } else if (data->IsInstance()) { // if the data is a CallNode, use TupleGetItems + auto call = Downcast(data); + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + tuple_exprs.push_back(TupleGetItem(call, i)); + } + } + CHECK(!tuple_exprs.empty()); auto tuple_input_scales = input_scales.as(); CHECK(tuple_input_scales != nullptr); @@ -160,7 +166,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, int idx = 0; Array requantized_exprs; - for (auto quantized_expr : tuple_data->fields) { + for (auto quantized_expr : tuple_exprs) { // Get the input scale for the idx quantized input tensor. auto input_scale = tuple_input_scales->fields[idx]; @@ -188,22 +194,23 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.concatenate") -.describe(R"code(Concatenate the quantized input tensors along the given axis. + .describe(R"code(Concatenate the quantized input tensors along the given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "The tensor to concatenate.") -.add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.") -.add_argument("input_zero_points", "Tensor", "The quantization zero_points of the input tensors.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("QnnConcatenate", QnnConcatenateRel) -.set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) -.set_attr("FInferCorrectLayout", QnnConcatenateLayout); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate") -.set_body_typed(MakeQnnConcatenate); + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "The tensor to concatenate.") + .add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.") + .add_argument("input_zero_points", "Tensor", + "The quantization zero_points of the input tensors.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("QnnConcatenate", QnnConcatenateRel) + .set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) + .set_attr("FInferCorrectLayout", QnnConcatenateLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate").set_body_typed(MakeQnnConcatenate); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 37186283ba511..ae52a42e42b8e 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -21,15 +21,16 @@ * \file src/relay/qnn/op/convolution.cc * \brief Property def of qnn convolution operator. 
*/ -#include +#include "../../op/nn/convolution.h" + #include #include #include #include #include #include +#include -#include "../../op/nn/convolution.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -88,9 +89,8 @@ Array> QnnConvInferCorrectLayout(const Attrs& attrs, } bool is_depthwise(const Conv2DAttrs* param) { - return param->channels.defined() && - tvm::tir::ExprDeepEqual()(param->channels, param->groups) && - param->groups != 1; + return param->channels.defined() && tvm::tir::ExprDeepEqual()(param->channels, param->groups) && + param->groups != 1; } // Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier @@ -201,8 +201,8 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2D auto pad_left_value = get_const_int(param->padding[1]); auto pad_bottom_value = get_const_int(param->padding[2]); auto pad_right_value = get_const_int(param->padding[3]); - bool do_pad = pad_top_value != 0 || pad_left_value != 0 || - pad_bottom_value != 0 || pad_right_value != 0; + bool do_pad = + pad_top_value != 0 || pad_left_value != 0 || pad_bottom_value != 0 || pad_right_value != 0; if (do_pad) { Array pad_n({0, 0}); Array pad_c({0, 0}); @@ -676,13 +676,12 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.conv2d"); - return Call( - op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, - Attrs(attrs), {}); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); } RELAY_REGISTER_OP("qnn.conv2d") -.describe(R"code(2D quantized convolution layer. + .describe(R"code(2D quantized convolution layer. This operator convolves quantized weight with quantized data. The scale of the output quantized tensor is the product of the weight_scale and input_scale of the input quantized tensors. The zero point of the output quantized tensor is @@ -694,18 +693,19 @@ operator to understand how to scale back the int32 output to (u)int8. - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
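The is_depthwise predicate reformatted above reduces to: a grouped convolution in which every input channel forms its own group. As a truth table in code (standalone reimplementation over plain ints, not the IndexExpr version TVM uses):

#include <cassert>

// Depthwise detection: channels == groups, and it is actually grouped.
bool IsDepthwise(int channels, int groups) {
  return channels == groups && groups != 1;
}

int main() {
  assert(IsDepthwise(32, 32));   // classic depthwise conv
  assert(!IsDepthwise(32, 1));   // dense conv
  assert(!IsDepthwise(32, 4));   // grouped, but not depthwise
}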
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(6) -.add_argument("data", "Tensor", "The quantized input data tensor.") -.add_argument("weight", "Tensor", "The quantized weight tensor.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") -.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") -.set_support_level(11) -.add_type_rel("QnnConv2D", QnnConv2DRel) -.set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) -.set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "Tensor", "The quantized input data tensor.") + .add_argument("weight", "Tensor", "The quantized weight tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QnnConv2D", QnnConv2DRel) + .set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) + .set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 7b9733c365864..464b3f9aeff34 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../op/nn/nn.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -72,9 +73,8 @@ Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kern attrs->units = std::move(units); attrs->out_dtype = out_dtype; static const Op& op = Op::Get("qnn.dense"); - return Call( - op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, - Attrs(attrs), {}); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); } Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, @@ -173,25 +173,25 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.dense") -.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. - **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` - **weight**: quantized(int8, unit8) `(units, input_dim)` - **out**: quantized(int32) `(x1, x2, ..., xn, units)`. 
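Behind this docstring, the canonicalized quantized dense accumulates zero-centered int8 products in int32, and the int32 result then carries scale input_scale * weight_scale for downstream requantization. One output element in scalar form (TVM expands this into separate terms for efficiency; this sketch just computes the sum directly):

#include <cassert>
#include <cstdint>
#include <vector>

// Quantized dot product: accumulate (q_data - zp_data) * (q_weight - zp_w)
// in int32. The result's scale is input_scale * weight_scale.
int32_t QDenseDot(const std::vector<int8_t>& d, const std::vector<int8_t>& w,
                  int32_t zp_d, int32_t zp_w) {
  int32_t acc = 0;
  for (size_t i = 0; i < d.size(); ++i)
    acc += (static_cast<int32_t>(d[i]) - zp_d) *
           (static_cast<int32_t>(w[i]) - zp_w);
  return acc;
}

int main() {
  // With both zero points at 1: (3-1)(2-1) + (4-1)(5-1) = 2 + 12 = 14.
  assert(QDenseDot({3, 4}, {2, 5}, 1, 1) == 14);
}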
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(6) -.add_argument("data", "quantized nD Tensor", "Input data.") -.add_argument("weight", "quantized 2D Tensor", "Weight matrix.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") -.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") -.set_support_level(11) -.add_type_rel("QDense", QnnDenseRel) -.set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense") -.set_body_typed(MakeQuantizedDense); + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "quantized nD Tensor", "Input data.") + .add_argument("weight", "quantized 2D Tensor", "Weight matrix.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QDense", QnnDenseRel) + .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 69389a7317aa6..7c014d71a76a8 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" @@ -33,19 +34,16 @@ namespace tvm { namespace relay { namespace qnn { -bool DequantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); CHECK(data != nullptr); const auto input_dtype = data->dtype; - CHECK(input_dtype == DataType::Int(8) || - input_dtype == DataType::UInt(8) || + CHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) || input_dtype == DataType::Int(32)) - << "Input type should be one of the quantized types [unit8, int8, int32] but was " - << input_dtype; + << "Input type should be one of the quantized types [unit8, int8, int32] but was " + << input_dtype; // Check the types of scale and zero points. CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale @@ -83,20 +81,19 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.dequantize") -.describe(R"code(Dequantizes the input and produces float32 output. + .describe(R"code(Dequantizes the input and produces float32 output. The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point. - **data**: Quantized tensor of any shape to dequantize. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("data", "Tensor", "The tensor to dequantize.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.set_support_level(11) -.add_type_rel("Dequantize", DequantizeRel) -.set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); + .set_num_inputs(3) + .add_argument("data", "Tensor", "The tensor to dequantize.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .set_support_level(11) + .add_type_rel("Dequantize", DequantizeRel) + .set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); -TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize") -.set_body_typed(MakeDequantize); +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc index 5f9251b35080b..ec74b799407b4 100644 --- a/src/relay/qnn/op/mul.cc +++ b/src/relay/qnn/op/mul.cc @@ -24,6 +24,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" #include "op_common.h" @@ -85,21 +86,17 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, auto new_input_zero_point = zero_scalar; // Requantize to get Q_c - output = Requantize(output, input_type.shape, - new_input_scale, - new_input_zero_point, - args.output_scale, - args.output_zero_point, - input_type.dtype); + output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point, + args.output_scale, args.output_zero_point, input_type.dtype); return output; } // QNN Multiplication operator. QNN_REGISTER_BINARY_OP("mul") -.describe("Elementwise mul with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnMulCanonicalize); + .describe("Elementwise mul with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnMulCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index f780f70dc7b30..50fc0cda30cf2 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -28,7 +28,9 @@ #include #include #include + #include + #include "../../op/type_relations.h" #include "../../transforms/infer_layout_util.h" #include "../util.h" @@ -87,10 +89,9 @@ struct QnnBinaryOpArguments { */ struct QnnBinaryOpTensorType { DataType dtype; - Array shape; + Array shape; - explicit QnnBinaryOpTensorType(const Array& arg_types, - const int32_t arg_idx) { + explicit QnnBinaryOpTensorType(const Array& arg_types, const int32_t arg_idx) { CHECK_EQ(arg_types.size(), kNumQnnBinaryOpArgTypes); auto tensor_type = arg_types[arg_idx].as(); CHECK(tensor_type != nullptr); @@ -109,8 +110,7 @@ struct QnnBinaryOpTensorType { * \return New expression with target dtype and possibly lower * precision.
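 *
 * A minimal illustration of the clip-then-cast behaviour described above, using
 * the int8 bounds from GetQmin/GetQmax (illustrative only):
 *
 * \code
 *   Expr clipped = Clip(expr, -128, 127);           // q_min/q_max for int8
 *   Expr output = Cast(clipped, DataType::Int(8));  // now safe to narrow
 * \endcode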
*/ -inline Expr ConvertDtype(const Expr& expr, - const DataType& target_dtype) { +inline Expr ConvertDtype(const Expr& expr, const DataType& target_dtype) { auto q_min = GetQmin(target_dtype); auto q_max = GetQmax(target_dtype); auto output = Clip(expr, q_min, q_max); @@ -134,18 +134,15 @@ inline Expr ConvertDtype(const Expr& expr, * it simply casts the given expression to Int32 as no requantization is * needed in this case. */ -inline Expr RequantizeOrUpcast(const Expr& expr, - const Expr& expr_scale, - const Expr& expr_zero_point, - const Expr& target_scale, - const Expr& target_zero_point, - const Array& expr_shape, +inline Expr RequantizeOrUpcast(const Expr& expr, const Expr& expr_scale, + const Expr& expr_zero_point, const Expr& target_scale, + const Expr& target_zero_point, const Array& expr_shape, const DataType& target_dtype = DataType::Int(32)) { auto result = expr; if (!IsEqualScalar(expr_scale, target_scale) || !IsEqualScalar(expr_zero_point, target_zero_point)) { - result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, - target_scale, target_zero_point, target_dtype); + result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, target_scale, + target_zero_point, target_dtype); } else { result = Cast(result, target_dtype); } @@ -153,27 +150,23 @@ inline Expr RequantizeOrUpcast(const Expr& expr, } /*! \brief Infer layout for QNN binary broadcast operators */ -inline Array > QnnBinaryBroadcastLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { +inline Array > QnnBinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // Use Relay Binary Broadcast Infer correct layout. auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these // tensors can be treated as C. Layout channel_layout = Layout("C"); - Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, + Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, channel_layout, channel_layout, channel_layout, channel_layout}; Array output_layouts = layouts[1]; return {input_layouts, output_layouts}; } - -static inline bool QnnBroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +static inline bool QnnBroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes); @@ -201,28 +194,28 @@ static inline bool QnnBroadcastRel(const Array& types, * * \param OpName the name of registry. */ -#define QNN_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ - Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ - static const Op& op = Op::Get("qnn." OpName); \ - return Call(op, {lhs, rhs, \ - lhs_scale, lhs_zero_point, \ - rhs_scale, rhs_zero_point, \ - output_scale, output_zero_point}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP("qnn." 
OpName) \ - .set_num_inputs(kNumQnnBinaryOpInputs) \ - .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ - .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ - .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ - .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ - .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ - .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ - .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ - .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ - .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) +#define QNN_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ + .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ + static const Op& op = Op::Get("qnn." OpName); \ + return Call(op, \ + {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \ + output_zero_point}, \ + Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP("qnn." OpName) \ + .set_num_inputs(kNumQnnBinaryOpInputs) \ + .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ + .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ + .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ + .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ + .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ + .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ + .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ + .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ + .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 43ba4b6b1ba4f..28f0b8994a014 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" @@ -35,24 +36,21 @@ namespace qnn { TVM_REGISTER_NODE_TYPE(QuantizeAttrs); -bool QuantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); CHECK(data != nullptr); const auto input_dtype = data->dtype; CHECK(input_dtype == DataType::Float(32)) - << "Input type should be one of float32 but was " << input_dtype; + << "Input type should be one of float32 but was " << input_dtype; const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1: axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; CHECK_LT(axis, static_cast(data->shape.size())) << "axis " << quantize_attrs->axis << " is out of range"; - CHECK_GE(axis, 0) - << "axis " << quantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. 
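// Illustrative aside (shape chosen arbitrarily): for NHWC data of shape
// (1, 56, 56, 64) and axis == -1, the normalization above yields axis == 3, so a
// per-axis scale must be a float32 vector of data->shape[3] == 64 elements; a
// scalar scale remains valid as well (see AssignType below).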
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale @@ -130,7 +128,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.quantize") -.describe(R"code(Quantizes the input and produces quantized output. + .describe(R"code(Quantizes the input and produces quantized output. The input can be either float or quantized(int8, uint8). If the input is float, this op takes scale and zero point and quantizes the float value to quantized output, in int8 or uint8 format. If the input is a quantized value, @@ -140,17 +138,17 @@ scale and zero point. - **data**: Tensor of any shape to quantize. The input data can be of floating point or quantized. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "The tensor to quantize.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("Quantize", QuantizeRel) -.set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize") -.set_body_typed(MakeQuantize); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The tensor to quantize.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("Quantize", QuantizeRel) + .set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index a2a46497e1977..79cb08d3f9482 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -25,8 +25,9 @@ #include #include #include -#include "../../transforms/pattern_util.h" + #include "../../transforms/infer_layout_util.h" +#include "../../transforms/pattern_util.h" #include "../util.h" namespace tvm { @@ -68,7 +69,7 @@ Array> RequantizeInferCorrectLayout(const Attrs& attrs, for (auto iter_var : new_in_layouts[0]->axes) { const auto& layout_axis = LayoutAxis::Get(iter_var); const std::string& layout_dim = layout_axis.name(); - if (old_dim == layout_dim) { + if (old_dim == layout_dim) { new_axis = tvm::Integer(axis_index); } // Collect only the primal axis. @@ -249,18 +250,16 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); CHECK(data != nullptr); const auto in_dtype = data->dtype; - CHECK(in_dtype == DataType::Int(8) || - in_dtype == DataType::UInt(8) || + CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) || in_dtype == DataType::Int(32)) << "Input type should be one of [int8, uint8, int32] but was " << in_dtype; const RequantizeAttrs* requantize_attrs = attrs.as(); int axis = requantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1: axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; CHECK_LT(axis, static_cast(data->shape.size())) << "axis " << requantize_attrs->axis << " is out of range"; - CHECK_GE(axis, 0) - << "axis " << requantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points.
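// Illustrative aside (all values arbitrary): by the formula in the qnn.requantize
// description below,
//   Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input),
// an input Q_input = 58 with scale_input = 0.1 and zp_input = 0, requantized to
// scale_output = 0.2 and zp_output = 10, gives Q_output = 10 + 0.5 * 58 = 39;
// non-integral intermediate values are resolved by the configured rounding mode.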
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale @@ -272,8 +271,7 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const Array oshape = data->shape; // assign output type auto out_dtype = requantize_attrs->out_dtype; - CHECK(out_dtype == DataType::Int(8) || - out_dtype == DataType::UInt(8) || + CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || out_dtype == DataType::Int(32)) << "Output type should be one of [int8, uint8, int32] but was " << out_dtype; reporter->Assign(types[5], TensorType(oshape, out_dtype)); @@ -290,11 +288,11 @@ Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr out attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.requantize"); return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, - Attrs(attrs), {}); + Attrs(attrs), {}); } RELAY_REGISTER_OP("qnn.requantize") -.describe(R"code(Requantize operator. + .describe(R"code(Requantize operator. The requantize operator converts one quantized tensor to another quantized tensor. For the output tensor, we are provided with output scale and zero point. The computation looks like this Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "The quantized input tensor.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("Requantize", RequantizeRel) -.set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) -.set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize") -.set_body_typed(MakeRequantize); + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "The quantized input tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("Requantize", RequantizeRel) + .set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) + .set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize").set_body_typed(MakeRequantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc index c6ce3e33f48f0..b928bd5e465c4 100644 --- a/src/relay/qnn/op/subtract.cc +++ b/src/relay/qnn/op/subtract.cc @@ -23,6 +23,7 @@ */ #include #include + #include "op_common.h" namespace tvm { @@ -36,8 +37,7 @@ namespace qnn { * \param arg_types The types of input and output. * \return The sequence of Relay ops for the subtract op. */ -Expr QnnSubtractCanonicalize(const Attrs& attrs, - const Array& new_args, +Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array& new_args, const Array& arg_types) { // Get the args.
QnnBinaryOpArguments args(new_args); @@ -66,17 +66,13 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, // The subtract op is done in int32 precision. // Requantize LHS if necessary. Computes Q_a' - auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, - args.lhs_zero_point, - args.output_scale, - args.output_zero_point, - input_type.shape); + auto requantized_lhs = + RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Requantize RHS if necessary. Computes Q_b' - auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, - args.rhs_zero_point, - args.output_scale, - args.output_zero_point, - input_type.shape); + auto requantized_rhs = + RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Computes Q_a' - Q_b' auto output = Subtract(requantized_lhs, requantized_rhs); @@ -93,10 +89,9 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, // QNN Subtraction operator. QNN_REGISTER_BINARY_OP("subtract") -.describe("Elementwise subtract with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); - + .describe("Elementwise subtract with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 91fe3ca2a9488..7171ded765b9c 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -23,6 +23,7 @@ */ #include "util.h" + #include "../transforms/pattern_util.h" namespace tvm { @@ -48,8 +49,7 @@ namespace qnn { * * Credit to TFLite reference implementation. */ -std::pair GetFixedPointMultiplierShift( - double double_multiplier) { +std::pair GetFixedPointMultiplierShift(double double_multiplier) { int32_t significand, exponent; if (double_multiplier == 0.) { significand = 0; @@ -84,8 +84,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& // 1) Calculating the integer multiplier and integer shift int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = - GetFixedPointMultiplierShift(multiplier); + std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); int left_shift = shift > 0 ? shift : 0; int right_shift = shift > 0 ? 0 : -shift; @@ -119,8 +118,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = - Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); } else { LOG(FATAL) << "Rounding mode " << rounding << " not supported."; } @@ -128,8 +126,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& tensor = Add(tensor, round_scalar); // 5) Simply right shift the result to get the final output. - tensor = - RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
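// Aside: a self-contained sketch of the significand/shift decomposition performed
// by GetFixedPointMultiplierShift above (assuming the frexp-based scheme credited
// to the TFLite reference implementation; illustrative, not the exact code):
//
//   int exponent;
//   double significand = std::frexp(0.25, &exponent);  // 0.5, with exponent == -1
//   auto fixed = static_cast<int32_t>(std::round(significand * (1ll << 31)));
//   // fixed == 1 << 30 (a Q31 fixed-point value), shift == -1, so multiplying
//   // by 0.25 becomes a Q31 multiply by 2^30 followed by one extra right shift.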
return Cast(tensor, DataType::Int(32)); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index d4046ae906071..736b7361a300a 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -25,14 +25,15 @@ #ifndef TVM_RELAY_QNN_UTIL_H_ #define TVM_RELAY_QNN_UTIL_H_ -#include -#include #include #include +#include +#include + #include #include -#include #include +#include namespace tvm { namespace relay { @@ -46,8 +47,7 @@ static inline Array get_shape(const Type& type) { } static inline int32_t GetQmin(const DataType& dtype) { - CHECK_LE(dtype.bits(), 32) - << "QNN ops support int32 or lower precision"; + CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { auto* min_value = tir::as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); @@ -59,8 +59,7 @@ static inline int32_t GetQmin(const DataType& dtype) { } static inline int32_t GetQmax(const DataType& dtype) { - CHECK_LE(dtype.bits(), 32) - << "QNN ops support int32 or lower precision"; + CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { auto* max_value = tir::as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); @@ -171,8 +170,7 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons const TypeReporter& reporter) { // Scale/Zero_points can be either const scalar or a vector with C axis num elems. const auto* tensor_type = expr_type.as(); - CHECK(tensor_type) << "Can assign type to Tensor type only. But got " - << AsText(expr_type, false); + CHECK(tensor_type) << "Can assign type to Tensor type only. But got " << AsText(expr_type, false); const auto tensor_dtype = tensor_type->dtype; CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype; if (tensor_type->shape.size() != 0) { diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 4492ed5bebca4..8ae7df9e29412 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -24,8 +24,9 @@ * \brief Annotating the graph with simulated quantize operators. */ -#include #include +#include + #include "./quantize.h" namespace tvm { @@ -63,10 +64,7 @@ class QAnnotateExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode); }; - -Expr QAnnotateExprNode::Realize() const { - return expr; -} +Expr QAnnotateExprNode::Realize() const { return expr; } QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { auto rnode = make_object(); @@ -75,12 +73,10 @@ QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { data_ = std::move(rnode); } -TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr") -.set_body_typed([](Expr expr, int kind) { +TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) { return QAnnotateExpr(expr, static_cast(kind)); }); - Pass QuantizeAnnotate() { // TODO(tvm-teams): since partition has added cast_hint in different // branches, try to remove this in the future. 
@@ -88,8 +84,7 @@ Pass QuantizeAnnotate() { if (e->IsInstance()) { const auto* n = e.as(); CHECK(n); - const PackedFunc* f = - runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); Expr ret = (*f)(n->expr, static_cast(kQInput)); return static_cast(QAnnotateExpr(ret, kQInput)); } @@ -97,23 +92,18 @@ Pass QuantizeAnnotate() { }; runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); - auto new_params = func->params; - for (const auto& x : FreeVars(func)) { - new_params.push_back(x); - } - return Function(new_params, - func->body, - func->ret_type, - func->type_params, - func->attrs); - }; + [=](Function f, IRModule m, PassContext pc) { + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); + auto new_params = func->params; + for (const auto& x : FreeVars(func)) { + new_params.push_back(x); + } + return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs); + }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate") -.set_body_typed(QuantizeAnnotate); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate); TVM_REGISTER_NODE_TYPE(QAnnotateExprNode); diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 7b1e909501b5f..ea42a198bf849 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -26,7 +26,9 @@ #include #include #include + #include + #include "./quantize.h" namespace tvm { @@ -65,8 +67,8 @@ static std::vector SmoothDistribution(const std::vector& p, } static float ComputeEntropy(float* p, float* q, size_t size) { - float p_sum = std::accumulate(p, p+size, 0.f); - float q_sum = std::accumulate(q, q+size, 0.f); + float p_sum = std::accumulate(p, p + size, 0.f); + float q_sum = std::accumulate(q, q + size, 0.f); float ret = 0; for (size_t i = 0; i < size; i++) { CHECK(p[i] > 0 && q[i] > 0); @@ -77,9 +79,8 @@ static float ComputeEntropy(float* p, float* q, size_t size) { return ret; } -float MinimizeKL(const std::vector& hist, - const std::vector& hist_edges, - int num_bins, int num_quantized_bins) { +float MinimizeKL(const std::vector& hist, const std::vector& hist_edges, int num_bins, + int num_quantized_bins) { const int zero_bin_idx = num_bins / 2; const int num_half_quantized_bins = num_quantized_bins / 2; std::vector thresholds(num_bins / 2 + 1 - num_quantized_bins / 2, 0.f); @@ -137,9 +138,9 @@ float MinimizeKL(const std::vector& hist, divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size()); } } - auto min_divergence_idx = std::distance(divergence.begin(), - std::min_element(divergence.begin(), divergence.end())); - return thresholds[min_divergence_idx];; + auto min_divergence_idx = + std::distance(divergence.begin(), std::min_element(divergence.begin(), divergence.end())); + return thresholds[min_divergence_idx]; } class StatsCollector : private ExprMutator { @@ -152,7 +153,7 @@ class StatsCollector : private ExprMutator { CHECK(func) << "Input should be Function"; Expr new_body = Tuple(std::move(profile_data_)); return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + func->attrs); } private: @@ -167,7 +168,7 @@ class StatsCollector : private ExprMutator { auto attrs =
new_call->attrs.as(); // rewrite the annotation auto new_attrs = make_object(); - const Expr& quantize_input = new_call->args[0]; // expression being quantized + const Expr& quantize_input = new_call->args[0]; // expression being quantized auto placeholder = MakeConstantScalar(DataType::Float(32), 0.); // unused argument Array new_args{quantize_input, placeholder, placeholder, placeholder}; new_attrs->kind = QAnnotateKind::kQIdentity; @@ -198,24 +199,20 @@ class StatsCollector : private ExprMutator { * \param expr The simulation graph after annotation. * \return The profile graph. */ -Expr CreateStatsCollector(const Expr& expr) { - return StatsCollector().Collect(expr); -} - -TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector") -.set_body_typed(CreateStatsCollector); +Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); } +TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector").set_body_typed(CreateStatsCollector); TVM_REGISTER_GLOBAL("relay._quantize.FindScaleByKLMinimization") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int* hist_ptr = static_cast(static_cast(args[0])); - float* hist_edges_ptr = static_cast(static_cast(args[1])); - int num_bins = args[2]; - int num_quantized_bins = args[3]; - std::vector hist(hist_ptr, hist_ptr + num_bins); - std::vector hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1); - ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + int* hist_ptr = static_cast(static_cast(args[0])); + float* hist_edges_ptr = static_cast(static_cast(args[1])); + int num_bins = args[2]; + int num_quantized_bins = args[3]; + std::vector hist(hist_ptr, hist_ptr + num_bins); + std::vector hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1); + ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins); + }); } // namespace quantize } // namespace relay diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index 39de0bc49d4ca..14b420d6034c0 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -25,6 +25,7 @@ */ #include + #include "../transforms/pattern_util.h" #include "./quantize.h" @@ -34,16 +35,13 @@ namespace quantize { using namespace relay::transform; - class QPartitionExpr; class QPartitionExprNode : public TempExprNode { public: /*! 
\brief The original expression */ Expr expr; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("expr", &expr); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } Expr Realize() const final; @@ -62,7 +60,6 @@ class QPartitionExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); }; - Expr QPartitionExprNode::Realize() const { // insert cast hint and stop fusion const QConfig& cfg = QConfig::Current(); @@ -76,23 +73,20 @@ QPartitionExpr::QPartitionExpr(Expr expr) { data_ = std::move(rnode); } -TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") -.set_body_typed([](Expr expr) { +TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr").set_body_typed([](Expr expr) { return QPartitionExpr(expr); }); Pass QuantizePartition() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - auto ret = Downcast( - ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); - return ret; - }; + [=](Function f, IRModule m, PassContext pc) { + auto ret = Downcast(ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); + return ret; + }; return CreateFunctionPass(pass_func, 1, "QuantizePartition", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition") -.set_body_typed(QuantizePartition); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition").set_body_typed(QuantizePartition); TVM_REGISTER_NODE_TYPE(QPartitionExprNode); diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 631d8c0fdf582..d197458154fbe 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -23,12 +23,13 @@ * \brief transform a graph to a low-bit graph * for compression and acceleration. */ +#include "./quantize.h" + #include #include #include -#include -#include "./quantize.h" +#include namespace tvm { namespace relay { @@ -36,9 +37,7 @@ namespace quantize { TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); -bool SimulatedQuantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SimulatedQuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 5); const auto param = attrs.as(); @@ -48,36 +47,34 @@ bool SimulatedQuantizeRel(const Array& types, CHECK(data != nullptr); CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; - reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale - reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min - reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max - reporter->Assign(types[4], types[0]); // output + reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale + reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min + reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max + reporter->Assign(types[4], types[0]); // output return true; } RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) -.set_num_inputs(4) -.add_argument("data", "Tensor", "The input data.") -.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") -.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") -.add_argument("clip_max", "Tensor", "upper bound. 
It should be a scalar") -.set_attrs_type() -.set_support_level(11) -.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); + .describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) + .set_num_inputs(4) + .add_argument("data", "Tensor", "The input data.") + .add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") + .add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") + .add_argument("clip_max", "Tensor", "upper bound. It should be a scalar") + .set_attrs_type() + .set_support_level(11) + .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") -.set_body_typed( - [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, - int kind, bool sign, std::string rounding) { - auto attrs = make_object(); - attrs->kind = kind; - attrs->sign = sign; - attrs->rounding = rounding; - static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); - return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); - }); - + .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, + std::string rounding) { + auto attrs = make_object(); + attrs->kind = kind; + attrs->sign = sign; + attrs->rounding = rounding; + static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); + return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); + }); /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMQConfigThreadLocalEntry { @@ -87,26 +84,24 @@ struct TVMQConfigThreadLocalEntry { /*! \brief The current build config context */ std::stack context_stack; - TVMQConfigThreadLocalEntry() : - default_config(make_object()) { - } + TVMQConfigThreadLocalEntry() : default_config(make_object()) {} }; /*! \brief Thread local store to hold the BuildConfig context stack. 
*/ typedef dmlc::ThreadLocalStore TVMQConfigThreadLocalStore; void QConfig::EnterQConfigScope(const QConfig& build_config) { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); entry->context_stack.push(build_config); } void QConfig::ExitQConfigScope() { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); entry->context_stack.pop(); } QConfig& QConfig::Current() { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } @@ -117,31 +112,31 @@ QConfig& QConfig::Current() { TVM_REGISTER_NODE_TYPE(QConfigNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* op = static_cast(ref.get()); - p->stream << "qconfig("; - p->stream << "nbit_input=" << op->nbit_input << ", "; - p->stream << "nbit_weight=" << op->nbit_weight << ", "; - p->stream << "nbit_activation=" << op->nbit_activation << ", "; - p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; - p->stream << "global_scale=" << op->global_scale << ", "; - p->stream << "weight_scale=" << op->weight_scale << ", "; - p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; - p->stream << "do_simulation==" << op->do_simulation << ", "; - p->stream << "round_for_shift==" << op->round_for_shift << ", "; - p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", "; - p->stream << "rounding==" << op->rounding; - p->stream << ")"; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* op = static_cast(ref.get()); + p->stream << "qconfig("; + p->stream << "nbit_input=" << op->nbit_input << ", "; + p->stream << "nbit_weight=" << op->nbit_weight << ", "; + p->stream << "nbit_activation=" << op->nbit_activation << ", "; + p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; + p->stream << "global_scale=" << op->global_scale << ", "; + p->stream << "weight_scale=" << op->weight_scale << ", "; + p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; + p->stream << "do_simulation==" << op->do_simulation << ", "; + p->stream << "round_for_shift==" << op->round_for_shift << ", "; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; + p->stream << "rounding==" << op->rounding; + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig").set_body_typed([]() -> QConfig { + return QConfig::Current(); }); -TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig") -.set_body_typed(QConfig::Current); - TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope") -.set_body_typed(QConfig::EnterQConfigScope); + .set_body_typed(QConfig::EnterQConfigScope); -TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope") -.set_body_typed(QConfig::ExitQConfigScope); +TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope").set_body_typed(QConfig::ExitQConfigScope); } // namespace quantize } // namespace relay diff --git a/src/relay/quantize/quantize.h b/src/relay/quantize/quantize.h index 563f47f569339..a883cb1922d2b 100644 --- a/src/relay/quantize/quantize.h +++ b/src/relay/quantize/quantize.h @@ -24,9 +24,11 @@ #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_ #define TVM_RELAY_QUANTIZE_QUANTIZE_H_ -#include #include +#include + #include + #include "../transforms/pattern_util.h" 
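// Aside: an illustrative use of the scoped configuration machinery declared in
// this header (field values are hypothetical):
//
//   {
//     QConfig cfg(make_object<QConfigNode>());
//     cfg->nbit_input = 8;                // configure quantization parameters
//     QConfigContext scope(cfg);          // QConfig::Current() now returns cfg
//     // ... run the annotate/calibrate/realize passes ...
//   }                                     // previous configuration restored here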
namespace tvm { @@ -34,12 +36,7 @@ namespace relay { namespace quantize { /*! \brief Kind of annotate field */ -enum QAnnotateKind : int { - kQIdentity = 0, - kQInput = 1, - kQWeight = 2, - kQActivation = 3 -}; +enum QAnnotateKind : int { kQIdentity = 0, kQInput = 1, kQWeight = 2, kQActivation = 3 }; /*! \brief Attribute for simulated quantize operator */ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { @@ -48,20 +45,17 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { std::string rounding; TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(kind) - .describe("kind of field, hint for nbit/dtype configuration."); - TVM_ATTR_FIELD(sign).set_default(true) - .describe("whether to use signed data type."); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round").describe( + "rounding mode. Can be 'floor', 'ceil', 'round'"); } }; - class QConfig; /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class QConfigNode : public Object { public: int nbit_input = 8; @@ -103,20 +97,16 @@ class QConfigNode : public Object { }; /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class QConfig : public ObjectRef { public: QConfig() {} explicit QConfig(ObjectPtr n) : ObjectRef(n) {} - const QConfigNode* operator->() const { - return static_cast(get()); - } + const QConfigNode* operator->() const { return static_cast(get()); } - QConfigNode* operator->() { - return static_cast(get_mutable()); - } + QConfigNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Push a new BuildConfig context onto the thread local stack. @@ -150,14 +140,10 @@ struct QConfigContext { * context. When the BuildConfigContext is destructed, the previous context is restored. * \param build_config The BuildConfig to set as the new current context. */ - explicit QConfigContext(const QConfig& qconfig) { - QConfig::EnterQConfigScope(qconfig); - } + explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); } /*! \brief Destructor. Pops the context off the thread local stack. */ - ~QConfigContext() { - QConfig::ExitQConfigScope(); - } + ~QConfigContext() { QConfig::ExitQConfigScope(); } }; } // namespace quantize diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 6d56e19d229cc..49d1e522f7d74 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -25,12 +25,13 @@ * graph. 
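 *
 * For example, once realized, a quantized tensor is carried as a QRealizeIntExpr
 * holding (data, dom_scale, dtype); its floating-point view is recovered by
 * multiplying data by dom_scale, which is what QRealizeIntExprNode::Realize()
 * below does when it dequantizes.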
*/ -#include #include #include -#include "./quantize.h" -#include "../transforms/pattern_util.h" +#include + #include "../qnn/util.h" +#include "../transforms/pattern_util.h" +#include "./quantize.h" namespace tvm { namespace relay { @@ -53,7 +54,6 @@ class QRealizeExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode); }; - class QRealizeIntExprNode : public QRealizeExprNode { public: Expr dom_scale; @@ -67,7 +67,7 @@ class QRealizeIntExprNode : public QRealizeExprNode { Expr Realize() const final; - static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; + static constexpr const char* _type_key = "relay.quantize.QRealizeIntExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode); }; @@ -78,7 +78,6 @@ class QRealizeIntExpr : public QRealizeExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); }; - Expr QRealizeIntExprNode::Realize() const { Expr data = this->data; // dequantize @@ -95,15 +94,13 @@ QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) { data_ = std::move(n); } - inline Expr ForwardOp(const Call& ref_call, const Array& args) { return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args); } - /* calculate `data * s1 / s2`, use shift if possible */ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, - const Array &data_shape) { + const Array& data_shape) { const QConfig& cfg = QConfig::Current(); // here we assume the dtype of data is dtype activation if (s1 == s2) return data; @@ -112,8 +109,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, float shift_factor = std::log2(factor); CHECK_GT(shift_factor, 0); if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(dtype, - static_cast(shift_factor))); + return LeftShift(data, MakeConstantScalar(dtype, static_cast(shift_factor))); } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { @@ -122,9 +118,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, } } -Expr QuantizeRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); // do not handle data type cast const auto param = ref_call->attrs.as(); @@ -158,22 +152,20 @@ Expr QuantizeRealize(const Call& ref_call, // use right shift if (cfg->round_for_shift) { float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(round_bias))); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); + data = RightShift(data, + MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } else { - data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); + data = LeftShift(data, + MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExpr(data, dom_scale, n->dtype); } else { data = Cast(data, DataType::Int(64)); data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, - ref_call->type_as()->shape, - cfg->rounding); + ref_call->type_as()->shape, cfg->rounding); data = Cast(Clip(data, clip_min_imm, clip_max_imm), 
n->dtype); return QRealizeIntExpr(data, dom_scale, n->dtype); } @@ -195,12 +187,9 @@ Expr FoldConstantOpt(const Expr& expr) { } RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.set_attr("FQRealizeRewrite", QuantizeRealize); - + .set_attr("FQRealizeRewrite", QuantizeRealize); -Expr Conv2dRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr Conv2dRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() && !new_args[1]->IsInstance()) { @@ -223,20 +212,15 @@ Expr Conv2dRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = Call(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FQRealizeRewrite", Conv2dRealize); - +RELAY_REGISTER_OP("nn.conv2d").set_attr("FQRealizeRewrite", Conv2dRealize); -Expr DenseRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr DenseRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() || !new_args[1]->IsInstance()) { @@ -257,20 +241,15 @@ Expr DenseRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = Call(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } -RELAY_REGISTER_OP("nn.dense") -.set_attr("FQRealizeRewrite", DenseRealize); +RELAY_REGISTER_OP("nn.dense").set_attr("FQRealizeRewrite", DenseRealize); - -Expr MulRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr MulRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { @@ -297,9 +276,7 @@ Expr MulRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("multiply") -.set_attr("FQRealizeRewrite", MulRealize); - +RELAY_REGISTER_OP("multiply").set_attr("FQRealizeRewrite", MulRealize); float ChooseDomScale(const std::vector& nptrs) { if (nptrs.size() == 2) { @@ -316,7 +293,6 @@ float ChooseDomScale(const std::vector& nptrs) { } } - /* \brief Unify the dom scale of arguments */ Array UnifyDTypeScale(const Array& ref_args, const Array& args, DataType* dtype_ptr, Expr* scale_ptr) { @@ -366,9 +342,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args return ret; } -Expr AddRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr AddRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { DataType dtype; @@ -382,12 +356,9 @@ Expr AddRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("add") -.set_attr("FQRealizeRewrite", AddRealize); +RELAY_REGISTER_OP("add").set_attr("FQRealizeRewrite", AddRealize); -Expr 
ClipRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr ClipRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { const auto ref_attrs = ref_call->attrs.as(); @@ -396,21 +367,16 @@ Expr ClipRealize(const Call& ref_call, attrs->a_min = ref_attrs->a_min / dom_scale; attrs->a_max = ref_attrs->a_max / dom_scale; - Expr ret = Call(ref_call->op, - {n->data}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args); return QRealizeIntExpr(ret, n->dom_scale, n->dtype); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); } -RELAY_REGISTER_OP("clip") -.set_attr("FQRealizeRewrite", ClipRealize); - +RELAY_REGISTER_OP("clip").set_attr("FQRealizeRewrite", ClipRealize); -Expr ConcatenateRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); CHECK_EQ(ref_call->args.size(), 1); @@ -435,14 +401,10 @@ Expr ConcatenateRealize(const Call& ref_call, } } -RELAY_REGISTER_OP("concatenate") -.set_attr("FQRealizeRewrite", ConcatenateRealize); - +RELAY_REGISTER_OP("concatenate").set_attr("FQRealizeRewrite", ConcatenateRealize); /* \brief forward the original operator */ -Expr IdentityRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr IdentityRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); @@ -452,18 +414,15 @@ Expr IdentityRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("nn.relu").set_attr("FQRealizeRewrite", IdentityRealize); -RELAY_REGISTER_OP("strided_slice") -.set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("strided_slice").set_attr("FQRealizeRewrite", IdentityRealize); RELAY_REGISTER_OP("annotation.stop_fusion") -.set_attr("FQRealizeRewrite", IdentityRealize); + .set_attr("FQRealizeRewrite", IdentityRealize); /* \brief for unary operators which requantize its input to dtype_nbit */ -Expr CastDtypeInputRealize(const Call& ref_call, - const Array& new_args, +Expr CastDtypeInputRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); @@ -477,12 +436,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, } RELAY_REGISTER_OP("nn.max_pool2d") -.set_attr("FQRealizeRewrite", CastDtypeInputRealize); - + .set_attr("FQRealizeRewrite", CastDtypeInputRealize); -Expr AvgPoolRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr AvgPoolRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -497,15 +453,12 @@ Expr AvgPoolRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("nn.avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); +RELAY_REGISTER_OP("nn.avg_pool2d").set_attr("FQRealizeRewrite", AvgPoolRealize); RELAY_REGISTER_OP("nn.global_avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); + .set_attr("FQRealizeRewrite", AvgPoolRealize); -Expr CastHintRealize(const Call& ref_call, - const Array& new_args, - 
const ObjectRef& ctx) { +Expr CastHintRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const auto param = ref_call->attrs.as(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -517,19 +470,17 @@ Expr CastHintRealize(const Call& ref_call, } RELAY_REGISTER_OP("annotation.cast_hint") -.set_attr("FQRealizeRewrite", CastHintRealize); + .set_attr("FQRealizeRewrite", CastHintRealize); Pass QuantizeRealizePass() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); + }; return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize") -.set_body_typed(QuantizeRealizePass); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize").set_body_typed(QuantizeRealizePass); } // namespace quantize } // namespace relay diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index aab0b3a30a7cf..7b91e8cca1d3d 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -24,20 +24,20 @@ custom layouts or other general weight pre-transformation. */ #include -#include -#include #include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include -#include "transform_layout.h" #include "pattern_util.h" +#include "transform_layout.h" namespace tvm { namespace relay { @@ -85,8 +85,8 @@ class AlterTransformMemorizer : public TransformMemorizer { } // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes. // Probably we need to disable the AlterOpLayout when compiling dynamic models. 
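// Aside: a minimal sketch of what such a layout-alteration hook can look like.
// The signature is inferred from the call site below and should be treated as
// illustrative rather than canonical:
//
//   RELAY_REGISTER_OP("nn.conv2d").set_attr<FTVMAlterOpLayout>(
//       "FTVMAlterOpLayout",
//       [](const Attrs& attrs, const Array<Expr>& args,
//          const Array<te::Tensor>& tinfos, const Type& out_type) -> Expr {
//         return Expr();  // an undefined Expr leaves the original call unchanged
//       });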
- Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos, - ref_call->checked_type()); + Expr altered_value = + falter_layout[op](ref_call->attrs, new_args, tinfos, ref_call->checked_type()); if (altered_value.defined()) { new_e = altered_value; modified = true; @@ -122,14 +122,13 @@ namespace transform { Pass AlterOpLayout() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::alter_op_layout::AlterOpLayout(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") -.set_body_typed(AlterOpLayout); +TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout").set_body_typed(AlterOpLayout); } // namespace transform diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 4caac042f0161..36359473e9c78 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -299,8 +299,7 @@ Pass AnnotateTarget(const Array& targets) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); }; - auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", - {"InferType"}); + auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index ebcbd578b5f0b..f47810739143d 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -22,9 +22,10 @@ * \brief Canonicalize cast expressions to make operator fusion more efficient. */ #include -#include #include +#include #include + #include "pass_util.h" #include "pattern_util.h" @@ -112,8 +113,7 @@ class CastCanonicalizer : public ExprMutator { const CallNode* new_call = new_expr.as(); CHECK(new_call); CHECK(new_call->op == cast_op_); - return Call(new_call->op, new_call->args, new_call->attrs, - new_call->type_args); + return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args); } } } @@ -122,22 +122,19 @@ class CastCanonicalizer : public ExprMutator { } }; -Expr CanonicalizeCast(const Expr& e) { - return CastCanonicalizer().Mutate(e); -} +Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); } namespace transform { Pass CanonicalizeCast() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeCast(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeCast(f)); + }; return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") -.set_body_typed(CanonicalizeCast); +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast); } // namespace transform diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc index 1d3111b29d7d5..fec757ee68d54 100644 --- a/src/relay/transforms/canonicalize_ops.cc +++ b/src/relay/transforms/canonicalize_ops.cc @@ -23,10 +23,11 @@ This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) 
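 *
 * For instance, with NCHW input the expansion mentioned above rewrites
 * bias_add(x, bias, axis=1) into add(x, expand_dims(bias, 1, 2)), turning the
 * (C,) bias into a broadcastable (C, 1, 1) tensor (axis values illustrative).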
*/ #include +#include #include #include -#include #include + #include "pattern_util.h" namespace tvm { @@ -71,14 +72,13 @@ namespace transform { Pass CanonicalizeOps() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeOps(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") -.set_body_typed(CanonicalizeOps); +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps").set_body_typed(CanonicalizeOps); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index af6b1353f5acf..1990414c3aa45 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -33,15 +33,17 @@ */ #include -#include #include #include +#include #include #include + #include #include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -50,13 +52,10 @@ namespace relay { class ParallelConv2DCombiner : public ParallelOpCombiner { public: explicit ParallelConv2DCombiner(uint64_t min_num_branches) - : ParallelOpCombiner("nn.conv2d", min_num_branches) { - } + : ParallelOpCombiner("nn.conv2d", min_num_branches) {} protected: - bool IsSupportedOp(const CallNode* n) { - return n->attrs.as()->groups == 1; - } + bool IsSupportedOp(const CallNode* n) { return n->attrs.as()->groups == 1; } bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { StructuralEqual eq; @@ -67,10 +66,10 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { CHECK(attrs_b); const auto* tweight_a = a->args[1]->type_as(); const auto* tweight_b = b->args[1]->type_as(); - const auto shape_a = tir::BijectiveLayout( - Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); - const auto shape_b = tir::BijectiveLayout( - Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); + const auto shape_a = + tir::BijectiveLayout(Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); + const auto shape_b = + tir::BijectiveLayout(Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && @@ -118,8 +117,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto toutput_a = a->type_as(); auto toutput_b = b->type_as(); - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; // Position of the 'C' dimension in the argument size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size(); @@ -132,15 +130,12 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { for (size_t i = 0; i < ta->shape.size(); i++) { if (i == arg_channel_pos) continue; - if (!eq(ta->shape[i], tb->shape[i])) - return false; + if (!eq(ta->shape[i], tb->shape[i])) return false; } return true; } - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = 
branches[0][depth]; @@ -166,9 +161,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return Call(call->op, new_args, call->attrs, {}); } - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int64_t index = 0; for (const auto& branch : branches) { @@ -217,14 +210,13 @@ namespace transform { Pass CombineParallelConv2D(uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelConv2D(f, min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") -.set_body_typed(CombineParallelConv2D); +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D").set_body_typed(CombineParallelConv2D); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 1278020ac7353..8613dbe1466e8 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -32,16 +32,18 @@ */ #include -#include #include #include +#include #include #include + #include #include + +#include "./combine_parallel_op_batch.h" #include "./expr_subst.h" #include "pattern_util.h" -#include "./combine_parallel_op_batch.h" namespace tvm { namespace relay { @@ -49,8 +51,7 @@ namespace relay { class ParallelDenseCombiner : public ParallelOpBatchCombiner { public: explicit ParallelDenseCombiner(uint64_t min_num_branches) - : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) { - } + : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {} protected: virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { @@ -63,8 +64,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner { const auto* weight_b = b->args[1]->type_as(); return eq(attrs_a->out_dtype, attrs_b->out_dtype) && - eq(weight_a->shape[0], weight_b->shape[0]) && - eq(weight_a->shape[1], weight_b->shape[1]); + eq(weight_a->shape[0], weight_b->shape[0]) && eq(weight_a->shape[1], weight_b->shape[1]); } }; @@ -77,14 +77,13 @@ namespace transform { Pass CombineParallelDense(uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelDense(f, min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CombineParallelDense(f, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") -.set_body_typed(CombineParallelDense); +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense").set_body_typed(CombineParallelDense); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc index a7f7af2b79e52..854a1aee5a5e3 100644 --- a/src/relay/transforms/combine_parallel_op.cc +++ b/src/relay/transforms/combine_parallel_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -23,33 +23,33 @@ * \brief Abstract class to combine parallel ops and their successive element-wise ops. */ +#include "combine_parallel_op.h" + #include #include -#include #include #include +#include #include #include #include + #include -#include #include #include +#include + #include "expr_subst.h" #include "pattern_util.h" -#include "combine_parallel_op.h" - namespace tvm { namespace relay { -BranchGroupFinder::BranchGroupFinder(const Op& op, - FIsSupportedOp fis_supported_op, +BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops) - : cached_op_(op), - fis_supported_op_(fis_supported_op), - fare_compatible_ops_(fare_compatible_ops) { -} + : cached_op_(op), + fis_supported_op_(fis_supported_op), + fare_compatible_ops_(fare_compatible_ops) {} std::vector BranchGroupFinder::Find(const Expr& expr) { this->VisitExpr(expr); @@ -111,18 +111,13 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) { } ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) - : cached_op_(Op::Get(op_name)), - min_num_branches_(min_num_branches) { -} + : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) {} Expr ParallelOpCombiner::Combine(const Expr& expr) { - auto groups = BranchGroupFinder(cached_op_, - [&](const CallNode* n) { - return IsSupportedOp(n); - }, - [&](const CallNode* a, const CallNode* b) { - return CanOpsBeCombined(a, b); - }).Find(expr); + auto groups = BranchGroupFinder( + cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); }, + [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); }) + .Find(expr); for (const Group& group : groups) { if (group.size() < min_num_branches_) { continue; @@ -135,10 +130,9 @@ Expr ParallelOpCombiner::Combine(const Expr& expr) { void ParallelOpCombiner::CombineBranches(const Group& branches) { Call combined = MakeCombinedOp(branches); auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); + [](const Branch& branch_a, const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); size_t depth = it->size(); size_t i; // starting from 1 to skip the op @@ -155,32 +149,30 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { } bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { - const CallNode* call = branches[0][depth]; - tvm::StructuralEqual attrs_equal; - // check if all branches in current depth can be combined - for (auto it = branches.begin() + 1; it != branches.end(); it++) { - const Branch& branch = *it; - if (!branch[depth]->op.same_as(call->op) || - !attrs_equal(branch[depth]->attrs, call->attrs) || - branch[depth]->args.size() != call->args.size()) { - return false; - } + const CallNode* call = branches[0][depth]; + tvm::StructuralEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != 
call->args.size()) { + return false; + } - if (branch[depth]->args[parent_index].get() != branch[depth - 1]) - return false; + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) return false; - // Check args - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) continue; + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; - if (!IsArgCompatible(call, branch[depth], i) || - !attrs_equal(call->attrs, branch[depth]->attrs)) { - return false; - } + if (!IsArgCompatible(call, branch[depth], i) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; } } - return true; } + return true; +} } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/combine_parallel_op.h b/src/relay/transforms/combine_parallel_op.h index 0097e29b13ea4..23fe347544302 100644 --- a/src/relay/transforms/combine_parallel_op.h +++ b/src/relay/transforms/combine_parallel_op.h @@ -26,26 +26,27 @@ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ #include -#include #include #include +#include #include #include + +#include #include #include #include -#include + #include "./expr_subst.h" #include "pattern_util.h" - namespace tvm { namespace relay { using Branch = std::vector; using Group = std::vector; -using FIsSupportedOp = std::function; -using FAreCompatibleOps = std::function; +using FIsSupportedOp = std::function; +using FAreCompatibleOps = std::function; using ExprSubstMap = std::unordered_map; /* @@ -74,8 +75,7 @@ class BranchGroupFinder : private ExprVisitor { * \param fare_compatible_ops function that returns true if * two ops are compatible for combining */ - BranchGroupFinder(const Op& op, - FIsSupportedOp fis_supported_op, + BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); /* @@ -188,10 +188,8 @@ class ParallelOpCombiner { * all combined ops * \return new combined call */ - virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, - size_t parent_index) = 0; + virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, + size_t depth, size_t parent_index) = 0; /* * \brief Updates map of expr to substitute with combined expr. 
This usually involves @@ -201,9 +199,7 @@ class ParallelOpCombiner { * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ - virtual void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) = 0; private: diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index 361565ef11d76..5cd287c927f40 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -44,17 +44,20 @@ * */ +#include "./combine_parallel_op_batch.h" + #include -#include #include #include +#include #include #include + #include #include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" -#include "./combine_parallel_op_batch.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -63,13 +66,9 @@ namespace relay { ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches) - : ParallelOpCombiner(op_name, min_num_branches), - batch_op_name_(batch_op_name) { -} + : ParallelOpCombiner(op_name, min_num_branches), batch_op_name_(batch_op_name) {} -bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { - return true; -} +bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { return true; } bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) { if (a->args.size() != b->args.size()) { @@ -116,19 +115,16 @@ bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; for (size_t i = 0; i < ta->shape.size(); i++) { - if (!eq(ta->shape[i], tb->shape[i])) - return false; + if (!eq(ta->shape[i], tb->shape[i])) return false; } return true; } Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; @@ -160,10 +156,8 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, return Call(call->op, new_args, call->attrs, {}); } -void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, - ExprSubstMap* subst_map) { +void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, + size_t depth, ExprSubstMap* subst_map) { int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { @@ -174,30 +168,25 @@ void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, } /*! 
\brief Combine parallel op into batched op if number of branches >= min_num_branches */ -Expr CombineParallelOpBatch(const Expr& expr, - const std::string& op_name, - const std::string& batch_op_name, - uint64_t min_num_branches) { +Expr CombineParallelOpBatch(const Expr& expr, const std::string& op_name, + const std::string& batch_op_name, uint64_t min_num_branches) { return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr); } namespace transform { -Pass CombineParallelOpBatch(const std::string& op_name, - const std::string& batch_op_name, +Pass CombineParallelOpBatch(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelOpBatch(f, - op_name, - batch_op_name, - min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast( + CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") -.set_body_typed(CombineParallelOpBatch); + .set_body_typed(CombineParallelOpBatch); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_op_batch.h b/src/relay/transforms/combine_parallel_op_batch.h index 6876604339466..9f87d9d2184f2 100644 --- a/src/relay/transforms/combine_parallel_op_batch.h +++ b/src/relay/transforms/combine_parallel_op_batch.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,16 +25,18 @@ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_ #include -#include #include #include +#include #include #include + +#include #include #include -#include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -68,8 +70,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * \param min_num_branches min number of parallel branches beginning with op * to start combining */ - ParallelOpBatchCombiner(const std::string& op_name, - const std::string& batch_op_name, + ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches); protected: @@ -116,9 +117,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * all combined ops * \return new combined call as batch op by stacking args */ - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) final; /* @@ -129,15 +128,13 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) final; private: /* \brief name of op to replace combined ops with. 
for example, * for combining parallel dense, this will be set to - * nn.batch_matmul + * nn.batch_matmul */ std::string batch_op_name_; }; diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index dbb2c38e3f274..f43c8f63b4df9 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -24,20 +24,20 @@ custom layouts or other general weight pre-transformation. */ #include -#include -#include #include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include -#include "transform_layout.h" #include "pattern_util.h" namespace tvm { namespace relay { @@ -132,8 +132,7 @@ Pass ConvertLayout(const std::string& desired_layout) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); }; - return CreateFunctionPass( - pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); + return CreateFunctionPass(pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 1b83e7188df27..36aaa478eab60 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -133,20 +133,12 @@ Pass DenseToSparse(const Array& weight_name, // Remove FreeVar warnings auto f0 = Downcast(DenseToSparse(f, weight_name, weight_shape)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, - f0->body, - f0->ret_type, - f0->type_params, - f0->attrs); + auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, - f1->body, - f1->ret_type, - f1->type_params, - f1->attrs); + return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); }; return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 48b8666856a68..1c250f102e77a 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -23,17 +23,15 @@ * \brief Use a fresh Id for every Var to make the result well-formed. */ #include -#include #include +#include #include namespace tvm { namespace relay { Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, - public ExprMutator, - public PatternMutator { + class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { TypeVar ret = TypeVar(tv->name_hint, tv->kind); @@ -65,9 +63,7 @@ Expr DeDup(const Expr& e) { return Let(v, VisitExpr(op->value), VisitExpr(op->body)); } - Type VisitType(const Type& t) final { - return t.defined() ? TypeMutator::VisitType(t) : t; - } + Type VisitType(const Type& t) final { return t.defined() ?
TypeMutator::VisitType(t) : t; } Expr VisitExpr_(const FunctionNode* op) final { tvm::Array type_params; @@ -78,29 +74,19 @@ Expr DeDup(const Expr& e) { for (const Var& param : op->params) { params.push_back(Fresh(param)); } - return Function(params, - VisitExpr(op->body), - VisitType(op->ret_type), - type_params, - op->attrs); + return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs); } - Pattern VisitPattern(const Pattern& p) final { - return PatternFunctor::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } - Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVar(Fresh(op->var)); - } + Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(Fresh(op->var)); } Type VisitType_(const TypeVarNode* op) final { TypeVar v = GetRef(op); return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; } - Var VisitVar(const Var& v) final { - return Fresh(v); - } + Var VisitVar(const Var& v) final { return Fresh(v); } private: std::unordered_map rename_; @@ -113,8 +99,7 @@ Expr DeDup(const Expr& e) { return ret; } -TVM_REGISTER_GLOBAL("relay._transform.dedup") -.set_body_typed(DeDup); +TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index a0d093f197d6d..9aa0f49baa514 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -30,12 +30,13 @@ #include #include #include + #include "let_list.h" namespace tvm { namespace relay { -template +template using VarMap = std::unordered_map; using VarSet = std::unordered_set; @@ -59,20 +60,18 @@ class Eliminator : private ExprMutator { VarMap expr_map_; VarMap use_map_; bool inline_once_; - explicit Eliminator(const VarMap& expr_map, - const VarMap& use_map, - bool inline_once) : - expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { } + explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, bool inline_once) + : expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) {} friend CalcDep; bool HasLet(const Var& v) { switch (use_map_[v]) { - case 0: - return false; - case 1: - return !inline_once_; - default: - return true; + case 0: + return false; + case 1: + return !inline_once_; + default: + return true; } } @@ -104,8 +103,7 @@ class CalcDep : protected MixedModeVisitor { } private: - explicit CalcDep(const VarMap& expr_map) - : MixedModeVisitor(2), expr_map_(expr_map) {} + explicit CalcDep(const VarMap& expr_map) : MixedModeVisitor(2), expr_map_(expr_map) {} VarMap expr_map_; VarMap use_map_; @@ -123,9 +121,7 @@ class CalcDep : protected MixedModeVisitor { } } - void VisitExpr_(const LetNode* l) final { - VisitExpr(l->body); - } + void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); } void VisitExpr_(const VarNode* v) final { Var var = GetRef(v); @@ -144,14 +140,13 @@ namespace transform { Pass DeadCodeElimination(bool inline_once) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(DeadCodeElimination(f, inline_once)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(DeadCodeElimination(f, inline_once)); + }; return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination") -.set_body_typed(DeadCodeElimination); 
+TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); } // namespace transform diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index d5e1d2efd8e6d..39cf563f730a6 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -28,12 +28,12 @@ * 3. Collect the device allocation of each expression. */ -#include -#include #include +#include #include #include #include +#include #include #include @@ -103,8 +103,7 @@ class ValidateAnnotation : private ExprVisitor { * \return The device type. */ int GetDeviceId(const CallNode* call_node) { - CHECK(IsOnDeviceNode(call_node)) - << "The input call node must be on_device node."; + CHECK(IsOnDeviceNode(call_node)) << "The input call node must be on_device node."; const OnDeviceAttrs* on_device_attr = call_node->attrs.as(); return on_device_attr->device_type; } @@ -160,8 +159,7 @@ class RewriteAnnotation : public ExprMutator { Expr VisitExpr_(const TupleGetItemNode* op) final { Expr tuple = op->tuple; if (NeedDeviceCopy(tuple.operator->(), op)) { - Expr new_expr = - TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); + Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); UpdateAnnotationMap(op, new_expr.operator->()); return this->VisitExpr(new_expr); } else { @@ -201,8 +199,7 @@ class RewriteAnnotation : public ExprMutator { } if (annotated) { - Call new_call = Call(call_node->op, new_args, call_node->attrs, - call_node->type_args); + Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); UpdateAnnotationMap(call_node, new_call.operator->()); return this->VisitExpr(new_call); @@ -235,8 +232,7 @@ class RewriteAnnotation : public ExprMutator { return CreateDeviceCopy(src, fallback_device_, dit->second); } else { const auto dit = annotation_map_.find(dst); - int dst_dev_type = - dit == annotation_map_.end() ? fallback_device_ : dit->second; + int dst_dev_type = dit == annotation_map_.end() ? fallback_device_ : dit->second; return CreateDeviceCopy(src, sit->second, dst_dev_type); } } @@ -301,6 +297,7 @@ class AnnotatationVisitor : private ExprVisitor { visitor(expr); return visitor.annotations_; } + private: void VisitExpr_(const CallNode* call_node) { if (IsOnDeviceNode(call_node)) { @@ -414,9 +411,7 @@ class DeviceInfo { // TODO(zhiics) Skip annotation of tuple node for now. 
} - void VisitExpr_(const TupleGetItemNode* op) final { - ExprVisitor::VisitExpr_(op); - } + void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); @@ -432,7 +427,6 @@ class DeviceInfo { post_dfs_order_.push_back(std::make_pair(in, has_copy_)); } - int num_device_copy_ops_{0}; bool has_copy_ = false; std::vector> post_dfs_order_; @@ -479,25 +473,23 @@ class DeviceInfo { const auto* attrs = last_copy_node->attrs.as(); cur_dev_type = attrs->src_dev_type; if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; - if (it->second) device_map_.Set(GetRef(it->first), - attrs->dst_dev_type); + if (it->second) device_map_.Set(GetRef(it->first), attrs->dst_dev_type); } else if (last_copy_node) { Expr expr = GetRef(it->first); CHECK_EQ(device_map_.count(expr), 0U); if (it->second) device_map_.Set(expr, cur_dev_type); } } - return out_dev_type; + return out_dev_type; } void FillPropagation(int out_dev_type) { for (const auto& it : post_visitor_.post_dfs_order_) { - Expr expr = GetRef(it.first); - if (!it.second) device_map_.Set(expr, out_dev_type); + Expr expr = GetRef(it.first); + if (!it.second) device_map_.Set(expr, out_dev_type); } } - PostDfsOrderVisitor post_visitor_; Map device_map_; }; @@ -521,14 +513,12 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { } CHECK_GT(new_body.size(), 0U); if (new_body.size() == 1) { - return Function(params, new_body[0], Type(nullptr), - fn->type_params, fn->attrs); + return Function(params, new_body[0], Type(nullptr), fn->type_params, fn->attrs); } else if (tuple->fields.size() == new_body.size()) { - return new_expr; + return new_expr; } else { Tuple tuple_body = Tuple(new_body); - return Function(params, tuple_body, Type(nullptr), - fn->type_params, fn->attrs); + return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs); } } else { return new_expr; @@ -544,40 +534,35 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { if (tuple->fields.size() == new_fields.size()) { return new_fields.size() == 1 ? new_fields[0] : new_expr; } else { - return new_fields.size() == 1 ? new_fields[0] - : Tuple(new_fields); + return new_fields.size() == 1 ? 
new_fields[0] : Tuple(new_fields); } } else { return new_expr; } } -Map CollectDeviceInfo(const Expr& expr) { - return DeviceInfo::GetDeviceMap(expr); -} +Map CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); } Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo") -.set_body_typed(CollectDeviceInfo); +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo").set_body_typed(CollectDeviceInfo); TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") -.set_body_typed(CollectDeviceAnnotationOps); + .set_body_typed(CollectDeviceAnnotationOps); namespace transform { Pass RewriteAnnotatedOps(int fallback_device) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); + }; return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") -.set_body_typed(RewriteAnnotatedOps); +TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation").set_body_typed(RewriteAnnotatedOps); } // namespace transform diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 68c59f5ea2ef1..2861f32d52a36 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -29,7 +29,9 @@ #include #include #include + #include + #include "pattern_util.h" namespace tvm { @@ -37,7 +39,7 @@ namespace relay { class CommonSubexprEliminator : public ExprMutator { public: - explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip): fskip_(fskip) {} + explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} Expr VisitExpr_(const CallNode* call) final { static auto op_stateful = Op::GetAttr("TOpIsStateful"); @@ -88,14 +90,14 @@ namespace transform { Pass EliminateCommonSubexpr(PackedFunc fskip) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(EliminateCommonSubexpr(f, fskip)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") -.set_body_typed(EliminateCommonSubexpr); + .set_body_typed(EliminateCommonSubexpr); } // namespace transform diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index c720bdfa14ee7..5b43d07ee1107 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -24,9 +24,9 @@ * */ #include +#include #include #include -#include namespace tvm { namespace relay { @@ -62,16 +62,14 @@ class EtaExpander : public ExprMutator { type_var_replacer_(TypeVarReplacer()), expand_constructor_(expand_constructor), expand_global_var_(expand_global_var) { - CHECK(expand_constructor || expand_global_var) - << "must expand at least one language feature"; + CHECK(expand_constructor || expand_global_var) << "must expand at least one language feature"; } IRModule Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { const BaseFunc base_func = mod_->Lookup(global_var); if (auto* n = base_func.as()) { - const Function new_func = 
Downcast( - VisitExpr(GetRef(n))); + const Function new_func = Downcast(VisitExpr(GetRef(n))); mod_->Update(global_var, new_func); } } @@ -111,11 +109,8 @@ class EtaExpander : public ExprMutator { Expr body = Call(cons, params, Attrs()); Type ret_type = TypeCall(cons->belong_to, type_params); - return Function( - Downcast>(params), - body, - ret_type, - Downcast>(type_params)); + return Function(Downcast>(params), body, ret_type, + Downcast>(type_params)); } Expr VisitExpr_(const GlobalVarNode* gvar_node) final { @@ -124,7 +119,7 @@ class EtaExpander : public ExprMutator { return std::move(gvar); } const auto base_func = mod_->Lookup(gvar); - if (auto *ptr = base_func.as()) { + if (auto* ptr = base_func.as()) { // handle relay function, skip external functions. auto func = GetRef(ptr); tvm::Array params; @@ -135,11 +130,7 @@ class EtaExpander : public ExprMutator { args.push_back(var); } - return Function( - args, - Call(gvar, params), - func->ret_type, - func->type_params); + return Function(args, Call(gvar, params), func->ret_type, func->type_params); } else { return std::move(gvar); } @@ -161,15 +152,14 @@ class EtaExpander : public ExprMutator { namespace transform { Pass EtaExpand(bool expand_constructor, bool expand_global_var) { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext pc) { return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); }; return CreateModulePass(pass_func, 1, "EtaExpand", {}); } -TVM_REGISTER_GLOBAL("relay._transform.EtaExpand") -.set_body_typed(EtaExpand); +TVM_REGISTER_GLOBAL("relay._transform.EtaExpand").set_body_typed(EtaExpand); } // namespace transform diff --git a/src/relay/transforms/expr_subst.cc b/src/relay/transforms/expr_subst.cc index d3e6aa8dbfe60..54731ed69eae8 100644 --- a/src/relay/transforms/expr_subst.cc +++ b/src/relay/transforms/expr_subst.cc @@ -22,9 +22,10 @@ * \brief Utility functions for substituting expressions. */ -#include #include "./expr_subst.h" +#include + namespace tvm { namespace relay { diff --git a/src/relay/transforms/expr_subst.h b/src/relay/transforms/expr_subst.h index 849ffc2db9e2a..e82e3e6ca62b7 100644 --- a/src/relay/transforms/expr_subst.h +++ b/src/relay/transforms/expr_subst.h @@ -24,13 +24,13 @@ #ifndef TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_ #define TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_ #include + #include namespace tvm { namespace relay { -Expr ExprSubst(const Expr& expr, - std::unordered_map subst_map); +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 8234dea5e075f..3c8d8db637c81 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -22,10 +22,11 @@ * \brief Replaces non linear activation functions with their fast but approximate counterparts. 
*/ #include -#include #include -#include +#include #include +#include + #include "pattern_util.h" namespace tvm { @@ -33,10 +34,7 @@ namespace relay { class FastMathMutator : public ExprRewriter { public: - FastMathMutator() - : exp_op_(Op::Get("exp")), - erf_op_(Op::Get("erf")), - tanh_op_(Op::Get("tanh")) {} + FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (pre->op == exp_op_) { @@ -67,14 +65,11 @@ namespace transform { Pass FastMath() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(FastMath(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(FastMath(f)); }; return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.FastMath") -.set_body_typed(FastMath); +TVM_REGISTER_GLOBAL("relay._transform.FastMath").set_body_typed(FastMath); } // namespace transform diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index a52f42054c3e3..70df0ed8c2b4c 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -21,15 +21,16 @@ * \file constant_folding.cc */ #include +#include #include +#include #include #include -#include -#include #include -#include -#include #include +#include +#include + #include "pattern_util.h" namespace tvm { @@ -48,8 +49,7 @@ class ConstantChecker : private ExprVisitor { return true; } const auto it = memo_.find(expr); - if (it != memo_.end()) - return it->second; + if (it != memo_.end()) return it->second; VisitExpr(expr); return memo_[expr]; // return memoized result or the default value false } @@ -69,12 +69,9 @@ class ConstantChecker : private ExprVisitor { } }; -bool ConstantCheck(const Expr& e) { - return ConstantChecker().Check(e); -} +bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } -TVM_REGISTER_GLOBAL("relay.analysis.check_constant") -.set_body_typed(ConstantCheck); +TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantCheck); // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. @@ -98,9 +95,7 @@ class ConstantFolder : public ExprMutator { } else { Var var = Downcast(this->Mutate(op->var)); Expr body = this->Mutate(op->body); - if (var.same_as(op->var) && - value.same_as(op->value) && - body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(var, value, body); @@ -123,7 +118,7 @@ class ConstantFolder : public ExprMutator { const OpNode* op = call->op.as(); if (op == nullptr) return res; if (skip_list.count(op->name)) { - return res; + return res; } // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; @@ -133,9 +128,7 @@ class ConstantFolder : public ExprMutator { } // We should think about potentially constant evaluation over these ops too. 
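// A hedged usage sketch (assumed driver code, not from this patch; the helper name is ours):
// the FoldConstant pass registered at the end of this file rewrites, e.g., add(2, 3) into the
// literal constant 5 wherever the operands are compile-time constants and the op is stateless.
// This mirrors the transform::Sequential usage inside ConstEvaluate in the hunk that follows.
#include <tvm/ir/module.h>
#include <tvm/relay/transform.h>
tvm::IRModule RunFoldConstant(tvm::IRModule mod) {
  namespace tf = tvm::relay::transform;
  // InferType first so type information is available to the folder.
  tf::Sequential seq({tf::InferType(), tf::FoldConstant()});
  return seq(mod);  // Pass::operator() applies the whole sequence to the module.
}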
- if (call->op == invoke_tvm_op_ || - call->op == shape_func_op_ || - call->op == alloc_tensor_op_ || + if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || call->op == alloc_storage_op_) { return GetRef(call); } @@ -184,8 +177,7 @@ class ConstantFolder : public ExprMutator { if (value->IsInstance()) { auto nd_array = Downcast(value); for (auto dim : nd_array.Shape()) { - CHECK_GT(dim, 0) - << "invalid dimension after constant eval"; + CHECK_GT(dim, 0) << "invalid dimension after constant eval"; } return Constant(nd_array); } else if (const auto* val = value.as()) { @@ -202,7 +194,7 @@ class ConstantFolder : public ExprMutator { } // Constant evaluate a expression. Expr ConstEvaluate(Expr expr) { - std::vector passes = {transform::FuseOps(0), + std::vector passes = {transform::FuseOps(0), transform::ToANormalForm(), transform::InferType()}; Function func; if (expr.as()) { @@ -211,10 +203,7 @@ class ConstantFolder : public ExprMutator { // TODO(@jroesch): fix this func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); } - auto mod = IRModule( - {}, - module_->type_definitions, - module_->Imports()); + auto mod = IRModule({}, module_->type_definitions, module_->Imports()); auto global = GlobalVar("main"); mod->Add(global, func); auto seq = transform::Sequential(passes); @@ -250,7 +239,7 @@ class ConstantFolder : public ExprMutator { value = runtime::NDArray::Empty({}, cdtype, ctx); } else { CHECK_NE(ishape.size(), 0); - std::vector cshape = { static_cast(ishape.size()) }; + std::vector cshape = {static_cast(ishape.size())}; value = runtime::NDArray::Empty(cshape, cdtype, ctx); int32_t* dims = static_cast(value->data); using ::tvm::tir::IntImmNode; @@ -273,12 +262,11 @@ class ConstantFolder : public ExprMutator { // Cast the constant into correct dtype auto cast_attrs = make_object(); cast_attrs->dtype = param->dtype; - Expr ret = Call(cast_op_, { shape }, Attrs(cast_attrs), {}); + Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {}); return ConstEvaluate(ret); } }; - Expr FoldConstant(const Expr& expr, const IRModule& mod) { DLContext ctx; ctx.device_type = kDLCPU; @@ -295,14 +283,13 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(FoldConstant(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FoldConstant(f, m)); + }; return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } -TVM_REGISTER_GLOBAL("relay._transform.FoldConstant") -.set_body_typed(FoldConstant); +TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant); } // namespace transform diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index cfe74bfd8ef16..57e3d6925b20b 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -23,14 +23,14 @@ * \brief Fold axis scaling into weights of * conv/dense operators. 
*/ -#include #include #include #include #include -#include "pattern_util.h" -#include "pass_util.h" +#include +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -43,7 +43,6 @@ namespace fold_scale_axis { using runtime::TypedPackedFunc; - // FoldScaleAxis algorithm: // // The general idea is to transform Expr to tuple of @@ -109,7 +108,7 @@ class Message : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode); }; -Message::Message(const AxesSet& axes, bool require_positive) { +Message::Message(const AxesSet& axes, bool require_positive) { auto n = make_object(); n->axes = axes; n->require_positive = require_positive; @@ -139,7 +138,8 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { ++j; } else { ret.push_back(lhs[i]); - ++i; ++j; + ++i; + ++j; } } return ret; @@ -166,8 +166,8 @@ Message Intersect(const Message& lhs, const Message& rhs) { * positive scale is required. * \return The message containing the result scaling on axes of the input. */ -using FForwardPrep = runtime::TypedPackedFunc< - Array (const Call& call, const Message& out_message)>; +using FForwardPrep = + runtime::TypedPackedFunc(const Call& call, const Message& out_message)>; /*! \brief Axis scale tuple. */ class ScaledExprNode : public TempExprNode { @@ -180,8 +180,7 @@ class ScaledExprNode : public TempExprNode { Expr scale = NullValue(); Expr Realize() const final { - CHECK(!axes.defined()) - << "outstanding scale"; + CHECK(!axes.defined()) << "outstanding scale"; return value; } @@ -195,18 +194,15 @@ class ScaledExprNode : public TempExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode); }; -using FForwardRewrite = TypedPackedFunc< - Expr(const Call& ref_call, - const Array& new_args, - const Message& message)>; +using FForwardRewrite = TypedPackedFunc& new_args, + const Message& message)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map - Prepare(const Expr& body) { + std::unordered_map Prepare(const Expr& body) { this->Update(body, NullValue()); this->VisitExpr(body); // flist is added in the Post-DFS order @@ -222,7 +218,7 @@ class ForwardPrep : private ExprVisitor { private: // The invoke list - std::vector > flist_; + std::vector> flist_; // The message on each node. std::unordered_map message_; // Update the message stored at node. @@ -245,15 +241,11 @@ class ForwardPrep : private ExprVisitor { } } // Visitor pattern override. - void VisitExpr_(const LetNode* call) { - LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; - } + void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; } void VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); - auto flazy = [this, op] { - this->Update(op->body, NullValue()); - }; + auto flazy = [this, op] { this->Update(op->body, NullValue()); }; flist_.push_back(flazy); } @@ -261,8 +253,7 @@ class ForwardPrep : private ExprVisitor { ExprVisitor::VisitExpr_(call); // function to be lazily invoked auto flazy = [this, call]() { - static const auto& fprep = - Op::GetAttr("FScaleAxisForwardPrep"); + static const auto& fprep = Op::GetAttr("FScaleAxisForwardPrep"); // find the message send to this node. 
auto it = message_.find(call); Message out_message; @@ -326,31 +317,26 @@ Array ReluForwardPrep(const Call& call, const Message& out_message) { return {out_message}; } -Expr ReluForwardRewrite(const Call& ref_call, - const Array& new_args, - const Message& message) { +Expr ReluForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d auto rnode = make_object(); - rnode->value = Call( - ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; rnode->axes = input->axes; return Expr(rnode); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisForwardPrep", ReluForwardPrep); -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisForwardRewrite", + ReluForwardRewrite); -RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); +RELAY_REGISTER_OP("nn.leaky_relu").set_attr("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub Array AddSubForwardPrep(const Call& call, const Message& out_message) { @@ -367,8 +353,7 @@ Array AddSubForwardPrep(const Call& call, const Message& out_message) { return {none, none}; } -Expr AddSubForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { const auto* slhs = new_args[0].as(); const auto* srhs = new_args[1].as(); @@ -380,43 +365,36 @@ Expr AddSubForwardRewrite(const Call& ref_call, if (slhs != nullptr) { CHECK(srhs == nullptr); CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); - Expr scale = ExpandBiasToMatchAxis( - slhs->scale, tlhs->shape.size(), slhs->axes); + Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes); Expr rhs = Divide(new_args[1], scale); - rnode->value = Call(ref_call->op, {slhs->value, rhs}, - ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; rnode->axes = slhs->axes; } else { CHECK(srhs != nullptr); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); - Expr scale = ExpandBiasToMatchAxis( - srhs->scale, trhs->shape.size(), srhs->axes); + Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes); Expr lhs = Divide(new_args[0], scale); - rnode->value = Call(ref_call->op, {lhs, srhs->value}, - ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; rnode->axes = srhs->axes; } return Expr(rnode); } -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisForwardRewrite", + AddSubForwardRewrite); -RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); 
+RELAY_REGISTER_OP("subtract").set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); // Producer operators // Multiply produces the scale-axis pair. -Expr MultiplyForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr MultiplyForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { if (!message.defined()) return Expr(); const auto& expected_out_axes = message->axes; @@ -451,7 +429,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call, } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); // Consumer operators // Conv2D send out requirement of axis folding. @@ -476,8 +454,7 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && - c_small_axis < 0 && + if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { data_axes = {c_big_axis}; } @@ -488,8 +465,7 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { } // Conv2D consumes the scale axis during transformation. -Expr Conv2DForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { // if data do not have scale, normal transform path. const auto* sdata = new_args[0].as(); @@ -505,8 +481,7 @@ Expr Conv2DForwardRewrite(const Call& ref_call, // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(sdata->axes.size() == 1 && - c_big_axis == sdata->axes[0]->value); + CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); @@ -518,29 +493,24 @@ Expr Conv2DForwardRewrite(const Call& ref_call, // match the ic_axis if (is_depthwise_conv2d) { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_oc_axis}); + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, scale); } else { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_ic_axis}); + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis}); weight = Multiply(weight, scale); } // return transformed conv2d - return Call( - ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); + return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); +RELAY_REGISTER_OP("nn.conv2d").set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); - + .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> ObjectRef{ + auto 
fcontext = [&](const Call& call) -> ObjectRef { auto it = message.find(call.get()); if (it != message.end()) { return it->second; @@ -548,8 +518,7 @@ Expr ForwardFoldScaleAxis(const Expr& data) { return ObjectRef(nullptr); } }; - return ForwardRewrite( - data, "FScaleAxisForwardRewrite", fcontext); + return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); } //---------------------------------------- @@ -564,14 +533,11 @@ class BackwardTransformer; * positive scale is required. * \return Message containing the result scaling on axes of the input. */ -using FBackwardPrep = TypedPackedFunc< - Message(const Call& call, const Array& in_messages)>; +using FBackwardPrep = TypedPackedFunc& in_messages)>; -using FBackwardTransform = TypedPackedFunc< - Expr(const Call& call, - const Message& message, - const Expr& scale, - const BackwardTransformer& transformer)>; +using FBackwardTransform = + TypedPackedFunc; //---------------------------------------------- // Generic Visitors for FScaleAxisBackward @@ -580,8 +546,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map - Prepare(const Expr& body) { + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); return std::move(message_); @@ -595,8 +560,7 @@ class BackwardPrep : private ExprVisitor { // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); - static const auto& fprep = - Op::GetAttr("FScaleAxisBackwardPrep"); + static const auto& fprep = Op::GetAttr("FScaleAxisBackwardPrep"); auto f = fprep.get(call->op, nullptr); if (f == nullptr) return; auto rit = ref_counter_.find(call); @@ -620,9 +584,7 @@ class BackwardPrep : private ExprVisitor { } }; -class BackwardTransformerNode : - public Object, - private ExprMutator { +class BackwardTransformerNode : public Object, private ExprMutator { public: // Run forward transform. 
Expr Fold(Expr expr) { @@ -692,19 +654,15 @@ class BackwardTransformerNode : class BackwardTransformer : public ObjectRef { public: BackwardTransformer() {} - explicit BackwardTransformer( - ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { - } + explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} BackwardTransformerNode* operator->() const { return static_cast(get_mutable()); } using ContainerType = BackwardTransformerNode; }; -Expr BackwardTransformerNode::Transform( - const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = - Op::GetAttr("FScaleAxisBackwardTransform"); +Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { + static const auto& ftransform = Op::GetAttr("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); if (f != nullptr) { const Call call = GetRef(call_node); @@ -712,10 +670,7 @@ Expr BackwardTransformerNode::Transform( if (it != memo_.end()) { return it->second; } - Expr new_expr = f(GetRef(call_node), - message, - scale, - GetRef(this)); + Expr new_expr = f(GetRef(call_node), message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { @@ -724,7 +679,6 @@ Expr BackwardTransformerNode::Transform( } } - //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- @@ -737,45 +691,38 @@ Message ReluBackwardPrep(const Call& call, const Array& in_messages) { return in_messages[0]; } -Expr ReluBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } - Expr input = transformer->Transform( - call->args[0], message, scale); + Expr input = transformer->Transform(call->args[0], message, scale); return Call(call->op, {input}, call->attrs, call->type_args); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisBackwardTransform", + ReluBackwardTransform); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); + .set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); // AddSub Message AddSubBackwardPrep(const Call& call, const Array& in_messages) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); StructuralEqual equal; - if (in_messages[0].defined() && - MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { + if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { return in_messages[0]; } else if (in_messages[1].defined() && MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { return in_messages[1]; - } else if (in_messages[0].defined() && - in_messages[1].defined() && - equal(in_messages[0]->axes, in_messages[1]->axes) && - equal(tlhs->shape, trhs->shape)) { + } else if (in_messages[0].defined() && in_messages[1].defined() && + equal(in_messages[0]->axes, 
in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { // add of two elements. return in_messages[0]; } else { @@ -784,9 +731,7 @@ Message AddSubBackwardPrep(const Call& call, const Array& in_messages) } } -Expr AddSubBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); @@ -806,19 +751,15 @@ Expr AddSubBackwardTransform(const Call& call, } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); - Expr rhs = transformer->Transform( - call->args[1], NullValue(), NullValue()); - Expr rhs_scale = ExpandBiasToMatchAxis( - scale, tlhs->shape.size(), message->axes); + Expr rhs = transformer->Transform(call->args[1], NullValue(), NullValue()); + Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes); rhs = Multiply(rhs, rhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { CHECK(equal(message->axes, rhs_message->axes)); - Expr lhs = transformer->Transform( - call->args[0], NullValue(), NullValue()); + Expr lhs = transformer->Transform(call->args[0], NullValue(), NullValue()); Expr rhs = transformer->Transform(call->args[1], message, scale); - Expr lhs_scale = ExpandBiasToMatchAxis( - scale, trhs->shape.size(), message->axes); + Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes); lhs = Multiply(lhs, lhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { @@ -827,23 +768,19 @@ Expr AddSubBackwardTransform(const Call& call, } } -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisBackwardTransform", + AddSubBackwardTransform); -RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); +RELAY_REGISTER_OP("subtract").set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); // Producer operators // Multiply produces the scale-axis pair. -Expr MultiplyBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { CHECK(!message.defined()) << "outstanding scale"; const auto* tlhs = call->args[0]->type_as(); @@ -871,7 +808,7 @@ Expr MultiplyBackwardTransform(const Call& call, } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisBackwardTransform", MultiplyBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", MultiplyBackwardTransform); // Consumer operators // Conv2D send out requirement of axis folding. 
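// The backward transforms in this file keep pushing a pending per-channel
// scale toward its producers until a consumer such as conv2d folds it into
// the kernel's output-channel ('O') axis via ExpandBiasToMatchAxis followed
// by Multiply. A minimal standalone sketch of the identity this folding
// relies on, using a 1x1 convolution reduced to a dot product (plain C++;
// names and shapes are illustrative, not TVM API):
#include <cassert>
#include <vector>

// One output channel of a 1x1 convolution: a dot product of the input
// channels with that channel's kernel slice.
static float ConvOneChannel(const std::vector<float>& x, const std::vector<float>& w) {
  float acc = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) acc += x[i] * w[i];
  return acc;
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f};   // input channels
  std::vector<float> w = {0.5f, -1.0f, 2.0f};  // kernel slice for one output channel
  float s = 4.0f;                              // per-output-channel scale

  // Scale applied after the convolution ...
  float scaled_after = ConvOneChannel(x, w) * s;

  // ... equals the convolution with the scale folded into the weights,
  // which is exactly what Conv2DBackwardTransform below emits.
  std::vector<float> w_folded = w;
  for (float& v : w_folded) v *= s;

  assert(scaled_after == ConvOneChannel(x, w_folded));  // s * (x . w) == x . (s * w)
  return 0;
}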
@@ -893,8 +830,7 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 && - kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && - c_small_axis < 0 && + kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { return Message({c_big_axis}, false); } else { @@ -903,9 +839,7 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages) } // Conv2D consumes the scale axis during transformation. -Expr Conv2DBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); @@ -920,31 +854,26 @@ Expr Conv2DBackwardTransform(const Call& call, // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1); CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(message->axes.size() == 1 && - c_big_axis == message->axes[0]->value); + CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value); int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); - Expr data = transformer->Transform( - call->args[0], NullValue<Message>(), NullValue<Expr>()); - Expr weight = transformer->Transform( - call->args[1], NullValue<Message>(), NullValue<Expr>()); + Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>()); + Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>()); // scale on input for depthwise.
- Expr wscale = ExpandBiasToMatchAxis( - scale, kernel_layout.ndim(), {big_oc_axis}); + Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, wscale); - return Call( - call->op, {data, weight}, call->attrs, call->type_args); + return Call(call->op, {data, weight}, call->attrs, call->type_args); } RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisBackwardPrep", Conv2DBackwardPrep); + .set_attr("FScaleAxisBackwardPrep", Conv2DBackwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); Expr BackwardFoldScaleAxis(const Expr& data) { return make_object()->Fold(data); @@ -956,39 +885,33 @@ namespace transform { Pass ForwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - relay::fold_scale_axis::ForwardFoldScaleAxis(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") -.set_body_typed(ForwardFoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis); Pass BackwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - relay::fold_scale_axis::BackwardFoldScaleAxis(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") -.set_body_typed(BackwardFoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis); Pass FoldScaleAxis() { // FoldScaleAxis pass contains the following three passes. Therefore, we can // register it as a sequential pass. - Pass pass = Sequential( - {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, - "FoldScaleAxis"); + Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, + "FoldScaleAxis"); return pass; } -TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis") -.set_body_typed(FoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis); } // namespace transform diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index f01c4faeff3eb..226b3384eba11 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -26,6 +26,7 @@ #include #include #include + #include "pass_util.h" namespace tvm { @@ -36,9 +37,7 @@ namespace relay { // so that calling realize repeatively won't hurt perf. 
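// ForwardRewriter below consults a per-op FForwardRewrite callback; a
// callback that returns an undefined Expr means "keep the original call".
// A minimal sketch of that dispatch contract on a toy IR (ToyCall and
// ToyRewrite are hypothetical stand-ins, not the Relay API):
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct ToyCall {
  std::string op;
  std::vector<int> args;
};

// Mirrors the contract of FForwardRewrite: return a value to rewrite,
// or "nothing" to fall back to the unchanged call.
using ToyRewrite = std::function<std::optional<int>(const ToyCall&)>;

static int RunRewrite(const ToyCall& call, const ToyRewrite& frewrite, int fallback) {
  if (frewrite) {
    if (auto res = frewrite(call)) return *res;  // rewrite succeeded
  }
  return fallback;  // abort, keep the old rule
}

int main() {
  // Only "add" calls are rewritten; everything else falls through.
  ToyRewrite rule = [](const ToyCall& c) -> std::optional<int> {
    if (c.op != "add") return std::nullopt;
    int sum = 0;
    for (int a : c.args) sum += a;
    return sum;
  };
  std::cout << RunRewrite({"add", {1, 2, 3}}, rule, -1) << "\n";  // 6
  std::cout << RunRewrite({"mul", {1, 2, 3}}, rule, -1) << "\n";  // -1 (kept)
  return 0;
}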
class TempRealizer : private MixedModeMutator { public: - Expr Realize(Expr expr) { - return Mutate(expr); - } + Expr Realize(Expr expr) { return Mutate(expr); } private: Expr DispatchVisitExpr(const Expr& expr) final { @@ -57,17 +56,12 @@ class ForwardRewriter : private MixedModeMutator { ForwardRewriter(const OpMap* rewrite_map, std::function fcontext, std::function fmulti_ref_trigger) - : rewrite_map_(rewrite_map), - fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) {} + : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} ForwardRewriter(const FForwardRewrite* rewrite_func, std::function fcontext, std::function fmulti_ref_trigger) - : rewrite_func_(rewrite_func), - fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) {} - + : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} // Transform expression. Expr Rewrite(const Expr& expr) { @@ -91,7 +85,7 @@ class ForwardRewriter : private MixedModeMutator { TempRealizer realizer_; // Visit and allow non-realized version. - Expr GetTempExpr(const Expr& expr, const Expr& post) { + Expr GetTempExpr(const Expr& expr, const Expr& post) { if (fmulti_ref_trigger_ != nullptr) { Expr ret = post; auto it = ref_counter_.find(expr.get()); @@ -160,9 +154,8 @@ class ForwardRewriter : private MixedModeMutator { } // try to rewrite. if (frewrite != nullptr) { - Expr res = frewrite( - ref_call, call_args, - fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); + Expr res = frewrite(ref_call, call_args, + fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); if (res.defined()) return res; // abort, use old rule for (size_t i = 0; i < call_args.size(); ++i) { @@ -175,21 +168,18 @@ class ForwardRewriter : private MixedModeMutator { } } if (unchanged) return ref_call; - return Call( - new_op, call_args, call_node->attrs, call_node->type_args); + return Call(new_op, call_args, call_node->attrs, call_node->type_args); } }; -Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_name, +Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name, std::function fcontext, std::function fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); } -Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, +Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function fcontext, std::function fmulti_ref_trigger) { return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 0b6a04e943e76..054244dc35169 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -24,14 +24,15 @@ * \brief This is a backend-aware optimization pass. * Fuse necessary ops into a single one. */ -#include #include #include #include #include -#include "pattern_util.h" +#include + #include "../../support/arena.h" #include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -54,8 +55,9 @@ namespace relay { However, at the point of conv2d we do not necessarily know that all the future paths will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. - The immediate post-dominator of a node defined by the closest node where all the future path goes into. 
- In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows: + The immediate post-dominator of a node defined by the closest node where all the future path goes + into. In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm + is as follows: - Construct a DAG of dataflow graph for dominator analysis - Construct a post-dominator tree which gives immediate post dominator of each node. @@ -74,8 +76,8 @@ namespace relay { - CommitFuse: mark all the nodes between source and post-dominator as the same group. - We use an Union-Find data structure to manage the groups. */ -using support::LinkNode; using support::LinkedList; +using support::LinkNode; constexpr uint32_t kMaxFusedOps = 256; @@ -124,9 +126,7 @@ class IndexedForwardGraph { std::ostringstream os; for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; - os << "node[" << i << "], " - << GetRef(node->ref) - << " outputs=["; + os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; } @@ -148,8 +148,7 @@ class IndexedForwardGraph { // Creator of post dominator tree of the dataflow class IndexedForwardGraph::Creator : private ExprVisitor { public: - explicit Creator(support::Arena* arena) - : arena_(arena) {} + explicit Creator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); @@ -165,9 +164,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // attribute equal comparator StructuralEqual attr_equal_; // Update the message stored at the node. - void Update(const Expr& node, - IndexedForwardGraph::Node* parent, - OpPatternKind pattern) { + void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { const tvm::Object* key = node.get(); IndexedForwardGraph::Node* current; auto it = graph_.node_map.find(key); @@ -189,8 +186,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void AddNode(const tvm::Object* key) { auto it = graph_.node_map.find(key); - CHECK(it != graph_.node_map.end()) - << "Cannot find node " << GetRef(key); + CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef(key); IndexedForwardGraph::Node* node = it->second; CHECK(node->ref == nullptr); node->ref = key; @@ -215,12 +211,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor { Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. - bool is_simple_const = ( - dtype == DataType::Int(32) || - dtype == DataType::Int(64) || - dtype == DataType::Float(32) || - dtype == DataType::Float(64) || - dtype == DataType::Bool()); + bool is_simple_const = + (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) || + dtype == DataType::Float(64) || dtype == DataType::Bool()); if (op->is_scalar() && is_simple_const) { node->pattern = kElemWise; } else { @@ -233,8 +226,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); - static auto fpattern = - Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttr("TOpPattern"); // Now we set the pattern of this call. 
// // If we see a call mentioning an operator we should mark it with its @@ -262,13 +254,10 @@ class IndexedForwardGraph::Creator : private ExprVisitor { const auto* rtype = call->checked_type().as(); // pass the analysis back to all the children it references. for (size_t i = 0; i < call->args.size(); ++i) { - const auto* arg_type = - call->args[i]->checked_type().as(); + const auto* arg_type = call->args[i]->checked_type().as(); // specifically check if result type is the same as arguments type OpPatternKind edge_pattern = op_pattern; - if (edge_pattern == kBroadcast && - arg_type != nullptr && - rtype != nullptr && + if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr && attr_equal_(rtype->shape, arg_type->shape)) { edge_pattern = kElemWise; } @@ -320,9 +309,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const VarNode* op) final { - this->AddNode(op); - } + void VisitExpr_(const VarNode* op) final { this->AddNode(op); } void VisitExpr_(const LetNode* op) final { // do not fuse through let. @@ -371,8 +358,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create( - support::Arena* arena, const Expr& body) { +IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { return Creator(arena).Prepare(body); } @@ -405,13 +391,11 @@ class DominatorTree { * \note This algorithm makes use of the fact that graph is DAG, * and runs a single pass algorithm via LCA (Least Common Ancestor) */ - static DominatorTree PostDom(support::Arena* arena, - const IndexedForwardGraph& graph); + static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); private: // Combine pattern together. - static OpPatternKind CombinePattern( - OpPatternKind lhs, OpPatternKind rhs) { + static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > rhs) return lhs; return rhs; } @@ -423,26 +407,19 @@ class DominatorTree { * The combined edge pattern across all the parents. * \return The least common ancestor of the two. */ - static Node* LeastCommonAncestor( - Node* lhs, - Node* rhs, - OpPatternKind* edge_pattern) { + static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { while (lhs != rhs) { if (lhs == nullptr) return nullptr; if (rhs == nullptr) return nullptr; if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern( - edge_pattern[0], rhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); rhs = rhs->parent; } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern( - edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); lhs = lhs->parent; } else { - edge_pattern[0] = CombinePattern( - edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern( - edge_pattern[0], rhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); lhs = lhs->parent; rhs = rhs->parent; } @@ -503,9 +480,7 @@ class DominatorTree { } }; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, - const IndexedForwardGraph& graph) { +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { DominatorTree tree; tree.nodes.resize(graph.post_dfs_order.size(), nullptr); // reverse topo order @@ -579,13 +554,11 @@ class GraphPartitioner { /*! 
\brief internal field used for deduplication */ std::unordered_set<IndexedForwardGraph::Node*> visited_; // Internal implementation of CheckPath - template <typename F> - bool CheckPath_(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - F fcond) { + template <typename F> + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { if (visited_.count(src)) return true; visited_.insert(src); - Group* gnode = groups_[src->index]; + Group* gnode = groups_[src->index]; CHECK(gnode != nullptr); gnode = gnode->FindRoot(); if (!fcond(gnode->pattern, src == sink)) return false; @@ -607,10 +580,8 @@ class GraphPartitioner { * \tparam F the condition function, with signature * \note sink must be a post-dominator of src. */ - template <typename F> - bool CheckPath(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - F fcond) { + template <typename F> + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { CHECK(!src->extern_ref); visited_.clear(); CHECK(src != sink); @@ -620,8 +591,7 @@ class GraphPartitioner { return true; } // Combine two patterns together. - static OpPatternKind CombinePattern( - OpPatternKind lhs, OpPatternKind rhs) { + static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > kBroadcast && rhs > kBroadcast) { LOG(FATAL) << "Cannot merge two complex groups together"; } @@ -644,14 +614,11 @@ class GraphPartitioner { if (child->master_ref != nullptr) { CHECK(parent->master_ref == nullptr); parent->master_ref = child->master_ref; - parent->pattern = CombinePattern( - child->pattern, parent->pattern); + parent->pattern = CombinePattern(child->pattern, parent->pattern); } } // Internal implementation of CommitFuse - void CommitFuse_(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - Group* target) { + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { if (src == sink) return; if (visited_.count(src)) return; visited_.insert(src); @@ -669,8 +636,7 @@ class GraphPartitioner { * \param sink The termination node. * \note sink must be a post-dominator of src. */ - void CommitFuse(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink) { + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { Group* target = groups_[sink->index]; visited_.clear(); CHECK(src != sink); @@ -694,9 +660,7 @@ class GraphPartitioner { } // execute the fusion algorithm. - void RunFuse(const IndexedForwardGraph& graph, - const DominatorTree& post_dom_tree, - int phase) { + void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { for (size_t nid = 0; nid < groups_.size(); ++nid) { // the group of current node has been specified already.
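// Fused groups are tracked with a union-find: each group's root carries the
// dominant pattern, and CombinePattern above keeps the "harder" of two kinds.
// A simplified standalone sketch of the merge discipline (no master_ref
// bookkeeping; the enum values are assumed to mirror relay's OpPatternKind,
// ordered from cheapest to fuse up to opaque):
#include <cassert>

enum Kind { kElemWise = 0, kBroadcast = 1, kInjective = 2, kOpaque = 8 };

struct Group {
  Group* parent = nullptr;  // union-find parent; null means this is the root
  Kind pattern = kElemWise;

  Group* FindRoot() {  // find with path compression, as in GraphPartitioner
    Group* root = this;
    while (root->parent != nullptr) root = root->parent;
    for (Group* p = this; p != root;) {
      Group* next = p->parent;
      p->parent = root;
      p = next;
    }
    return root;
  }
};

// Merge child's group into parent's, keeping the harder pattern at the root.
static void Merge(Group* child, Group* parent) {
  child = child->FindRoot();
  parent = parent->FindRoot();
  if (child == parent) return;
  child->parent = parent;
  parent->pattern = parent->pattern > child->pattern ? parent->pattern : child->pattern;
}

int main() {
  Group a, b;
  b.pattern = kBroadcast;
  Merge(&a, &b);  // commit-fuse the elemwise node into b's group
  assert(a.FindRoot() == b.FindRoot());
  assert(a.FindRoot()->pattern == kBroadcast);
  return 0;
}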
auto* graph_node = graph.post_dfs_order[nid]; @@ -711,8 +675,7 @@ class GraphPartitioner { size_t dom_parent_gindex = dom_node->parent->gnode->index; // refuse the fusion if too many ops are going to be fused together - if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) - continue; + if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue; if (phase == 2) { // Fuse injective ops into intermediate tuples, if any @@ -723,9 +686,7 @@ class GraphPartitioner { if (dom_root_group->pattern == kTuple) continue; if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kInjective; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; // dom_root_group can also be tuple, as in inception layers // CheckPath is needed to avoid fusing two intermediate tuples if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { @@ -750,9 +711,7 @@ class GraphPartitioner { if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { CHECK(dom_node->parent->gnode != nullptr); // The fuse can be executed if all the intermediate ops are still broadcast. - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kBroadcast; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } @@ -760,8 +719,7 @@ class GraphPartitioner { } else if (group_node->pattern <= kBroadcast) { // Pre-condition: can only be fused to parent which is injective or reduction. if (dom_node->parent != nullptr && - (dom_node->pattern <= kInjective || - dom_node->pattern == kCommReduce)) { + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { // Check if all the intermediate ops are still broadcast. // The final terminal node can already be fused to a OutEWiseFusable group. auto fcond = [](OpPatternKind kind, bool is_sink) { @@ -770,9 +728,7 @@ class GraphPartitioner { // are allowed be fused to the elemwise/broadcast master. return kind <= kInjective; } else { - return (kind <= kBroadcast || - kind == kCommReduce || - kind == kInjective || + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || kind == kOutEWiseFusable); } }; @@ -785,9 +741,7 @@ class GraphPartitioner { // so conv2d always finishes fusing. if (phase != 1) continue; // Check if all path are injective. - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kInjective; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } @@ -799,8 +753,8 @@ class GraphPartitioner { } }; -std::vector -GraphPartitioner::Partition(const IndexedForwardGraph& graph) { +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { this->InitGroups(graph); if (opt_level_ == 0) return std::move(groups_); // get post dominator tree @@ -818,8 +772,7 @@ class FuseMutator : private ExprMutator { Expr Transform(const Expr& body, int fuse_opt_level) { // setup the group map. 
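// Each phase of RunFuse walks every path from a node to its immediate
// post-dominator and tests the intermediate groups with a pattern predicate
// (fcond) before committing the fuse. A compact standalone restatement of
// the three predicates used above (Kind values assumed to mirror relay's
// OpPatternKind):
#include <iostream>

enum Kind {
  kElemWise = 0, kBroadcast = 1, kInjective = 2,
  kCommReduce = 3, kOutEWiseFusable = 4, kTuple = 7, kOpaque = 8
};

int main() {
  // OutEWiseFusable source (e.g. conv2d) toward an elemwise post-dominator:
  // everything in between must stay elemwise/broadcast.
  auto into_master = [](Kind k, bool /*is_sink*/) { return k <= kBroadcast; };

  // Elemwise/broadcast source into an injective or reduction parent: the
  // path may contain injective ops; only the sink itself may additionally
  // be an OutEWiseFusable group.
  auto into_reduce = [](Kind k, bool is_sink) {
    return is_sink ? (k <= kBroadcast || k == kCommReduce || k == kInjective ||
                      k == kOutEWiseFusable)
                   : k <= kInjective;
  };

  // Injective chains (phase 1) and tuple consumers (phase 2).
  auto injective_only = [](Kind k, bool /*is_sink*/) { return k <= kInjective; };

  std::cout << into_master(kElemWise, false)        // 1: fine to fuse through
            << into_reduce(kOutEWiseFusable, true)  // 1: allowed at the sink
            << injective_only(kCommReduce, false)   // 0: a reduction blocks it
            << "\n";
  return 0;
}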
auto graph = IndexedForwardGraph::Create(&arena_, body); - auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( - graph); + auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { CHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; @@ -829,7 +782,6 @@ class FuseMutator : private ExprMutator { return this->Mutate(body); } - private: /*! \brief Temporary information from each group. */ struct GroupInfo { @@ -872,8 +824,7 @@ class FuseMutator : private ExprMutator { // Transform calls. Expr VisitExpr_(const CallNode* call) { if (call->op.as()) { - static auto fnoncomputational = - Op::GetAttr("TNonComputational"); + static auto fnoncomputational = Op::GetAttr("TNonComputational"); if (fnoncomputational.get(Downcast(call->op), false)) { return ExprMutator::VisitExpr_(call); @@ -888,8 +839,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(call)->FindRoot(); Array new_args = GetNewArguments(call->args, ret_group); - auto new_call = Call( - call->op, new_args, call->attrs, call->type_args); + auto new_call = Call(call->op, new_args, call->attrs, call->type_args); if (ret_group->root_ref == call) { // This is the root of the group @@ -936,9 +886,7 @@ class FuseMutator : private ExprMutator { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { bool has_call = false; - void VisitExpr_(const CallNode* op) final { - has_call = true; - } + void VisitExpr_(const CallNode* op) final { has_call = true; } } visitor; visitor(body); const GroupInfo& ginfo = ginfo_[group]; @@ -967,13 +915,13 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string { - auto it = gmap_.find(expr.get()); - if (it == gmap_.end()) return ""; - std::ostringstream os; - auto *group = it->second->FindRoot(); - os << " /* group=" << group << " */"; - return os.str(); - }); + auto it = gmap_.find(expr.get()); + if (it == gmap_.end()) return ""; + std::ostringstream os; + auto* group = it->second->FindRoot(); + os << " /* group=" << group << " */"; + return os.str(); + }); LOG(INFO) << "Dump of group info:\n" << text; } }; @@ -986,15 +934,14 @@ namespace transform { Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; - return Downcast(FuseOps(f, opt_level, m)); - }; + [=](Function f, IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + return Downcast(FuseOps(f, opt_level, m)); + }; return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.FuseOps") -.set_body_typed(FuseOps); +TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps); } // namespace transform diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index d0ff169445fb2..67c62f3ac6656 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -22,13 +22,14 @@ * \brief API for Automatic Differentiation for the Relay IR. 
*/ #include -#include -#include #include +#include #include -#include "pattern_util.h" -#include "pass_util.h" +#include + #include "let_list.h" +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -42,12 +43,14 @@ using namespace tvm::runtime; * Formally speaking, such requirement mean that the input function is a closed expression - * that is, it only refer to local variable that is it's parameter, or defined inside it. * Every top level definition satisfy this criteria. - * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> (Float[] -> Float[]). - * In relay we currently only support compile-time AD, but it should be enough for a lot of use case. + * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> + * (Float[] -> Float[]). In relay we currently only support compile-time AD, but it should be enough + * for a lot of use case. * - * In deep learning, the most common way to train a deep neural network is by gradient descent or some of it's variant. - * Such optimization method require us to input the gradient of neural network, which can be obtained easily using AD. - * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD! + * In deep learning, the most common way to train a deep neural network is by gradient descent or + * some of it's variant. Such optimization method require us to input the gradient of neural + * network, which can be obtained easily using AD. In fact, back propagation is essentially + * reverse-mode automatic differentiation, a kind of AD! */ /*! In relay, automatic differentiation(AD) is a macro, @@ -55,10 +58,10 @@ using namespace tvm::runtime; * (x0, x1, x2, ...) -> Float[] to * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)), * When x0, x1, x2... are Float of different shape. - * the return value is a pair, with left hand side as the original value, and right hand side as gradient of the input. - * WithGradientType will take the type of input, and produce the type of output. - * There are multiple implementation of AD in relay, with different characteristic. - * However, they all transform the input expr according to WithGradientType. + * the return value is a pair, with left hand side as the original value, and right hand side as + * gradient of the input. WithGradientType will take the type of input, and produce the type of + * output. There are multiple implementation of AD in relay, with different characteristic. However, + * they all transform the input expr according to WithGradientType. */ Type WithGradientType(const Type&); @@ -71,10 +74,7 @@ Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; - return FuncType(ty->arg_types, - TupleType({ - ty->ret_type, - TupleType(ty->arg_types)}), {}, {}); + return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); } //! \brief if the expression is a GlobalVar, transform to it's expression. @@ -95,7 +95,7 @@ Expr DeGlobal(const IRModule& mod, const Expr& e) { * pass. 
*/ struct ADValueNode { - virtual ~ADValueNode() { } + virtual ~ADValueNode() {} template T& get() { auto ret = dynamic_cast(this); @@ -110,8 +110,8 @@ using ADValue = std::shared_ptr; struct ADTensor : ADValueNode { Expr forward; mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& forward) : - forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { + ADTensor(LetList* ll, const Expr& forward) + : forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { this->forward->checked_type_ = forward->checked_type(); } }; @@ -121,51 +121,46 @@ struct ADTensor : ADValueNode { * can compute away this function to obtain a reverse mode program. */ struct ADFunction : ADValueNode { - std::function&, - const Attrs&, - const tvm::Array&)> func; - explicit ADFunction(const std::function&, - const Attrs&, - const tvm::Array&)>& func) : - func(func) { } + std::function&, const Attrs&, + const tvm::Array&)> + func; + explicit ADFunction(const std::function&, + const Attrs&, const tvm::Array&)>& func) + : func(func) {} }; -struct FirstOrderReverseAD : ExprFunctor { +struct FirstOrderReverseAD : ExprFunctor { const OpMap rev_map = Op::GetAttr("FPrimalGradient"); std::vector> backprop_actions; // we assume no closure so no need for lexical scoping std::unordered_map env; LetList* ll; - FirstOrderReverseAD(LetList* ll) : ll(ll) { } + FirstOrderReverseAD(LetList* ll) : ll(ll) {} ADValue VisitExpr_(const OpNode* op) final { Op op_ref = GetRef(op); - CHECK(rev_map.count(op_ref)) - << op->name << " does not have reverse mode defined"; - return std::make_shared([this, op_ref](const Type& orig_type, - const std::vector& args, - const Attrs& attrs, - const tvm::Array& type_args) { - std::vector call_args; - for (const ADValue& adval : args) { - call_args.push_back(adval->get().forward); - } - auto orig = Call(op_ref, call_args, attrs, type_args); - orig->checked_type_ = orig_type; - auto ret = std::make_shared(ll, orig); - backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, ret->reverse); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - args[i]->get().reverse = - ll->Push(Add(args[i]->get().reverse, rev[i])); - } - }); - return ret; - }); + CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined"; + return std::make_shared( + [this, op_ref](const Type& orig_type, const std::vector& args, const Attrs& attrs, + const tvm::Array& type_args) { + std::vector call_args; + for (const ADValue& adval : args) { + call_args.push_back(adval->get().forward); + } + auto orig = Call(op_ref, call_args, attrs, type_args); + orig->checked_type_ = orig_type; + auto ret = std::make_shared(ll, orig); + backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { + tvm::Array rev = rev_map[op_ref](orig, ret->reverse); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + args[i]->get().reverse = + ll->Push(Add(args[i]->get().reverse, rev[i])); + } + }); + return ret; + }); } ADValue VisitExpr_(const ConstantNode* op) final { @@ -185,16 +180,15 @@ struct FirstOrderReverseAD : ExprFunctor { ADValue VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); // todo: assert no closure - return std::make_shared([this, f](const Type& orig_type, - const std::vector& args, - const Attrs& attrs, - const tvm::Array& type_args) { - CHECK_EQ(f->params.size(), args.size()); - for (size_t i = 0; i < 
f->params.size(); ++i) { - env[f->params[i]] = args[i]; - } - return VisitExpr(f->body); - }); + return std::make_shared( + [this, f](const Type& orig_type, const std::vector& args, const Attrs& attrs, + const tvm::Array& type_args) { + CHECK_EQ(f->params.size(), args.size()); + for (size_t i = 0; i < f->params.size(); ++i) { + env[f->params[i]] = args[i]; + } + return VisitExpr(f->body); + }); } ADValue VisitExpr_(const VarNode* op) final { @@ -240,8 +234,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { const auto& res = c->get(); Expr grad = LetList::With([&](LetList* ll) { res.reverse = OnesLike(res.forward); - for (auto it = reverse_ad.backprop_actions.rbegin(); - it != reverse_ad.backprop_actions.rend(); + for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); ++it) { (*it)(ll); } @@ -257,8 +250,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { return Function(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") -.set_body_typed(FirstOrderGradient); +TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { Type VisitType_(const TensorTypeNode* ttn) final { @@ -267,17 +259,13 @@ struct ReverseADType : TypeMutator { } }; -Type ReverseType(const Type& t) { - return ReverseADType()(t); -} +Type ReverseType(const Type& t) { return ReverseADType()(t); } /*! \brief Lift a function that transform Tensor to a function that also transform more type * by doing a structure preserving map. */ Expr LiftTensor(const std::function& f, - const std::function& tf, - const Type& forward_type, - const Expr& e, + const std::function& tf, const Type& forward_type, const Expr& e, LetList* ll) { CHECK(IsAtomic(e)) << e; if (forward_type.as()) { @@ -288,11 +276,7 @@ Expr LiftTensor(const std::function& f, tvm::Array fields; tvm::Array types; for (size_t i = 0; i < tt->fields.size(); ++i) { - auto field = LiftTensor(f, - tf, - tt->fields[i], - ll->Push(GetField(e, i)), - ll); + auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll); fields.push_back(field); types.push_back(field->checked_type_); } @@ -308,10 +292,7 @@ Expr LiftTensor(const std::function& f, /*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr, * by stitching the references in the AD values. */ -void TransferGrads(const Type& forward_type, - const Expr& from, - const Expr& to, - LetList* ll) { +void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) { CHECK(IsAtomic(from)) << from; CHECK(IsAtomic(to)) << to; if (forward_type.as()) { @@ -320,9 +301,7 @@ void TransferGrads(const Type& forward_type, ll->Push(RefWrite(to_ref, RefRead(from_ref))); } else if (auto* tt = forward_type.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { - TransferGrads(tt->fields[i], - ll->Push(TupleGetItem(from, i)), - ll->Push(TupleGetItem(to, i)), + TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)), ll); } } else { @@ -333,48 +312,31 @@ void TransferGrads(const Type& forward_type, /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. 
*/ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { - auto rev = [&](const Expr& e) { - return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); - }; - auto rev_type = [&](const Type& forward_type) { - return ReverseType(forward_type); - }; + auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); }; + auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); }; return LiftTensor(rev, rev_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the original value. */ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { - auto val = [&](const Expr& e) { - return GetField(e, 0); - }; - auto val_type = [&](const Type& forward_type) { - return forward_type; - }; + auto val = [&](const Expr& e) { return GetField(e, 0); }; + auto val_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(val, val_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the gradient. */ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { - auto grad = [&](const Expr& e) { - return ll->Push(RefRead(GetField(e, 1))); - }; - auto grad_type = [&](const Type& forward_type) { - return forward_type; - }; + auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); }; + auto grad_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(grad, grad_type, forward_type, e, ll); } void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { if (t.as()) { - ll->Push(RefWrite(GetField(arg, 1), - Add(ll->Push(RefRead(GetField(arg, 1))), - grad))); + ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad))); } else if (auto* tt = t.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { - UpdateGrad(tt->fields[i], - ll->Push(GetField(arg, i)), - ll->Push(GetField(grad, i)), - ll); + UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll); } } else { LOG(FATAL) << "unsupported arg type of operator: " << t; @@ -394,15 +356,14 @@ struct ReverseAD : ExprMutator { std::shared_ptr ad_vars; const OpMap rev_map = Op::GetAttr("FPrimalGradient"); - explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) - : bp(bp), ad_vars(ad_vars) { } + explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) : bp(bp), ad_vars(ad_vars) {} Expr VisitExpr_(const OpNode* op) final { LOG(FATAL) << "op should only be inside call"; throw; } - Expr VisitCheckpoint(const CallNode *call) { + Expr VisitCheckpoint(const CallNode* call) { const OpNode* op_node = call->op.as(); CHECK(op_node) << "expected op in call"; Op op_ref = GetRef(op_node); @@ -412,20 +373,17 @@ struct ReverseAD : ExprMutator { auto x_var = ll->Push(x); auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function( - {}, - LetList::With([&](LetList* ll) { - // we need a new ReverseAD visitor to avoid clobbering the bp local var - auto dup_bp = ll->Push(BPEmpty()); - ReverseAD dup_diff(dup_bp, ad_vars); - auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); - - TransferGrads(call->checked_type(), ret, dup_ad, ll); - ll->Push(Call(RefRead(dup_bp), {})); - return Call(bpv, {}); - }), - TupleType::Empty(), - {}); + Expr nbp = Function({}, LetList::With([&](LetList* ll) { + // we need a new ReverseAD visitor to avoid clobbering the bp local var + auto dup_bp = ll->Push(BPEmpty()); + ReverseAD dup_diff(dup_bp, ad_vars); + auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); + + 
TransferGrads(call->checked_type(), ret, dup_ad, ll); + ll->Push(Call(RefRead(dup_bp), {})); + return Call(bpv, {}); + }), + TupleType::Empty(), {}); ll->Push(RefWrite(bp, nbp)); return ret; }); @@ -439,8 +397,7 @@ struct ReverseAD : ExprMutator { return VisitCheckpoint(call); } - CHECK(rev_map.count(op_ref)) - << op_node->name << " does not have reverse mode defined"; + CHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined"; return LetList::With([&](LetList* ll) { std::vector args; for (const auto& arg : call->args) { @@ -456,18 +413,16 @@ struct ReverseAD : ExprMutator { orig_var->checked_type_ = call->checked_type(); auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function( - {}, - LetList::With([&](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); - } - return Call(bpv, {}); - }), - TupleType::Empty(), - {}); + Expr nbp = Function({}, LetList::With([&](LetList* ll) { + tvm::Array rev = + rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); + } + return Call(bpv, {}); + }), + TupleType::Empty(), {}); ll->Push(RefWrite(bp, nbp)); return ret; }); @@ -481,9 +436,8 @@ struct ReverseAD : ExprMutator { } Expr VisitExpr_(const IfNode* op) final { - return If(TupleGetItem(VisitExpr(op->cond), 0), - VisitExpr(op->true_branch), - VisitExpr(op->false_branch)); + return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch), + VisitExpr(op->false_branch)); } Expr VisitExpr_(const VarNode* var) final { @@ -497,9 +451,7 @@ struct ReverseAD : ExprMutator { return ad_vars->at(var_ref); } - Type VisitType(const Type& t) final { - return t.defined() ? ReverseType(t) : t; - } + Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } }; bool MissingGrad(const Expr& e) { @@ -585,8 +537,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) { return Function(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_GLOBAL("relay._transform.gradient") -.set_body_typed(Gradient); +TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/infer_layout_util.h b/src/relay/transforms/infer_layout_util.h index ca730034327a8..e4df647e65acc 100644 --- a/src/relay/transforms/infer_layout_util.h +++ b/src/relay/transforms/infer_layout_util.h @@ -27,11 +27,13 @@ #ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_ #define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_ -#include #include #include +#include + #include #include + #include "pattern_util.h" namespace tvm { @@ -94,17 +96,15 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o * \return infered_layout An array of two elements that are inferred input layouts and * inferred output layouts. */ -using FInferCorrectLayout = runtime::TypedPackedFunc< - Array>(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types)>; +using FInferCorrectLayout = runtime::TypedPackedFunc>( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types)>; /*! 
\brief take arbitrary input layout and copy to output */ -inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +inline Array> ElemwiseArbitraryLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Layout ret; if (new_in_layouts.defined()) { @@ -119,14 +119,14 @@ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, } } - return Array >{Array(old_in_layouts.size(), ret), {ret}}; + return Array>{Array(old_in_layouts.size(), ret), {ret}}; } /*! \brief Infer layout for binary broadcast operators */ -inline Array > BinaryBroadcastLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +inline Array> BinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Array layouts; Array> old_in_shapes; for (auto old_in_t : old_in_types) { @@ -142,28 +142,27 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, if (!layouts[0].defined() && !layouts[1].defined()) { // both undefined, infer fails - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } else if (!layouts[0].defined() || !layouts[1].defined()) { // only one is defined, use shape information to help infer int defined_idx = layouts[0].defined() ? 0 : 1; int undef_idx = 1 - defined_idx; if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { - layouts.Set(undef_idx, - layouts[defined_idx].SubLayout( - old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), - old_in_shapes[undef_idx].size())); - return Array >{layouts, {layouts[defined_idx]}}; + layouts.Set(undef_idx, layouts[defined_idx].SubLayout(old_in_shapes[defined_idx].size() - + old_in_shapes[undef_idx].size(), + old_in_shapes[undef_idx].size())); + return Array>{layouts, {layouts[defined_idx]}}; } else { // only know the tensor with smaller dimensions, // so we cannot infer the final broadcasted output. // fails in this case. - return Array >{{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } } else if (layouts[0].defined() && layouts[1].defined() && - (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) { + (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) { int scalar = layouts[0].ndim() == 0 ? 0 : 1; - return Array >{layouts, {layouts[1-scalar]}}; + return Array>{layouts, {layouts[1 - scalar]}}; } else { // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims // while transforming layout. 
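// When only one operand of a broadcasting op has a defined layout, the
// branch above infers the lower-rank operand's layout as the trailing
// sublayout of the defined one. A plain-string sketch of that start/length
// arithmetic (the real code uses Layout::SubLayout on Layout objects):
#include <cassert>
#include <string>

static std::string TrailingSubLayout(const std::string& defined, size_t undef_rank) {
  // Same computation as SubLayout(defined_rank - undef_rank, undef_rank).
  return defined.substr(defined.size() - undef_rank, undef_rank);
}

int main() {
  // NCHW (rank 4) broadcast against a rank-2 tensor: the rank-2 operand is
  // inferred as "HW", and the output keeps the defined "NCHW".
  assert(TrailingSubLayout("NCHW", 2) == "HW");
  // Against a rank-1 operand, the trailing axis "W" is taken.
  assert(TrailingSubLayout("NCHW", 1) == "W");
  return 0;
}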
@@ -217,8 +216,7 @@ static inline std::tuple, Array, bool> InferCorrectLayouts Op op = Downcast(call->op); if (finfer_layout.count(op)) { Array> inferred_layouts; - inferred_layouts = - finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); + inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); CHECK_EQ(inferred_layouts.size(), 2) << "FInferCorrectLayout should return an array with size of 2"; for (auto x : inferred_layouts) { diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index ba0f5688ea9d8..c9a0de44e2d46 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -35,8 +35,9 @@ #include #include -#include #include +#include + #include #include @@ -83,11 +84,8 @@ class Inliner : ExprMutator { } Function Inline(const Function& func) { - return Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, + func->attrs); } private: @@ -115,20 +113,13 @@ class Inliner : ExprMutator { } // Make a new Relay expression to replace the callee. - Expr MakeNewExpr(const GlobalVar& global, - const Array& args, - const Expr& callee) { - CHECK(callee->IsInstance() || - callee->IsInstance()); + Expr MakeNewExpr(const GlobalVar& global, const Array& args, const Expr& callee) { + CHECK(callee->IsInstance() || callee->IsInstance()); auto base_func = call_graph_->GetGlobalFunction(global); const auto* fn = base_func.as(); CHECK(fn) << "Expected to work on a Relay function."; - auto func = Function(fn->params, - fn->body, - fn->ret_type, - fn->type_params, - fn->attrs); + auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. if (!func->GetAttr(attr::kCompiler).defined()) { @@ -144,14 +135,13 @@ class Inliner : ExprMutator { // Cannot replace TensorType/TensorTupleType with FuncType. Therefore, // we simply inline the function as a closure instead of directly using // its body when the global var returns FuncType. - return ret_type->IsInstance() ? std::move(func) - : func->body; + return ret_type->IsInstance() ? std::move(func) : func->body; } else { CHECK(callee->IsInstance()); return Bind(func->body, bind_map); } } else if (const auto* call_node = callee.as()) { - return Call(func, args, call_node->attrs, call_node->type_args); + return Call(func, args, call_node->attrs, call_node->type_args); } else { return std::move(func); } @@ -214,14 +204,11 @@ namespace transform { Pass Inline() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::Inline(m); - }; + [=](IRModule m, PassContext pc) { return relay::Inline(m); }; return CreateModulePass(pass_func, 1, "InlineGlobals", {}); } -TVM_REGISTER_GLOBAL("relay._transform.Inline") -.set_body_typed(Inline); +TVM_REGISTER_GLOBAL("relay._transform.Inline").set_body_typed(Inline); } // namespace transform diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index e6248f11a00ea..3cd29d66abfd5 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -24,21 +24,21 @@ * \brief Lazily instantiate 0-filled or 1-filled tensors. 
 * This pass should be used after reverse-mode AD so that gradient tensors
 * are not instantiated until after the forward pass.
- *
- * This pass delays or removes memory allocation by converting tensors into
+ *
+ * This pass delays or removes memory allocation by converting tensors into
 * GradCell, an algebraic data type defined in gradient.rly.
- *
+ *
 * This will delay or decrease memory usage. All calls to
 * ones, ones_like, zeros, zeros_like will call the One or Zero constructor
 * of GradCell, which will not instantiate in memory until needed. All other cases result
 * in using the Raw constructor which means the tensor is instantiated in memory.
- *
+ *
 * It also overloads + and * operations which can increase performance when doing
 * operations involving tensors with values of only 0 or 1.
- *
+ *
 * Note: this pass can only be used with functions where the input/output types are
 * a combination of TupleTypes and TensorTypes
- *
+ *
 * This pass optimizes 6 ops:
 * - add
 * - multiply
@@ -46,39 +46,40 @@
 * - ones_like
 * - zeros
 * - zeros_like
- *
+ *
 * This pass makes use of three visitors. The most important one visits the entire function;
 * one is used to wrap inputs and one to unwrap outputs.
- *
+ *
 * For example:
 * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
- *
+ *
 * After this pass
 * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
- *
+ *
 * Thus, it is necessary to wrap this outer function so that the input/output types remain the same
 */
+#include
 #include
 #include
 #include
-#include
 #include
+
 #include "let_list.h"

namespace tvm {
namespace relay {

/*!
-* \brief Visitor appropriately wraps tensors with Raw constructor
-*
-* Recursively looks at the type of the expression (TensorType or TupleType are only supported for now)
-* and either call the GradCell constructor if TensorType
-* or unfold and recursively visit if TupleType
-*/
-class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
+ * \brief Visitor appropriately wraps tensors with Raw constructor
+ *
+ * Recursively looks at the type of the expression (TensorType or TupleType are only supported for
+ * now) and either calls the GradCell constructor if TensorType or unfolds and recursively visits
+ * if TupleType
+ */
+class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
 public:
- explicit InputVisitor(IRModule module): module_(module) {}
+ explicit InputVisitor(IRModule module) : module_(module) {}

 Expr VisitExpr_(const VarNode* op, const Type& t) final {
 std::cout << op->type_annotation << std::endl;
@@ -88,13 +89,13 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
 Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
 return WrapExpr(GetRef<TupleGetItem>(op), t);
 }
+
 private:
 IRModule module_;

 Expr WrapExpr(const Expr expr, const Type& type) {
 if (type.as<TensorTypeNode>()) {
- return Call(module_->GetConstructor("GradCell", "Raw"),
- {expr}, Attrs(), {type});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
 } else if (auto* type_anno = type.as<TupleTypeNode>()) {
 tvm::Array<Expr> fields;
 for (size_t i = 0; i < type_anno->fields.size(); i++) {
@@ -110,15 +111,15 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
 };

/*!
-* \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
-*
-* Recursively looks at the type of the expression
-* and either use the FromGradCell function if TypeCall to GradCell
-* or unfold and recursively visit if TupleType
-*/
-class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
+ * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
+ *
+ * Recursively looks at the type of the expression
+ * and either uses the FromGradCell function if TypeCall to GradCell
+ * or unfolds and recursively visits if TupleType
+ */
+class OutputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
 public:
- explicit OutputVisitor(IRModule module): module_(module) {}
+ explicit OutputVisitor(IRModule module) : module_(module) {}

 Expr VisitExpr_(const CallNode* op, const Type& t) final {
 return UnwrapExpr(GetRef<Call>(op), t);
@@ -127,6 +128,7 @@ class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
 Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
 return UnwrapExpr(GetRef<TupleGetItem>(op), t);
 }
+
 private:
 IRModule module_;

@@ -150,19 +152,18 @@ class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
 }
 };

-class LazyGradientInitializer: public ExprMutator, public TypeMutator {
+class LazyGradientInitializer : public ExprMutator, public TypeMutator {
 public:
- explicit LazyGradientInitializer(IRModule module):
- module_(module) {
- module_->ImportFromStd("gradient.rly");
- }
+ explicit LazyGradientInitializer(IRModule module) : module_(module) {
+ module_->ImportFromStd("gradient.rly");
+ }

 /*!
- * \brief apply LazyGradientInit transformation and wrap function
- * so that function type stays the same
- *
- * input/output types should only be a combination of TupleTypes and TensorTypes
- */
+ * \brief apply LazyGradientInit transformation and wrap function
+ * so that function type stays the same
+ *
+ * input/output types should only be a combination of TupleTypes and TensorTypes
+ */
 Expr Transform(const Expr& e) {
 auto* f = (e).as<FunctionNode>();
 auto* transformed = this->Mutate(e).as<FunctionNode>();
@@ -185,8 +186,8 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
 }

 Expr VisitExpr_(const ConstantNode* op) final {
- return Call(module_->GetConstructor("GradCell", "Raw"),
- {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef<Constant>(op)}, Attrs(),
+ {op->checked_type()});
 }

 Expr VisitExpr_(const CallNode* call_node) final {
@@ -203,12 +204,12 @@
 if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
 // fn() -> T, function returns result of the operation
- Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
- {call_node->checked_type()}, {});
+ Expr func =
+ Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {});
 // call appropriate GradCell constructor
 std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
- return Call(module_->GetConstructor("GradCell", constructor_name),
- {func}, Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
+ {call_node->checked_type()});
 }

 if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
@@ -218,23 +219,21 @@
 Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
 // call appropriate GradCell constructor
 std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
- return Call(module_->GetConstructor("GradCell", "One"),
- {func}, Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
+ {call_node->checked_type()});
 }

 // handle all other ops
 Expr result = CallPrimitiveOp(call_node);
 // wrap result with Raw constructor
- return Call(module_->GetConstructor("GradCell", "Raw"), {result},
- Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
+ {call_node->checked_type()});
 }
 // not an op
 return ExprMutator::VisitExpr_(call_node);
 }

- Type VisitType(const Type& t) final {
- return TypeMutator::VisitType(t);
- }
+ Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }

 Type VisitType_(const TensorTypeNode* op) {
 GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
@@ -248,23 +247,22 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
 IRModule module_;

 /*!
- * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
- */
+ * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
+ */
 Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
 // can only use overloaded functions if 2 arguments of same type
 if (call_node->args.size() != 2 ||
 !tvm::StructuralEqual()(call_node->args[0]->checked_type(),
 call_node->args[1]->checked_type())) {
 Expr result = CallPrimitiveOp(call_node);
- return Call(module_->GetConstructor("GradCell", "Raw"), {result},
- Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
+ {call_node->checked_type()});
 }

 tvm::Array<Expr> args;
 // create "fallback" function for overloaded function
 Type paramType = call_node->args[0]->checked_type();
- tvm::Array<Var> params = {Var("lhs", paramType),
- Var("rhs", paramType)};
+ tvm::Array<Var> params = {Var("lhs", paramType), Var("rhs", paramType)};
 // use primitive op in this case
 Expr callOp = Call(call_node->op, {params[0], params[1]});
 Expr func = Function(params, callOp, paramType, Array<TypeVar>());
@@ -279,16 +277,15 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator {
 }

/*!
- * \brief Convert calls to other ops by converting args into TensorType - * \return call expr returning result of op - */ + * \brief Convert calls to other ops by converting args into TensorType + * \return call expr returning result of op + */ Expr CallPrimitiveOp(const CallNode* call_node) { const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr : call_node->args) { - args.push_back(Call(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } // result of operation return Call(call_node->op, args); @@ -302,14 +299,13 @@ Expr LazyGradientInit(const Expr& e, IRModule mod) { namespace transform { Pass LazyGradientInit() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(LazyGradientInit(f, m)); - }; - return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(LazyGradientInit(f, m)); + }; + return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); } -TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit") -.set_body_typed(LazyGradientInit); +TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit); } // namespace transform diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index 0b5c671ab7f64..25919b4ca3de1 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -23,10 +23,10 @@ * shape, dtype or layout to another op or a sequence of ops. */ -#include #include #include #include +#include namespace tvm { namespace relay { diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index f195c3060e2f1..c0e0b3a238642 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -29,12 +29,14 @@ #ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_ #define TVM_RELAY_TRANSFORMS_LET_LIST_H_ -#include #include +#include + +#include +#include #include #include -#include -#include + #include "tvm/relay/type.h" namespace tvm { @@ -77,9 +79,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr, Type ty) { - return Push(Var("x", ty), expr); - } + Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); } /*! * \brief insert a binding. @@ -88,9 +88,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr) { - return Push(expr, Type()); - } + Var Push(Expr expr) { return Push(expr, Type()); } /*! * \brief wrap an expr around the LetList. @@ -130,16 +128,14 @@ class LetList { * * \return the wrapped Expr. 
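 *
 * A short usage sketch (editorial illustration; MakeMyExpr is a hypothetical
 * stand-in for any code that produces an Expr, and Add is the helper from
 * pattern_util.h):
 *
 * \code
 * Expr body = LetList::With([&](LetList* ll) {
 *   Var a = ll->Push(MakeMyExpr());  // bind an expr to a fresh Var
 *   Var b = ll->Push(Add(a, a));     // later pushes may refer to earlier Vars
 *   return Add(b, b);                // Get() wraps this body in nested Lets
 * });
 * \endcode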
*/ - template + template static Expr With(F&& f) { LetList ll; return ll.Get(f(&ll)); } static Expr LetBind(const Expr& e, const std::function& f) { - return With([&](LetList* ll) { - return f(ll->Push(e)); - }); + return With([&](LetList* ll) { return f(ll->Push(e)); }); } private: diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 75d95f0378f1c..46fdae0cf90c1 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -234,8 +234,7 @@ Pass MergeComposite(const tvm::Array& pattern_names, return func_pass; } -TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { tvm::Array pattern_names = args[0]; tvm::Array patterns = args[1]; std::vector checks; diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index cd1f40c28767e..a27cb79da8cf3 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -91,12 +91,13 @@ */ #include #include -#include #include -#include #include -#include "pass_util.h" +#include +#include + #include "let_list.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -109,9 +110,7 @@ using namespace runtime; * Use VarHash to hash Var by id. */ struct VarHash { - size_t operator()(const Var& v) const { - return ObjectHash()(v->vid); - } + size_t operator()(const Var& v) const { return ObjectHash()(v->vid); } }; /*! \brief Compare Var by it's id. @@ -119,9 +118,7 @@ struct VarHash { * Use VarEqual to compare Var by id. */ struct VarEqual { - bool operator()(const Var& l, const Var& r) const { - return l->vid.get() == r->vid.get(); - } + bool operator()(const Var& l, const Var& r) const { return l->vid.get() == r->vid.get(); } }; Expr PostProcess(const Expr&); @@ -137,9 +134,7 @@ class Static : public ObjectRef { public: Static() {} explicit Static(ObjectPtr n) : ObjectRef(n) {} - const StaticNode* operator->() const { - return static_cast(get()); - } + const StaticNode* operator->() const { return static_cast(get()); } using ContainerType = StaticNode; }; @@ -156,9 +151,9 @@ struct PStaticNode : Object { Static pstatic; // may be null Expr dynamic; Time created_time; - PStaticNode(const Static& pstatic, const Expr& dynamic) : - pstatic(pstatic), dynamic(dynamic), created_time(time()) { } - explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } + PStaticNode(const Static& pstatic, const Expr& dynamic) + : pstatic(pstatic), dynamic(dynamic), created_time(time()) {} + explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) {} static constexpr const char* _type_key = "relay.PStatic"; TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object); }; @@ -170,7 +165,7 @@ class PStatic : public ObjectRef { struct STupleNode : StaticNode { std::vector fields; - explicit STupleNode(const std::vector& fields) : fields(fields) { } + explicit STupleNode(const std::vector& fields) : fields(fields) {} static constexpr const char* _type_key = "relay.STuple"; TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode); }; @@ -186,7 +181,7 @@ Static MkSTuple(const std::vector& fields) { struct STensorNode : StaticNode { runtime::NDArray data; - explicit STensorNode(const NDArray& data) : data(data) { } + explicit STensorNode(const NDArray& data) : data(data) {} static constexpr const char* _type_key = "relay.STensor"; 
 TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode);
 };

@@ -196,15 +191,13 @@ class STensor : public Static {
 TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode);
 };

-Static MkSTensor(const NDArray& data) {
- return Static(make_object<STensorNode>(data));
-}
+Static MkSTensor(const NDArray& data) { return Static(make_object<STensorNode>(data)); }

 struct SConstructorNode : StaticNode {
 Constructor constructor;
 std::vector<PStatic> fields;
- SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
- constructor(constructor), fields(fields) { }
+ SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields)
+ : constructor(constructor), fields(fields) {}
 static constexpr const char* _type_key = "relay.SConstructor";
 TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode);
 };

@@ -229,19 +222,14 @@ class SRef : public Static {
 TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode);
 };

-Static MkSRef() {
- return Static(make_object<SRefNode>());
-}
+Static MkSRef() { return Static(make_object<SRefNode>()); }

-using Func = std::function<PStatic(const PStatic&, const std::vector<PStatic>&,
- const Attrs&,
- const Array<Type>&,
- LetList*)>;
+using Func = std::function<PStatic(const PStatic&, const std::vector<PStatic>&, const Attrs&,
+ const Array<Type>&, LetList*)>;

 struct SFuncNode : StaticNode {
 Func func;
- explicit SFuncNode(const Func& func) : func(func) { }
+ explicit SFuncNode(const Func& func) : func(func) {}
 static constexpr const char* _type_key = "relay.SFunc";
 TVM_DECLARE_FINAL_OBJECT_INFO(SFuncNode, StaticNode);
 };

@@ -251,15 +239,13 @@ class SFunc : public Static {
 TVM_DEFINE_OBJECT_REF_METHODS(SFunc, Static, SFuncNode);
 };

-Static MkSFunc(const Func& func) {
- return Static(make_object<SFuncNode>(func));
-}
-
+Static MkSFunc(const Func& func) { return Static(make_object<SFuncNode>(func)); }

 class FuelNode;
/*! \brief A meet-semilattice with finite descending chain.
 * It means that we can meet two elements to get an element,
- * and for every element, there is only a finite amount of meet before getting back the same element.
+ * and for every element, there is only a finite number of meets before getting back the same
+ * element.
 *
 * Every time we recurse, we do a meet and require that progress must be made.
 * This ensures we do not recurse infinitely in the Partial Evaluator.
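To make the meet contract concrete, the following self-contained sketch (an editorial illustration; every name in it is hypothetical and not part of this patch) models a fuel over non-negative integers. Its Meet mirrors FTimeNode::Meet below: take the minimum of the two values and report whether the result strictly decreased. Since an integer can only strictly decrease finitely many times, any recursion that insists on progress at each meet must terminate, which is exactly the property the partial evaluator relies on:

#include <algorithm>
#include <cstdio>
#include <tuple>

// Toy fuel over non-negative integers. Meet takes the minimum and reports
// whether the value strictly decreased, mirroring FuelNode::Meet's
// (new fuel, progress) result.
struct IntFuel {
  int value;
  std::tuple<IntFuel, bool> Meet(const IntFuel& other) const {
    int new_value = std::min(value, other.value);
    return std::make_tuple(IntFuel{new_value}, new_value < value);
  }
};

// Recursion guarded by the meet: every step must make progress, and an int
// can only strictly decrease finitely often, so the recursion terminates.
int SpecializeDepth(IntFuel current, IntFuel incoming) {
  auto res = current.Meet(incoming);
  if (!std::get<1>(res)) return 0;                         // no progress: stop
  return 1 + SpecializeDepth(std::get<0>(res), incoming);  // progress: recurse
}

int main() {
  std::printf("%d\n", SpecializeDepth(IntFuel{5}, IntFuel{3}));  // prints 1
  return 0;
}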
@@ -301,9 +287,7 @@ class FuelNode : public RelayNode { TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode); }; -const FuelNode* Fuel::operator->() const { - return static_cast(get()); -} +const FuelNode* Fuel::operator->() const { return static_cast(get()); } Fuel MkFSeq(const std::vector& fuels); struct FSeqNode : FuelNode { @@ -318,7 +302,7 @@ struct FSeqNode : FuelNode { } return MkFSeq(new_fuels); } - explicit FSeqNode(const std::vector& fuels) : fuels(fuels) { } + explicit FSeqNode(const std::vector& fuels) : fuels(fuels) {} static constexpr const char* _type_key = "relay.FSeq"; TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode); }; @@ -328,9 +312,7 @@ class FSeq : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode); }; -Fuel MkFSeq(const std::vector& fuels) { - return Fuel(make_object(fuels)); -} +Fuel MkFSeq(const std::vector& fuels) { return Fuel(make_object(fuels)); } Fuel MkFTime(Time time); struct FTimeNode : FuelNode { @@ -341,7 +323,7 @@ struct FTimeNode : FuelNode { Time new_time = std::min(time, x->time); return std::make_tuple(MkFTime(new_time), new_time < time); } - explicit FTimeNode(Time time) : time(time) { } + explicit FTimeNode(Time time) : time(time) {} static constexpr const char* _type_key = "relay.FTime"; TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode); }; @@ -351,9 +333,7 @@ class FTime : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode); }; -Fuel MkFTime(Time time) { - return Fuel(make_object(time)); -} +Fuel MkFTime(Time time) { return Fuel(make_object(time)); } Fuel MkFTValue(size_t tvalue); /*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */ @@ -365,7 +345,7 @@ struct FTValueNode : FuelNode { size_t new_tvalue = std::min(tvalue, x->tvalue); return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue); } - explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } + explicit FTValueNode(size_t tvalue) : tvalue(tvalue) {} static constexpr const char* _type_key = "relay.FTValue"; TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode); }; @@ -375,9 +355,7 @@ class FTValue : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode); }; -Fuel MkFTValue(size_t tvalue) { - return Fuel(make_object(tvalue)); -} +Fuel MkFTValue(size_t tvalue) { return Fuel(make_object(tvalue)); } /*! \brief Initially every element has Fuel of FTop. It is the largest element. * @@ -397,9 +375,7 @@ class FTop : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode); }; -Fuel MkFTop() { - return Fuel(make_object()); -} +Fuel MkFTop() { return Fuel(make_object()); } /*! * \brief A stack frame in the Relay interpreter. @@ -414,10 +390,10 @@ struct Frame { class Environment { public: - Environment() : env_({Frame()}) { } + Environment() : env_({Frame()}) {} Environment(const Environment&) = delete; - template + template T Extend(const std::function& body) { FrameContext fc(this); return body(); @@ -447,12 +423,8 @@ class Environment { struct FrameContext { Environment* env_; - explicit FrameContext(Environment* env) : env_(env) { - env_->env_.push_back(Frame()); - } - ~FrameContext() { - env_->env_.pop_back(); - } + explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); } + ~FrameContext() { env_->env_.pop_back(); } }; }; @@ -470,16 +442,16 @@ struct StoreFrame { * It only outdate the frame above it, but not the current frame. 
*/ bool history_valid = true; - explicit StoreFrame(const std::unordered_map& store) : store(store) { } + explicit StoreFrame(const std::unordered_map& store) : store(store) {} StoreFrame() = default; }; class Store { public: - Store() : store_({StoreFrame()}) { } + Store() : store_({StoreFrame()}) {} Store(const Store&) = delete; - template + template T Extend(const std::function& body) { StoreFrameContext sfc(this); return body(); @@ -534,13 +506,9 @@ PStatic HasStatic(const Static& stat, const Expr& dynamic) { return PStatic(make_object(stat, dynamic)); } -PStatic NoStatic(const Expr& dynamic) { - return PStatic(make_object(dynamic)); -} +PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object(dynamic)); } -enum struct MatchStatus { - Match, NoMatch, Unknown -}; +enum struct MatchStatus { Match, NoMatch, Unknown }; bool StatefulOp(const Expr& e) { static auto op_stateful = Op::GetAttr("TOpIsStateful"); @@ -582,20 +550,16 @@ struct WithFuncIdAttrs : public tvm::AttrsNode { FuncId fid; TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { - TVM_ATTR_FIELD(fid) - .describe("The FuncId that an function is annotated with.") - .set_default(-1); + TVM_ATTR_FIELD(fid).describe("The FuncId that an function is annotated with.").set_default(-1); } }; TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); - RELAY_REGISTER_OP("annotation.with_funcid") -.describe(R"code(Annotate a function with a funcid.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("func", "Function", "The input data."); + .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("func", "Function", "The input data."); // Cache with_funcid op to reduce lookup overhead during traversal. static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); @@ -624,7 +588,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const IRModule& mod) : mod_(mod) { } + PartialEvaluator(const IRModule& mod) : mod_(mod) {} PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -639,9 +603,8 @@ class PartialEvaluator : public ExprFunctor return VisitExpr(c->args[0], ll, name); } } - PStatic ret = e.as() ? - VisitFunc(Downcast(e), ll, name) : - VisitExpr(e, ll); + PStatic ret = + e.as() ? 
VisitFunc(Downcast(e), ll, name) : VisitExpr(e, ll); CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; return ret; } @@ -670,9 +633,7 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const VarNode* op, LetList* ll) final { - return env_.Lookup(GetRef(op)); - } + PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef(op)); } PStatic VisitGlobalVar(const GlobalVar& gv) { CHECK(mod_.defined()); @@ -714,15 +675,11 @@ class PartialEvaluator : public ExprFunctor } } else { Expr t = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - return VisitExpr(op->true_branch, ll)->dynamic; - }); - }); + return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; }); + }); Expr f = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - return VisitExpr(op->false_branch, ll)->dynamic; - }); - }); + return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; }); + }); store_.Invalidate(); return NoStatic(ll->Push(If(c->dynamic, t, f))); } @@ -782,16 +739,12 @@ class PartialEvaluator : public ExprFunctor PartialEvaluator* pe_; FuncId fid_; Fuel old_fuel; - FuelFrame(PartialEvaluator* pe, - FuncId fid, - const Fuel& new_fuel) : pe_(pe), fid_(fid) { + FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) { CHECK_GT(pe_->fuel_map_.count(fid_), 0); old_fuel = pe_->fuel_map_[fid_]; pe_->fuel_map_[fid_] = new_fuel; } - ~FuelFrame() { - pe_->fuel_map_[fid_] = old_fuel; - } + ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; } }; size_t GetFTValue(const PStatic& ps) { @@ -829,82 +782,76 @@ class PartialEvaluator : public ExprFunctor free_vars.push_back(std::pair(v, env_.Lookup(v))); } } - return [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + return [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { return env_.Extend([&]() { - CHECK_EQ(pv.size(), func->params.size()); - CHECK_GT(func_map_.count(func), 0); - FuncId fid = func_map_.at(func); - if (fuel_map_.count(fid) == 0) { - fuel_map_.insert({fid, MkFTop()}); + CHECK_EQ(pv.size(), func->params.size()); + CHECK_GT(func_map_.count(func), 0); + FuncId fid = func_map_.at(func); + if (fuel_map_.count(fid) == 0) { + fuel_map_.insert({fid, MkFTop()}); + } + std::vector args_fuel; + for (const auto& v : pv) { + args_fuel.push_back(GetFuel(v)); + } + auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); + if (std::get<1>(meet_res)) { + FuelFrame tf(this, fid, std::get<0>(meet_res)); + Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); + Function func = AsFunc(dedup_func); + if (var.as()) { + env_.Insert(Downcast(var), self); } - std::vector args_fuel; - for (const auto& v : pv) { - args_fuel.push_back(GetFuel(v)); + for (size_t i = 0; i < pv.size(); ++i) { + env_.Insert(func->params[i], pv[i]); + } + for (const auto& p : free_vars) { + env_.Insert(p.first, p.second); + } + tvm::Map subst; + for (size_t i = 0; i < type_args.size(); ++i) { + subst.Set(func->type_params[i], type_args[i]); } - auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); - if (std::get<1>(meet_res)) { - FuelFrame tf(this, fid, std::get<0>(meet_res)); - Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); - Function func = AsFunc(dedup_func); - if (var.as()) { - env_.Insert(Downcast(var), self); - } - for (size_t i = 0; i < pv.size(); ++i) { - env_.Insert(func->params[i], pv[i]); - 
} - for (const auto& p : free_vars) { - env_.Insert(p.first, p.second); - } - tvm::Map subst; - for (size_t i = 0; i < type_args.size(); ++i) { - subst.Set(func->type_params[i], type_args[i]); - } - for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], IncompleteType(kType)); - } - return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); - } else { - std::vector dyn; - for (const auto& v : pv) { - dyn.push_back(v->dynamic); - } - return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); + for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { + subst.Set(func->type_params[i], IncompleteType(kType)); } - }); + return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); + } else { + std::vector dyn; + for (const auto& v : pv) { + dyn.push_back(v->dynamic); + } + return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); + } + }); }; } Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return Function(func->params, - LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - }), func->ret_type, func->type_params, func->attrs); + return Function(func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; + }), + func->ret_type, func->type_params, func->attrs); }); } - PStatic VisitFunc(const Function& func, - LetList* ll, - const Var& name = Var("x", Type())) { + PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. // restore letrec support across whole relay. - return HasStatic(MkSFunc(f), - ll->Push(name, VisitFuncDynamic(u_func, f, name))); + return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name))); } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { @@ -912,7 +859,7 @@ class PartialEvaluator : public ExprFunctor } struct ReflectError : dmlc::Error { - ReflectError() : dmlc::Error("static value not found") { } + ReflectError() : dmlc::Error("static value not found") {} }; Expr Reflect(const PStatic& st) { @@ -954,31 +901,24 @@ class PartialEvaluator : public ExprFunctor // Constant evaluate a expression. PStatic ConstEvaluate(const Expr& expr, LetList* ll) { - std::vector passes = {transform::FuseOps(0), - transform::InferType()}; + std::vector passes = {transform::FuseOps(0), transform::InferType()}; auto mod = IRModule::FromExpr(expr); auto seq = transform::Sequential(passes); mod = seq(mod); auto entry_func = Downcast(mod->Lookup("main")); - auto fused_infered = - expr.as() == nullptr ? entry_func->body : entry_func; + auto fused_infered = expr.as() == nullptr ? 
entry_func->body : entry_func; return Reify(executor_(fused_infered), ll); } Func ConstEvaluateFunc(const Expr& expr) { CHECK_EQ(FreeVars(expr).size(), 0); - return [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + return [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { tvm::Array ns_args; for (const PStatic& ps : pv) { ns_args.push_back(ps->dynamic); } - auto ns = [&]() { - return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); - }; + auto ns = [&]() { return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); }; if (StatefulOp(expr)) { return ns(); } @@ -988,8 +928,7 @@ class PartialEvaluator : public ExprFunctor args.push_back(Reflect(ps)); } return ConstEvaluate(Call(expr, args, attrs, type_args), ll); - } - catch (const ReflectError&) { + } catch (const ReflectError&) { return ns(); } }; @@ -1001,11 +940,8 @@ class PartialEvaluator : public ExprFunctor PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { Constructor c = GetRef(op); - Func f = [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + Func f = [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { tvm::Array dyn; for (const PStatic& ps : pv) { dyn.push_back(ps->dynamic); @@ -1020,30 +956,30 @@ class PartialEvaluator : public ExprFunctor return env_.Extend([&]() { for (const Clause& c : op->clauses) { switch (VisitPattern(c->lhs, ps)) { - case MatchStatus::Match: - return VisitExpr(c->rhs, ll); - case MatchStatus::NoMatch: - continue; - case MatchStatus::Unknown: - return [&]() { - tvm::Array clauses; - for (const Clause& c : op->clauses) { - Expr expr = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - for (const Var& v : BoundVars(c->lhs)) { - env_.Insert(v, NoStatic(v)); - } - return VisitExpr(c->rhs, ll)->dynamic; + case MatchStatus::Match: + return VisitExpr(c->rhs, ll); + case MatchStatus::NoMatch: + continue; + case MatchStatus::Unknown: + return [&]() { + tvm::Array clauses; + for (const Clause& c : op->clauses) { + Expr expr = store_.Extend([&]() { + return LetList::With([&](LetList* ll) { + for (const Var& v : BoundVars(c->lhs)) { + env_.Insert(v, NoStatic(v)); + } + return VisitExpr(c->rhs, ll)->dynamic; + }); }); - }); - clauses.push_back(Clause(c->lhs, expr)); - } - store_.Invalidate(); - return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); - }(); - default: - LOG(FATAL) << "Unknown MatchStatus"; - throw; + clauses.push_back(Clause(c->lhs, expr)); + } + store_.Invalidate(); + return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); + }(); + default: + LOG(FATAL) << "Unknown MatchStatus"; + throw; } } LOG(FATAL) << "No case Match"; @@ -1071,12 +1007,12 @@ class PartialEvaluator : public ExprFunctor for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]); switch (ms) { - case MatchStatus::Match: - continue; - case MatchStatus::NoMatch: - return MatchStatus::NoMatch; - case MatchStatus::Unknown: - current_match_status = MatchStatus::Unknown; + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; } } return current_match_status; @@ -1095,12 +1031,12 @@ class PartialEvaluator : public ExprFunctor for 
(size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]); switch (ms) { - case MatchStatus::Match: - continue; - case MatchStatus::NoMatch: - return MatchStatus::NoMatch; - case MatchStatus::Unknown: - current_match_status = MatchStatus::Unknown; + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; } } return current_match_status; @@ -1112,7 +1048,7 @@ class PartialEvaluator : public ExprFunctor void InitializeFuncId(const Expr& e) { struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; - explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } + explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {} void VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); @@ -1121,9 +1057,7 @@ class PartialEvaluator : public ExprFunctor VisitExpr(f->body); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; InitializeFuncIdVisitor(this).VisitExpr(e); } @@ -1131,7 +1065,7 @@ class PartialEvaluator : public ExprFunctor Expr RegisterFuncId(const Expr& e) { struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; - explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } + explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {} void VisitExpr_(const CallNode* op) final { if (op->op == with_funcid_op) { @@ -1154,9 +1088,7 @@ class PartialEvaluator : public ExprFunctor ExprVisitor::VisitExpr_(op); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; RegisterFuncIdVisitor(this).VisitExpr(e); return e; @@ -1165,7 +1097,7 @@ class PartialEvaluator : public ExprFunctor Expr AnnotateFuncId(const Expr& e) { struct AnnotateFuncIdMutator : ExprMutator, PatternMutator { PartialEvaluator* pe; - explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { } + explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) {} Expr VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); @@ -1173,13 +1105,9 @@ class PartialEvaluator : public ExprFunctor return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f)); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Var VisitVar(const Var& v) final { - return v; - } + Var VisitVar(const Var& v) final { return v; } }; return AnnotateFuncIdMutator(this).VisitExpr(e); } @@ -1199,7 +1127,8 @@ class PartialEvaluator : public ExprFunctor * If no progress is made, we do not inline. * In both case, we remap the mapping to the new Fuel * when we PE inside the Function body. - * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet. + * Termination is guaranteed because Fuel is finitely descending - there can only be so many + * meet. 
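 *
 * As an editorial recap (not new behavior), the guard in VisitFuncStatic
 * above has this shape:
 * \code
 * auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
 * if (std::get<1>(meet_res)) {
 *   // progress was made: specialize the body under the smaller fuel
 * } else {
 *   // no progress: emit a dynamic call instead of inlining further
 * }
 * \endcode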
*/ std::unordered_map func_map_; std::unordered_map fuel_map_; @@ -1219,9 +1148,7 @@ Expr Remap(const Expr& e) { return remap_.at(v); } - Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); - } + Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } private: std::unordered_map remap_; @@ -1240,20 +1167,14 @@ Expr StripWithFuncId(const Expr& e) { } } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Var VisitVar(const Var& v) final { - return v; - } + Var VisitVar(const Var& v) final { return v; } }; return StripWithFuncIdMutator().VisitExpr(e); } -Expr PostProcess(const Expr& e) { - return StripWithFuncId(DeDup(Remap(e))); -} +Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); } } // namespace partial_eval @@ -1273,14 +1194,11 @@ namespace transform { Pass PartialEval() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::PartialEval(m); - }; + [=](IRModule m, PassContext pc) { return relay::PartialEval(m); }; return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } -TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate") -.set_body_typed(PartialEval); +TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval); } // namespace transform diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 15ad60be3a955..634434d3a3d0b 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -54,39 +54,30 @@ namespace partitioning { static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); -/*! - * \brief The checker that verifies if a Relay program is annotated correctly - * for partitioning. +/*! \brief This struct maintains the required metadata for a region to generate a corresponding + * global function and function call. Global function will be passed to the target specific codegen + * and function call will be used in the transform Relay graph to invoke the function in runtime. */ -class AnnotationChecker : public ExprVisitor { - public: - bool Check() { - if (!found_start_ && !found_end_) { - LOG(WARNING) << "No compiler annotation found"; - } else if (!found_start_) { - LOG(ERROR) << "compiler_begin annotation is missing"; - return false; - } else if (!found_end_) { - LOG(ERROR) << "compiler_end annotation is missing"; - return false; - } - return true; - } +struct RegionFuncMetadata { + /*! \brief The call node of the generated global function for this region. */ + Call func_call; - void VisitExpr_(const CallNode* call) final { - auto op_node = call->op.as(); - if (op_node == nullptr || call->attrs.as() == nullptr) { - return; - } else if (call->op == compiler_begin_op) { - found_start_ = true; - } else if (call->op == compiler_end_op) { - found_end_ = true; - } - } + /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used + * as a function node argument; input expression is used as a function call parameter. + */ + std::vector> args; - private: - bool found_start_{false}; - bool found_end_{false}; + /*! \brief Map from each region output expr (compiler end) node to + * the corresponding function output expr. + */ + std::unordered_map region_func_out; + + /*! 
\brief Map from each region input expression (compiler begin) to + * the corresponding function input variable. This cache is used to make sure + * a region function will not have duplicated inputs even if it refers to + * the same expr multiple times. + */ + std::unordered_map region_func_in; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor { * the compiler name. */ -class Partitioner : public ExprMutator { +class Partitioner : public MixedModeMutator { public: explicit Partitioner(const IRModule& module) : module_(module) { for (auto f : module->functions) { GlobalVar f_var = f.first; BaseFunc f_func = f.second; - // Creating regionset per function in the module + // Creating regionset per function in the module. auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, partitioning::compiler_end_op); regions_sets_[region_set] = f_func; } } - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { - return ExprMutator::VisitExpr_(call); + return post; } else if (call->op == compiler_begin_op) { - // The annotation node is inserted on edge so it must have only one - // argument. + // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); // Traverse the rest graph. Expr parent = call->args[0]; - auto input_expr = VisitExpr(parent); + auto input_expr = Downcast(post)->args[0]; // Backtrace the parent to find the first ancestor node that is not a begin or end op while (const auto* parent_call = parent.as()) { - if (parent_call->op == compiler_begin_op || - parent_call->op == compiler_end_op) { + if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) { parent = parent_call->args[0]; } else { break; @@ -165,8 +154,8 @@ class Partitioner : public ExprMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { - return shared_output_[parent][sg]; + if (region_func_meta_[sg].region_func_in.count(parent)) { + return region_func_meta_[sg].region_func_in[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. @@ -177,11 +166,11 @@ class Partitioner : public ExprMutator { std::pair cand = std::make_pair(var, input_expr); - if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == - region_args[sg].end()) { - region_args[sg].push_back(cand); + if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) == + region_func_meta_[sg].args.end()) { + region_func_meta_[sg].args.push_back(cand); } - shared_output_[parent][sg] = var; + region_func_meta_[sg].region_func_in[parent] = var; return std::move(var); } } else { @@ -197,114 +186,21 @@ class Partitioner : public ExprMutator { BaseFunc f = GetFunc(GetRef(call)); // Traverse subgraph inputs. - auto input = VisitExpr(call->args[0]); + auto input = Downcast(post)->args[0]; CHECK(region.defined()) << "Region not defined for " << GetRef(call); // functions are created for each annotated regions, // when their first output is encountered. // If multiple outputs are there, a tuple node is inserted at the end. 
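 // As an editorial illustration (the region and target names here are
 // hypothetical), a single-output region such as
 //   %b = compiler_begin(%x); %a = abs(%b); %e = compiler_end(%a)
 // ends up served by a lifted, kCompiler-annotated global function plus a call:
 //   def @my_target_0(%i) { abs(%i) }
 //   %e = @my_target_0(%x)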
- // region_function_calls is map that maintains - // (each annotated regions) --> created function - if (region_function_calls.find(region) == region_function_calls.end()) { - // First time this region is encountered in the traversal. - // Creating the function. + if (!region_func_meta_[region].func_call.defined()) { + // First time this region is encountered in the traversal. Creating the function. CreateFunction(region, call); } - // Retrieve this particular output of function. - return GetFunctionOutput(region, GetRef(call)); - } - } - - Expr VisitExpr_(const TupleNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array fields; - for (auto field : op->fields) { - fields.push_back(VisitExpr(field)); - } - return Tuple(fields); - } - } - - Expr VisitExpr_(const TupleGetItemNode* g) final { - auto region = GetRegion(GetRef(g)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(g); - } else { - auto t = VisitExpr(g->tuple); - return TupleGetItem(t, g->index); - } - } - - Expr VisitExpr_(const FunctionNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array params; - for (auto param : op->params) { - Var new_param = Downcast(VisitExpr(param)); - params.push_back(new_param); - } - auto body = VisitExpr(op->body); - return Function(params, body, op->ret_type, op->type_params, op->attrs); - } - } - - Expr VisitExpr_(const LetNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Var var = Downcast(VisitExpr(op->var)); - auto value = VisitExpr(op->value); - auto body = VisitExpr(op->body); - return Let(var, value, body); - } - } - - Expr VisitExpr_(const IfNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - auto guard = VisitExpr(op->cond); - auto true_b = VisitExpr(op->true_branch); - auto false_b = VisitExpr(op->false_branch); - return If(guard, true_b, false_b); - } - } - - Expr VisitExpr_(const RefCreateNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr value = VisitExpr(op->value); - return RefCreate(value); - } - } - - Expr VisitExpr_(const RefReadNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - return RefRead(ref); - } - } - Expr VisitExpr_(const RefWriteNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - Expr value = VisitExpr(op->value); - return RefWrite(ref, value); + // Retrieve this particular output of function. + Expr region_out_expr = Downcast(GetRef(call))->args[0]; + CHECK(region_func_meta_[region].region_func_out.count(region_out_expr)); + return region_func_meta_[region].region_func_out[region_out_expr]; } } @@ -370,35 +266,41 @@ class Partitioner : public ExprMutator { } /*! - * \brief This function is called first time that we encounter a compiler_end - * node to create the function for the subgraph. + * \brief Create a function and its function call for the given region. 
If the function has + * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes + * will be created to serve output consumers. */ - void CreateFunction(AnnotatedRegion region, const CallNode* call) { - // Create fields which is a unique list of outputs. Also populate - // region_return_indices_ map which maps parent of compiler_end node to - // corresponding index in fields. + void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { + // Create fields which is a unique list of outputs. Array fields; - int i = 0; - for (auto ret : region->GetOutputs()) { - auto ret_node = Downcast(ret)->args[0]; + std::unordered_map out_expr_to_idx; + int out_idx = 0; + for (auto region_end_node : region->GetOutputs()) { + auto ret_node = Downcast(region_end_node)->args[0]; // Don't duplicate outputs. - if (!region_return_indices_.count(region) || - !region_return_indices_[region].count(ret_node)) { - auto ret_expr = VisitExpr(ret_node); + if (!out_expr_to_idx.count(ret_node)) { + auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); - region_return_indices_[region][ret_node] = i; - i++; + out_expr_to_idx[ret_node] = out_idx++; } } Array params; Array param_expr; - std::unordered_map params_bind; + Map params_bind; + + auto IsConstant = [](const Expr& expr) { + if (expr->IsInstance()) return true; + if (!expr->IsInstance()) return false; + const auto* tn = expr.as(); + return std::all_of(tn->fields.begin(), tn->fields.end(), + [](const Expr& e) { return e->IsInstance(); }); + }; - for (auto pair : region_args[region]) { + for (auto pair : region_func_meta_[region].args) { params.push_back(pair.first); - if (const auto* cn = pair.second.as()) { - params_bind[pair.first->name_hint()] = cn->data; + if (IsConstant(pair.second)) { + params_bind.Set(pair.first, pair.second); } else { param_expr.push_back(pair.second); } @@ -408,32 +310,29 @@ class Partitioner : public ExprMutator { if (fields.size() == 1) { // If there are only a single output; no need to add a tuple global_region_func = - Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); + Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs()); } else { auto tuple = Tuple(fields); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); } - std::string target = call->attrs.as()->compiler; + std::string target = end_node->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); - global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, - runtime::String(name)); global_region_func = - WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); - global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, - tvm::runtime::String(target)); + WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name)); + global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); global_region_func = - WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); + WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target)); + global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); // Constant propagation if (!params_bind.empty()) { - global_region_func = backend::BindParamsByName(global_region_func, params_bind); + global_region_func = Downcast(relay::Bind(global_region_func, params_bind)); } 
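 // (Editorial note: relay::Bind substitutes each bound Var with its constant
 // Expr inside the body and drops it from the parameter list, so constant
 // inputs never surface as runtime call arguments.)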
 std::string fname = name;
- CHECK(!module_->ContainGlobalVar(fname))
- << "Global function " << fname << " already exists";
+ CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
 // Create a global function and add it to the IRModule for the region.
 // This way we lift the functions that should be handled by external
 // codegen to the module scope and rely on the pass manager to prevent
@@ -442,129 +341,81 @@ class Partitioner : public ExprMutator {
 GlobalVar glob_func(fname);
 module_->Add(glob_func, global_region_func);

- // The return type of callnode is the same as the type of the
- // compiler_end node.
- auto ret = Call(glob_func, param_expr);
- region_function_calls[region] = ret;
- }
+ // Create a call node for the function.
+ auto call = Call(glob_func, param_expr);
+ region_func_meta_[region].func_call = call;

- /*!
- * \brief Get the return(output) of the function for compiler end node "end_arg".
- * This will return either a Call (for a function with a single output) or a
- * TupleGetItem (for a function with multiple outputs).
- */
- Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
- Expr arg = Downcast<Call>(end_arg)->args[0];
- // Function has one output.
- if (region_return_indices_[region].size() == 1) {
- return region_function_calls[region];
- }
- // Function has multiple outputs.
- // Use already made TupleGetItem.
- if (region_return_tuplegetitem_.count(region) &&
- region_return_tuplegetitem_[region].count(arg)) {
- return region_return_tuplegetitem_[region][arg];
+ // Create output expr(s) for the function call.
+ if (out_expr_to_idx.size() == 1) {
+ // A single output directly uses the call node as the output expr.
+ region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
+ } else {
+ // Multiple outputs require TupleGetItem nodes as output exprs.
+ for (auto pair : out_expr_to_idx) {
+ Expr region_out_expr = pair.first; // The arg of a compiler end node of this region.
+ int idx = pair.second; // Corresponding function output tuple index.
+ auto tuple_get_item = TupleGetItem(call, idx);
+ tuple_get_item->checked_type_ = region_out_expr->checked_type_;
+ region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
+ }
 }
- // Create new TupleGetItem.
- CHECK(region_return_indices_.count(region) &&
- region_return_indices_[region].count(arg));
- int index = region_return_indices_[region][arg];
-
- auto func_call = region_function_calls[region];
- auto tuple_get_item_ = TupleGetItem(func_call, index);
- tuple_get_item_->checked_type_ = arg->checked_type_;
- region_return_tuplegetitem_[region][arg] = tuple_get_item_;
- return std::move(tuple_get_item_);
 }

- /*!
- * \brief This map maintains the already created function calls.
- * This is required in the multi-output scenario, to link rest of the outputs
- * to call
- */
- std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
-
- /*!
- * \brief This map maintains arguments (of region) visits through visitor
- * patterns. Those arguement var and expression will be used to when creating
- * the function.
- */
- std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
- region_args;
-
- /*!
- * \brief This map maintains the index of an output in the subgraph function
- * for a given region. If there are multiple entries for a region, then the
- * function has a tuple of multiple outputs for its return.
- */
- using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
- std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
- region_return_indices_;
+ /*!
\brief Map from each region to its metadata of the generated function. */ + std::unordered_map + region_func_meta_; - /*! - * \brief This map holds already created TupleGetItem nodes for accessing - * outputs of a function. - */ - using RegionRetTupleGetItemMap = std::unordered_map; - std::unordered_map - region_return_tuplegetitem_; - - /*! - * \brief Each region set is associated with a function in the module. + /*! \brief Each region set is associated with a function in the module. * This map maintains the mapping between regionsets and the function it * belongs to */ std::unordered_map regions_sets_; - /*!\brief Cache the output that is shared by different nodes. */ - using RegionOutputMap = std::unordered_map; - std::unordered_map shared_output_; - /*!\brief The IRModule used for partitioning. */ IRModule module_; }; -class DefaultRemover : public ExprMutator { - public: - explicit DefaultRemover(const IRModule& module) : module_(module) {} +IRModule RemoveDefaultAnnotations(IRModule module) { + class DefaultRemover : public ExprRewriter { + public: + DefaultRemover() = default; - IRModule Remove() { - auto glob_funcs = module_->functions; - for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); - module_->Update(pair.first, func); + Expr Rewrite_(const CallNode* call, const Expr& post) final { + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return Downcast(post)->args[0]; } + return post; } - return module_; - } + }; - Expr VisitExpr_(const CallNode* call) final { - auto attrs = call->attrs.as(); - if (attrs != nullptr && attrs->compiler == "default") { - return VisitExpr(call->args[0]); + auto glob_funcs = module->functions; + // module is mutable, hence, we make a copy of it. + module.CopyOnWrite(); + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + DefaultRemover remover; + auto removed = PostOrderRewrite(func->body, &remover); + func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + module->Update(pair.first, func); } - return ExprMutator::VisitExpr_(call); } - - private: - IRModule module_; -}; + return module; +} } // namespace partitioning namespace transform { Pass PartitionGraph() { - runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { - // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute - // by treating them as un-annotated, but we don't have it yet. This workaround pass removes - // all "default" annotations and should be deleted in the future. - auto new_m = partitioning::DefaultRemover(m).Remove(); - return partitioning::Partitioner(new_m).Partition(); + runtime::TypedPackedFunc part_func = [=](IRModule m, + PassContext pc) { + // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute + // by treating them as un-annotated, but we don't have it yet. This workaround pass removes + // all "default" annotations and should be deleted in the future. 
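 // (Editorial illustration with a hypothetical %x: the rewrite turns
 //  compiler_end(compiler_begin(%x, compiler="default"), compiler="default")
 //  back into plain %x, so "default" regions are treated as un-annotated.)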
+ auto new_m = partitioning::RemoveDefaultAnnotations(m); + return partitioning::Partitioner(new_m).Partition(); }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()}); diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 507b3ccb11526..cbdd4b4a626b1 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -25,9 +25,10 @@ #ifndef TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ #define TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ -#include -#include #include +#include +#include + #include #include @@ -114,41 +115,37 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } -template +template struct TreeNode { typedef std::shared_ptr> pointer; virtual ~TreeNode() {} }; -template +template struct TreeLeafNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; Expr body; - explicit TreeLeafNode(Expr body): body(body) {} + explicit TreeLeafNode(Expr body) : body(body) {} - static TreeObjectPtr Make(Expr body) { - return std::make_shared(body); - } + static TreeObjectPtr Make(Expr body) { return std::make_shared(body); } ~TreeLeafNode() {} }; -template +template struct TreeLeafFatalNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; TreeLeafFatalNode() = default; - static TreeObjectPtr Make() { - return std::make_shared(); - } + static TreeObjectPtr Make() { return std::make_shared(); } ~TreeLeafFatalNode() {} }; -template +template struct TreeBranchNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; @@ -156,15 +153,11 @@ struct TreeBranchNode : TreeNode { TreeObjectPtr then_branch; TreeObjectPtr else_branch; - TreeBranchNode(ConditionObjectPtr cond, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) - : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - + TreeBranchNode(ConditionObjectPtr cond, TreeObjectPtr then_branch, TreeObjectPtr else_branch) + : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - static TreeObjectPtr Make(ConditionObjectPtr cond, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) { + static TreeObjectPtr Make(ConditionObjectPtr cond, TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { return std::make_shared(cond, then_branch, else_branch); } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 5f06bcf0326f2..6d213b2cb7334 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -28,19 +28,18 @@ #include #include -#include -#include -#include #include #include -#include #include +#include +#include +#include #include +#include #include -#include #include - +#include namespace tvm { namespace relay { @@ -49,42 +48,42 @@ namespace relay { * \brief Dispatch DataType to the C++ data type * during runtime. */ -#define TVM_DTYPE_DISPATCH(type, DType, ...) 
\ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ } /*! @@ -99,10 +98,8 @@ namespace relay { * \param rhs_value A squeezed version of rhs which only contains matched dimension. * \return Whether match is successful. */ -inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, - const TensorTypeNode* trhs, - const Array& lhs_axes, - Expr* rhs_value = nullptr) { +inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTypeNode* trhs, + const Array& lhs_axes, Expr* rhs_value = nullptr) { if (tlhs->shape.size() < trhs->shape.size()) return false; StructuralEqual equal; size_t base = tlhs->shape.size() - trhs->shape.size(); @@ -145,9 +142,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, * \param target_ndim Target dimension. * \param axes The axis on the output we want to match on. */ -inline Expr ExpandBiasToMatchAxis(Expr bias, - int target_ndim, - const Array& axes) { +inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array& axes) { static const Op& expand_dims = Op::Get("expand_dims"); for (size_t i = axes.size(); i != 0; --i) { if (i == axes.size()) { @@ -179,14 +174,12 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, * \param param The conv2d attributes. * \return Whether it is depthwise_conv2d. 
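 *
 * Editorial usage sketch (assumes call is a conv2d Call that has already
 * been type-checked):
 * \code
 * const auto* param = call->attrs.as<Conv2DAttrs>();
 * if (param && IsDepthwiseConv2D(call, param, param->kernel_layout)) {
 *   // the kernel shape is (groups, 1, kh, kw) when viewed as OIHW
 * }
 * \endcode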
*/ -inline bool IsDepthwiseConv2D(const Call& call, - const Conv2DAttrs* param, +inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, const Layout& kernel_layout) { static const Layout kOIHW("OIHW"); const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW); auto wshape = bilayout.ForwardShape(call->args[1]->type_as()->shape); - return tir::is_const_int(wshape[0], param->groups) && - tir::is_const_int(wshape[1], 1); + return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1); } /*! @@ -195,12 +188,12 @@ inline bool IsDepthwiseConv2D(const Call& call, * \return Super-dimension size of output channels of conv2d. */ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { - auto param = call->attrs.as(); - auto tweight = call->args[1]->type_as(); - auto index = param->kernel_layout.find('O'); - CHECK_NE(index, std::string::npos); - auto channels = tir::as_const_int(tweight->shape[index]); - return *channels; + auto param = call->attrs.as(); + auto tweight = call->args[1]->type_as(); + auto index = param->kernel_layout.find('O'); + CHECK_NE(index, std::string::npos); + auto channels = tir::as_const_int(tweight->shape[index]); + return *channels; } /*! @@ -332,13 +325,9 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { return tvm::StructuralEqual()(a, b); } -inline Expr GetField(Expr t, size_t i) { - return TupleGetItem(t, i); -} +inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); } -inline Expr Pair(Expr l, Expr r) { - return Tuple({l, r}); -} +inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); } inline Expr Exp(Expr e) { static const Op& op = Op::Get("exp"); @@ -390,25 +379,21 @@ inline Expr Negative(Expr x) { return Call(op, {x}, Attrs(), {}); } - inline Expr Sqrt(Expr x) { static const Op& op = Op::Get("sqrt"); return Call(op, {x}, Attrs(), {}); } - inline Expr Relu(Expr x) { static const Op& op = Op::Get("nn.relu"); return Call(op, {x}, Attrs(), {}); } - inline Expr Round(Expr x) { static const Op& op = Op::Get("round"); return Call(op, {x}, Attrs(), {}); } - inline Expr Clip(Expr x, double a_min, double a_max) { static const Op& op = Op::Get("clip"); auto attrs = make_object(); @@ -417,25 +402,21 @@ inline Expr Clip(Expr x, double a_min, double a_max) { return Call(op, {x}, Attrs(attrs), {}); } - inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Subtract(Expr lhs, Expr rhs) { static const Op& op = Op::Get("subtract"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Divide(Expr lhs, Expr rhs) { static const Op& op = Op::Get("divide"); return Call(op, {lhs, rhs}, Attrs(), {}); @@ -474,31 +455,26 @@ inline Expr Power(Expr lhs, Expr rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr RightShift(Expr x, Expr nbit) { static const Op& op = Op::Get("right_shift"); return Call(op, {x, nbit}, Attrs(), {}); } - inline Expr LeftShift(Expr x, Expr nbit) { static const Op& op = Op::Get("left_shift"); return Call(op, {x, nbit}, Attrs(), {}); } - inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Copy(Expr data) { static const Op& op = Op::Get("copy"); return Call(op, {data}, Attrs(), {}); } - inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { auto attrs = 
make_object(); attrs->axis = std::move(axis); @@ -517,7 +493,6 @@ inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, b return Call(op, {data, mean}, Attrs(attrs), {}); } - static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { static const Op& op = Op::Get("where"); return Call(op, {condition, x, y}); @@ -528,9 +503,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } -static inline Expr Full(Expr fill_value, - Array shape, - DataType dtype) { +static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); @@ -557,10 +530,7 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, return Call(op, {data, weight}, Attrs(attrs), {}); } -static inline Expr Dense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { +static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc index 6cd77f424d18d..99ded0ba591d0 100644 --- a/src/relay/transforms/simplify_fc_transpose.cc +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -128,20 +128,12 @@ Pass SimplifyFCTranspose(const Array& target_weights) { // Remove FreeVar warning auto f0 = Downcast(SimplifyFCTranspose(f, target_weights)); Array wt_params = FreeVars(f0); - auto f1 = Function(wt_params, - f0->body, - f0->ret_type, - f0->type_params, - f0->attrs); + auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); Array params = FreeVars(f1); for (const auto& var : wt_params) { params.push_back(var); } - return Function(params, - f1->body, - f1->ret_type, - f1->type_params, - f1->attrs); + return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); }; return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index a9ceec26ce06c..7c33947e39621 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -21,22 +21,18 @@ * \file simplify_inference.cc */ #include -#include #include -#include +#include #include +#include + #include "pattern_util.h" namespace tvm { namespace relay { -Expr BatchNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Expr moving_mean, - Expr moving_var, - Type tdata) { +Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean, + Expr moving_var, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -64,12 +60,7 @@ Expr BatchNormToInferUnpack(const Attrs attrs, return out; } - -Expr GroupNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -88,20 +79,20 @@ Expr GroupNormToInferUnpack(const Attrs attrs, // new shape = N, num_groups, C/num_groups, H, W // reduce_axes = axis of (C/num_groups, H, W) for (int i = 0; i < ndim; ++i) { - auto val = ttype->shape[i].as()->value; - - // Save the old shape to reshape later - old_shape.push_back(val); - if (i == axis) { - 
new_shape.push_back(num_groups); - new_shape.push_back(channel / num_groups); - reduced_axes.push_back(i + 1); - continue; - } - if (i >= axis) { - reduced_axes.push_back(i + 1); - } - new_shape.push_back(val); + auto val = ttype->shape[i].as()->value; + + // Save the old shape to reshape later + old_shape.push_back(val); + if (i == axis) { + new_shape.push_back(num_groups); + new_shape.push_back(channel / num_groups); + reduced_axes.push_back(i + 1); + continue; + } + if (i >= axis) { + reduced_axes.push_back(i + 1); + } + new_shape.push_back(val); } data = Reshape(data, new_shape); @@ -124,11 +115,7 @@ Expr GroupNormToInferUnpack(const Attrs attrs, return out; } -Expr LayerNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -151,11 +138,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs, return out; } -Expr InstanceNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -165,8 +148,7 @@ Expr InstanceNormToInferUnpack(const Attrs attrs, int axis = (param->axis < 0) ? param->axis + ndim : param->axis; Array reduced_axes; for (int i = 1; i < ndim; ++i) { - if (i != axis) - reduced_axes.push_back(i); + if (i != axis) reduced_axes.push_back(i); } Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast(param->epsilon)); @@ -259,22 +241,19 @@ class InferenceSimplifier : public ExprMutator { std::unordered_map ty_map_; }; -Expr SimplifyInference(const Expr& e) { - return InferenceSimplifier().Mutate(e); -} +Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); } namespace transform { Pass SimplifyInference() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(SimplifyInference(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyInference(f)); + }; return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") -.set_body_typed(SimplifyInference); +TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference); } // namespace transform diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 21c516201dd73..c0c92860e3459 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -26,12 +26,12 @@ #include #include #include -#include #include -#include "let_list.h" -#include "pass_util.h" + #include "../../support/arena.h" #include "../analysis/dependency_graph.h" +#include "let_list.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -47,13 +47,11 @@ struct ScopeNode { size_t level; Scope parent; std::shared_ptr ll = std::make_shared(); - explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { } - ScopeNode() : level(0) { } + explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {} + ScopeNode() : level(0) {} }; -Scope ChildScope(const Scope& s) { - return std::make_shared(s); -} +Scope ChildScope(const Scope& s) { return std::make_shared(s); } Scope LCA(Scope lhs, Scope rhs) { while (lhs != rhs) { @@ -100,8 
+98,7 @@ std::unordered_map CalcScope(const DependencyGrap
 */
class Fill : ExprFunctor {
 public:
-  static Expr ToANormalForm(const Expr& e,
-                            const DependencyGraph& dg,
+  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg,
                             std::unordered_map* node_scope) {
    Fill fi(dg, node_scope);
    return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
@@ -112,14 +109,10 @@ class Fill : ExprFunctor {
  std::unordered_map* node_scope_;
  std::unordered_map memo;

-  Fill(const DependencyGraph& dg,
-       std::unordered_map* node_scope) :
-    dg_(dg),
-    node_scope_(node_scope) { }
+  Fill(const DependencyGraph& dg, std::unordered_map* node_scope)
+      : dg_(dg), node_scope_(node_scope) {}

-  Scope GetScope(const Expr& e) {
-    return node_scope_->at(dg_.expr_node.at(e));
-  }
+  Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }

  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
@@ -144,18 +137,12 @@ class Fill : ExprFunctor {
    return ret;
  }

-  Expr VisitExpr(const Expr& e) {
-    return this->VisitExpr(e, Var());
-  }
+  Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }

-  Expr Atomic(const Expr& e, const Var& v) {
-    return v.defined() ? GetScope(e)->ll->Push(v, e) : e;
-  }
+  Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; }

  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
-    Var var = v.defined() ?
-      v :
-      Var(std::string("x"), Type());
+    Var var = v.defined() ? v : Var(std::string("x"), Type());
    return GetScope(orig)->ll->Push(var, now);
  }

@@ -199,9 +186,8 @@ class Fill : ExprFunctor {
  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    Expr e = GetRef(i);
-    Expr ret = If(VisitExpr(i->cond),
-                  GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
-                  GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
+    Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
+                  GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

@@ -211,11 +197,8 @@ class Fill : ExprFunctor {
    if (f->HasNonzeroAttr(attr::kPrimitive)) {
      ret = e;
    } else {
-      ret = Function(f->params,
-                     GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
-                     f->ret_type,
-                     f->type_params,
-                     f->attrs);
+      ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type,
+                     f->type_params, f->attrs);
    }
    return Compound(e, ret, v);
  }

@@ -257,9 +240,8 @@ class Fill : ExprFunctor {
    Expr data = VisitExpr(m->data);
    std::vector clauses;
    for (const Clause& c : m->clauses) {
-      clauses.push_back(Clause(
-        c->lhs,
-        GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
+      clauses.push_back(
+          Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
    }
    return Compound(e, Match(data, clauses, m->complete), v);
  }

@@ -301,14 +283,9 @@ IRModule ToANormalForm(const IRModule& m) {
    if (const auto* n = it.second.as()) {
      if (n->GetAttr(attr::kCompiler).defined()) continue;
    }
-    Expr ret =
-      TransformF([&](const Expr& e) {
-        return ToANormalFormAux(e);
-      }, it.second);
+    Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second);
    CHECK_EQ(FreeVars(ret).size(), 0)
-      << AsText(ret)
-      << "should not has free vars: "
-      << FreeVars(ret);
+        << AsText(ret) << " should not have free vars: " << FreeVars(ret);
    updates.Set(it.first, Downcast(ret));
  }

@@ -325,14 +302,11 @@ namespace transform {

Pass ToANormalForm() {
  runtime::TypedPackedFunc pass_func =
-    [=](IRModule m, PassContext pc) {
-      return relay::ToANormalForm(m);
-    };
+      [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); };
  return CreateModulePass(pass_func, 1, "ToANormalForm", {});
}

-TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm")
-.set_body_typed(ToANormalForm);
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm);

}  // namespace transform

diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index e6c83928b098a..81545b685068f 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -51,9 +51,10 @@
 * whether directly invoking it, or indirectly by recursion.
 */
#include
-#include
#include
#include
+#include
+
#include "let_list.h"
#include "pass_util.h"

@@ -62,9 +63,7 @@ namespace relay {

// we assume the data type has no closure - no idea how to look into datatype right now.

-Type Arrow(const Type& l, const Type& r) {
-  return FuncType({l}, r, {}, {});
-}
+Type Arrow(const Type& l, const Type& r) { return FuncType({l}, r, {}, {}); }

Type CPSType(const Type& t, const TypeVar& answer);

@@ -79,7 +78,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) {

Type CPSType(const Type& t, const TypeVar& answer) {
  struct CPSTypeMutator : TypeMutator {
-    explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) { }
+    explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) {}
    TypeVar answer;
    Type VisitType_(const FuncTypeNode* t) final {
      return CPSFuncType(GetRef(t), answer);
@@ -113,22 +112,15 @@ using MCont = std::function;

Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm);

-Function ToCPS(const Function& f,
-               const IRModule& m,
-               CPSMap* cm,
-               VarMap* vm,
+Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
               const TypeVar& answer) {
-  std::function remap = [&](const Var& v) {
-    return vm->count(v) == 0 ? v : vm->at(v);
-  };
+  std::function remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); };
  auto function_type = Downcast(f->checked_type());
  // Each MCont can be used at most once.
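For orientation while reading the hunks below: the transform these files implement can be sketched in plain C++ with std::function. This is an illustration only, not Relay or TVM API; inc/sqr and their *_cps twins are hypothetical helpers.

#include <functional>
#include <iostream>

// Direct style and its CPS twin: the CPS version takes the continuation k
// as an extra argument instead of returning a value.
int inc(int x) { return x + 1; }
int sqr(int x) { return x * x; }
void inc_cps(int x, const std::function<void(int)>& k) { k(x + 1); }
void sqr_cps(int x, const std::function<void(int)>& k) { k(x * x); }

int main() {
  std::cout << sqr(inc(3)) << "\n";  // direct style: 16, sequenced by the call stack
  // CPS: the same computation, sequenced by nesting continuations; this is
  // the shape the functor below gives to Relay expressions.
  inc_cps(3, [](int v) { sqr_cps(v, [](int r) { std::cout << r << "\n"; }); });
  return 0;
}
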
struct CPSFunctor : ExprFunctor, PatternMutator { - CPSFunctor(const std::function& remap, - const TypeVar& answer, - const IRModule& m, - VarMap* vm, - CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } + CPSFunctor(const std::function& remap, const TypeVar& answer, const IRModule& m, + VarMap* vm, CPSMap* cm) + : remap(remap), answer(answer), m(m), vm(vm), cm(cm) {} const std::function& remap; TypeVar answer; IRModule m; @@ -136,9 +128,8 @@ Function ToCPS(const Function& f, CPSMap* cm; Expr VisitExpr_(const LetNode* op, const MCont& k) final { - return VisitExpr(op->value, [&](const Expr& v) { - return Let(remap(op->var), v, VisitExpr(op->body, k)); - }); + return VisitExpr( + op->value, [&](const Expr& v) { return Let(remap(op->var), v, VisitExpr(op->body, k)); }); } Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { @@ -150,13 +141,9 @@ Function ToCPS(const Function& f, return k(GetRef(op)); } - Expr VisitExpr_(const VarNode* op, const MCont& k) final { - return k(remap(GetRef(op))); - } + Expr VisitExpr_(const VarNode* op, const MCont& k) final { return k(remap(GetRef(op))); } - Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVar(remap(op->var)); - } + Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(remap(op->var)); } Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { auto gv = GetRef(op); @@ -186,16 +173,14 @@ Function ToCPS(const Function& f, } Expr reify(const MCont& k, const std::function& cont) { - return LetList::LetBind(reify(k), - [&](const Var& f) { + return LetList::LetBind(reify(k), [&](const Var& f) { return cont([&](const Expr& e) { return Call(f, {e}); }); }); } Expr VisitExpr_(const IfNode* op, const MCont& k) final { return reify(k, [&](const MCont& kf) { - return VisitExpr(op->cond, - [&](const Expr& v) { + return VisitExpr(op->cond, [&](const Expr& v) { return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); }); }); @@ -214,19 +199,13 @@ Function ToCPS(const Function& f, } Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { - return VisitExpr(op->ref, - [&](const Expr& r) { - return LetList::LetBind(RefRead(r), k); - }); + return VisitExpr(op->ref, [&](const Expr& r) { return LetList::LetBind(RefRead(r), k); }); } Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final { - return VisitExpr(op->ref, - [&](const Expr& r) { + return VisitExpr(op->ref, [&](const Expr& r) { return VisitExpr(op->value, - [&](const Expr& v) { - return LetList::LetBind(RefWrite(r, v), k); - }); + [&](const Expr& v) { return LetList::LetBind(RefWrite(r, v), k); }); }); } @@ -234,20 +213,18 @@ Function ToCPS(const Function& f, tvm::Array fields; std::function next; next = [&]() { - return (fields.size() == op->fields.size()) ? - k(Tuple(fields)) : - VisitExpr(op->fields[fields.size()], [&](const Expr& v) { - fields.push_back(v); - return next(); - }); + return (fields.size() == op->fields.size()) + ? 
k(Tuple(fields)) + : VisitExpr(op->fields[fields.size()], [&](const Expr& v) { + fields.push_back(v); + return next(); + }); }; return next(); } Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { - return VisitExpr(op->tuple, [&](const Expr& v) { - return k(TupleGetItem(v, op->index)); - }); + return VisitExpr(op->tuple, [&](const Expr& v) { return k(TupleGetItem(v, op->index)); }); } Expr VisitExpr_(const CallNode* op, const MCont& k) final { @@ -259,9 +236,9 @@ Function ToCPS(const Function& f, return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k); } else { return VisitExpr(op->args[args.size()], [&](const Expr& v) { - args.push_back(v); - return next(); - }); + args.push_back(v); + return next(); + }); } }; return next(); @@ -279,7 +256,7 @@ Function ToCPS(const Function& f, return next(); }); } - }; + }; return VisitExpr(op->op, [&](const Expr& v) { f = v; return next(); @@ -293,19 +270,15 @@ Function ToCPS(const Function& f, new_params.push_back(remap(v)); } new_params.push_back(k); - return Function(new_params, - mut.VisitExpr(f->body, - [&](const Expr& e) { return Call(k, {e}); }), - answer, - f->type_params, - f->attrs); + return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), + answer, f->type_params, f->attrs); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { TypeVar answer = TypeVar("answer", kType); VarMap var; struct Remapper : ExprVisitor, PatternVisitor { - Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } + Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) {} TypeVar answer; VarMap* vm; void VisitExpr_(const VarNode* vn) final { @@ -316,13 +289,9 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { } } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitPattern_(const PatternVarNode* op) final { - VisitExpr(op->var); - } + void VisitPattern_(const PatternVarNode* op) final { VisitExpr(op->var); } } remap(answer, &var); remap.VisitExpr(f); Function ret = ToCPS(f, m, cm, &var, answer); @@ -366,43 +335,32 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - return Function(new_params, - Call(f, args, {}, type_args), - new_ret_type, - new_type_params, - f->attrs); + return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params, + f->attrs); } TVM_REGISTER_GLOBAL("relay._transform.to_cps") -.set_body_typed(static_cast(ToCPS)); + .set_body_typed(static_cast(ToCPS)); -TVM_REGISTER_GLOBAL("relay._transform.un_cps") -.set_body_typed(UnCPS); +TVM_REGISTER_GLOBAL("relay._transform.un_cps").set_body_typed(UnCPS); namespace transform { Pass ToCPS() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Function(ToCPS(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); }; return CreateFunctionPass(pass_func, 1, "ToCPS", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToCPS") -.set_body_typed(ToCPS); - +TVM_REGISTER_GLOBAL("relay._transform.ToCPS").set_body_typed(ToCPS); Pass UnCPS() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Function(UnCPS(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); }; return CreateFunctionPass(pass_func, 1, "UnCPS", {}); } 
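UnCPS above runs in the other direction; at the value level it amounts to instantiating the answer type and supplying a continuation that just captures the result. Again a plain C++ sketch with hypothetical names, not the Relay transform itself.

#include <functional>

// A CPS-style function; the answer type is fixed to void here for simplicity.
void square_cps(int x, const std::function<void(int)>& k) { k(x * x); }

// "Un-CPS": recover a direct-style function by passing a continuation that
// stores the value, mirroring how UnCPS calls f with fresh parameters plus
// a trivial continuation.
int square(int x) {
  int result = 0;
  square_cps(x, [&result](int v) { result = v; });
  return result;
}
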
-TVM_REGISTER_GLOBAL("relay._transform.UnCPS") -.set_body_typed(UnCPS); +TVM_REGISTER_GLOBAL("relay._transform.UnCPS").set_body_typed(UnCPS); } // namespace transform diff --git a/src/relay/transforms/to_graph_normal_form.cc b/src/relay/transforms/to_graph_normal_form.cc index 8bf41a4610c04..2e6c5456249f2 100644 --- a/src/relay/transforms/to_graph_normal_form.cc +++ b/src/relay/transforms/to_graph_normal_form.cc @@ -26,6 +26,7 @@ #include #include #include + #include "let_list.h" namespace tvm { @@ -33,7 +34,7 @@ namespace relay { class UseVarVisitor : public ExprVisitor { public: - explicit UseVarVisitor(const Var& v) : v(v) { } + explicit UseVarVisitor(const Var& v) : v(v) {} static bool UseVar(const Var& v, const Expr& e) { UseVarVisitor uv(v); @@ -45,9 +46,7 @@ class UseVarVisitor : public ExprVisitor { bool use_var = false; Var v; - void VisitExpr_(const VarNode* vn) override { - use_var = use_var || (v == GetRef(vn)); - } + void VisitExpr_(const VarNode* vn) override { use_var = use_var || (v == GetRef(vn)); } }; class GNF : public ExprMutator { @@ -58,9 +57,7 @@ class GNF : public ExprMutator { return var_map_.count(v) == 0 ? v : var_map_.at(v); } - static bool UseVar(const Var& v, const Expr& e) { - return UseVarVisitor::UseVar(v, e); - } + static bool UseVar(const Var& v, const Expr& e) { return UseVarVisitor::UseVar(v, e); } static Expr WrapRec(const Var& var, const Expr& val) { return UseVar(var, val) ? Let(var, val, var) : val; @@ -72,22 +69,19 @@ class GNF : public ExprMutator { } }; -Expr ToGraphNormalForm(const Expr& e) { - return GNF()(e); -} +Expr ToGraphNormalForm(const Expr& e) { return GNF()(e); } namespace transform { Pass ToGraphNormalForm() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(ToGraphNormalForm(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ToGraphNormalForm(f)); + }; return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm") -.set_body_typed(ToGraphNormalForm); +TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm").set_body_typed(ToGraphNormalForm); } // namespace transform diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index b6e75ae4f5855..19632defc8262 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -26,14 +26,16 @@ #ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ #define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ -#include #include +#include + #include -#include #include +#include #include -#include "pattern_util.h" + #include "infer_layout_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -49,8 +51,8 @@ class TransformMemorizerNode : public Object { struct key_hash : public std::function { std::size_t operator()(const TransformKey& k) const { return dmlc::HashCombine( - dmlc::HashCombine( - std::hash()(std::get<0>(k)), std::get<1>(k)), + dmlc::HashCombine(std::hash()(std::get<0>(k)), + std::get<1>(k)), (std::get<2>(k))); } }; @@ -300,8 +302,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj // new_in2, new_out = op.infer(new_in) if (new_call->op->IsInstance()) { success = false; - std::tie(new_in2, new_out, success) = - InferCorrectLayouts(new_call, new_in, old_in, types); + std::tie(new_in2, new_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types); if (!success) { return Expr(nullptr); } diff --git 
a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 3a16d8ff793b7..0782484835879 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -37,14 +37,15 @@ * If we can not infer a type or there are conflicting typing * constraints we will trigger an error. */ -#include #include +#include +#include #include #include -#include #include -#include "pass_util.h" + #include "../analysis/type_solver.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -53,21 +54,16 @@ namespace relay { struct TupleGetItemAttrs : public tvm::AttrsNode { int index; - TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { - TVM_ATTR_FIELD(index); - } + TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { TVM_ATTR_FIELD(index); } }; -bool TupleGetItemRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TupleGetItemRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); if (types[0].as()) return false; const auto* data = types[0].as(); - CHECK(data != nullptr) - << "TupleGetItem expect input type to be TupleType " - << " get " << types[0] << " instead"; + CHECK(data != nullptr) << "TupleGetItem expect input type to be TupleType " + << " get " << types[0] << " instead"; const auto* param = attrs.as(); CHECK(param != nullptr); CHECK_GE(param->index, 0); @@ -77,9 +73,7 @@ bool TupleGetItemRel(const Array& types, } TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); -TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem") -.set_body_typed( - TupleGetItemRel); +TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem").set_body_typed(TupleGetItemRel); struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) @@ -105,8 +99,10 @@ class TypeInferencer : private ExprFunctor, // constructors explicit TypeInferencer(IRModule mod, GlobalVar current_func) - : mod_(mod), current_func_(current_func), - err_reporter(), solver_(current_func, mod, &this->err_reporter) { + : mod_(mod), + current_func_(current_func), + err_reporter(), + solver_(current_func, mod, &this->err_reporter) { CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; } @@ -140,22 +136,16 @@ class TypeInferencer : private ExprFunctor, Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) { try { return solver_.Unify(t1, t2, expr); - } catch (const dmlc::Error &e) { + } catch (const dmlc::Error& e) { this->ReportFatalError( - expr, - ErrorBuilder() - << "Error unifying `" - << t1 - << "` and `" - << t2 - << "`: " << e.what()); + expr, ErrorBuilder() << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); return Type(); } } // Lazily get type for expr // expression, we will populate it now, and return the result. 
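  // (By example: the first GetType(call) visits the call, pushes any
  //  constraints it generates into solver_, and caches the result in
  //  type_map_; a second GetType on the same node is a plain map lookup.)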
-  Type GetType(const Expr &expr) {
+  Type GetType(const Expr& expr) {
    auto it = type_map_.find(expr);
    if (it != type_map_.end() && it->second.checked_type.defined()) {
      return it->second.checked_type;
@@ -186,19 +176,15 @@ class TypeInferencer : private ExprFunctor,
  Type VisitExpr_(const GlobalVarNode* op) final {
    GlobalVar var = GetRef(op);
    if (!mod_.defined()) {
-      this->ReportFatalError(
-        GetRef(op),
-        ErrorBuilder() <<
-        "Cannot do type inference on global variables " \
-        "without a module");
+      this->ReportFatalError(GetRef(op),
+                             ErrorBuilder() << "Cannot do type inference on global variables "
+                                               "without a module");
    }
    Expr e = mod_->Lookup(var);
    return e->checked_type();
  }

-  Type VisitExpr_(const ConstantNode* op) final {
-    return op->tensor_type();
-  }
+  Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); }

  Type VisitExpr_(const TupleNode* op) final {
    Array types;
@@ -209,23 +195,22 @@ class TypeInferencer : private ExprFunctor,
  }

  Type VisitExpr_(const TupleGetItemNode* op) final {
-    if (!tuple_getitem_rel_.defined()) {
-      tuple_getitem_rel_ = Downcast(
-        EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
+    if (!tuple_getitem_rel_.defined()) {
+      tuple_getitem_rel_ =
+          Downcast(EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
    }
    Type tuple_type = GetType(op->tuple);
    Type rtype = IncompleteType(Kind::kType);
    auto attrs = make_object();
    attrs->index = op->index;
-    solver_.AddConstraint(TypeRelation(
-      tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op));
+    solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)),
+                          GetRef(op));
    return rtype;
  }

  void VisitPattern_(const PatternConstructorNode* con, const Type& t) {
-    CHECK(mod_.defined())
-      << "Cannot do type inference without a environment:"
-      << con->constructor->name_hint;
+    CHECK(mod_.defined()) << "Cannot do type inference without an environment: "
+                          << con->constructor->name_hint;
    TypeData td = mod_->type_definitions.at(con->constructor->belong_to);
    auto pc = GetRef(con);
@@ -242,15 +227,14 @@ class TypeInferencer : private ExprFunctor,
      this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified);
    }
    if (td->header != tc->func) {
-      this->ReportFatalError(pc,
-                             ErrorBuilder() << "ADT headers must match, but we have "
-                             << td->header << " and " << tc->func);
+      this->ReportFatalError(pc, ErrorBuilder() << "ADT headers must match, but we have "
+                                                << td->header << " and " << tc->func);
    }
    if (td->type_vars.size() != tc->args.size()) {
-      this->ReportFatalError(pc,
-                             ErrorBuilder() << "The number of type args must match"
-                             << "the number of type vars in the type data: "
-                             << td->type_vars.size() << " != " << tc->args.size());
+      this->ReportFatalError(
+          pc, ErrorBuilder() << "The number of type args must match "
+                             << "the number of type vars in the type data: " << td->type_vars.size()
+                             << " != " << tc->args.size());
    }
    std::unordered_map type_var_map_;
    for (size_t i = 0; i < td->type_vars.size(); ++i) {
@@ -258,10 +242,9 @@ class TypeInferencer : private ExprFunctor,
    }
    CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern";
    if (con->constructor->inputs.size() != con->patterns.size()) {
-      this->ReportFatalError(pc,
-                             ErrorBuilder() << "Not enough inputs for the constructor; "
-                             << "expected " << con->constructor->inputs.size()
-                             << ", got " << con->patterns.size());
+      this->ReportFatalError(pc, ErrorBuilder() << "Not enough inputs for the constructor; "
+                                                << "expected " << con->constructor->inputs.size()
+
<< ", got " << con->patterns.size()); } for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); @@ -294,7 +277,7 @@ class TypeInferencer : private ExprFunctor, Unify(vt, t, pv->span); } - void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } + void VisitPattern_(const PatternWildcardNode* wc, const Type& t) {} Type VisitExpr_(const MatchNode* op) final { Type dtype = GetType(op->data); @@ -303,9 +286,7 @@ class TypeInferencer : private ExprFunctor, } Type rtype = IncompleteType(Kind::kType); for (const auto& c : op->clauses) { - rtype = this->Unify(rtype, - GetType(c->rhs), - op->span); + rtype = this->Unify(rtype, GetType(c->rhs), op->span); } if (op->complete) { @@ -319,18 +300,14 @@ class TypeInferencer : private ExprFunctor, for (auto cs : unmatched_cases) { ss << "case " << i++ << ": \n" << PrettyPrint(cs); } - this->ReportFatalError( - match, - ss); + this->ReportFatalError(match, ss); } } return rtype; } - Type VisitExpr_(const OpNode* op) final { - return op->op_type; - } + Type VisitExpr_(const OpNode* op) final { return op->op_type; } Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion @@ -342,7 +319,6 @@ class TypeInferencer : private ExprFunctor, type_map_[let->var].checked_type = let_type; } - if (let->var->type_annotation.defined()) { let_type = Unify(let_type, let->var->type_annotation, GetRef(let)); } @@ -360,9 +336,7 @@ class TypeInferencer : private ExprFunctor, // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. Type cond_type = this->GetType(ite->cond); - this->Unify(cond_type, - TensorType::Scalar(tvm::DataType::Bool()), - ite->cond); + this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond); Type checked_true = this->GetType(ite->true_branch); Type checked_false = this->GetType(ite->false_branch); return this->Unify(checked_true, checked_false, GetRef(ite)); @@ -372,9 +346,7 @@ class TypeInferencer : private ExprFunctor, // which are registered in the style defined in src/relay/op/*. // // The result will be the return type of the operator. 
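To make the comment above concrete: op calls resolve through type relation functions with the same signature as TupleGetItemRel earlier in this file. A hypothetical relation for an elementwise identity op might look like the following; IdentityRel is illustrative only, not a registered TVM relation.

// One input plus one output type; return false to be re-run once the
// input type becomes known, true when the relation is fully resolved.
bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  if (types[0].as<IncompleteTypeNode>()) return false;  // input not solved yet
  reporter->Assign(types[1], types[0]);                 // output := input
  return true;
}
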
- Type PrimitiveCall(const FuncTypeNode* op, - Array arg_types, - const Attrs& attrs, + Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs, const ObjectRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); @@ -387,8 +359,7 @@ class TypeInferencer : private ExprFunctor, Type rtype = IncompleteType(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here - solver_.AddConstraint(TypeRelation( - rel->func, arg_types, arg_types.size() - 1, attrs), loc); + solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), loc); return rtype; } @@ -417,9 +388,7 @@ class TypeInferencer : private ExprFunctor, ret_type = IncompleteType(Kind::kType); } - Type inst_ty = FuncType(fn_ty->arg_types, - ret_type, {}, - fn_ty->type_constraints); + Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); inst_ty = Bind(inst_ty, subst_map); return Downcast(inst_ty); } @@ -437,7 +406,6 @@ class TypeInferencer : private ExprFunctor, return InstantiateFuncType(fn_ty, type_args); } - void AddTypeArgs(const Expr& expr, Array type_args) { auto type_info = type_map_.find(expr); if (type_info == type_map_.end()) { @@ -456,10 +424,8 @@ class TypeInferencer : private ExprFunctor, if (fn_ty_node == nullptr && inc_ty_node == nullptr) { this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "only expressions with function types can be called, found " - << ftype); + GetRef(call), + ErrorBuilder() << "only expressions with function types can be called, found " << ftype); } // incomplete type => it must be a function taking the arg types @@ -474,12 +440,10 @@ class TypeInferencer : private ExprFunctor, Array type_args = call->type_args; if (type_args.size() > fn_ty_node->type_params.size()) { this->ReportFatalError(GetRef(call), - ErrorBuilder() - << "Incorrect number of type args in " - << call->span << ": " - << "Expected " - << fn_ty_node->type_params.size() - << "but got " << type_args.size()); + ErrorBuilder() + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() << "but got " + << type_args.size()); } FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); @@ -491,17 +455,15 @@ class TypeInferencer : private ExprFunctor, if (type_arity != number_of_args) { if (type_arity < number_of_args) { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "the function is provided too many arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->ReportFatalError(GetRef(call), + ErrorBuilder() + << "the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args); } else { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "the function is provided too few arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->ReportFatalError(GetRef(call), + ErrorBuilder() + << "the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args); } } @@ -511,9 +473,8 @@ class TypeInferencer : private ExprFunctor, for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { - solver_.AddConstraint( - TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), - GetRef(call)); + solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), + GetRef(call)); } else { solver_.AddConstraint(cs, GetRef(call)); } 
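A concrete reading of InstantiateFuncType above, in illustrative Relay-like syntax rather than real code:

  fn<A>(x: A, y: A) -> A    with type_args = [int32]
  subst_map: { A -> int32 }
  result:    fn(x: int32, y: int32) -> int32

With fewer explicit type_args than type params, the remaining params, and the return type, are bound to fresh IncompleteType holes that the solver later fills in from the argument types.
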
@@ -529,9 +490,7 @@ class TypeInferencer : private ExprFunctor, } if (const OpNode* opnode = call->op.as()) { - Type rtype = PrimitiveCall(opnode->op_type.as(), - arg_types, - call->attrs, + Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, call->attrs, GetRef(call)); if (rtype.defined()) { AddTypeArgs(GetRef(call), arg_types); @@ -560,9 +519,7 @@ class TypeInferencer : private ExprFunctor, return solver_.Resolve(ret); } - Type VisitExpr_(const RefCreateNode* op) final { - return RelayRefType(GetType(op->value)); - } + Type VisitExpr_(const RefCreateNode* op) final { return RelayRefType(GetType(op->value)); } Type VisitExpr_(const RefReadNode* op) final { Type it = IncompleteType(Kind::kType); @@ -578,16 +535,13 @@ class TypeInferencer : private ExprFunctor, } Type VisitExpr_(const ConstructorNode* c) final { - CHECK(mod_.defined()) - << "Cannot do type inference without a environment:" - << c->name_hint; + CHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint; TypeData td = mod_->LookupTypeDef(c->belong_to); std::vector types; - for (const auto & t : td->type_vars) { + for (const auto& t : td->type_vars) { types.push_back(t); } - return FuncType(c->inputs, TypeCall(c->belong_to, types), - td->type_vars, {}); + return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } void Solve() { @@ -603,72 +557,39 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { - } + : tmap_(tmap), solver_(solver) {} - Expr VisitExpr_(const VarNode* op) final { - return VisitVar(GetRef(op)); - } + Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef(op)); } - Expr VisitExpr_(const ConstantNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const GlobalVarNode* op) final { - return GetRef(op); - } + Expr VisitExpr_(const GlobalVarNode* op) final { return GetRef(op); } - Expr VisitExpr_(const OpNode* op) final { - return ExprMutator::VisitExpr_(op); - } + Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); } - Expr VisitExpr_(const TupleNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const TupleGetItemNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const FunctionNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const CallNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const LetNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const IfNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefCreateNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefReadNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefReadNode* op) final { return 
AttachCheckedType(op); } - Expr VisitExpr_(const RefWriteNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefWriteNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const ConstructorNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const ConstructorNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const MatchNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const MatchNode* op) final { return AttachCheckedType(op); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Var VisitVar(const Var& v) final { if (vmap_.count(v) == 0) { @@ -678,7 +599,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } // attach checked type to the mutated node. - template + template Expr AttachCheckedType(const T* op) { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); @@ -687,42 +608,34 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // TODO(@jroesch): it would be nice if we would report resolution // errors directly on the program. CHECK(checked_type.as() == nullptr) - << "Cannot resolve type of " << GetRef(op) - << " at " << op->span; + << "Cannot resolve type of " << GetRef(op) << " at " << op->span; Expr new_e = ExprMutator::VisitExpr_(op); // new_call and new_var's code is only going to be valid for VarNode/CallNode. // Compiler optimization will likely fold these away for other nodes. - CallNode* new_call =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); - VarNode* new_var =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); - FunctionNode* new_fn =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); + CallNode* new_call = (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); + VarNode* new_var = (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); + FunctionNode* new_fn = + (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); // check if we need update the new_e bool need_update_type = !checked_type.same_as(new_e->checked_type_); - bool need_update_call = ( - std::is_base_of::value && - it->second.type_args.defined() && - !it->second.type_args.same_as(new_call->type_args)); - bool need_update_var = ( - std::is_base_of::value && - update_missing_type_annotation_ && - !new_var->type_annotation.defined()); - - bool need_update_fn =( - std::is_base_of::value && - update_missing_type_annotation_ && - !new_fn->ret_type.defined()); - - if (!need_update_type && - !need_update_var && - !need_update_call && - !need_update_fn) { + bool need_update_call = + (std::is_base_of::value && it->second.type_args.defined() && + !it->second.type_args.same_as(new_call->type_args)); + bool need_update_var = (std::is_base_of::value && update_missing_type_annotation_ && + !new_var->type_annotation.defined()); + + bool need_update_fn = (std::is_base_of::value && + update_missing_type_annotation_ && !new_fn->ret_type.defined()); + + if (!need_update_type && !need_update_var && !need_update_call && !need_update_fn) { return new_e; } @@ -732,15 +645,11 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // we make a copy mutating an existing reference. ObjectPtr ptr = make_object(*new_e.as()); new_e = Expr(ptr); - new_call = ( - std::is_base_of::value ? 
- static_cast(ptr.get()) : nullptr); - new_var = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); - new_fn = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); + new_call = + (std::is_base_of::value ? static_cast(ptr.get()) : nullptr); + new_var = (std::is_base_of::value ? static_cast(ptr.get()) : nullptr); + new_fn = (std::is_base_of::value ? static_cast(ptr.get()) + : nullptr); } // attach the information. @@ -765,9 +674,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { return new_e; } - Type VisitType(const Type &t) final { - return solver_->Resolve(t); - } + Type VisitType(const Type& t) final { return solver_->Resolve(t); } private: std::unordered_map vmap_; @@ -793,17 +700,21 @@ Expr TypeInferencer::Infer(Expr expr) { struct AllCheckTypePopulated : ExprVisitor { void VisitExpr(const Expr& e) { - if (e.as()) { return; } - if (e.as()) { return; } - if (e.as()) { return; } + if (e.as()) { + return; + } + if (e.as()) { + return; + } + if (e.as()) { + return; + } CHECK(e->checked_type_.defined()) << "Expression: " << e; return ExprVisitor::VisitExpr(e); } }; -void EnsureCheckedType(const Expr& e) { - AllCheckTypePopulated().VisitExpr(e); -} +void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } Expr InferType(const Expr& expr, const IRModule& mod) { auto main = mod->GetGlobalVar("main"); @@ -811,15 +722,12 @@ Expr InferType(const Expr& expr, const IRModule& mod) { auto e = inferencer.Infer(expr); CHECK(WellFormed(e)); auto free_tvars = FreeTypeVars(e, mod); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in " << e << ": " << free_tvars; + CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars; EnsureCheckedType(e); return e; } -Function InferType(const Function& func, - const IRModule& mod, - const GlobalVar& var) { +Function InferType(const Function& func, const IRModule& mod, const GlobalVar& var) { CHECK(mod.defined()) << "internal error: module must be set for type inference"; Function func_copy = Function(make_object(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); @@ -828,11 +736,9 @@ Function InferType(const Function& func, mod->Remove(var); CHECK(WellFormed(func_ret)); auto free_tvars = FreeTypeVars(func_ret, mod); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in: " - << std::endl - << AsText(func, true) - << std::endl << free_tvars; + CHECK(free_tvars.size() == 0) << "Found unbound type variables in: " << std::endl + << AsText(func, true) << std::endl + << free_tvars; return Downcast(func_ret); } @@ -840,16 +746,11 @@ namespace transform { Pass InferType() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(InferType(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(InferType(f, m)); }; return CreateFunctionPass(pass_func, 0, "InferType", {}); } -TVM_REGISTER_GLOBAL("relay._transform.InferType") -.set_body_typed([]() { - return InferType(); -}); +TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); } // namespace transform diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index 60dc55d8c24a3..d229491a4c7b5 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,7 +20,7 @@ /*! * \file builtin_fp16.cc * \brief Functions for conversion between fp32 and fp16 -*/ + */ #include #include diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index fb1f74da2103c..0164b1bc4d399 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -22,20 +22,22 @@ * \brief Device specific implementations */ #include -#include #include -#include +#include +#include #include +#include #include -#include -#include -#include + #include -#include -#include +#include #include -#include "runtime_base.h" +#include +#include +#include + #include "object_internal.h" +#include "runtime_base.h" namespace tvm { namespace runtime { @@ -90,9 +92,7 @@ class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; // Get API - static DeviceAPI* Get(const TVMContext& ctx) { - return Get(ctx.device_type); - } + static DeviceAPI* Get(const TVMContext& ctx) { return Get(ctx.device_type); } static DeviceAPI* Get(int dev_type, bool allow_missing = false) { return Global()->GetAPI(dev_type, allow_missing); } @@ -102,9 +102,7 @@ class DeviceAPIManager { DeviceAPI* rpc_api_{nullptr}; std::mutex mutex_; // constructor - DeviceAPIManager() { - std::fill(api_.begin(), api_.end(), nullptr); - } + DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } // Global static variable. static DeviceAPIManager* Global() { static DeviceAPIManager inst; @@ -130,8 +128,7 @@ class DeviceAPIManager { std::string factory = "device_api." + name; auto* f = Registry::Get(factory); if (f == nullptr) { - CHECK(allow_missing) - << "Device API " << name << " is not enabled."; + CHECK(allow_missing) << "Device API " << name << " is not enabled."; return nullptr; } void* ptr = (*f)(); @@ -140,19 +137,14 @@ class DeviceAPIManager { }; DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { - return DeviceAPIManager::Get( - static_cast(ctx.device_type), allow_missing); + return DeviceAPIManager::Get(static_cast(ctx.device_type), allow_missing); } -void* DeviceAPI::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); } -void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { - FreeDataSpace(ctx, ptr); -} +void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { FreeDataSpace(ctx, ptr); } TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) { LOG(FATAL) << "Device does not support stream api."; @@ -163,8 +155,7 @@ void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) { LOG(FATAL) << "Device does not support stream api."; } -void DeviceAPI::SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, +void DeviceAPI::SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) { LOG(FATAL) << "Device does not support stream api."; } @@ -256,7 +247,8 @@ std::string NormalizeError(std::string err_msg) { // Parse error type. 
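    // By example: for "ValueError: shapes do not match", the scan below
    // skips leading spaces, reads up to the first ':' as error_type
    // ("ValueError"), then trims the space that follows; if the first line
    // has no "Name:" prefix, error_type stays empty and a default is used.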
{ size_t start_pos = 0, end_pos; - for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) { + } for (end_pos = start_pos; end_pos < line.length(); ++end_pos) { char ch = line[end_pos]; if (ch == ':') { @@ -268,8 +260,9 @@ std::string NormalizeError(std::string err_msg) { } if (error_type.length() != 0) { // if we successfully detected error_type: trim the following space. - for (start_pos = end_pos + 1; - start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' '; + ++start_pos) { + } line = line.substr(start_pos); } else { // did not detect error_type, use default value. @@ -345,22 +338,16 @@ struct TVMRuntimeEntry { typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; -const char *TVMGetLastError() { - return TVMAPIRuntimeStore::Get()->last_error.c_str(); -} +const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } -int TVMAPIHandleException(const std::runtime_error &e) { +int TVMAPIHandleException(const std::runtime_error& e) { TVMAPISetLastError(NormalizeError(e.what()).c_str()); return -1; } -void TVMAPISetLastError(const char* msg) { - TVMAPIRuntimeStore::Get()->last_error = msg; -} +void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; } -int TVMModLoadFromFile(const char* file_name, - const char* format, - TVMModuleHandle* out) { +int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { API_BEGIN(); TVMRetValue ret; ret = Module::LoadFromFile(file_name, format); @@ -371,21 +358,16 @@ int TVMModLoadFromFile(const char* file_name, API_END(); } -int TVMModImport(TVMModuleHandle mod, - TVMModuleHandle dep) { +int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { API_BEGIN(); - ObjectInternal::GetModuleNode(mod)->Import( - GetRef(ObjectInternal::GetModuleNode(dep))); + ObjectInternal::GetModuleNode(mod)->Import(GetRef(ObjectInternal::GetModuleNode(dep))); API_END(); } -int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *func) { +int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* func) { API_BEGIN(); - PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction( - func_name, query_imports != 0); + PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); if (pf != nullptr) { *func = new PackedFunc(pf); } else { @@ -394,23 +376,15 @@ int TVMModGetFunction(TVMModuleHandle mod, API_END(); } -int TVMModFree(TVMModuleHandle mod) { - return TVMObjectFree(mod); -} +int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } -int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *func) { +int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { API_BEGIN(); - *func = (TVMFunctionHandle)( - static_cast(mod_node)->GetFuncFromEnv(func_name)); + *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name)); API_END(); } -void* TVMBackendAllocWorkspace(int device_type, - int device_id, - uint64_t size, - int dtype_code_hint, +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, int dtype_bits_hint) { TVMContext ctx; ctx.device_type = static_cast(device_type); @@ -421,14 +395,10 @@ void* TVMBackendAllocWorkspace(int 
device_type, type_hint.bits = static_cast(dtype_bits_hint); type_hint.lanes = 1; - return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, - static_cast(size), - type_hint); + return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast(size), type_hint); } -int TVMBackendFreeWorkspace(int device_type, - int device_id, - void* ptr) { +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { TVMContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; @@ -436,10 +406,7 @@ int TVMBackendFreeWorkspace(int device_type, return 0; } -int TVMBackendRunOnce(void** handle, - int (*f)(void*), - void* cdata, - int nbytes) { +int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { if (*handle == nullptr) { *handle = reinterpret_cast(1); return (*f)(cdata); @@ -453,20 +420,14 @@ int TVMFuncFree(TVMFunctionHandle func) { API_END(); } -int TVMFuncCall(TVMFunctionHandle func, - TVMValue* args, - int* arg_type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code) { +int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, + TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); + TVMRetValue rv; - (*static_cast(func)).CallPacked( - TVMArgs(args, arg_type_codes, num_args), &rv); + (*static_cast(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. - if (rv.type_code() == kTVMStr || - rv.type_code() == kTVMDataType || - rv.type_code() == kTVMBytes) { + if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); if (rv.type_code() != kTVMDataType) { e->ret_str = *rv.ptr(); @@ -488,10 +449,7 @@ int TVMFuncCall(TVMFunctionHandle func, API_END(); } -int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret) { +int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { API_BEGIN(); CHECK_EQ(num_ret, 1); TVMRetValue* rv = static_cast(ret); @@ -499,32 +457,28 @@ int TVMCFuncSetReturn(TVMRetValueHandle ret, API_END(); } -int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out) { +int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, + TVMFunctionHandle* out) { API_BEGIN(); if (fin == nullptr) { - *out = new PackedFunc( - [func, resource_handle](TVMArgs args, TVMRetValue* rv) { - int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) - args.num_args, rv, resource_handle); - if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); - } - }); + *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { + int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, resource_handle); + if (ret != 0) { + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + } + }); } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. 
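      // The idiom in isolation: std::shared_ptr<void> p(handle, deleter)
      // invokes deleter(handle) exactly once, when the last copy of p is
      // destroyed; copying rpack into the lambda below therefore ties the
      // finalizer's lifetime to the PackedFunc object itself.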
std::shared_ptr rpack(resource_handle, fin); - *out = new PackedFunc( - [func, rpack](TVMArgs args, TVMRetValue* rv) { - int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) - args.num_args, rv, rpack.get()); - if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); - } - }); + *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { + int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, rpack.get()); + if (ret != 0) { + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + } + }); } API_END(); } @@ -565,9 +519,7 @@ int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { API_END(); } -int TVMStreamStreamSynchronize(int device_type, - int device_id, - TVMStreamHandle src, +int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, TVMStreamHandle dst) { API_BEGIN(); TVMContext ctx; @@ -585,35 +537,55 @@ int TVMCbArgToReturn(TVMValue* value, int* code) { API_END(); } +int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint, + void** out_data) { + API_BEGIN(); + out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint); + API_END(); +} + +int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) { + API_BEGIN(); + DeviceAPIManager::Get(ctx)->FreeDataSpace(ctx, ptr); + API_END(); +} + +int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) { + API_BEGIN(); + TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to; + DeviceAPIManager::Get(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, num_bytes, ctx_from, + ctx_to, type_hint, stream); + API_END(); +} + // set device api TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMContext ctx; - ctx.device_type = static_cast(args[0].operator int()); - ctx.device_id = args[1]; - DeviceAPIManager::Get(ctx)->SetDevice(ctx); - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + TVMContext ctx; + ctx.device_type = static_cast(args[0].operator int()); + ctx.device_id = args[1]; + DeviceAPIManager::Get(ctx)->SetDevice(ctx); + }); // set device api -TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMContext ctx; - ctx.device_type = static_cast(args[0].operator int()); - ctx.device_id = args[1]; - - DeviceAttrKind kind = static_cast(args[2].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, ret); - } else { - *ret = 0; - } +TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body([](TVMArgs args, TVMRetValue* ret) { + TVMContext ctx; + ctx.device_type = static_cast(args[0].operator int()); + ctx.device_id = args[1]; + + DeviceAttrKind kind = static_cast(args[2].operator int()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, ret); } else { - DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); + *ret = 0; } - }); - + } else { + DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); + } +}); -TVM_REGISTER_GLOBAL("runtime.TVMSetStream") -.set_body_typed(TVMSetStream); +TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream); diff --git a/src/runtime/container.cc 
b/src/runtime/container.cc index 6145926105842..62220a8852082 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -24,31 +24,27 @@ #include #include #include -#include #include +#include namespace tvm { namespace runtime { using namespace vm; -TVM_REGISTER_GLOBAL("runtime.GetADTTag") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.tag()); }); -TVM_REGISTER_GLOBAL("runtime.GetADTSize") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTSize").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.size()); }); - -TVM_REGISTER_GLOBAL("runtime.GetADTFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTFields").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; const auto& adt = Downcast(obj); @@ -56,8 +52,7 @@ TVM_REGISTER_GLOBAL("runtime.GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("runtime.Tuple") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.Tuple").set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; for (auto i = 0; i < args.size(); ++i) { fields.push_back(args[i]); @@ -65,8 +60,7 @@ TVM_REGISTER_GLOBAL("runtime.Tuple") *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("runtime.ADT") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); std::vector fields; @@ -76,13 +70,11 @@ TVM_REGISTER_GLOBAL("runtime.ADT") *rv = ADT(tag, fields); }); -TVM_REGISTER_GLOBAL("runtime.String") -.set_body_typed([](std::string str) { +TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { return String(std::move(str)); }); -TVM_REGISTER_GLOBAL("runtime.GetFFIString") -.set_body_typed([](String str) { +TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { return std::string(str); }); diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index d4959be64cf17..0cf4c69cdf1e9 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -21,8 +21,9 @@ * \file Use external cblas library call. 
*/ #include -#include #include +#include + #include "gemm_common.h" extern "C" { @@ -50,8 +51,8 @@ struct CblasSgemmOp { void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, int ldb, float beta, float* C, int ldc) { #if USE_DNNL == 1 - dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, - ldb, A, lda, beta, C, ldc); + dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, ldb, A, + lda, beta, C, ldc); #else cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -159,8 +160,7 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -170,8 +170,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") CallGemm(args, ret, CblasDgemmOp()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { @@ -182,14 +181,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); - } -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index b73ababbbaded..96d6322cc5921 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -23,15 +23,16 @@ */ #pragma once -#include #include +#include + #include namespace tvm { namespace contrib { using namespace runtime; -inline int ColumnStride(DLTensor *tensor) { +inline int ColumnStride(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -42,7 +43,7 @@ inline int ColumnStride(DLTensor *tensor) { } } -inline int ElementStride(DLTensor *tensor) { +inline int ElementStride(DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -51,25 +52,21 @@ inline int ElementStride(DLTensor *tensor) { } // Reversed strides indicates an in-place transpose operation. 
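// (Illustrative aside, assuming a dense row-major M x N matrix: its strides are
// {N, 1}, so strides[1] > strides[0] can only hold when the two axes have been
// swapped without moving the data -- e.g. a 4x3 matrix reinterpreted as its 3x4
// transpose carries strides {1, 3}. The check below relies on that invariant.)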
-inline bool IsInPlaceTransposed(DLTensor *tensor) { +inline bool IsInPlaceTransposed(DLTensor* tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } -inline int RowCount(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 1 : 0]; -} +inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } -inline int ColumnCount(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 0 : 1]; -} +inline int ColumnCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so we switch the arguments. template -inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; @@ -92,20 +89,17 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), - ColumnCount(A, transa), static_cast(alpha), - reinterpret_cast( - static_cast(B->data) + B->byte_offset), + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), + static_cast(alpha), + reinterpret_cast(static_cast(B->data) + B->byte_offset), ColumnStride(B), - reinterpret_cast( - static_cast(A->data) + A->byte_offset), + reinterpret_cast(static_cast(A->data) + A->byte_offset), ColumnStride(A), static_cast(beta), - reinterpret_cast( - static_cast(C->data) + C->byte_offset), + reinterpret_cast(static_cast(C->data) + C->byte_offset), ColumnStride(C)); } -inline int ColumnStride3D(DLTensor *tensor) { +inline int ColumnStride3D(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -115,7 +109,7 @@ inline int ColumnStride3D(DLTensor *tensor) { return tensor->shape[2]; } } -inline int ElementStride3D(DLTensor *tensor) { +inline int ElementStride3D(DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[1], tensor->strides[2]); } else { @@ -123,22 +117,18 @@ } } // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed3D(DLTensor *tensor) { +inline bool IsInPlaceTransposed3D(DLTensor* tensor) { return tensor->strides && (tensor->strides[2] > tensor->strides[1]); } -inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; } -inline int RowCount3D(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 2 : 1]; -} -inline int ColumnCount3D(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 1 : 2]; -} +inline int BatchCount3D(DLTensor* tensor) { return tensor->shape[0]; } +inline int RowCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 2 : 1]; } +inline int ColumnCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ?
1 : 2]; } template -inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { +inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { using DType = typename TBatchGemmOp::TDatatype; - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(DType) * 8; @@ -163,16 +153,15 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { const int A_size = A->shape[1] * A->shape[2]; const int B_size = B->shape[1] * B->shape[2]; const int C_size = C->shape[1] * C->shape[2]; - DType *A_data = reinterpret_cast( - static_cast(A->data) + A->byte_offset); - DType *B_data = reinterpret_cast( - static_cast(B->data) + B->byte_offset); - DType *C_data = reinterpret_cast( - static_cast(C->data) + C->byte_offset); - op(batch_size, transb, transa, ColumnCount3D(B, transb), - RowCount3D(A, transa), ColumnCount3D(A, transa), - static_cast(alpha), - B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), + DType* A_data = reinterpret_cast(static_cast(A->data) + + A->byte_offset); + DType* B_data = reinterpret_cast(static_cast(B->data) + + B->byte_offset); + DType* C_data = reinterpret_cast(static_cast(C->data) + + C->byte_offset); + op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa), + ColumnCount3D(A, transa), static_cast(alpha), B_data, B_size, + ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), static_cast(beta), C_data, C_size, ColumnStride3D(C)); } diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index fada8008f18a6..404afa2f547e1 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -25,16 +25,16 @@ #ifndef TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ -#import #import +#import #include #include #include -#include -#include #include +#include +#include namespace tvm { namespace runtime { @@ -53,15 +53,12 @@ class CoreMLRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const { - return "CoreMLRuntime"; - } + const char* type_key() const { return "CoreMLRuntime"; } /*! * \brief Invoke the coreml prediction. @@ -74,9 +71,8 @@ class CoreMLRuntime : public ModuleNode { * \param ctx The context where the coreml model will be executed on. * \param output_names The output names of the model. */ - void Init(const std::string& model_path, - TVMContext ctx, - const std::vector& output_names); + void Init(const std::string& model_path, TVMContext ctx, + const std::vector& output_names); /*! * \brief set input to the model. 
@@ -99,13 +95,13 @@ class CoreMLRuntime : public ModuleNode { int GetNumOutputs() const; // CoreML model - MLModel *model_; + MLModel* model_; // CoreML model input dictionary - NSMutableDictionary *input_dict_; + NSMutableDictionary* input_dict_; // CoreML model output id output_; // List of output names - std::vector output_names_; + std::vector output_names_; // TVM context TVMContext ctx_; }; diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 614842b457816..1ce84a00efd87 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -27,28 +27,27 @@ namespace tvm { namespace runtime { -MLModel *load_coreml_model(const std::string& model_path) { +MLModel* load_coreml_model(const std::string& model_path) { NSBundle* bundle = [NSBundle mainBundle]; NSString* base = [bundle privateFrameworksPath]; NSString* fname = [NSString stringWithUTF8String:("tvm/" + model_path).c_str()]; - NSString* assetPath = [base stringByAppendingPathComponent: fname]; + NSString* assetPath = [base stringByAppendingPathComponent:fname]; if (![[NSFileManager defaultManager] fileExistsAtPath:assetPath]) { - assetPath = [NSString stringWithCString: model_path.c_str() encoding:NSUTF8StringEncoding]; + assetPath = [NSString stringWithCString:model_path.c_str() encoding:NSUTF8StringEncoding]; } - NSURL *url = [NSURL fileURLWithPath:assetPath]; + NSURL* url = [NSURL fileURLWithPath:assetPath]; - MLModel *model = [MLModel modelWithContentsOfURL:url error:nil]; + MLModel* model = [MLModel modelWithContentsOfURL:url error:nil]; if (model == nil) { NSLog(@"modelc %@ not found", url); } return model; } -void CoreMLRuntime::Init(const std::string& model_path, - TVMContext ctx, - const std::vector& output_names) { +void CoreMLRuntime::Init(const std::string& model_path, TVMContext ctx, + const std::vector& output_names) { model_ = load_coreml_model(model_path); ctx_ = ctx; input_dict_ = [NSMutableDictionary dictionary]; @@ -56,13 +55,14 @@ } void CoreMLRuntime::Invoke() { - id input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_ error:nil]; + id input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_ + error:nil]; output_ = [model_ predictionFromFeatures:input error:nil]; } void CoreMLRuntime::SetInput(const std::string& key, DLTensor* data_in) { int64_t size = 1; - NSMutableArray *shape = [[NSMutableArray alloc] init]; + NSMutableArray* shape = [[NSMutableArray alloc] init]; for (int64_t i = 0; i < data_in->ndim; ++i) { size *= data_in->shape[i]; [shape addObject:[NSNumber numberWithInteger:data_in->shape[i]]]; @@ -81,21 +81,20 @@ return; } - MLMultiArray *dest = [[MLMultiArray alloc] initWithShape:shape - dataType:dataType error:nil]; + MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil]; CHECK(data_in->strides == NULL); memcpy(dest.dataPointer, data_in->data, size); - NSString *nsKey = [NSString stringWithUTF8String:key.c_str()]; + NSString* nsKey = [NSString stringWithUTF8String:key.c_str()]; [input_dict_ setObject:dest forKey:nsKey]; } NDArray CoreMLRuntime::GetOutput(int index) const { - NSString *name = output_names_[index]; - MLModelDescription *model_desc = model_.modelDescription; - MLFeatureDescription *output_desc = model_desc.outputDescriptionsByName[name]; - MLMultiArrayConstraint *data_desc = output_desc.multiArrayConstraint; + NSString* name = output_names_[index]; + MLModelDescription* model_desc = 
model_.modelDescription; + MLFeatureDescription* output_desc = model_desc.outputDescriptionsByName[name]; + MLMultiArrayConstraint* data_desc = output_desc.multiArrayConstraint; std::vector shape; int64_t size = 1; for (int64_t i = 0; i < data_desc.shape.count; ++i) { @@ -114,59 +113,50 @@ } else { LOG(FATAL) << "unexpected data type " << data_desc.dataType; } - MLMultiArray *src = [output_ featureValueForName:name].multiArrayValue; + MLMultiArray* src = [output_ featureValueForName:name].multiArrayValue; NDArray ret = NDArray::Empty(shape, dtype, ctx_); ret.CopyFromBytes(src.dataPointer, size); return ret; } -int CoreMLRuntime::GetNumOutputs() const { - return output_names_.size(); -} +int CoreMLRuntime::GetNumOutputs() const { return output_names_.size(); } -PackedFunc CoreMLRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc CoreMLRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "invoke") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Invoke(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); } else if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - const auto& input_name = args[0].operator std::string(); - this->SetInput(input_name, args[1]); - }); + const auto& input_name = args[0].operator std::string(); + this->SetInput(input_name, args[1]); + }); } else if (name == "get_output") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutput(args[0]); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); } else if (name == "get_num_outputs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetNumOutputs(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetNumOutputs(); }); } else { return PackedFunc(); } } -Module CoreMLRuntimeCreate(const std::string& model_path, - TVMContext ctx, - const std::vector& output_names) { +Module CoreMLRuntimeCreate(const std::string& model_path, TVMContext ctx, + const std::vector& output_names) { auto exec = make_object(); exec->Init(model_path, ctx, output_names); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - std::vector output_names; - for (size_t i = 2; i < args.size(); i++) { - const std::string& name = args[i]; - output_names.push_back([NSString stringWithUTF8String:name.c_str()]); - } - *rv = CoreMLRuntimeCreate(args[0], args[1], output_names); - }); +TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector output_names; + for (size_t i = 2; i < args.size(); i++) { + const std::string& name = args[i]; + output_names.push_back([NSString stringWithUTF8String:name.c_str()]); + } + *rv = CoreMLRuntimeCreate(args[0], args[1], output_names); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 5424f4cdcddf8..ff204457d1c47 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,152 +20,98 @@ /*! * \file Use external cublas library call.
*/ -#include -#include #include +#include +#include + #include "../cblas/gemm_common.h" #include "cublas_utils.h" - namespace tvm { namespace contrib { using namespace runtime; -inline cublasOperation_t BooleanToTranspose(bool item) { - return item ? CUBLAS_OP_T : CUBLAS_OP_N; -} +inline cublasOperation_t BooleanToTranspose(bool item) { return item ? CUBLAS_OP_T : CUBLAS_OP_N; } inline void TryEnableTensorCore(cublasHandle_t hdl) { // TensorCores are only supported in cublas 9.0 or higher int version; CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version)); - if (version >= 9000) - CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); + if (version >= 9000) CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); } struct CublasHgemmOp { typedef half TDatatype; cublasHandle_t handle; - explicit CublasHgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - - void operator()(bool ta, bool tb, - int M, int N, int K, - half alpha, half* A, int lda, - half* B, int ldb, - half beta, half* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasHgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasHgemmOp(cublasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, half alpha, half* A, int lda, half* B, + int ldb, half beta, half* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasHgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasSgemmOp { typedef float TDatatype; cublasHandle_t handle; - explicit CublasSgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - - void operator()(bool ta, bool tb, - int M, int N, int K, - float alpha, float* A, int lda, - float* B, int ldb, - float beta, float* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasSgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasSgemmOp(cublasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasSgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasDgemmOp { typedef double TDatatype; cublasHandle_t handle; - explicit CublasDgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - void operator()(bool ta, bool tb, - int M, int N, int K, - double alpha, double* A, int lda, - double* B, int ldb, - double beta, double* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasDgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasDgemmOp(cublasHandle_t hdl) : handle(hdl) {} + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasDgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasHgemmBatchOp { typedef half TDatatype; cublasHandle_t handle; - explicit CublasHgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasHgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, half alpha, half* A, int a_stride, int lda, half* B, int b_stride, int ldb, half beta, half* C, int c_stride, int ldc) { - 
CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; struct CublasSgemmBatchOp { typedef float TDatatype; cublasHandle_t handle; - explicit CublasSgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasSgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; struct CublasDgemmBatchOp { typedef double TDatatype; cublasHandle_t handle; - explicit CublasDgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasDgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; @@ -174,22 +120,19 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { return TypeMatch(in_dtype, kDLInt, 8); } else if (TypeMatch(out_dtype, kDLFloat, 32)) { - return TypeMatch(in_dtype, kDLInt, 8) || - TypeMatch(in_dtype, kDLFloat, 16); + return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16); } else { return false; } } -int roundoff(int v, int d) { - return (v + d - 1) / d * d; -} +int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 -inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; // Reversed strides indicates an in-place transpose operation. 
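All of the gemm wrappers in this diff reuse the same trick: rather than materializing a transpose, they detect reversed strides and flip the BLAS transpose flag (the `transb = IsInPlaceTransposed(B) ? !transb : transb;` pattern below). The following is a minimal, self-contained sketch of that dispatch; `Tensor2D` is a simplified stand-in for DLTensor, and the names are illustrative, not from the patch:

#include <cassert>
#include <cstdint>

// Simplified stand-in for DLTensor (illustrative only).
struct Tensor2D {
  int64_t shape[2];
  const int64_t* strides;  // nullptr means dense row-major
};

// Mirrors IsInPlaceTransposed: a column stride larger than the row stride
// means the buffer is laid out column-major, i.e. transposed in place.
bool ReversedStrides(const Tensor2D& t) {
  return t.strides && (t.strides[1] > t.strides[0]);
}

int main() {
  // A 4x3 row-major buffer viewed as its 3x4 transpose: strides become {1, 3}.
  const int64_t strides[2] = {1, 3};
  Tensor2D a{{3, 4}, strides};
  bool transa = false;
  // Flip the flag instead of copying the data, as the wrappers here do.
  transa = ReversedStrides(a) ? !transa : transa;
  assert(transa);
  return 0;
}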
@@ -230,53 +173,37 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; cublasLtMatmulDesc_t operationDesc = nullptr; CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(opTranspose))); cublasOperation_t opTransA = BooleanToTranspose(transa); cublasOperation_t opTransB = BooleanToTranspose(transb); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTransB, sizeof(opTransB))); // Create descriptors for the original matrices - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , - opTransA == CUBLAS_OP_N ? k : m, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , - opTransB == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k, + opTransA == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n, + opTransB == CUBLAS_OP_N ? n : k, ldb)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32))); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); - - CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, - operationDesc, - &alpha, - B_data, - Adesc, - A_data, - Bdesc, - &beta, - C_data, - Cdesc, - C_data, - Cdesc, - NULL, - NULL, - 0, - 0)); + Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32))); + + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, + C_data, Cdesc, C_data, Cdesc, NULL, NULL, 0, 0)); } #endif -inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; CHECK_EQ(A->ndim, 2); @@ -297,10 +224,10 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { transb = IsInPlaceTransposed(B) ? 
!transb : transb; CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - CHECK(!TypeMatch(A->dtype, kDLInt, 8) || - ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - CHECK(!TypeMatch(B->dtype, kDLInt, 8) || - ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; @@ -320,28 +247,21 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { beta_ptr = &beta_float; } - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - - CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, - BooleanToTranspose(transb), - BooleanToTranspose(transa), - ColumnCount(B, transb), - RowCount(A, transa), - ColumnCount(A, transa), - alpha_ptr, - B_data, cuda_in_type, ColumnStride(B), - A_data, cuda_in_type, ColumnStride(A), - beta_ptr, - C_data, cuda_out_type, ColumnStride(C), - cuda_out_type, algo)); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), + ColumnCount(B, transb), RowCount(A, transa), + ColumnCount(A, transa), alpha_ptr, B_data, cuda_in_type, + ColumnStride(B), A_data, cuda_in_type, ColumnStride(A), beta_ptr, + C_data, cuda_out_type, ColumnStride(C), cuda_out_type, algo)); } -inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; CHECK_EQ(A->ndim, 3); @@ -364,10 +284,10 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) transb = IsInPlaceTransposed(B) ? !transb : transb; CHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type"; - CHECK(!TypeMatch(A->dtype, kDLInt, 8) || - ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - CHECK(!TypeMatch(B->dtype, kDLInt, 8) || - ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? 
args[6] : 0.0; @@ -391,88 +311,76 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) beta_ptr = &beta_float; } - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(hdl, - BooleanToTranspose(transb), - BooleanToTranspose(transa), - ColumnCount3D(B, transb), - RowCount3D(A, transa), - ColumnCount3D(A, transa), - alpha_ptr, - B_data, cuda_in_type, ColumnStride3D(B), B_size, - A_data, cuda_in_type, ColumnStride3D(A), A_size, - beta_ptr, - C_data, cuda_out_type, ColumnStride3D(C), C_size, - batch_size, cuda_out_type, algo)); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx( + hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), ColumnCount3D(B, transb), + RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, cuda_in_type, + ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, beta_ptr, C_data, + cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo)); } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - DLTensor* C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); + TryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - CHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || + if (TypeEqual(A->dtype, C->dtype)) { + CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); - else - CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); - } else { - CallGemmEx(args, ret, entry_ptr->handle); - } + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); + else + CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } }); #if CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); + TryEnableTensorCore(entry_ptr->handle); - CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; - cublasLtHandle_t ltHandle; - CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); - CallLtIgemm(args, ret, ltHandle); - 
CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); + CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle)); + CallLtIgemm(args, ret, ltHandle); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); #endif // CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - DLTensor* C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - CHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || + TryEnableTensorCore(entry_ptr->handle); + if (TypeEqual(A->dtype, C->dtype)) { + CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); - else - CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); - } else { - CallBatchGemmEx(args, ret, entry_ptr->handle); - } + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); + else + CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemmEx(args, ret, entry_ptr->handle); + } }); } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 9953cda323797..d4ec087707232 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -21,18 +21,16 @@ * \file Use external cublas utils function */ #include "cublas_utils.h" + #include #include + #include "../../cuda/cuda_common.h" namespace tvm { namespace contrib { - -CuBlasThreadEntry::CuBlasThreadEntry() { - CHECK_CUBLAS_ERROR(cublasCreate(&handle)); -} - +CuBlasThreadEntry::CuBlasThreadEntry() { CHECK_CUBLAS_ERROR(cublasCreate(&handle)); } CuBlasThreadEntry::~CuBlasThreadEntry() { if (handle) { @@ -41,10 +39,8 @@ CuBlasThreadEntry::~CuBlasThreadEntry() { } } - typedef dmlc::ThreadLocalStore CuBlasThreadStore; - CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); @@ -52,6 +48,5 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { return retval; } - } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 2e553e28493b0..5189c4f483a86 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -24,11 +24,12 @@ #ifndef TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ -#include -#include #include #include #include +#include +#include + #include #if CUDART_VERSION >= 10010 #include @@ -39,27 +40,35 @@ namespace contrib { inline const char* GetCublasErrorString(int error) { switch (error) { - case
CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; } return "Unrecognized error"; } #ifndef CHECK_CUBLAS_ERROR -#define CHECK_CUBLAS_ERROR(fn) \ - do { \ - int error = static_cast(fn); \ +#define CHECK_CUBLAS_ERROR(fn) \ + do { \ + int error = static_cast(fn); \ CHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \ } while (0) // ; intentionally left off. -#endif // CHECK_CUBLAS_ERROR - +#endif // CHECK_CUBLAS_ERROR struct CuBlasThreadEntry { CuBlasThreadEntry(); @@ -71,19 +80,26 @@ struct CuBlasThreadEntry { inline cudaDataType_t GetCudaDataType(DLDataType type) { if (type.code == kDLInt) { switch (type.bits) { - case 8: return CUDA_R_8I; - case 32: return CUDA_R_32I; + case 8: + return CUDA_R_8I; + case 32: + return CUDA_R_32I; } } else if (type.code == kDLUInt) { switch (type.bits) { - case 8: return CUDA_R_8U; - case 32: return CUDA_R_32U; + case 8: + return CUDA_R_8U; + case 32: + return CUDA_R_32U; } } else if (type.code == kDLFloat) { switch (type.bits) { - case 16: return CUDA_R_16F; - case 32: return CUDA_R_32F; - case 64: return CUDA_R_64F; + case 16: + return CUDA_R_16F; + case 32: + return CUDA_R_32F; + case 64: + return CUDA_R_64F; } } LOG(FATAL) << "Unsupported cuda type"; diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index c4c05d8079f9c..223a5b4fe435a 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -20,9 +20,10 @@ /*! 
* \file Use external cudnn utils function */ -#include #include #include +#include + #include "cudnn_utils.h" namespace tvm { @@ -30,19 +31,9 @@ namespace contrib { using namespace runtime; -void ConvolutionForward( - int mode, - int format, - int algo, - int dims, - int groups, - const int pad[], - const int stride[], - const int dilation[], - DLTensor* x, - DLTensor* w, - DLTensor* y, - const std::string& conv_dtype) { +void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], + const int stride[], const int dilation[], DLTensor* x, DLTensor* w, + DLTensor* y, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); @@ -67,17 +58,11 @@ void ConvolutionForward( CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); if (dims == 2) { // Set Desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, - pad[0], - pad[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - entry_ptr->conv_entry.mode, - entry_ptr->conv_entry.data_type)); + CUDNN_CALL(cudnnSetConvolution2dDescriptor( + entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], + dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); int ni, ci, hi, wi; - if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { ni = 0; ci = 3; hi = 1; @@ -90,67 +75,46 @@ void ConvolutionForward( } // Set Filter - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - static_cast(w->shape[ni]), - static_cast(w->shape[ci]), - static_cast(w->shape[hi]), - static_cast(w->shape[wi]))); + CUDNN_CALL(cudnnSetFilter4dDescriptor( + entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format, + static_cast(w->shape[ni]), static_cast(w->shape[ci]), + static_cast(w->shape[hi]), static_cast(w->shape[wi]))); // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(x->shape[ni]), - static_cast(x->shape[ci]), - static_cast(x->shape[hi]), - static_cast(x->shape[wi]))); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, + static_cast(x->shape[ni]), static_cast(x->shape[ci]), + static_cast(x->shape[hi]), static_cast(x->shape[wi]))); // Set Output - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(y->shape[ni]), - static_cast(y->shape[ci]), - static_cast(y->shape[hi]), - static_cast(y->shape[wi]))); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type, + static_cast(y->shape[ni]), static_cast(y->shape[ci]), + static_cast(y->shape[hi]), static_cast(y->shape[wi]))); } else { - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - entry_ptr->conv_entry.mode, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); // Set Filter for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(w->shape[i]); } - 
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, dim.data())); // Set Input for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(x->shape[i]); } GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - dim.data(), - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + dim.data(), tensor_stride.data())); // Set Output for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(y->shape[i]); } GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, - data_type, - full_dims, - dim.data(), - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, + dim.data(), tensor_stride.data())); } if (cudnnGetVersion() > 7000) { @@ -159,42 +123,23 @@ void ConvolutionForward( // Set workspace size_t workspace_size = 0; - CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.fwd_algo, - &workspace_size)); + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.fwd_algo, &workspace_size)); entry_ptr->conv_entry.UpdateWorkspace(workspace_size); - CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle, - CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), - entry_ptr->conv_entry.input_desc, - x->data, - entry_ptr->conv_entry.filter_desc, - w->data, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, - entry_ptr->conv_entry.workspace, - workspace_size, - CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), - entry_ptr->conv_entry.output_desc, - y->data)); + CUDNN_CALL(cudnnConvolutionForward( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.output_desc, y->data)); } - -void OutputShape( - int format, - int dims, - int groups, - const int pad[], - const int stride[], - const int dilation[], - const int x_dim[], - const int w_dim[], - void *out_shape, - const std::string& data_dtype, - const std::string& conv_dtype) { +void OutputShape(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int w_dim[], void* out_shape, + const std::string& data_dtype, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Data Type @@ -207,79 +152,46 @@ void OutputShape( // conv desc CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, 
- dilation, - CUDNN_CROSS_CORRELATION, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); - if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { // Set Input CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - x_dim[0], - x_dim[3], - x_dim[1], - x_dim[2])); + entry_ptr->conv_entry.tensor_format, data_type, x_dim[0], + x_dim[3], x_dim[1], x_dim[2])); // filter desc - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - w_dim[0], - w_dim[3], - w_dim[1], - w_dim[2])); - - CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - static_cast(out_shape), - static_cast(out_shape) + 3, - static_cast(out_shape) + 1, - static_cast(out_shape) + 2)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3], + w_dim[1], w_dim[2])); + + CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, static_cast(out_shape), + static_cast(out_shape) + 3, static_cast(out_shape) + 1, + static_cast(out_shape) + 2)); } else { // Set Input std::vector tensor_stride(full_dims); GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - x_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + x_dim, tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, - w_dim)); - - CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - full_dims, - static_cast(out_shape))); + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); + + CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, full_dims, static_cast(out_shape))); } } - -void FindAlgo( - int format, - int dims, - int groups, - const int pad[], - const int stride[], - const int dilation[], - const int x_dim[], - const int w_dim[], - const int y_dim[], - const std::string& data_dtype, - const std::string& conv_dtype, - TVMRetValue *ret) { +void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], + const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Data Type @@ -292,65 +204,46 @@ void FindAlgo( // conv desc CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - CUDNN_CROSS_CORRELATION, + 
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); std::vector tensor_stride(full_dims); // input desc GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - x_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + x_dim, tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, - w_dim)); + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); // output desc GetCudnnStride(full_dims, y_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, - data_type, - full_dims, - y_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, + y_dim, tensor_stride.data())); if (cudnnGetVersion() > 7000) { CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) } int returned_algo_count = 0; cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - CUDNN_CONVOLUTION_FWD_ALGO_COUNT, - &returned_algo_count, - perf_results)); - - const std::vector fwd_algo_names{ - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT", - "CUDNN_CONVOLUTION_FWD_ALGO_FFT", - "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING", - "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD", - "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED" - }; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, + CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT", + "CUDNN_CONVOLUTION_FWD_ALGO_FFT", + "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING", + "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD", + "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"}; auto best_algo = perf_results[0].algo; - LOG(INFO) << "\tCUDNN Found " << returned_algo_count - << " fwd algorithms, choosing " << fwd_algo_names[best_algo]; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; for (int i = 0; i < returned_algo_count; ++i) { LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo] << " - time: " << perf_results[i].time << " ms" @@ -360,87 +253,83 @@ void FindAlgo( ret[0] = best_algo; } - TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i]; - 
stride_v[i] = args[5 + i]; - dilation_v[i] = args[7 + i]; - } - DLTensor* x = args[9]; - DLTensor* w = args[10]; - DLTensor* y = args[11]; - std::string conv_dtype = args[12]; - int groups = args[13]; - - ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, - dilation_v, x, w, y, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* x = args[9]; + DLTensor* w = args[10]; + DLTensor* y = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_v[3], stride_v[3], dilation_v[3]; - for (int i = 0; i < 3; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[6 + i]; - dilation_v[i] = args[9 + i]; - } - DLTensor *x = args[12]; - DLTensor *w = args[13]; - DLTensor *y = args[14]; - std::string conv_dtype = args[15]; - int groups = args[16]; - - ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, - dilation_v, x, w, y, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[3], stride_v[3], dilation_v[3]; + for (int i = 0; i < 3; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[6 + i]; + dilation_v[i] = args[9 + i]; + } + DLTensor* x = args[12]; + DLTensor* w = args[13]; + DLTensor* y = args[14]; + std::string conv_dtype = args[15]; + int groups = args[16]; + + ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int format = args[0]; - int dims = args[1]; - int* pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - void* out_shape = args[7]; - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - int groups = args[10]; - - OutputShape(format, dims, groups, pad, stride, dilation, x_dim, - w_dim, out_shape, data_dtype, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + void* out_shape = args[7]; + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int format = args[0]; - int dims = args[1]; - int* pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = 
static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - int* y_dim = static_cast(static_cast(args[7])); - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - int groups = args[10]; - - FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, - w_dim, y_dim, data_dtype, conv_dtype, ret); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + int* y_dim = static_cast(static_cast(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, + conv_dtype, ret); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 9c895c5b7e066..cd934bcb70818 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -21,38 +21,44 @@ * \file Use external cudnn utils function */ #include "cudnn_utils.h" + #include #include - namespace tvm { namespace contrib { // CuDNN Data Type -cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) { +cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) { switch (dtype.code) { - case kDLInt: - if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8; - else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32; - else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4; - else - LOG(FATAL) << "Unsupported type"; - break; - case kDLUInt: + case kDLInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return CUDNN_DATA_INT8; + else if (dtype.bits == 32 && dtype.lanes == 1) + return CUDNN_DATA_INT32; + else if (dtype.bits == 8 && dtype.lanes == 4) + return CUDNN_DATA_INT8x4; + else LOG(FATAL) << "Unsupported type"; - break; - case kDLFloat: - if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT; - else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE; - else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF; - else - LOG(FATAL) << "Unsupported type"; - break; - } - return CUDNN_DATA_FLOAT; + break; + case kDLUInt: + LOG(FATAL) << "Unsupported type"; + break; + case kDLFloat: + if (dtype.bits == 32 && dtype.lanes == 1) + return CUDNN_DATA_FLOAT; + else if (dtype.bits == 64 && dtype.lanes == 1) + return CUDNN_DATA_DOUBLE; + else if (dtype.bits == 16 && dtype.lanes == 1) + return CUDNN_DATA_HALF; + else + LOG(FATAL) << "Unsupported type"; + break; + } + return CUDNN_DATA_FLOAT; } -template<> +template <> const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { static const int int_v = 0; static const float float_v = 0; @@ -69,7 +75,7 @@ const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { return nullptr; } -template<> +template <> const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { static const int int_v = 1; static const float float_v = 1.f; @@ -91,22 +97,18 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { CuDNNThreadEntry::CuDNNThreadEntry() { auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; auto func = runtime::Registry::Get("device_api.gpu"); - void *ret = (*func)(); + void* ret = 
(*func)(); cuda_api = static_cast(ret); CUDNN_CALL(cudnnCreate(&handle)); CUDNN_CALL(cudnnSetStream(handle, stream)); conv_entry.cuda_api = cuda_api; } -CuDNNThreadEntry::~CuDNNThreadEntry() { - CUDNN_CALL(cudnnDestroy(handle)); -} +CuDNNThreadEntry::~CuDNNThreadEntry() { CUDNN_CALL(cudnnDestroy(handle)); } typedef dmlc::ThreadLocalStore CuDNNThreadStore; -CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { - return CuDNNThreadStore::Get(); -} +CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { return CuDNNThreadStore::Get(); } // ConvEntry @@ -142,13 +144,9 @@ void ConvEntry::CleanWorkspace() { // SoftmaxEntry -SoftmaxEntry::SoftmaxEntry() { - CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); -} +SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); } -SoftmaxEntry::~SoftmaxEntry() { - CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); -} +SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index c2000d02e0c99..1b4eb40f193f1 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -24,11 +24,11 @@ #ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ -#include #include +#include #include -#include "../../cuda/cuda_common.h" +#include "../../cuda/cuda_common.h" namespace tvm { namespace contrib { @@ -41,24 +41,22 @@ namespace contrib { /*! breif Convert DLTensor type to CuDNN type */ struct CuDNNDataType { - static cudnnDataType_t DLTypeToCuDNNType(const DLDataType &dtype); - template + static cudnnDataType_t DLTypeToCuDNNType(const DLDataType& dtype); + template static const void* GetConst(cudnnDataType_t type); }; // struct CuDNNDataType -inline void GetStride(int nbdim, const int *dims, int *strides) { +inline void GetStride(int nbdim, const int* dims, int* strides) { int mul = 1; - for (int i = nbdim - 1; i >=0; --i) { + for (int i = nbdim - 1; i >= 0; --i) { mul *= dims[i]; strides[i] = mul; } } -inline void GetCudnnStride(int nbdim, - const int* dims, - int* strides) { +inline void GetCudnnStride(int nbdim, const int* dims, int* strides) { int mul = 1; - for (int i = nbdim - 1; i >=0; --i) { + for (int i = nbdim - 1; i >= 0; --i) { strides[i] = mul; mul *= dims[i]; } @@ -75,8 +73,8 @@ struct ConvEntry { cudnnConvolutionFwdAlgo_t fwd_algo; // cudnnMathType_t math_type; TVMContext ctx; - runtime::DeviceAPI *cuda_api; - void *workspace{nullptr}; + runtime::DeviceAPI* cuda_api; + void* workspace{nullptr}; size_t workspace_size{0}; ConvEntry(); ~ConvEntry(); @@ -98,7 +96,7 @@ struct CuDNNThreadEntry { cudnnHandle_t handle{nullptr}; ConvEntry conv_entry; SoftmaxEntry softmax_entry; - runtime::DeviceAPI *cuda_api{nullptr}; + runtime::DeviceAPI* cuda_api{nullptr}; static CuDNNThreadEntry* ThreadLocal(); }; // CuDNNThreadEntry diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index fb6d8a6fdc56f..ff6d6a1dbd812 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -21,8 +21,9 @@ * \file src/runtime/contrib/cudnn/softmax.cc * \brief Use external cudnn softmax function */ -#include #include +#include + #include "cudnn_utils.h" namespace tvm { @@ -31,64 +32,53 @@ namespace contrib { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* 
x = args[0];
-  DLTensor* y = args[1];
-  int axis = args[2];
-  int ndim = x->ndim;
-  int64_t* shape = x->shape;
-  if (axis < 0) axis += ndim;
-  CHECK(axis >= 0 && axis < ndim);
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      DLTensor* x = args[0];
+      DLTensor* y = args[1];
+      int axis = args[2];
+      int ndim = x->ndim;
+      int64_t* shape = x->shape;
+      if (axis < 0) axis += ndim;
+      CHECK(axis >= 0 && axis < ndim);

-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+      CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+      entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

-  // Set mode and shape descriptor
-  if (axis == ndim - 1) {
-    int64_t N = 1;
-    for (int i = 0; i < ndim - 1; ++i) {
-      N *= shape[i];
-    }
-    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
-                                          CUDNN_TENSOR_NCHW,
-                                          entry_ptr->softmax_entry.data_type,
-                                          static_cast<int>(N),
-                                          static_cast<int>(shape[ndim - 1]),
-                                          1,
-                                          1));
-  } else {
-    int64_t pre_axis_dim = 1;
-    int64_t post_axis_dim = 1;
-    for (int i = 0; i < ndim; ++i) {
-      if (i < axis) {
-        pre_axis_dim *= shape[i];
-      } else if (i > axis) {
-        post_axis_dim *= shape[i];
+      // Set mode and shape descriptor
+      if (axis == ndim - 1) {
+        int64_t N = 1;
+        for (int i = 0; i < ndim - 1; ++i) {
+          N *= shape[i];
+        }
+        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+        CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
+                                              CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
+                                              static_cast<int>(N),
+                                              static_cast<int>(shape[ndim - 1]), 1, 1));
+      } else {
+        int64_t pre_axis_dim = 1;
+        int64_t post_axis_dim = 1;
+        for (int i = 0; i < ndim; ++i) {
+          if (i < axis) {
+            pre_axis_dim *= shape[i];
+          } else if (i > axis) {
+            post_axis_dim *= shape[i];
+          }
+        }
+        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+        CUDNN_CALL(cudnnSetTensor4dDescriptor(
+            entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
+            entry_ptr->softmax_entry.data_type, static_cast<int>(pre_axis_dim),
+            static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
       }
-    }
-    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
-                                          CUDNN_TENSOR_NCHW,
-                                          entry_ptr->softmax_entry.data_type,
-                                          static_cast<int>(pre_axis_dim),
-                                          static_cast<int>(shape[axis]),
-                                          static_cast<int>(post_axis_dim),
-                                          1));
-  }
-  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
-  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
-  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle,
-                                 CUDNN_SOFTMAX_ACCURATE,
-                                 entry_ptr->softmax_entry.mode,
-                                 alpha,
-                                 entry_ptr->softmax_entry.shape_desc,
-                                 x->data,
-                                 beta,
-                                 entry_ptr->softmax_entry.shape_desc,
-                                 y->data));
-});
+      auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+      auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
+      CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE,
+                                     entry_ptr->softmax_entry.mode, alpha,
+                                     entry_ptr->softmax_entry.shape_desc, x->data, beta,
+                                     entry_ptr->softmax_entry.shape_desc, y->data));
+    });

 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc
index 0922ac1a65dfc..5b9f5e17232ca 100644
--- a/src/runtime/contrib/dnnl/dnnl.cc
+++ b/src/runtime/contrib/dnnl/dnnl.cc
@@ -22,8 +22,6 @@
  * \brief TVM
compatible wrappers for dnnl kernels. */ -#include "dnnl_kernel.h" - #include #include #include @@ -34,6 +32,8 @@ #include #include +#include "dnnl_kernel.h" + namespace tvm { namespace runtime { namespace contrib { @@ -133,8 +133,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op()); } -extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, - int p_I_, int p_O_) { +extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -157,8 +156,8 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, auto bias_memory = memory(bias_md, eng, bias.data()); auto dst_memory = memory(dst_md, eng); - auto dense_desc = inner_product_forward::desc( - prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md); + auto dense_desc = inner_product_forward::desc(prop_kind::forward_inference, data_md, weight_md, + bias_md, dst_md); auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng); assert(dst_md == dense_prim_desc.dst_desc()); @@ -171,8 +170,7 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, - int p_W_) { +extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -186,8 +184,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, auto data_memory = memory(data_md, eng, data); auto dst_memory = memory(data_md, eng); - auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference, - algorithm::eltwise_relu, data_md, 0); + auto relu_desc = + eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_relu, data_md, 0); auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng); assert(data_md == relu_prim_desc.dst_desc()); @@ -215,8 +213,7 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo auto bn_desc = batch_normalization_forward::desc( prop_kind::forward_inference, data_md, p_E_, - normalization_flags::use_global_stats | - normalization_flags::use_scale_shift); + normalization_flags::use_global_stats | normalization_flags::use_scale_shift); auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng); assert(data_md == bn_prim_desc.dst_desc()); @@ -239,8 +236,8 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, - int p_C_, int p_H_, int p_W_) { +extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_, + int p_W_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -257,15 +254,14 @@ extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, auto weight_memory = memory(weight_md, eng, weight); auto dst_memory = memory(dst_md, eng); - auto add_desc = - binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); auto add_prim_desc = binary::primitive_desc(add_desc, eng); assert(dst_md == add_prim_desc.dst_desc()); auto add = binary(add_prim_desc); - add.execute(s, {{DNNL_ARG_SRC_0, data_memory}, - 
{DNNL_ARG_SRC_1, weight_memory}, - {DNNL_ARG_DST, dst_memory}}); + add.execute( + s, + {{DNNL_ARG_SRC_0, data_memory}, {DNNL_ARG_SRC_1, weight_memory}, {DNNL_ARG_DST, dst_memory}}); s.wait(); read_from_dnnl_memory(out, dst_memory); } diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index f92d7679aeeea..dbc064a6bc993 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -26,6 +26,7 @@ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #include + #include "dnnl.hpp" namespace tvm { diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 4823ef7de9590..13b3c34a6b175 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -20,25 +20,23 @@ /*! * \file edgetpu_runtime.cc */ -#include +#include "edgetpu_runtime.h" + +#include #include #include #include -#include - - -#include "edgetpu_runtime.h" +#include namespace tvm { namespace runtime { -void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, - TVMContext ctx) { +void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); // Load compiled model as a FlatBufferModel std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); // Build resolver tflite::ops::builtin::BuiltinOpResolver resolver; // Init EdgeTPUContext object @@ -58,16 +56,14 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, ctx_ = ctx; } -Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, - TVMContext ctx) { +Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); exec->Init(tflite_model_bytes, ctx); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = EdgeTPURuntimeCreate(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = EdgeTPURuntimeCreate(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h index 78730d530018c..af3517ba76f3e 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.h +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.h @@ -25,8 +25,8 @@ #ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ -#include #include +#include #include "../tflite/tflite_runtime.h" @@ -44,17 +44,14 @@ class EdgeTPURuntime : public TFLiteRuntime { /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "EdgeTPURuntime"; - } + const char* type_key() const final { return "EdgeTPURuntime"; } /*! * \brief Initialize the edge TPU tflite runtime with tflite model and context. * \param tflite_model_bytes The tflite model. * \param ctx The context where the tflite model will be executed on. 
*/ - void Init(const std::string& tflite_model_bytes, - TVMContext ctx); + void Init(const std::string& tflite_model_bytes, TVMContext ctx); private: std::shared_ptr edgetpu_context_; diff --git a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc index 98078b68c23a7..1a63eded5adff 100644 --- a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc +++ b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc @@ -42,8 +42,8 @@ #include #include -#include #include +#include #include #include #include @@ -76,9 +76,8 @@ int Add(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Add_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Add_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -93,9 +92,8 @@ int Sub(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Sub_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Sub_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -110,9 +108,8 @@ int Mul(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Mul_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Mul_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -136,8 +133,7 @@ class ExampleJsonModule : public ModuleNode { * * \return The function pointer when it is found, otherwise, PackedFunc(nullptr). */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (this->graph_.find(name) != this->graph_.end()) { this->curr_subgraph_ = name; return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -215,9 +211,7 @@ class ExampleJsonModule : public ModuleNode { * * \param stream. The stream to save the binary. */ - void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(this->graph_json_); - } + void SaveToBinary(dmlc::Stream* stream) final { stream->Write(this->graph_json_); } /*! * \brief Parse the example json string. 
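// ---------------------------------------------------------------------------
// Aside: the ExampleJsonModule hunks above are pure reformatting of TVM's
// runtime-module pattern, in which a ModuleNode subclass serves named
// functions as PackedFuncs from GetFunction(). A minimal sketch of that
// pattern follows; `MyModule` and "get_payload" are illustrative names only,
// not part of this patch.

#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <string>
#include <utility>

namespace tvm {
namespace runtime {

class MyModule : public ModuleNode {
 public:
  explicit MyModule(std::string payload) : payload_(std::move(payload)) {}

  const char* type_key() const final { return "MyModule"; }

  // Resolve a function by name; returning PackedFunc(nullptr) means
  // "not found" and lets the next module in the import chain answer.
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    if (name == "get_payload") {
      // Capturing sptr_to_self keeps this module alive as long as the
      // returned closure is.
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = payload_; });
    }
    return PackedFunc(nullptr);
  }

 private:
  std::string payload_;
};

}  // namespace runtime
}  // namespace tvm
// ---------------------------------------------------------------------------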
@@ -333,12 +327,10 @@ class ExampleJsonModule : public ModuleNode { }; TVM_REGISTER_GLOBAL("runtime.module.loadfile_examplejson") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = ExampleJsonModule::Create(args[0]); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ExampleJsonModule::Create(args[0]); }); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_examplejson") -.set_body_typed(ExampleJsonModule::LoadFromBinary); + .set_body_typed(ExampleJsonModule::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index d4575484320ba..1353e2f996bb4 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -20,9 +20,10 @@ /*! * \file Use external miopen utils function */ -#include #include #include +#include + #include "miopen_utils.h" namespace tvm { @@ -31,8 +32,7 @@ namespace miopen { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup").set_body([](TVMArgs args, TVMRetValue* ret) { const int mode = args[0]; const int dtype = args[1]; const int pad_h = args[2]; @@ -50,72 +50,52 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") const int w_dim2 = args[14]; const int w_dim3 = args[15]; const int n_group = args[16]; - void *out_shape = args[17]; + void* out_shape = args[17]; MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); assert(n_group > 0 && "Group Size > 0 is expected"); - if (n_group > 1) - assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); + if (n_group > 1) assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); // Set Ctx entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; // Set Data Type - entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at - // this moment. + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), + // int32, int8 at this moment. 
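// ---------------------------------------------------------------------------
// Aside: every set_body lambda touched by this patch unpacks its arguments
// positionally — each args[i] is a TVMArgValue that converts implicitly to
// the declared C++ type, which is why the body above can simply write
// `const int mode = args[0];`. A self-contained sketch of the same pattern
// ("example.scale_add" is an illustrative name, not a function in this patch):

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

TVM_REGISTER_GLOBAL("example.scale_add")
    .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* ret) {
      int a = args[0];         // implicit TVMArgValue -> int
      int b = args[1];
      double scale = args[2];  // implicit TVMArgValue -> double
      *ret = scale * (a + b);
    });

// Call site: look the function up in the process-wide registry and invoke it.
//   const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("example.scale_add");
//   double r = (*f)(2, 3, 0.5);  // r == 2.5
// ---------------------------------------------------------------------------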
  // Set Desc
   MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
-                                              entry_ptr->conv_entry.mode,
-                                              pad_h,
-                                              pad_w,
-                                              stride_h,
-                                              stride_w,
-                                              dilation_h,
-                                              dilation_w));
+                                              entry_ptr->conv_entry.mode, pad_h, pad_w, stride_h,
+                                              stride_w, dilation_h, dilation_w));
   if (n_group > 1)
     MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
   // Set Filter
   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
-                                          entry_ptr->conv_entry.data_type,
-                                          w_dim0,
-                                          w_dim1/n_group,
-                                          w_dim2,
-                                          w_dim3));
+                                          entry_ptr->conv_entry.data_type, w_dim0, w_dim1 / n_group,
+                                          w_dim2, w_dim3));
   // Set Input
   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
-                                          entry_ptr->conv_entry.data_type,
-                                          x_dim0,
-                                          x_dim1,
-                                          x_dim2,
+                                          entry_ptr->conv_entry.data_type, x_dim0, x_dim1, x_dim2,
                                           x_dim3));
   // Set Output shape
-  MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc,
-                                                   entry_ptr->conv_entry.input_desc,
-                                                   entry_ptr->conv_entry.filter_desc,
-                                                   static_cast<int*>(out_shape),
-                                                   static_cast<int*>(out_shape) + 1,
-                                                   static_cast<int*>(out_shape) + 2,
-                                                   static_cast<int*>(out_shape) + 3));
-
-  const int *oshape = static_cast<int*>(out_shape);
+  MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+      entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
+      static_cast<int*>(out_shape) + 1, static_cast<int*>(out_shape) + 2,
+      static_cast<int*>(out_shape) + 3));
+
+  const int* oshape = static_cast<int*>(out_shape);
   // Set Output
   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
-                                          entry_ptr->conv_entry.data_type,
-                                          oshape[0],
-                                          oshape[1],
-                                          oshape[2],
-                                          oshape[3]));
+                                          entry_ptr->conv_entry.data_type, oshape[0], oshape[1],
+                                          oshape[2], oshape[3]));

   // Set workspace
   size_t workspace_size = 0;
-  MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle,
-                                                       entry_ptr->conv_entry.filter_desc,
-                                                       entry_ptr->conv_entry.input_desc,
-                                                       entry_ptr->conv_entry.conv_desc,
-                                                       entry_ptr->conv_entry.output_desc,
-                                                       &workspace_size));
+  MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(
+      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size));
   entry_ptr->conv_entry.UpdateWorkspace(workspace_size);

   const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3;
@@ -123,12 +103,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
   const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3];

   runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api;
-  float* input_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
-                                                                  input_size * sizeof(float)));
-  float* filter_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
-                                                                   filter_size * sizeof(float)));
-  float* output_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
-                                                                   output_size * sizeof(float)));
+  float* input_buf = static_cast<float*>(
+      rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, input_size * sizeof(float)));
+  float* filter_buf = static_cast<float*>(
+      rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, filter_size * sizeof(float)));
+  float* output_buf = static_cast<float*>(
+      rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, output_size * sizeof(float)));

   const int request_algo_count = 4;
   const bool exhaustive_search = false;
@@ -137,20 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
   int returned_algo_count = 0;
   miopenConvAlgoPerf_t
perfs[4]; - MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - input_buf, - entry_ptr->conv_entry.filter_desc, - filter_buf, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - output_buf, - request_algo_count, - &returned_algo_count, - perfs, - workspace, - workspace_size, - exhaustive_search)); + MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, input_buf, + entry_ptr->conv_entry.filter_desc, filter_buf, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, output_buf, request_algo_count, &returned_algo_count, + perfs, workspace, workspace_size, exhaustive_search)); rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf); rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf); @@ -163,8 +134,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") "miopenConvolutionFwdAlgoWinograd", }; const auto best_algo = perfs[0].fwd_algo; - LOG(INFO) << "\tMIOpen Found " << returned_algo_count - << " fwd algorithms, choosing " << fwd_algo_names[best_algo]; + LOG(INFO) << "\tMIOpen Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; for (int i = 0; i < returned_algo_count; ++i) { LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo] << " - time: " << perfs[i].time << " ms" @@ -174,79 +145,56 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ret[0] = static_cast(best_algo); }); - TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - const int mode = args[0]; - const int dtype = args[1]; - const int pad_h = args[2]; - const int pad_w = args[3]; - const int stride_h = args[4]; - const int stride_w = args[5]; - const int dilation_h = args[6]; - const int dilation_w = args[7]; - const int algo = args[8]; - const DLTensor *x = args[9]; - const DLTensor *w = args[10]; - const DLTensor *y = args[11]; - - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); - entry_ptr->conv_entry.fwd_algo = static_cast(algo); - // Set Mode - entry_ptr->conv_entry.mode = static_cast(mode); - // Set Ctx - entry_ptr->conv_entry.ctx = x->ctx; - // Set Data Type - entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at - // this moment. 
- // Set Desc - MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w)); - // Set Filter - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, - w->shape[0], - w->shape[1], - w->shape[2], - w->shape[3])); - // Set Input - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, - x->shape[0], - x->shape[1], - x->shape[2], - x->shape[3])); - // Set Output - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, - y->shape[0], - y->shape[1], - y->shape[2], - y->shape[3])); - - const float alpha = 1.f; - const float beta = 0.f; - MIOPEN_CALL(miopenConvolutionForward(entry_ptr->handle, - &alpha, - entry_ptr->conv_entry.input_desc, - x->data, - entry_ptr->conv_entry.filter_desc, - w->data, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, - &beta, - entry_ptr->conv_entry.output_desc, - y->data, - entry_ptr->conv_entry.workspace, - entry_ptr->conv_entry.workspace_size)); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + const int mode = args[0]; + const int dtype = args[1]; + const int pad_h = args[2]; + const int pad_w = args[3]; + const int stride_h = args[4]; + const int stride_w = args[5]; + const int dilation_h = args[6]; + const int dilation_w = args[7]; + const int algo = args[8]; + const DLTensor* x = args[9]; + const DLTensor* w = args[10]; + const DLTensor* y = args[11]; + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + entry_ptr->conv_entry.fwd_algo = static_cast(algo); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + // Set Ctx + entry_ptr->conv_entry.ctx = x->ctx; + // Set Data Type + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), + // fp16(miopenHalf) at this moment. 
+ // Set Desc + MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.mode, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w)); + // Set Filter + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.data_type, w->shape[0], + w->shape[1], w->shape[2], w->shape[3])); + // Set Input + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.data_type, x->shape[0], + x->shape[1], x->shape[2], x->shape[3])); + // Set Output + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.data_type, y->shape[0], + y->shape[1], y->shape[2], y->shape[3])); + + const float alpha = 1.f; + const float beta = 0.f; + MIOPEN_CALL(miopenConvolutionForward( + entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data, + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data, + entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size)); + }); } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index 330ccdd043d08..a57918045d87b 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -21,20 +21,22 @@ * \file Use external miopen utils function */ #include "miopen_utils.h" + #include #include -#include + #include +#include namespace tvm { namespace contrib { namespace miopen { std::string miopenGetErrorString(int error_code) { - const std::vector mio_err{ - "StatusSuccess ", "StatusNotInitialized ", "StatusInvalidValue ", - "StatusBadParm ", "StatusAllocFailed ", "StatusInternalError ", - "StatusNotImplemented ", "StatusUnknownError "}; + const std::vector mio_err{"StatusSuccess ", "StatusNotInitialized ", + "StatusInvalidValue ", "StatusBadParm ", + "StatusAllocFailed ", "StatusInternalError ", + "StatusNotImplemented ", "StatusUnknownError "}; return mio_err[error_code]; } @@ -42,22 +44,18 @@ std::string miopenGetErrorString(int error_code) { MIOpenThreadEntry::MIOpenThreadEntry() { auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream; auto func = runtime::Registry::Get("device_api.rocm"); - void *ret = (*func)(); + void* ret = (*func)(); rocm_api = static_cast(ret); MIOPEN_CALL(miopenCreate(&handle)); MIOPEN_CALL(miopenSetStream(handle, stream)); conv_entry.rocm_api = rocm_api; } -MIOpenThreadEntry::~MIOpenThreadEntry() { - MIOPEN_CALL(miopenDestroy(handle)); -} +MIOpenThreadEntry::~MIOpenThreadEntry() { MIOPEN_CALL(miopenDestroy(handle)); } typedef dmlc::ThreadLocalStore MIOpenThreadStore; -MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { - return MIOpenThreadStore::Get(); -} +MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { return MIOpenThreadStore::Get(); } // ConvEntry diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index 8831e4fac95c4..4dec2ad710ba4 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -27,7 +27,9 @@ #include #include #include + #include + #include "../../rocm/rocm_common.h" namespace tvm { @@ -36,11 +38,10 @@ namespace miopen { std::string miopenGetErrorString(int error_code); -#define MIOPEN_CALL(func) \ - { \ - miopenStatus_t e = (func); \ - CHECK_EQ(e, miopenStatusSuccess) \ - << "miopen error: " << 
miopenGetErrorString(e); \ +#define MIOPEN_CALL(func) \ + { \ + miopenStatus_t e = (func); \ + CHECK_EQ(e, miopenStatusSuccess) << "miopen error: " << miopenGetErrorString(e); \ } struct ConvEntry { @@ -52,8 +53,8 @@ struct ConvEntry { miopenTensorDescriptor_t output_desc; miopenConvFwdAlgorithm_t fwd_algo; TVMContext ctx; - runtime::DeviceAPI *rocm_api; - void *workspace{nullptr}; + runtime::DeviceAPI* rocm_api; + void* workspace{nullptr}; size_t workspace_size{0}; ConvEntry(); ~ConvEntry(); @@ -66,8 +67,8 @@ struct MIOpenThreadEntry { ~MIOpenThreadEntry(); miopenHandle_t handle{nullptr}; ConvEntry conv_entry; - runtime::DeviceAPI *rocm_api{nullptr}; - static MIOpenThreadEntry *ThreadLocal(); + runtime::DeviceAPI* rocm_api{nullptr}; + static MIOpenThreadEntry* ThreadLocal(); }; // MIOpenThreadEntry } // namespace miopen diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 064e6d53cfb81..b598014f02673 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -24,69 +24,59 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *buf = args[0]; - DLTensor *img = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* buf = args[0]; + DLTensor* img = args[1]; // copy to temp id mtlbuf = (__bridge id)(buf->data); - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id dev = entry_ptr->metal_api->GetDevice(buf->ctx); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length], - buf->ctx, buf->ctx, nullptr - ); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, + [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); - MPSImageDescriptor *desc = [MPSImageDescriptor - imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 - width:buf->shape[2] - height:buf->shape[1] - featureChannels:buf->shape[3]]; + MPSImageDescriptor* desc = + [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 + width:buf->shape[2] + height:buf->shape[1] + featureChannels:buf->shape[3]]; - MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc); + MPSImage* mpsimg = entry_ptr->AllocMPSImage(dev, desc); [mpsimg writeBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; - img->data = (__bridge void *)mpsimg; + img->data = (__bridge void*)mpsimg; [mpsimg readBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; +}); - }); - -TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *img = args[0]; - DLTensor *buf = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* img = args[0]; + DLTensor* buf = args[1]; id mtlbuf = (__bridge id)(buf->data); - MPSImage *mpsimg = (__bridge MPSImage *)(img->data); - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MPSImage* mpsimg = (__bridge 
MPSImage*)(img->data); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); [mpsimg readBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length], - buf->ctx, buf->ctx, nullptr); - - }); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, + [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); +}); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { // MPS-NHWC - DLTensor *data = args[0]; - DLTensor *weight = args[1]; - DLTensor *output = args[2]; + DLTensor* data = args[0]; + DLTensor* weight = args[1]; + DLTensor* output = args[2]; int pad = args[3]; int stride = args[4]; @@ -108,54 +98,48 @@ auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img"); auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer"); // Get Metal device API - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id dev = entry_ptr->metal_api->GetDevice(data->ctx); - id queue = - entry_ptr->metal_api->GetCommandQueue(data->ctx); + id queue = entry_ptr->metal_api->GetCommandQueue(data->ctx); id cb = [queue commandBuffer]; // data to MPSImage DLTensor tmp_in; (*f_buf2img)(data, &tmp_in); - MPSImage *tempA = (__bridge MPSImage *)tmp_in.data; + MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; // weight to temp memory id bufB = (__bridge id)(weight->data); id tempB = rt->GetTempBuffer(weight->ctx, [bufB length]); - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length], - weight->ctx, weight->ctx, nullptr); - float *ptr_w = (float *)[tempB contents]; + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, + [bufB length], weight -> ctx, weight -> ctx, nullptr); + float* ptr_w = (float*)[tempB contents]; // output to MPSImage DLTensor tmp_out; (*f_buf2img)(output, &tmp_out); - MPSImage *tempC = (__bridge MPSImage *)tmp_out.data; + MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; // conv desc - MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor - cnnConvolutionDescriptorWithKernelWidth:kW - kernelHeight:kH - inputFeatureChannels:iCh - outputFeatureChannels:oCh]; + MPSCNNConvolutionDescriptor* conv_desc = + [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iCh + outputFeatureChannels:oCh]; [conv_desc setStrideInPixelsX:stride]; [conv_desc setStrideInPixelsY:stride]; - MPSCNNConvolution *conv = - [[MPSCNNConvolution alloc] initWithDevice:dev - convolutionDescriptor:conv_desc - kernelWeights:ptr_w - biasTerms:nil - flags:MPSCNNConvolutionFlagsNone]; + MPSCNNConvolution* conv = [[MPSCNNConvolution alloc] initWithDevice:dev + convolutionDescriptor:conv_desc + kernelWeights:ptr_w + biasTerms:nil + flags:MPSCNNConvolutionFlagsNone]; if (pad == 0) { - conv.padding = [MPSNNDefaultPadding - 
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeSame]; + conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | + MPSNNPaddingMethodSizeSame]; } else if (pad == 1) { - conv.padding = [MPSNNDefaultPadding - paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeValidOnly]; + conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | + MPSNNPaddingMethodSizeValidOnly]; } [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; @@ -166,8 +150,7 @@ [cb waitUntilCompleted]; (*f_img2buf)(&tmp_out, output); +}); - }); - -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index bc1216704cc4a..109c952ff0c4a 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -24,11 +24,10 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; // call gemm for simple compact code. @@ -42,7 +41,7 @@ CHECK(TypeMatch(B->dtype, kDLFloat, 32)); CHECK(TypeMatch(C->dtype, kDLFloat, 32)); // Get Metal device API - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); // CHECK_EQ(A->ctx, B->ctx); // CHECK_EQ(A->ctx, C->ctx); id dev = entry_ptr->metal_api->GetDevice(A->ctx); @@ -55,36 +54,31 @@ CHECK_EQ(A->shape[1 - (transa ? 
1 : 0)], K); // mps a MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); - MPSMatrixDescriptor *descA = [MPSMatrixDescriptor - matrixDescriptorWithDimensions:M - columns:K - rowBytes:K * sizeof(MPSDataTypeFloat32) - dataType:MPSDataTypeFloat32]; + MPSMatrixDescriptor* descA = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:K + rowBytes:K * sizeof(MPSDataTypeFloat32) + dataType:MPSDataTypeFloat32]; id bufA = (__bridge id)(A->data); - MPSMatrix *matrixA = - [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; + MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; // mps b - MPSMatrixDescriptor *descB = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:K - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; + MPSMatrixDescriptor* descB = [MPSMatrixDescriptor matrixDescriptorWithDimensions:K + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; id bufB = (__bridge id)(B->data); - MPSMatrix *matrixB = - [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; + MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; // mps c - MPSMatrixDescriptor *descC = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; + MPSMatrixDescriptor* descC = [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; id bufC = (__bridge id)(C->data); - MPSMatrix *matrixC = - [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; + MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; // kernel - MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init]; - MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev + MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; + MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev transposeLeft:transa transposeRight:transb resultRows:M @@ -93,13 +87,9 @@ alpha:1.0f beta:0.0f]; CHECK(sgemm != nil); - [sgemm encodeToCommandBuffer:cb - leftMatrix:matrixA - rightMatrix:matrixB - resultMatrix:matrixC]; + [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; [cb commit]; +}); - }); - -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index f1fff95c1df32..170451ea385bd 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -27,10 +27,12 @@ #import #include #include +#include #include #include -#include + #include + #include "../../metal/metal_common.h" namespace tvm { @@ -38,18 +40,17 @@ namespace contrib { /*! 
breif Convert DLTensor type to MPS type */ struct MPSType { - static MPSDataType DLTypeToMPSType(const DLDataType &dtype); + static MPSDataType DLTypeToMPSType(const DLDataType& dtype); }; // struct MPSType struct MetalThreadEntry { MetalThreadEntry(); ~MetalThreadEntry(); - MPSImage *AllocMPSImage(id dev, MPSImageDescriptor *desc); - MPSTemporaryImage *AllocTempImage(id cb, - MPSImageDescriptor *desc); - runtime::metal::MetalWorkspace *metal_api{nullptr}; - static MetalThreadEntry *ThreadLocal(); - std::vector img_table; + MPSImage* AllocMPSImage(id dev, MPSImageDescriptor* desc); + MPSTemporaryImage* AllocTempImage(id cb, MPSImageDescriptor* desc); + runtime::metal::MetalWorkspace* metal_api{nullptr}; + static MetalThreadEntry* ThreadLocal(); + std::vector img_table; }; // MetalThreadEntry } // namespace contrib diff --git a/src/runtime/contrib/mps/mps_utils.mm b/src/runtime/contrib/mps/mps_utils.mm index b3d4070ca6b74..f9f80431165ee 100644 --- a/src/runtime/contrib/mps/mps_utils.mm +++ b/src/runtime/contrib/mps/mps_utils.mm @@ -23,60 +23,58 @@ namespace contrib { // MPS Data Type -MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) { +MPSDataType MPSType::DLTypeToMPSType(const DLDataType& dtype) { switch (dtype.code) { - case kDLInt: - if (dtype.bits == 8 && dtype.lanes == 1) - return MPSDataTypeInt8; - else if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeInt16; - else + case kDLInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return MPSDataTypeInt8; + else if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeInt16; + else + LOG(FATAL) << "Unsupported type"; + break; + case kDLUInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return MPSDataTypeUInt8; + else if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeUInt16; + else if (dtype.bits == 32 && dtype.lanes == 1) + return MPSDataTypeUInt32; LOG(FATAL) << "Unsupported type"; - break; - case kDLUInt: - if (dtype.bits == 8 && dtype.lanes == 1) - return MPSDataTypeUInt8; - else if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeUInt16; - else if (dtype.bits == 32 && dtype.lanes == 1) - return MPSDataTypeUInt32; - LOG(FATAL) << "Unsupported type"; - break; - case kDLFloat: - if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeFloat16; - else if (dtype.bits == 32 && dtype.lanes == 1) - return MPSDataTypeFloat32; - else + break; + case kDLFloat: + if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeFloat16; + else if (dtype.bits == 32 && dtype.lanes == 1) + return MPSDataTypeFloat32; + else + LOG(FATAL) << "Unsupported type"; + break; + default: LOG(FATAL) << "Unsupported type"; - break; - default: - LOG(FATAL) << "Unsupported type"; } return MPSDataTypeFloat32; } // MetalThreadEntry -MPSImage *MetalThreadEntry::AllocMPSImage(id dev, - MPSImageDescriptor *desc) { - MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]; +MPSImage* MetalThreadEntry::AllocMPSImage(id dev, MPSImageDescriptor* desc) { + MPSImage* mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]; img_table.push_back(mpsimg); return mpsimg; } -MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id cb, - MPSImageDescriptor *desc) { - MPSTemporaryImage *mpsimg = - [MPSTemporaryImage temporaryImageWithCommandBuffer:cb - imageDescriptor:desc]; +MPSTemporaryImage* MetalThreadEntry::AllocTempImage(id cb, + MPSImageDescriptor* desc) { + MPSTemporaryImage* mpsimg = [MPSTemporaryImage temporaryImageWithCommandBuffer:cb + imageDescriptor:desc]; return mpsimg; } 
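// ---------------------------------------------------------------------------
// Aside: CuDNNThreadEntry, MIOpenThreadEntry, and MetalThreadEntry all follow
// the same dmlc::ThreadLocalStore pattern that these hunks reformat: one
// lazily constructed entry per thread, so each thread owns its own vendor
// handle and no locking is needed. A minimal sketch under that assumption
// (`FooThreadEntry` is an illustrative type, not part of this patch):

#include <dmlc/thread_local.h>

struct FooThreadEntry {
  FooThreadEntry() { /* e.g. create the library handle, bind the stream */ }
  ~FooThreadEntry() { /* e.g. destroy the library handle */ }
  void* handle{nullptr};
  static FooThreadEntry* ThreadLocal();
};

typedef dmlc::ThreadLocalStore<FooThreadEntry> FooThreadStore;

// Get() constructs the entry on first use in each thread and caches it, so
// repeated ThreadLocal() calls on the same thread return the same object.
FooThreadEntry* FooThreadEntry::ThreadLocal() { return FooThreadStore::Get(); }
// ---------------------------------------------------------------------------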
MetalThreadEntry::MetalThreadEntry() { auto func = runtime::Registry::Get("device_api.metal"); - void *ret = (*func)(); - metal_api = static_cast(ret); + void* ret = (*func)(); + metal_api = static_cast(ret); } MetalThreadEntry::~MetalThreadEntry() { @@ -87,9 +85,7 @@ typedef dmlc::ThreadLocalStore MetalThreadStore; -MetalThreadEntry *MetalThreadEntry::ThreadLocal() { - return MetalThreadStore::Get(); -} +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/nnpack/convolution.cc b/src/runtime/contrib/nnpack/convolution.cc index 79ea19175d653..54c9ea4f969ba 100644 --- a/src/runtime/contrib/nnpack/convolution.cc +++ b/src/runtime/contrib/nnpack/convolution.cc @@ -20,11 +20,12 @@ /*! * \file Use external nnpack library call. */ -#include -#include -#include #include #include +#include +#include +#include + #include "nnpack_utils.h" namespace tvm { @@ -32,28 +33,25 @@ namespace contrib { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *input = args[0]; - DLTensor *kernel = args[1]; - DLTensor *bias = nullptr; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* input = args[0]; + DLTensor* kernel = args[1]; + DLTensor* bias = nullptr; if (args[2].type_code() == kTVMDLTensorHandle) { bias = args[2]; } - DLTensor *output = args[3]; - uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], - pad_left = args[7]; + DLTensor* output = args[3]; + uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; NNPackConfig(args[10]); uint64_t algo_ = args[11]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(input->ndim, 4); CHECK_EQ(kernel->ndim, 4); if (bias) { @@ -93,10 +91,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") size_t workspace_size = 0; nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_compute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); // Division with rounding up, in case size is not multiple of sizeof(float) @@ -107,24 +104,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") DeviceAPI* cpu_api = DeviceAPI::Get(ctx); void* workspace_buffer = - cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); + cpu_api->AllocWorkspace(ctx, 
workspace_elements * sizeof(float), type_hint); CHECK(workspace_buffer != nullptr); for (auto n = 0; n < input->shape[0]; ++n) { nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_compute, input_channels, - output_channels, input_size, input_padding, kernel_size, - stride_size, - static_cast(input->data) + n * input->shape[1] * - input->shape[2] * - input->shape[3], - static_cast(kernel->data), - bias ? static_cast(bias->data) : zero_bias->data(), - static_cast(output->data) + n * output->shape[1] * - output->shape[2] * - output->shape[3], - workspace_buffer, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, + static_cast(input->data) + + n * input->shape[1] * input->shape[2] * input->shape[3], + static_cast(kernel->data), + bias ? static_cast(bias->data) : zero_bias->data(), + static_cast(output->data) + + n * output->shape[1] * output->shape[2] * output->shape[3], + workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); } @@ -132,28 +126,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") }); TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *input = args[0]; - DLTensor *transformed_kernel = args[1]; - DLTensor *bias = nullptr; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* input = args[0]; + DLTensor* transformed_kernel = args[1]; + DLTensor* bias = nullptr; if (args[2].type_code() == kTVMDLTensorHandle) { bias = args[2]; } - DLTensor *output = args[3]; - uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], - pad_left = args[7]; + DLTensor* output = args[3]; + uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; NNPackConfig(args[10]); uint64_t algo_ = args[11]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(input->ndim, 4); if (bias) { CHECK_EQ(bias->ndim, 1); @@ -189,10 +180,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra size_t workspace_size = 0; nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_reuse, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); // Division with rounding up, in case size 
is not multiple of sizeof(float) @@ -203,38 +193,34 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra DeviceAPI* cpu_api = DeviceAPI::Get(ctx); void* workspace_buffer = - cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); + cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); CHECK(workspace_buffer != nullptr); for (auto n = 0; n < input->shape[0]; ++n) { nnp_status status = nnp_convolution_inference( algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels, input_size, input_padding, kernel_size, stride_size, - static_cast(input->data) + n * input->shape[1] * - input->shape[2] * - input->shape[3], - static_cast(transformed_kernel->data), - bias ? static_cast(bias->data) : zero_bias->data(), - static_cast(output->data) + n * output->shape[1] * - output->shape[2] * - output->shape[3], - workspace_buffer, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + static_cast(input->data) + + n * input->shape[1] * input->shape[2] * input->shape[3], + static_cast(transformed_kernel->data), + bias ? static_cast(bias->data) : zero_bias->data(), + static_cast(output->data) + + n * output->shape[1] * output->shape[2] * output->shape[3], + workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); } cpu_api->FreeWorkspace(ctx, workspace_buffer); }); -TVM_REGISTER_GLOBAL( - "tvm.contrib.nnpack.convolution_inference_weight_transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); +TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_weight_transform") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *kernel = args[0]; - DLTensor *transformed_kernel = args[1]; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* kernel = args[0]; + DLTensor* transformed_kernel = args[1]; // Dummy sizes nnp_padding input_padding{1, 1, 1, 1}; nnp_size stride_size{1, 1}; @@ -244,8 +230,7 @@ TVM_REGISTER_GLOBAL( NNPackConfig(args[2]); uint64_t algo_ = args[3]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(kernel->ndim, 4); size_t input_channels = kernel->shape[1]; size_t output_channels = kernel->shape[0]; @@ -259,21 +244,20 @@ TVM_REGISTER_GLOBAL( size_t transformed_kernel_size = 0; nnp_status status; status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_precompute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &transformed_kernel_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel)); status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_precompute, input_channels, - output_channels, 
input_size, input_padding, kernel_size, stride_size, - nullptr, static_cast(kernel->data), nullptr, nullptr, - static_cast(transformed_kernel->data), - &transformed_kernel_size, nnp_activation_identity, nullptr, - entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, + static_cast(kernel->data), nullptr, nullptr, + static_cast(transformed_kernel->data), &transformed_kernel_size, + nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); }); } // namespace contrib diff --git a/src/runtime/contrib/nnpack/fully_connected.cc b/src/runtime/contrib/nnpack/fully_connected.cc index 5f111efac4df5..543d239586339 100644 --- a/src/runtime/contrib/nnpack/fully_connected.cc +++ b/src/runtime/contrib/nnpack/fully_connected.cc @@ -20,10 +20,11 @@ /*! * \file Use external nnpack library call. */ -#include -#include #include #include +#include +#include + #include "nnpack_utils.h" namespace tvm { @@ -33,33 +34,30 @@ using namespace runtime; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference") -.set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); - nnp_initialize(); - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; - NNPackConfig(args[3]); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); + nnp_initialize(); + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + NNPackConfig(args[3]); - CHECK_EQ(A->ndim, 1); - CHECK_EQ(B->ndim, 2); - CHECK_EQ(C->ndim, 1); - CHECK_EQ(B->shape[0], C->shape[0]); - CHECK_EQ(B->shape[1], A->shape[0]); - CHECK(C->strides == nullptr); - CHECK(B->strides == nullptr); - CHECK(A->strides == nullptr); - CHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CHECK(TypeMatch(B->dtype, kDLFloat, 32)); - CHECK(TypeMatch(C->dtype, kDLFloat, 32)); + CHECK_EQ(A->ndim, 1); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 1); + CHECK_EQ(B->shape[0], C->shape[0]); + CHECK_EQ(B->shape[1], A->shape[0]); + CHECK(C->strides == nullptr); + CHECK(B->strides == nullptr); + CHECK(A->strides == nullptr); + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CHECK(TypeMatch(B->dtype, kDLFloat, 32)); + CHECK(TypeMatch(C->dtype, kDLFloat, 32)); - nnp_fully_connected_inference(B->shape[1], - B->shape[0], - static_cast(A->data), - static_cast(B->data), - static_cast(C->data), - entry->threadpool); - }); + nnp_fully_connected_inference(B->shape[1], B->shape[0], static_cast(A->data), + static_cast(B->data), static_cast(C->data), + entry->threadpool); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/nnpack/nnpack_utils.cc b/src/runtime/contrib/nnpack/nnpack_utils.cc index f01ad8557feea..91cf865128e90 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.cc +++ b/src/runtime/contrib/nnpack/nnpack_utils.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,13 +28,12 @@ using namespace runtime; typedef dmlc::ThreadLocalStore NNPackThreadLocalStore; - NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() { return NNPackThreadLocalStore::Get(); } bool NNPackConfig(uint64_t nthreads) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) { CHECK_NE(nthreads, 1); return true; @@ -55,11 +54,9 @@ bool NNPackConfig(uint64_t nthreads) { return true; } - -TVM_REGISTER_GLOBAL("contrib.nnpack._initialize") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = nnp_initialize(); - }); +TVM_REGISTER_GLOBAL("contrib.nnpack._initialize").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = nnp_initialize(); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/nnpack/nnpack_utils.h b/src/runtime/contrib/nnpack/nnpack_utils.h index 4ba586fe08ac7..bbb0d16bc868a 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.h +++ b/src/runtime/contrib/nnpack/nnpack_utils.h @@ -22,11 +22,11 @@ */ #ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ -#include -#include -#include #include +#include #include +#include +#include namespace tvm { namespace contrib { diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 37166e2c8d0fc..c628e327643e4 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief mt19937 random engine */ #include + #include #include #include @@ -34,45 +35,37 @@ namespace contrib { */ class RandomEngine { public: - /*! - * \brief Creates a RandomEngine using a default seed. - */ - RandomEngine() { - this->Seed(time(0)); - } - - /*! - * \brief Creates a RandomEngine, suggesting the use of a provided seed. - */ - explicit RandomEngine(unsigned seed) { - this->Seed(seed); - } - - /*! - * \brief Seeds the underlying RNG, if possible. - */ + /*! + * \brief Creates a RandomEngine using a default seed. + */ + RandomEngine() { this->Seed(time(0)); } + + /*! + * \brief Creates a RandomEngine, suggesting the use of a provided seed. + */ + explicit RandomEngine(unsigned seed) { this->Seed(seed); } + + /*! + * \brief Seeds the underlying RNG, if possible. + */ inline void Seed(unsigned seed) { rnd_engine_.seed(seed); this->rseed_ = static_cast(seed); } - /*! - * \return the seed associated with the underlying RNG. - */ - inline unsigned GetSeed() const { - return rseed_; - } + /*! + * \return the seed associated with the underlying RNG. + */ + inline unsigned GetSeed() const { return rseed_; } - /*! - * \return a random integer sampled from the RNG. 
- */ - inline unsigned GetRandInt() { - return rnd_engine_(); - } + /*! + * \return a random integer sampled from the RNG. + */ + inline unsigned GetRandInt() { return rnd_engine_(); } - /*! - * \brief Fills a tensor with values drawn from Unif(low, high) - */ + /*! + * \brief Fills a tensor with values drawn from Unif(low, high) + */ void SampleUniform(DLTensor* data, float low, float high) { CHECK_GT(high, low) << "high must be bigger than low"; CHECK(data->strides == nullptr); @@ -87,17 +80,16 @@ class RandomEngine { if (data->ctx.device_type == kDLCPU) { std::uniform_real_distribution uniform_dist(low, high); - std::generate_n(static_cast(data->data), size, [&] () { - return uniform_dist(rnd_engine_); - }); + std::generate_n(static_cast(data->data), size, + [&]() { return uniform_dist(rnd_engine_); }); } else { LOG(FATAL) << "Do not support random.uniform on this device yet"; } } - /*! - * \brief Fills a tensor with values drawn from Normal(loc, scale**2) - */ + /*! + * \brief Fills a tensor with values drawn from Normal(loc, scale**2) + */ void SampleNormal(DLTensor* data, float loc, float scale) { CHECK_GT(scale, 0) << "standard deviation must be positive"; CHECK(data->strides == nullptr); @@ -112,9 +104,8 @@ class RandomEngine { if (data->ctx.device_type == kDLCPU) { std::normal_distribution normal_dist(loc, scale); - std::generate_n(static_cast(data->data), size, [&] () { - return normal_dist(rnd_engine_); - }); + std::generate_n(static_cast(data->data), size, + [&]() { return normal_dist(rnd_engine_); }); } else { LOG(FATAL) << "Do not support random.normal on this device yet"; } diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 8ae1f8668c877..acba193c12305 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -20,32 +20,34 @@ /*! * \file External random functions for tensor. */ -#include -#include #include #include +#include +#include + #include + #include "mt_random_engine.cc" #define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) 
\
   if (type.code == kDLInt && type.bits == 32) {                \
     typedef int32_t DType;                                     \
-    {__VA_ARGS__}                                              \
+    { __VA_ARGS__ }                                            \
   } else if (type.code == kDLInt && type.bits == 16) {         \
     typedef int16_t DType;                                     \
-    {__VA_ARGS__}                                              \
+    { __VA_ARGS__ }                                            \
   } else if (type.code == kDLInt && type.bits == 8) {          \
     typedef int8_t DType;                                      \
-    {__VA_ARGS__}                                              \
+    { __VA_ARGS__ }                                            \
   } else if (type.code == kDLUInt && type.bits == 32) {        \
     typedef uint32_t DType;                                    \
-    {__VA_ARGS__}                                              \
+    { __VA_ARGS__ }                                            \
   } else if (type.code == kDLUInt && type.bits == 16) {        \
     typedef uint16_t DType;                                    \
-    {__VA_ARGS__}                                              \
+    { __VA_ARGS__ }                                            \
   } else if (type.code == kDLUInt && type.bits == 8) {         \
     typedef uint8_t DType;                                     \
-    {__VA_ARGS__}                                              \
+    { __VA_ARGS__ }                                            \
   } else {                                                     \
     LOG(FATAL) << "unknown data type";                         \
   }
@@ -66,61 +68,54 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() {
   return RandomThreadLocalStore::Get();
 }
 
+TVM_REGISTER_GLOBAL("tvm.contrib.random.randint").set_body([](TVMArgs args, TVMRetValue* ret) {
+  RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+  int64_t low = args[0];
+  int64_t high = args[1];
+  DLTensor* out = args[2];
+  CHECK_GT(high, low) << "high must be bigger than low";
+  CHECK(out->strides == nullptr);
+
+  DLDataType dtype = out->dtype;
+  int64_t size = 1;
+  for (int i = 0; i < out->ndim; ++i) {
+    size *= out->shape[i];
+  }
-TVM_REGISTER_GLOBAL("tvm.contrib.random.randint")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
-    int64_t low = args[0];
-    int64_t high = args[1];
-    DLTensor* out = args[2];
-    CHECK_GT(high, low) << "high must be bigger than low";
-    CHECK(out->strides == nullptr);
-
-    DLDataType dtype = out->dtype;
-    int64_t size = 1;
-    for (int i = 0; i < out->ndim; ++i) {
-      size *= out->shape[i];
+  DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, {
+    int64_t numeric_low = std::numeric_limits<DType>::min();
+    int64_t numeric_high = std::numeric_limits<DType>::max();
+    numeric_high += 1;  // exclusive upper bound
+    low = std::max(low, numeric_low);
+    high = std::min(high, numeric_high);
+
+    if (out->ctx.device_type == kDLCPU) {
+      // fill the data with random bytes
+      std::generate_n(static_cast<DType*>(out->data), size, [&]() {
+        unsigned rint = entry->random_engine.GetRandInt();
+        return low + rint % (high - low);
+      });
+    } else {
+      LOG(FATAL) << "Do not support random.randint on this device yet";
    }
-
-    DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, {
-      int64_t numeric_low = std::numeric_limits<DType>::min();
-      int64_t numeric_high = std::numeric_limits<DType>::max();
-      numeric_high += 1;  // exclusive upper bound
-      low = std::max(low, numeric_low);
-      high = std::min(high, numeric_high);
-
-      if (out->ctx.device_type == kDLCPU) {
-        // file the data with random byte
-        std::generate_n(static_cast<DType*>(out->data), size, [&] () {
-          unsigned rint = entry->random_engine.GetRandInt();
-          return low + rint % (high - low);
-        });
-      } else {
-        LOG(FATAL) << "Do not support random.randint on this device yet";
-      }
-    })
-  });
-
-
-TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
-    double low = args[0];
-    double high = args[1];
-    DLTensor* out = args[2];
-    entry->random_engine.SampleUniform(out, low, high);
-  });
-
-
-TVM_REGISTER_GLOBAL("tvm.contrib.random.normal")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal();
-    double loc = args[0];
-    double scale = args[1];
-    DLTensor* out = args[2];
-    entry->random_engine.SampleNormal(out, loc, scale);
-  });
-
+  })
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform").set_body([](TVMArgs args, TVMRetValue* ret) {
+  RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+  double low = args[0];
+  double high = args[1];
+  DLTensor* out = args[2];
+  entry->random_engine.SampleUniform(out, low, high);
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRetValue* ret) {
+  RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+  double loc = args[0];
+  double scale = args[1];
+  DLTensor* out = args[2];
+  entry->random_engine.SampleNormal(out, loc, scale);
+});
 }  // namespace contrib
 }  // namespace tvm
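Note on the randint body reformatted above: it first clamps the requested [low, high) range to what the target integer type can represent (with an exclusive upper bound), then samples by modulo. A self-contained sketch of the 16-bit branch, illustrative only and not part of this patch; as in the original, the modulo step carries a slight bias whenever (high - low) does not evenly divide the RNG's range:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <random>

// Mirrors one arm of DLPACK_INTEGER_TYPE_SWITCH in random.cc: clamp the
// requested range to int16_t, then map a raw 32-bit draw into [low, high).
// Assumes low < high after clamping, which the registered function CHECKs.
int16_t SampleRandInt16(int64_t low, int64_t high, std::mt19937& rng) {
  int64_t numeric_low = std::numeric_limits<int16_t>::min();
  int64_t numeric_high = static_cast<int64_t>(std::numeric_limits<int16_t>::max()) + 1;
  low = std::max(low, numeric_low);
  high = std::min(high, numeric_high);  // exclusive upper bound
  unsigned rint = rng();                // stands in for entry->random_engine.GetRandInt()
  return static_cast<int16_t>(low + rint % (high - low));
}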
diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc
index dda4ee30fde5c..0e6f4bd696864 100644
--- a/src/runtime/contrib/rocblas/rocblas.cc
+++ b/src/runtime/contrib/rocblas/rocblas.cc
@@ -20,75 +20,68 @@
 /*!
  * \file Use external rocblas library call.
  */
-#include
-#include
-#include
 #include "rocblas.h"
+#include
+#include
+#include
+
 namespace tvm {
 namespace contrib {
 
 using namespace runtime;
 
 #ifndef CHECK_ROCBLAS_ERROR
-#define CHECK_ROCBLAS_ERROR(error) \
-if (error != rocblas_status_success) { \
-  fprintf(stderr, "rocBLAS error: "); \
-  if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \
-  if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \
-  if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \
-  if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \
-  if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \
-  if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \
-  fprintf(stderr, "\n"); \
-  exit(EXIT_FAILURE); \
-}
+#define CHECK_ROCBLAS_ERROR(error)                                                                \
+  if (error != rocblas_status_success) {                                                          \
+    fprintf(stderr, "rocBLAS error: ");                                                           \
+    if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \
+    if (error == rocblas_status_not_implemented)                                                  \
+      fprintf(stderr, " rocblas_status_not_implemented");                                         \
+    if (error == rocblas_status_invalid_pointer)                                                  \
+      fprintf(stderr, "rocblas_status_invalid_pointer");                                          \
+    if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size");     \
+    if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error");     \
+    if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \
+    fprintf(stderr, "\n");                                                                        \
+    exit(EXIT_FAILURE);                                                                           \
+  }
 #endif
 
-
 // matrix multiplication for row major
-TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    DLTensor* A = args[0];
-    DLTensor* B = args[1];
-    DLTensor* C = args[2];
-    bool transa = args[3];
-    bool transb = args[4];
-    // call gemm for simple compact code.
-    CHECK_EQ(A->ndim, 2);
-    CHECK_EQ(B->ndim, 2);
-    CHECK_EQ(C->ndim, 2);
-    CHECK(C->strides == nullptr);
-    CHECK(B->strides == nullptr);
-    CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+  DLTensor* B = args[1];
+  DLTensor* C = args[2];
+  bool transa = args[3];
+  bool transb = args[4];
+  // Call gemm to keep the code simple and compact.
+  CHECK_EQ(A->ndim, 2);
+  CHECK_EQ(B->ndim, 2);
+  CHECK_EQ(C->ndim, 2);
+  CHECK(C->strides == nullptr);
+  CHECK(B->strides == nullptr);
+  CHECK(A->strides == nullptr);
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+  CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+  CHECK(TypeMatch(C->dtype, kDLFloat, 32));
 
-    rocblas_handle handle;
-    CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
-    float alpha = 1.0;
-    float beta = 0.0;
-    float *A_ptr = reinterpret_cast<float *>(static_cast<char *>(B->data) + B->byte_offset);
-    float *B_ptr = reinterpret_cast<float *>(static_cast<char *>(A->data) + A->byte_offset);
-    float *C_ptr = reinterpret_cast<float *>(static_cast<char *>(C->data) + C->byte_offset);
+  rocblas_handle handle;
+  CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
+  float alpha = 1.0;
+  float beta = 0.0;
+  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
+  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);
 
-    CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle,
-                                      transb ? rocblas_operation_transpose : rocblas_operation_none,
-                                      transa ? rocblas_operation_transpose : rocblas_operation_none,
-                                      transb ? B->shape[0] : B->shape[1],
-                                      transa ? A->shape[1] : A->shape[0],
-                                      transb ? B->shape[1] : B->shape[0],
-                                      &alpha,
-                                      A_ptr,
-                                      B->shape[1],
-                                      B_ptr,
-                                      A->shape[1],
-                                      &beta,
-                                      C_ptr,
-                                      C->shape[1]));
+  CHECK_ROCBLAS_ERROR(
+      rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none,
+                    transa ? rocblas_operation_transpose : rocblas_operation_none,
+                    transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0],
+                    transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr,
+                    A->shape[1], &beta, C_ptr, C->shape[1]));
 
-    CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
+  CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
 });
 }  // namespace contrib
 }  // namespace tvm
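One detail worth noting in the matmul body above: A_ptr is deliberately bound to B->data and B_ptr to A->data. rocBLAS, like BLAS, is column-major, while TVM tensors are row-major, and a row-major matrix reinterpreted as column-major is its transpose; so C = A * B is computed as C^T = B^T * A^T, which lands in memory exactly as row-major C. A dependency-free sketch of that identity (illustrative only, not part of this patch):

#include <cstddef>

// Reference column-major GEMM: C = A * B with A m x k, B k x n, C m x n,
// all column-major (element (r, c) lives at r + c * rows).
static void ColMajorGemm(const float* A, const float* B, float* C,
                         size_t m, size_t k, size_t n) {
  for (size_t c = 0; c < n; ++c) {
    for (size_t r = 0; r < m; ++r) {
      float acc = 0.f;
      for (size_t p = 0; p < k; ++p) acc += A[r + p * m] * B[p + c * k];
      C[r + c * m] = acc;
    }
  }
}

// Row-major C = A * B (A: m x k, B: k x n) via the column-major kernel.
// Row-major A reads as column-major A^T (k x m), and likewise for B, so we
// hand B in as the first operand with the dimensions swapped -- the same
// swap the rocblas_sgemm call above performs by binding A_ptr to B->data.
void RowMajorGemm(const float* A, const float* B, float* C,
                  size_t m, size_t k, size_t n) {
  ColMajorGemm(B, A, C, n, k, m);  // computes (A * B)^T column-major == A * B row-major
}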
diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc
index 0c9c57533dbe0..9543e4b4c64ed 100644
--- a/src/runtime/contrib/sort/sort.cc
+++ b/src/runtime/contrib/sort/sort.cc
@@ -21,8 +21,9 @@
  * \file Use standard C library call.
  */
-#include
 #include
+#include
+
 #include
 #include
 
@@ -31,19 +32,16 @@ namespace contrib {
 using namespace runtime;
 
-template<typename DType>
-bool CompareAscend(const std::pair<int64_t, DType>& lhs,
-                   const std::pair<int64_t, DType>& rhs) {
+template <typename DType>
+bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
   return lhs.second < rhs.second;
 }
 
-template<typename DType>
-bool CompareDescend(const std::pair<int64_t, DType>& lhs,
-                    const std::pair<int64_t, DType>& rhs) {
+template <typename DType>
+bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
   return lhs.second > rhs.second;
 }
 
-
 // Argsort implemented C library sort for nms.
 // Return indices of sorted tensor.
 // By default, the last axis will be used to sort.
@@ -51,17 +49,16 @@ bool CompareDescend(const std::pair<int64_t, DType>& lhs,
 // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
 // and sort axis is dk. sort_num should have dimension of
 // (d1, d2, ..., d(k-1), d(k+1), ..., dn).
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  DLTensor *input = args[0];
-  DLTensor *sort_num = args[1];
-  DLTensor *output = args[2];
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* input = args[0];
+  DLTensor* sort_num = args[1];
+  DLTensor* output = args[2];
   int32_t axis = args[3];
   bool is_ascend = args[4];
 
   auto dtype = input->dtype;
-  auto data_ptr = static_cast<float *>(input->data);
-  auto sort_num_ptr = static_cast<int32_t *>(sort_num->data);
+  auto data_ptr = static_cast<float*>(input->data);
+  auto sort_num_ptr = static_cast<int32_t*>(sort_num->data);
   std::vector<std::pair<int64_t, float>> sorter;
   int64_t axis_mul_before = 1;
   int64_t axis_mul_after = 1;
@@ -72,13 +69,14 @@
   // Currently only supports input dtype to be float32.
   CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
-      "to be float.";
+                             "to be float.";
 #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1)
   CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
-      "to be float32.";
+                              "to be float32.";
 #endif
   CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
-      "input ndim " << input->ndim;
+                                 "input ndim "
+                              << input->ndim;
 
   for (int i = 0; i < input->ndim; ++i) {
     if (i < axis) {
@@ -88,8 +86,8 @@
     }
   }
 
-  for (int64_t i = 0 ; i < axis_mul_before; ++i) {
-    for (int64_t j = 0 ; j < axis_mul_after; ++j) {
+  for (int64_t i = 0; i < axis_mul_before; ++i) {
+    for (int64_t j = 0; j < axis_mul_after; ++j) {
       sorter.clear();
       int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j);
       int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
@@ -103,7 +101,7 @@
         std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>);
       } else {
 #endif
-      std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
 #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
       }
 #endif
@@ -113,24 +111,24 @@
         std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>);
       } else {
 #endif
-      std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
+        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
 #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1)
       }
 #endif
     }
 
     for (int32_t k = 0; k < input->shape[axis]; ++k) {
-      *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
-        = k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k;
+      *(static_cast<int32_t*>(output->data) + base_idx + k * axis_mul_after) =
+          k < static_cast<int32_t>(sorter.size()) ?
sorter[k].first : k; } } } }); -template +template void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - auto data_ptr = static_cast(input->data); - auto out_ptr = static_cast(output->data); - std::vector > sorter; + auto data_ptr = static_cast(input->data); + auto out_ptr = static_cast(output->data); + std::vector> sorter; int axis_mul_before = 1; int axis_mul_after = 1; @@ -142,8 +140,8 @@ void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { } } - for (int i = 0 ; i < axis_mul_before; ++i) { - for (int j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0; i < axis_mul_before; ++i) { + for (int j = 0; j < axis_mul_after; ++j) { sorter.clear(); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; for (int64_t k = 0; k < input->shape[axis]; ++k) { @@ -169,17 +167,17 @@ void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *input = args[0]; - DLTensor *output = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* output = args[1]; int32_t axis = args[2]; bool is_ascend = args[3]; if (axis < 0) { axis = input->ndim + axis; } CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " + << input->ndim; auto data_dtype = DLDataType2String(input->dtype); auto out_dtype = DLDataType2String(output->dtype); @@ -228,7 +226,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { argsort(input, output, axis, is_ascend); } else if (out_dtype == "int64") { @@ -245,19 +243,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } }); -template -void topk(DLTensor* input, - DLTensor* out_values, - DLTensor* out_indices, - int k, - int axis, +template +void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, bool is_ascend) { - DataType* data_ptr = static_cast(input->data); - DataType* values_ptr = (out_values == nullptr) ? nullptr : - static_cast(out_values->data); - IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : - static_cast(out_indices->data); - std::vector > sorter; + DataType* data_ptr = static_cast(input->data); + DataType* values_ptr = + (out_values == nullptr) ? nullptr : static_cast(out_values->data); + IndicesType* indices_ptr = + (out_indices == nullptr) ? 
nullptr : static_cast(out_indices->data); + std::vector> sorter; int axis_mul_before = 1; int axis_mul_after = 1; @@ -272,8 +266,8 @@ void topk(DLTensor* input, k = input->shape[axis]; } - for (int i = 0 ; i < axis_mul_before; ++i) { - for (int j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0; i < axis_mul_before; ++i) { + for (int j = 0; j < axis_mul_after; ++j) { sorter.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; @@ -290,11 +284,10 @@ void topk(DLTensor* input, for (int64_t kk = 0; kk < cnt; ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].first); + static_cast(sorter[kk].first); } if (values_ptr != nullptr) { - values_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].second); + values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast(sorter[kk].second); } } } @@ -308,8 +301,7 @@ void topk(DLTensor* input, // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* input = args[0]; DLTensor* values_out = nullptr; DLTensor* indices_out = nullptr; @@ -371,7 +363,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { topk(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 56d3ce93433e4..53d7754be9469 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -20,53 +20,52 @@ /*! * \file tflite_runtime.cc */ -#include +#include "tflite_runtime.h" + #include #include #include - - -#include "tflite_runtime.h" +#include namespace tvm { namespace runtime { -#define TVM_DTYPE_DISPATCH(type, DType, ...) \ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) 
\ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ } DataType TfLiteDType2TVMDType(TfLiteType dtype) { @@ -91,12 +90,11 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { } } -void TFLiteRuntime::Init(const std::string& tflite_model_bytes, - TVMContext ctx) { +void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); tflite::ops::builtin::BuiltinOpResolver resolver; // Build interpreter TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_); @@ -108,24 +106,22 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes, ctx_ = ctx; } -void TFLiteRuntime::Invoke() { - interpreter_->Invoke(); -} +void TFLiteRuntime::Invoke() { interpreter_->Invoke(); } void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { DataType dtype(data_in->dtype); TVM_DTYPE_DISPATCH(dtype, DType, { - DType* dest = interpreter_->typed_input_tensor(index); - DType* src = static_cast(data_in->data); - CHECK(data_in->strides == NULL); - int64_t size = 1; - for (int64_t i = 0; i < data_in->ndim; ++i) { - size *= data_in->shape[i]; - } - for (int64_t i = 0; i < size; ++i) { - dest[i] = src[i]; - } - }); + DType* dest = interpreter_->typed_input_tensor(index); + DType* src = static_cast(data_in->data); + CHECK(data_in->strides == NULL); + int64_t size = 1; + for (int64_t i = 0; i < data_in->ndim; ++i) { + size *= data_in->shape[i]; + } + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); } NDArray TFLiteRuntime::GetOutput(int index) const { @@ -140,48 +136,42 @@ NDArray TFLiteRuntime::GetOutput(int index) const { } NDArray ret = NDArray::Empty(shape, dtype, ctx_); TVM_DTYPE_DISPATCH(dtype, DType, { - DType* dest = static_cast(ret->data); - DType* src = interpreter_->typed_output_tensor(index); - for (int64_t i = 0; i < size; ++i) { - dest[i] = src[i]; - } - }); + DType* dest = static_cast(ret->data); + DType* src = interpreter_->typed_output_tensor(index); + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); return ret; } -PackedFunc TFLiteRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc TFLiteRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. 
if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = args[0]; - CHECK_GE(in_idx, 0); - this->SetInput(in_idx, args[1]); - }); + int in_idx = args[0]; + CHECK_GE(in_idx, 0); + this->SetInput(in_idx, args[1]); + }); } else if (name == "get_output") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutput(args[0]); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); } else if (name == "invoke") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Invoke(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); } else { return PackedFunc(); } } -Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, - TVMContext ctx) { +Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); exec->Init(tflite_model_bytes, ctx); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = TFLiteRuntimeCreate(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = TFLiteRuntimeCreate(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index d823690126b1a..f61f6ee37e0b9 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -29,9 +29,9 @@ #include #include -#include -#include #include +#include +#include namespace tvm { namespace runtime { @@ -52,18 +52,15 @@ class TFLiteRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const { - return "TFLiteRuntime"; - } + const char* type_key() const { return "TFLiteRuntime"; } /*! - * \brief Invoke the internal tflite interpreter and run the whole model in + * \brief Invoke the internal tflite interpreter and run the whole model in * dependency order. */ void Invoke(); @@ -73,8 +70,7 @@ class TFLiteRuntime : public ModuleNode { * \param tflite_model_bytes The tflite model. * \param ctx The context where the tflite model will be executed on. */ - void Init(const std::string& tflite_model_bytes, - TVMContext ctx); + void Init(const std::string& tflite_model_bytes, TVMContext ctx); /*! * \brief set index-th input to the model. 
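Taken together, the TFLite runtime above exposes exactly three packed functions (set_input, invoke, get_output) plus the creation global registered at the end of tflite_runtime.cc. A minimal C++ driver sketch, not part of this patch: it assumes TVM was built with this runtime enabled and that model_bytes (an illustrative name) holds the flatbuffer model:

#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <string>

// Run one inference through the TFLiteRuntime module defined above.
tvm::runtime::NDArray RunTFLiteOnce(const std::string& model_bytes, DLTensor* input) {
  TVMContext ctx{kDLCPU, 0};
  // "tvm.tflite_runtime.create" is the global registered in tflite_runtime.cc.
  const tvm::runtime::PackedFunc* create =
      tvm::runtime::Registry::Get("tvm.tflite_runtime.create");
  CHECK(create != nullptr) << "TFLite runtime was not compiled into this build";
  tvm::runtime::Module mod = (*create)(model_bytes, ctx);
  mod.GetFunction("set_input")(0, input);   // copies into the interpreter's input tensor
  mod.GetFunction("invoke")();              // runs the whole model in dependency order
  return mod.GetFunction("get_output")(0);  // copies output 0 into a fresh NDArray
}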
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 920bdae5b964c..c70a4f29ccbe4 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -22,10 +22,12 @@ */ #include #include -#include #include +#include + #include #include + #include "workspace_pool.h" #ifdef __ANDROID__ @@ -42,9 +44,7 @@ class CPUDeviceAPI final : public DeviceAPI { *rv = 1; } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { void* ptr; #if _MSC_VER @@ -69,53 +69,38 @@ class CPUDeviceAPI final : public DeviceAPI { #endif } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { - memcpy(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {} void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; struct CPUWorkspacePool : public WorkspacePool { - CPUWorkspacePool() : - WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} + CPUWorkspacePool() : WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} }; -void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { - return dmlc::ThreadLocalStore::Get() - ->AllocWorkspace(ctx, size); +void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(ctx, size); } void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(ctx, data); } -TVM_REGISTER_GLOBAL("device_api.cpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CPUDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CPUDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/crt/graph_runtime.c b/src/runtime/crt/graph_runtime.c index a4c07f48ddf03..ab96a0cef7cef 100644 --- a/src/runtime/crt/graph_runtime.c +++ b/src/runtime/crt/graph_runtime.c @@ -815,10 +815,10 @@ void TVMGraphRuntime_Init(TVMGraphRuntime * runtime, const char * graph_json, const TVMModule * module, const TVMContext * ctxs) { JSONReader reader = JSONReader_Create(graph_json); runtime->Load(runtime, &reader); + JSONReader_Release(&reader); runtime->ctxs[0] = ctxs[0]; runtime->SetupStorage(runtime); runtime->SetupOpExecs(runtime); - JSONReader_Release(&reader); } TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, diff --git a/src/runtime/crt/graph_runtime.h b/src/runtime/crt/graph_runtime.h index 3cb8ba95e0fac..fd3b146332229 100644 --- a/src/runtime/crt/graph_runtime.h +++ b/src/runtime/crt/graph_runtime.h @@ -27,9 +27,9 @@ #include #include 
"load_json.h" +#include "module.h" #include "ndarray.h" #include "packed_func.h" -#include "module.h" /*! \brief operator attributes about tvm op */ typedef struct TVMOpParam { @@ -51,7 +51,7 @@ typedef struct TVMGraphRuntimeNodeEntry { uint32_t index; uint32_t version; // JSON Loader - void (*Load)(JSONReader *reader); + void (*Load)(JSONReader* reader); } TVMGraphRuntimeNodeEntry; // Node @@ -63,26 +63,26 @@ typedef struct TVMGraphRuntimeNode { // parameters TVMOpParam param; // inputs - TVMGraphRuntimeNodeEntry * inputs; + TVMGraphRuntimeNodeEntry* inputs; // number of inputs size_t inputs_count; // control deps uint32_t control_deps[20]; // JSON Loader - void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param); + void (*LoadAttrs)(struct TVMGraphRuntimeNode* node, JSONReader* reader, TVMOpParam* param); // JSON Loader - int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader); + int (*Load)(struct TVMGraphRuntimeNode* node, JSONReader* reader); } TVMGraphRuntimeNode; // Graph attribute typedef struct TVMGraphRuntimeGraphAttr { uint32_t storage_num_not_alloctaed; - uint32_t * storage_id; - uint32_t * device_index; - char * dltype; // "int8", "int16", "float32" + uint32_t* storage_id; + uint32_t* device_index; + char* dltype; // "int8", "int16", "float32" uint32_t dltype_count; - int64_t * shape; - uint32_t * ndim; + int64_t* shape; + uint32_t* ndim; uint32_t shape_count; } TVMGraphRuntimeGraphAttr; @@ -96,7 +96,7 @@ typedef DLTensor* DLTensorPtr; */ /* class GraphRuntime : public ModuleNode { */ typedef struct TVMGraphRuntime { - void (*Run)(struct TVMGraphRuntime * runtime); + void (*Run)(struct TVMGraphRuntime* runtime); /*! * \brief Initialize the graph executor with graph and context. @@ -107,10 +107,8 @@ typedef struct TVMGraphRuntime { * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ - void (*Init)(struct TVMGraphRuntime * runtime, - const char * graph_json, - const TVMModule * module, - const TVMContext * ctxs); + void (*Init)(struct TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module, + const TVMContext* ctxs); /*! * \brief Get the input index given the name of input. @@ -118,7 +116,7 @@ typedef struct TVMGraphRuntime { * \param name The name of the input. * \return The index of input. */ - int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name); + int (*GetInputIndex)(struct TVMGraphRuntime* runtime, const char* name); /*! * \brief set input to the graph based on name. @@ -126,7 +124,7 @@ typedef struct TVMGraphRuntime { * \param name The name of the input. * \param data_in The input data. */ - void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); + void (*SetInput)(struct TVMGraphRuntime* runtime, const char* name, DLTensor* data_in); /*! * \brief Return NDArray for given output index. @@ -135,7 +133,7 @@ typedef struct TVMGraphRuntime { * \param out The DLTensor corresponding to given output node index. * \return The result of this function execution. */ - int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out); + int (*GetOutput)(struct TVMGraphRuntime* runtime, const int32_t index, DLTensor* out); /*! * \brief Load parameters from parameter blob. * \param runtime The graph runtime. @@ -143,15 +141,15 @@ typedef struct TVMGraphRuntime { * \param param_size The parameter size. * \return The result of this function execution. 
*/ - int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob, + int (*LoadParams)(struct TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size); // The graph attribute fields. - int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader); + int (*Load)(struct TVMGraphRuntime* runtime, JSONReader* reader); /*! \brief Setup the temporal storage */ - void (*SetupStorage)(struct TVMGraphRuntime * runtime); + void (*SetupStorage)(struct TVMGraphRuntime* runtime); /*! \brief Setup the executors. */ - int (*SetupOpExecs)(struct TVMGraphRuntime * runtime); + int (*SetupOpExecs)(struct TVMGraphRuntime* runtime); /*! * \brief Create an execution function given input. @@ -163,25 +161,25 @@ typedef struct TVMGraphRuntime { * \param pf The created executor. * \return The result of this function execution. */ - int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs, - DLTensorPtr * args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc * pf); + int32_t (*CreateTVMOp)(struct TVMGraphRuntime* runtime, const TVMOpParam* attrs, + DLTensorPtr* args, const uint32_t args_count, uint32_t num_inputs, + TVMPackedFunc* pf); // Get node entry index. - uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index); + uint32_t (*GetEntryId)(struct TVMGraphRuntime* runtime, uint32_t nid, uint32_t index); /*! \brief The graph nodes. */ - TVMGraphRuntimeNode * nodes; + TVMGraphRuntimeNode* nodes; /*! \brief The graph nodes counter. */ uint32_t nodes_count; /*! \brief The argument nodes. */ - uint32_t * input_nodes; + uint32_t* input_nodes; uint32_t input_nodes_count; /*! \brief Used for quick entry indexing. */ - uint32_t * node_row_ptr; + uint32_t* node_row_ptr; uint32_t node_row_ptr_count; /*! \brief Output entries. */ - TVMGraphRuntimeNodeEntry * outputs; + TVMGraphRuntimeNodeEntry* outputs; /*! \brief Output entries counter. */ uint32_t outputs_count; /*! \brief Additional graph attributes. */ @@ -190,28 +188,28 @@ typedef struct TVMGraphRuntime { TVMModule module; /*! \brief Execution context of all devices including the host. */ TVMContext ctxs[1]; - uint32_t ctxs_count; + uint32_t ctxs_count; /*! \brief Common storage pool for all devices. */ - TVMNDArray * storage_pool; + TVMNDArray* storage_pool; uint32_t storage_pool_count; /*! \brief Data entry of each node. */ - TVMNDArray * data_entry; + TVMNDArray* data_entry; uint32_t data_entry_count; /*! \brief Operator on each node. 
*/ - TVMPackedFunc * op_execs; + TVMPackedFunc* op_execs; uint32_t op_execs_count; } TVMGraphRuntime; // public functions -TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m, - const TVMContext * ctxs); -void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime); +TVMGraphRuntime* TVMGraphRuntimeCreate(const char* sym_json, const TVMModule* m, + const TVMContext* ctxs); +void TVMGraphRuntimeRelease(TVMGraphRuntime** runtime); // private functions -void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); -int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob, +void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in); +int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size); -void TVMGraphRuntime_Run(TVMGraphRuntime * runtime); -int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out); +void TVMGraphRuntime_Run(TVMGraphRuntime* runtime); +int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out); #endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_ diff --git a/src/runtime/crt/load_json.h b/src/runtime/crt/load_json.h index a5df7a055af05..0c9324777c1d7 100644 --- a/src/runtime/crt/load_json.h +++ b/src/runtime/crt/load_json.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_CRT_LOAD_JSON_H_ #define TVM_RUNTIME_CRT_LOAD_JSON_H_ -#include #include +#include enum { JSON_READ_TYPE_U8 = 1, @@ -42,12 +42,12 @@ enum { }; typedef struct Seq { - uint32_t * data; + uint32_t* data; uint64_t allocated; uint32_t size; - void (*push_back)(struct Seq * seq, uint32_t src); - uint32_t * (*back)(struct Seq * seq); - void (*pop_back)(struct Seq * seq); + void (*push_back)(struct Seq* seq, uint32_t src); + uint32_t* (*back)(struct Seq* seq); + void (*pop_back)(struct Seq* seq); } Seq; /*! @@ -56,8 +56,8 @@ typedef struct Seq { */ typedef struct JSONReader { /*! \brief internal reader string */ - char * is_; - char * isptr; + char* is_; + char* isptr; /*! \brief "\\r" counter */ size_t line_count_r_; /*! \brief "\\n" counter */ @@ -66,27 +66,27 @@ typedef struct JSONReader { * \brief record how many element processed in * current array/object scope. 
*/ - Seq * scope_counter_; + Seq* scope_counter_; - char (*NextChar)(struct JSONReader * reader); - char (*NextNonSpace)(struct JSONReader * reader); - char (*PeekNextChar)(struct JSONReader * reader); - char (*PeekNextNonSpace)(struct JSONReader * reader); - int (*ReadUnsignedInteger)(struct JSONReader * reader, unsigned int * out_value); - int (*ReadInteger)(struct JSONReader * reader, int64_t * out_value); - int (*ReadString)(struct JSONReader * reader, char * out_value); - void (*BeginArray)(struct JSONReader * reader); - void (*BeginObject)(struct JSONReader * reader); - uint8_t (*NextObjectItem)(struct JSONReader * reader, char * out_key); - uint8_t (*NextArrayItem)(struct JSONReader * reader); + char (*NextChar)(struct JSONReader* reader); + char (*NextNonSpace)(struct JSONReader* reader); + char (*PeekNextChar)(struct JSONReader* reader); + char (*PeekNextNonSpace)(struct JSONReader* reader); + int (*ReadUnsignedInteger)(struct JSONReader* reader, unsigned int* out_value); + int (*ReadInteger)(struct JSONReader* reader, int64_t* out_value); + int (*ReadString)(struct JSONReader* reader, char* out_value); + void (*BeginArray)(struct JSONReader* reader); + void (*BeginObject)(struct JSONReader* reader); + uint8_t (*NextObjectItem)(struct JSONReader* reader, char* out_key); + uint8_t (*NextArrayItem)(struct JSONReader* reader); } JSONReader; /*! * \brief Constructor of JSONReader class * \param is the input source. */ -JSONReader JSONReader_Create(const char * is); +JSONReader JSONReader_Create(const char* is); -void JSONReader_Release(JSONReader * reader); +void JSONReader_Release(JSONReader* reader); #endif // TVM_RUNTIME_CRT_LOAD_JSON_H_ diff --git a/src/runtime/crt/logging.h b/src/runtime/crt/logging.h index 2c58834ca6a94..c711b3aa3bb94 100644 --- a/src/runtime/crt/logging.h +++ b/src/runtime/crt/logging.h @@ -27,31 +27,31 @@ #define TVM_RUNTIME_CRT_LOGGING_H_ #ifndef CHECK -#define CHECK(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "Check failed: %s\n", #x); \ - exit(-1); \ - } \ - }while(0) +#define CHECK(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "Check failed: %s\n", #x); \ + exit(-1); \ + } \ + } while (0) #endif #ifndef CHECK_BINARY_OP -#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ - do { \ - if (!(x op y)) { \ +#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ + do { \ + if (!(x op y)) { \ fprintf(stderr, "Check failed: %s %s %s: " fmt "\n", #x, #op, #y, ##__VA_ARGS__); \ - exit(-1); \ - } \ - }while(0) + exit(-1); \ + } \ + } while (0) #endif #ifndef CHECK_LT -#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__) +#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__) #endif #ifndef CHECK_GT -#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__) +#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__) #endif #ifndef CHECK_LE diff --git a/src/runtime/crt/module.h b/src/runtime/crt/module.h index 9ef287d650d84..57f8dd708f883 100644 --- a/src/runtime/crt/module.h +++ b/src/runtime/crt/module.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_CRT_MODULE_H_ #define TVM_RUNTIME_CRT_MODULE_H_ -#include #include +#include struct TVMPackedFunc; @@ -41,7 +41,7 @@ typedef struct TVMModule { * * This function will return PackedFunc(nullptr) if function do not exist. 
*/ - void (*GetFunction)(struct TVMModule * mod, const char * name, struct TVMPackedFunc * pf); + void (*GetFunction)(struct TVMModule* mod, const char* name, struct TVMPackedFunc* pf); } TVMModule; #endif // TVM_RUNTIME_CRT_MODULE_H_ diff --git a/src/runtime/crt/ndarray.h b/src/runtime/crt/ndarray.h index dde23ca6cd411..ae76726ae0b91 100644 --- a/src/runtime/crt/ndarray.h +++ b/src/runtime/crt/ndarray.h @@ -24,13 +24,12 @@ #ifndef TVM_RUNTIME_CRT_NDARRAY_H_ #define TVM_RUNTIME_CRT_NDARRAY_H_ -#include -#include #include - -#include #include #include +#include +#include +#include /*! \brief Magic number for NDArray file */ static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; @@ -42,17 +41,17 @@ typedef struct TVMNDArray { DLTensor dl_tensor; } TVMNDArray; -TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx); +TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx); -TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx); +TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx); -int TVMNDArray_Load(TVMNDArray * ret, const char ** strm); +int TVMNDArray_Load(TVMNDArray* ret, const char** strm); -TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape, - uint32_t ndim, DLDataType dtype); +TVMNDArray TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, uint32_t ndim, + DLDataType dtype); -int TVMNDArray_Release(TVMNDArray * arr); +int TVMNDArray_Release(TVMNDArray* arr); #endif // TVM_RUNTIME_CRT_NDARRAY_H_ diff --git a/src/runtime/crt/packed_func.h b/src/runtime/crt/packed_func.h index 93898a436c885..d4597e62fd0ff 100644 --- a/src/runtime/crt/packed_func.h +++ b/src/runtime/crt/packed_func.h @@ -24,29 +24,34 @@ #ifndef TVM_RUNTIME_CRT_PACKED_FUNC_H_ #define TVM_RUNTIME_CRT_PACKED_FUNC_H_ -#include - +#include #include #include -#include +#include #include "module.h" -static inline DLDataType String2DLDataType(const char * s) { +static inline DLDataType String2DLDataType(const char* s) { DLDataType t; // handle None type if (strlen(s) == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; + t.code = kTVMOpaqueHandle; return t; } - t.bits = 32; t.lanes = 1; + t.bits = 32; + t.lanes = 1; const char* scan; if (!strncmp(s, "int", 3)) { - t.code = kDLInt; scan = s + 3; + t.code = kDLInt; + scan = s + 3; } else if (!strncmp(s, "uint", 4)) { - t.code = kDLUInt; scan = s + 4; + t.code = kDLUInt; + scan = s + 4; } else if (!strncmp(s, "float", 5)) { - t.code = kDLFloat; scan = s + 5; + t.code = kDLFloat; + scan = s + 5; } else if (!strncmp(s, "handle", 6)) { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. 
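As a spot check of the parser above, here are the three cases fully determined by this hunk: the empty string ("None"), the "float" prefix with its 32-bit, 1-lane defaults, and the 64-bit "handle" default. The sketch is illustrative only; it assumes this header is included and that the tail of the function (not shown in the hunk) leaves the defaults untouched when no width digits follow the prefix:

#include <cassert>

// Exercise String2DLDataType from the CRT packed_func.h shown above.
static void CheckString2DLDataType() {
  DLDataType t = String2DLDataType("");  // empty string is treated as "None"
  assert(t.code == kTVMOpaqueHandle && t.bits == 0 && t.lanes == 0);

  t = String2DLDataType("float");  // defaults: 32 bits, 1 lane
  assert(t.code == kDLFloat && t.bits == 32 && t.lanes == 1);

  t = String2DLDataType("handle");  // handles default to 64 bits
  assert(t.code == kTVMOpaqueHandle && t.bits == 64 && t.lanes == 1);
}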
@@ -75,11 +80,11 @@ static inline DLDataType String2DLDataType(const char * s) { typedef struct TVMArgs { TVMValue values[TVM_CRT_MAX_ARGS]; - int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */ + int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */ uint32_t values_count; } TVMArgs; -static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint32_t values_count) { +static inline TVMArgs TVMArgs_Create(TVMValue* values, uint32_t* tcodes, uint32_t values_count) { uint32_t idx; TVMArgs args; memset(&args, 0, sizeof(args)); @@ -91,8 +96,8 @@ static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint3 return args; } -static inline int TVMNoOperation(TVMValue * args, int * type_codes, int num_args, - TVMRetValueHandle ret, void * res) { +static inline int TVMNoOperation(TVMValue* args, int* type_codes, int num_args, + TVMRetValueHandle ret, void* res) { return 0; } @@ -100,24 +105,24 @@ typedef struct TVMPackedFunc { char name[200]; TVMPackedCFunc fexec; TVMArgs args; - void (*Call)(struct TVMPackedFunc * pf); - void (*SetArgs)(struct TVMPackedFunc * pf, const struct TVMArgs * args); + void (*Call)(struct TVMPackedFunc* pf); + void (*SetArgs)(struct TVMPackedFunc* pf, const struct TVMArgs* args); } TVMPackedFunc; -static inline void TVMPackedFunc_Call(TVMPackedFunc * pf) { +static inline void TVMPackedFunc_Call(TVMPackedFunc* pf) { pf->fexec(pf->args.values, pf->args.tcodes, pf->args.values_count, 0, 0); } -static inline void TVMPackedFunc_SetArgs(TVMPackedFunc * pf, const TVMArgs * args) { +static inline void TVMPackedFunc_SetArgs(TVMPackedFunc* pf, const TVMArgs* args) { memcpy(&(pf->args), args, sizeof(TVMArgs)); } -TVMPackedFunc * g_fexecs = 0; +TVMPackedFunc* g_fexecs = 0; uint32_t g_fexecs_count = 0; // Implement TVMModule::GetFunction // Put implementation in this file so we have seen the TVMPackedFunc -static inline void TVMModule_GetFunction(TVMModule * mod, const char * name, TVMPackedFunc * pf) { +static inline void TVMModule_GetFunction(TVMModule* mod, const char* name, TVMPackedFunc* pf) { int idx; memset(pf, 0, sizeof(TVMPackedFunc)); assert(strlen(name) <= sizeof(pf->name)); diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index 87cf3be5491df..25ff28a91a6cd 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,7 +26,9 @@ #include #include + #include + #include "../workspace_pool.h" namespace tvm { @@ -36,18 +38,16 @@ namespace runtime { { \ CUresult result = x; \ if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \ - const char *msg; \ + const char* msg; \ cuGetErrorName(result, &msg); \ - LOG(FATAL) \ - << "CUDAError: " #x " failed with error: " << msg; \ + LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \ } \ } -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \ } /*! \brief Thread local workspace */ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index d9f03e773bc95..a6d4a54994697 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -21,13 +21,14 @@ * \file cuda_device_api.cc * \brief GPU specific API */ -#include - -#include -#include #include #include +#include +#include +#include + #include + #include "cuda_common.h" namespace tvm { @@ -35,40 +36,32 @@ namespace runtime { class CUDADeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - CUDA_CALL(cudaSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { CUDA_CALL(cudaSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { case kExist: - value = ( - cudaDeviceGetAttribute( - &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) - == cudaSuccess); + value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) == + cudaSuccess); break; case kMaxThreadsPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrWarpSize, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id)); break; } case kMaxSharedMemoryPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); + CUDA_CALL( + cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); break; } case kComputeVersion: { std::ostringstream os; - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMajor, ctx.device_id)); os << value << "."; - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMinor, ctx.device_id)); os << value; *rv = os.str(); return; @@ -81,40 +74,33 @@ class CUDADeviceAPI final : public DeviceAPI { return; } case kMaxClockRate: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrClockRate, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, ctx.device_id)); break; } case 
kMultiProcessorCount: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMultiProcessorCount, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMultiProcessorCount, ctx.device_id)); break; } case kMaxThreadDimensions: { int dims[3]; - CUDA_CALL(cudaDeviceGetAttribute( - &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); - CUDA_CALL(cudaDeviceGetAttribute( - &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); - CUDA_CALL(cudaDeviceGetAttribute( - &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; *rv = ss.str(); return; } - case kGcnArch: return; + case kGcnArch: + return; } *rv = value; } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { - CHECK_EQ(256 % alignment, 0U) - << "CUDA space is aligned at 256 bytes"; - void *ret; + CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; + void* ret; if (ctx.device_type == kDLCPUPinned) { CUDA_CALL(cudaMallocHost(&ret, nbytes)); } else { @@ -133,14 +119,8 @@ class CUDADeviceAPI final : public DeviceAPI { } } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { cudaStream_t cu_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -156,8 +136,8 @@ class CUDADeviceAPI final : public DeviceAPI { // In case there is a copy from host mem to host mem */ if (ctx_to.device_type == kDLCPU && ctx_from.device_type == kDLCPU) { - memcpy(to, from, size); - return; + memcpy(to, from, size); + return; } if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) { @@ -165,9 +145,7 @@ class CUDADeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream); } else { - cudaMemcpyPeerAsync(to, ctx_to.device_id, - from, ctx_from.device_id, - size, cu_stream); + cudaMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream); } } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) { CUDA_CALL(cudaSetDevice(ctx_from.device_id)); @@ -210,8 +188,7 @@ class CUDADeviceAPI final : public DeviceAPI { } void SetStream(TVMContext ctx, TVMStreamHandle stream) final { - CUDAThreadEntry::ThreadLocal() - ->stream = static_cast(stream); + CUDAThreadEntry::ThreadLocal()->stream = static_cast(stream); } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { @@ -223,16 +200,12 @@ class CUDADeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, - void* to, - size_t size, - 
cudaMemcpyKind kind, + static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind, cudaStream_t stream) { if (stream != 0) { CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream)); @@ -244,25 +217,19 @@ class CUDADeviceAPI final : public DeviceAPI { typedef dmlc::ThreadLocalStore CUDAThreadStore; -CUDAThreadEntry::CUDAThreadEntry() - : pool(kDLGPU, CUDADeviceAPI::Global()) { -} +CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {} -CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { - return CUDAThreadStore::Get(); -} +CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.gpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); -TVM_REGISTER_GLOBAL("device_api.cpu_pinned") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 0550712de9ab7..498a9b703a7b8 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -22,19 +22,21 @@ */ #include "cuda_module.h" -#include #include #include -#include +#include + #include -#include #include +#include #include -#include "cuda_common.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "cuda_common.h" namespace tvm { namespace runtime { @@ -45,8 +47,7 @@ namespace runtime { // The modules will be lazily loaded class CUDAModuleNode : public runtime::ModuleNode { public: - explicit CUDAModuleNode(std::string data, - std::string fmt, + explicit CUDAModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string cuda_source) : data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) { @@ -62,16 +63,11 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { - return "cuda"; - } + const char* type_key() const final { return "cuda"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -79,8 +75,7 @@ class CUDAModuleNode : public runtime::ModuleNode { SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, cuda_source_); } else { - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); } @@ -112,18 +107,14 @@ class CUDAModuleNode : public runtime::ModuleNode { CUfunction func; CUresult result = cuModuleGetFunction(&func, module_[device_id], 
func_name.c_str()); if (result != CUDA_SUCCESS) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) - << "CUDAError: cuModuleGetFunction " << func_name - << " failed with error: " << msg; + LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg; } return func; } // get a global var from primary context in device_id - CUdeviceptr GetGlobal(int device_id, - const std::string& global_name, - size_t expect_nbytes) { + CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); // must recheck under the lock scope if (module_[device_id] == nullptr) { @@ -132,15 +123,12 @@ class CUDAModuleNode : public runtime::ModuleNode { CUdeviceptr global; size_t nbytes; - CUresult result = cuModuleGetGlobal(&global, &nbytes, - module_[device_id], global_name.c_str()); + CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()); CHECK_EQ(nbytes, expect_nbytes); if (result != CUDA_SUCCESS) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) - << "CUDAError: cuModuleGetGlobal " << global_name - << " failed with error: " << msg; + LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg; } return global; } @@ -164,11 +152,8 @@ class CUDAModuleNode : public runtime::ModuleNode { class CUDAWrappedFunc { public: // initialize the CUDA function. - void Init(CUDAModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_void_args, - const std::vector& thread_axis_tags) { + void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_void_args, const std::vector& thread_axis_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; @@ -176,9 +161,7 @@ class CUDAWrappedFunc { thread_axis_cfg_.Init(num_void_args, thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void** void_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -186,24 +169,17 @@ class CUDAWrappedFunc { } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - CUresult result = cuLaunchKernel( - fcache_[device_id], - wl.grid_dim(0), - wl.grid_dim(1), - wl.grid_dim(2), - wl.block_dim(0), - wl.block_dim(1), - wl.block_dim(2), - 0, strm, void_args, 0); + CUresult result = + cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), + wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), 0, strm, void_args, 0); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); std::ostringstream os; os << "CUDALaunch Error: " << msg << "\n" - << " grid=(" << wl.grid_dim(0) << "," - << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " - << " block=(" << wl.block_dim(0) << "," - << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n"; + << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " + << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) + << ")\n"; std::string cuda = m_->GetSource(""); if (cuda.length() != 0) { os << "// func_name=" << func_name_ << "\n" @@ -231,9 +207,7 @@ class CUDAWrappedFunc { class CUDAPrepGlobalBarrier { public: - 
CUDAPrepGlobalBarrier(CUDAModuleNode* m,
-                        ObjectPtr sptr)
-      : m_(m), sptr_(sptr) {
+  CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr sptr) : m_(m), sptr_(sptr) {
     std::fill(pcache_.begin(), pcache_.end(), 0);
   }
@@ -241,8 +215,8 @@ class CUDAPrepGlobalBarrier {
     int device_id;
     CUDA_CALL(cudaGetDevice(&device_id));
     if (pcache_[device_id] == 0) {
-      pcache_[device_id] = m_->GetGlobal(
-          device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
+      pcache_[device_id] =
+          m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
     }
     CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
   }
@@ -256,12 +230,10 @@ class CUDAPrepGlobalBarrier {
   mutable std::array pcache_;
 };

-PackedFunc CUDAModuleNode::GetFunction(
-    const std::string& name,
-    const ObjectPtr& sptr_to_self) {
+PackedFunc CUDAModuleNode::GetFunction(const std::string& name,
+                                       const ObjectPtr& sptr_to_self) {
   CHECK_EQ(sptr_to_self.get(), this);
-  CHECK_NE(name, symbol::tvm_module_main)
-      << "Device function do not have main";
+  CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have a main";
   if (name == symbol::tvm_prepare_global_barrier) {
     return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
   }
@@ -273,18 +245,15 @@ PackedFunc CUDAModuleNode::GetFunction(
   return PackFuncVoidAddr(f, info.arg_types);
 }

-Module CUDAModuleCreate(
-    std::string data,
-    std::string fmt,
-    std::unordered_map fmap,
-    std::string cuda_source) {
+Module CUDAModuleCreate(std::string data, std::string fmt,
+                        std::unordered_map fmap,
+                        std::string cuda_source) {
   auto n = make_object(data, fmt, fmap, cuda_source);
   return Module(n);
 }

 // Load module from file.
-Module CUDAModuleLoadFile(const std::string& file_name,
-                          const std::string& format) {
+Module CUDAModuleLoadFile(const std::string& file_name, const std::string& format) {
   std::string data;
   std::unordered_map fmap;
   std::string fmt = GetFileFormat(file_name, format);
@@ -305,13 +274,10 @@ Module CUDAModuleLoadBinary(void* strm) {
   return CUDAModuleCreate(data, fmt, fmap, std::string());
 }

-TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin")
-.set_body_typed(CUDAModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile);

-TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx")
-.set_body_typed(CUDAModuleLoadFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile);

-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda")
-.set_body_typed(CUDAModuleLoadBinary);
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary);
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h
index bce0d63e98a16..e65c5fe60811a 100644
--- a/src/runtime/cuda/cuda_module.h
+++ b/src/runtime/cuda/cuda_module.h
@@ -25,10 +25,12 @@
 #define TVM_RUNTIME_CUDA_CUDA_MODULE_H_

 #include
+
 #include
-#include
 #include
 #include
+#include
+
 #include "../meta_data.h"

 namespace tvm {
 namespace runtime {
@@ -45,11 +47,9 @@ static constexpr const int kMaxNumGPUs = 32;
 * \param fmap The map function information map of each function.
* \param cuda_source Optional, cuda source file */ -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source); +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 378f976dead10..6d3eec402306c 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,10 +21,11 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ -#include #include -#include +#include #include +#include + #include "library_module.h" #if defined(_WIN32) @@ -43,13 +44,9 @@ class DSOLibrary final : public Library { ~DSOLibrary() { if (lib_handle_) Unload(); } - void Init(const std::string& name) { - Load(name); - } + void Init(const std::string& name) { Load(name); } - void* GetSymbol(const char* name) final { - return GetSymbol_(name); - } + void* GetSymbol(const char* name) final { return GetSymbol_(name); } private: // Platform dependent handling. @@ -58,8 +55,7 @@ class DSOLibrary final : public Library { HMODULE lib_handle_{nullptr}; void* GetSymbol_(const char* name) { - return reinterpret_cast( - GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } // Load the library @@ -67,8 +63,7 @@ class DSOLibrary final : public Library { // use wstring version that is needed by LLVM. std::wstring wname(name.begin(), name.end()); lib_handle_ = LoadLibraryW(wname.c_str()); - CHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name; + CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } void Unload() { @@ -81,14 +76,11 @@ class DSOLibrary final : public Library { // load the library void Load(const std::string& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - CHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name - << " " << dlerror(); + CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); } - void* GetSymbol_(const char* name) { - return dlsym(lib_handle_, name); - } + void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } void Unload() { dlclose(lib_handle_); @@ -97,11 +89,10 @@ class DSOLibrary final : public Library { #endif }; -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->Init(args[0]); + *rv = CreateModuleFromLibrary(n); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_util.cc b/src/runtime/file_util.cc index f94b2d37b72b3..68d174e470a24 100644 --- a/src/runtime/file_util.cc +++ b/src/runtime/file_util.cc @@ -20,13 +20,15 @@ /*! 
* \file file_util.cc */ +#include "file_util.h" + #include #include #include + #include -#include #include -#include "file_util.h" +#include namespace tvm { namespace runtime { @@ -69,8 +71,7 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { return true; } -std::string GetFileFormat(const std::string& file_name, - const std::string& format) { +std::string GetFileFormat(const std::string& file_name, const std::string& format) { std::string fmt = format; if (fmt.length() == 0) { size_t pos = file_name.find_last_of("."); @@ -103,7 +104,7 @@ std::string GetFileBasename(const std::string& file_name) { } std::string GetMetaFilePath(const std::string& file_name) { - size_t pos = file_name.find_last_of("."); + size_t pos = file_name.find_last_of("."); if (pos != std::string::npos) { return file_name.substr(0, pos) + ".tvm_meta.json"; } else { @@ -111,8 +112,7 @@ std::string GetMetaFilePath(const std::string& file_name) { } } -void LoadBinaryFromFile(const std::string& file_name, - std::string* data) { +void LoadBinaryFromFile(const std::string& file_name, std::string* data) { std::ifstream fs(file_name, std::ios::in | std::ios::binary); CHECK(!fs.fail()) << "Cannot open " << file_name; // get its size: @@ -123,17 +123,14 @@ void LoadBinaryFromFile(const std::string& file_name, fs.read(&(*data)[0], size); } -void SaveBinaryToFile( - const std::string& file_name, - const std::string& data) { +void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); CHECK(!fs.fail()) << "Cannot open " << file_name; fs.write(&data[0], data.length()); } -void SaveMetaDataToFile( - const std::string& file_name, - const std::unordered_map& fmap) { +void SaveMetaDataToFile(const std::string& file_name, + const std::unordered_map& fmap) { std::string version = "0.1.0"; std::ofstream fs(file_name.c_str()); CHECK(!fs.fail()) << "Cannot open file " << file_name; @@ -145,9 +142,8 @@ void SaveMetaDataToFile( fs.close(); } -void LoadMetaDataFromFile( - const std::string& file_name, - std::unordered_map* fmap) { +void LoadMetaDataFromFile(const std::string& file_name, + std::unordered_map* fmap) { std::ifstream fs(file_name.c_str()); CHECK(!fs.fail()) << "Cannot open file " << file_name; std::string version; @@ -159,9 +155,7 @@ void LoadMetaDataFromFile( fs.close(); } -void RemoveFile(const std::string& file_name) { - std::remove(file_name.c_str()); -} +void RemoveFile(const std::string& file_name) { std::remove(file_name.c_str()); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_util.h b/src/runtime/file_util.h index dfbaa16bded6f..1c350357ec9af 100644 --- a/src/runtime/file_util.h +++ b/src/runtime/file_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include + #include "meta_data.h" namespace tvm { @@ -35,8 +36,7 @@ namespace runtime { * \param file_name The name of the file. * \param format The format of the file. */ -std::string GetFileFormat(const std::string& file_name, - const std::string& format); +std::string GetFileFormat(const std::string& file_name, const std::string& format); /*! 
* \return the directory in which TVM stores cached files. @@ -62,34 +62,30 @@ std::string GetFileBasename(const std::string& file_name); * \param file_name The name of the file. * \param data The data to be loaded. */ -void LoadBinaryFromFile(const std::string& file_name, - std::string* data); +void LoadBinaryFromFile(const std::string& file_name, std::string* data); /*! * \brief Load binary file into a in-memory buffer. * \param file_name The name of the file. * \param data The binary data to be saved. */ -void SaveBinaryToFile(const std::string& file_name, - const std::string& data); +void SaveBinaryToFile(const std::string& file_name, const std::string& data); /*! * \brief Save meta data to file. * \param file_name The name of the file. * \param fmap The function info map. */ -void SaveMetaDataToFile( - const std::string& file_name, - const std::unordered_map& fmap); +void SaveMetaDataToFile(const std::string& file_name, + const std::unordered_map& fmap); /*! * \brief Load meta data to file. * \param file_name The name of the file. * \param fmap The function info map. */ -void LoadMetaDataFromFile( - const std::string& file_name, - std::unordered_map* fmap); +void LoadMetaDataFromFile(const std::string& file_name, + std::unordered_map* fmap); /*! * \brief Remove (unlink) a file. diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 1c85de8592739..9f206fd48d6ea 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -20,12 +20,13 @@ /*! * \file graph_runtime_debug.cc */ +#include #include #include -#include #include #include + #include "../graph_runtime.h" namespace tvm { @@ -59,15 +60,14 @@ class GraphRuntimeDebug : public GraphRuntime { std::ostringstream os; std::vector time_per_op(op_execs_.size(), 0); for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + std::chrono::time_point tbegin, + tend; double duration_ms = 0.0; do { std::fill(time_per_op.begin(), time_per_op.end(), 0); if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random + number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random } tbegin = std::chrono::high_resolution_clock::now(); for (int k = 0; k < number; k++) { @@ -78,15 +78,17 @@ class GraphRuntimeDebug : public GraphRuntime { op_execs_[index](); TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_tend = std::chrono::high_resolution_clock::now(); - double op_duration = std::chrono::duration_cast< - std::chrono::duration >(op_tend - op_tbegin).count(); + double op_duration = + std::chrono::duration_cast >(op_tend - op_tbegin) + .count(); time_per_op[index] += op_duration * 1e6; // us } } } tend = std::chrono::high_resolution_clock::now(); - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; + duration_ms = + std::chrono::duration_cast >(tend - tbegin).count() * + 1000; } while (duration_ms < min_repeat_ms); LOG(INFO) << "Iteration: " << i; @@ -94,8 +96,8 @@ class GraphRuntimeDebug : public GraphRuntime { for (size_t index = 0; index < time_per_op.size(); index++) { if (op_execs_[index]) { time_per_op[index] /= number; - LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " - << time_per_op[index] << " us/iter"; + LOG(INFO) << "Op #" << op++ << " " 
<< GetNodeName(index) << ": " << time_per_op[index]
+                  << " us/iter";
       }
     }
   }
@@ -110,17 +112,14 @@ class GraphRuntimeDebug : public GraphRuntime {
   * \param index The index of op which needs to be returned.
   * \param eid The Entry id of the op.
   */
-  NDArray GetOutputByLayer(int index, int eid) {
-    return data_entry_[entry_id(index, eid)];
-  }
+  NDArray GetOutputByLayer(int index, int eid) { return data_entry_[entry_id(index, eid)]; }

   /*!
   * \brief GetFunction Get the function based on input.
   * \param name The function which needs to be invoked.
   * \param sptr_to_self Packed function pointer.
   */
-  PackedFunc GetFunction(const std::string& name,
-                         const ObjectPtr& sptr_to_self);
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self);

   /*!
   * \brief Get the node index given the name of node.
@@ -135,53 +134,51 @@ class GraphRuntimeDebug : public GraphRuntime {
     }
     LOG(FATAL) << "cannot find " << name << " among nodes";
     return -1;
-}
+  }

-/*!
- * \brief Copy index-th node to data_out.
- *
- * This method will do a partial run of the the graph
- * from begining upto the index-th node and return output of index-th node.
- * This is costly operation and suggest to use only for debug porpose.
- *
- * \param index: The index of the node.
- * \param data_out the node data.
- */
-void DebugGetNodeOutput(int index, DLTensor* data_out) {
-  CHECK_LT(static_cast(index), op_execs_.size());
-  uint32_t eid = index;
+  /*!
+   * \brief Copy index-th node to data_out.
+   *
+   * This method will do a partial run of the graph
+   * from the beginning up to the index-th node and return that node's output.
+   * This is a costly operation; use it only for debug purposes.
+   *
+   * \param index: The index of the node.
+   * \param data_out the node data.
+   */
+  void DebugGetNodeOutput(int index, DLTensor* data_out) {
+    CHECK_LT(static_cast(index), op_execs_.size());
+    uint32_t eid = index;

-  for (size_t i = 0; i < op_execs_.size(); ++i) {
-    if (op_execs_[i]) op_execs_[i]();
-    if (static_cast(i) == index) break;
-  }
+    for (size_t i = 0; i < op_execs_.size(); ++i) {
+      if (op_execs_[i]) op_execs_[i]();
+      if (static_cast(i) == index) break;
+    }

-  data_entry_[eid].CopyTo(data_out);
-}
+    data_entry_[eid].CopyTo(data_out);
+  }
 };

-
 /*!
 * \brief GetFunction Get the function based on input.
 * \param name The function which needs to be invoked.
 * \param sptr_to_self Packed function pointer.
 */
-PackedFunc GraphRuntimeDebug::GetFunction(
-    const std::string& name,
-    const ObjectPtr& sptr_to_self) {
+PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name,
+                                          const ObjectPtr& sptr_to_self) {
   // Return member functions during query.
if (name == "get_output_by_layer") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutputByLayer(args[0], args[1]); - }); + *rv = this->GetOutputByLayer(args[0], args[1]); + }); } else if (name == "debug_get_output") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); - } else { - this->DebugGetNodeOutput(args[0], args[1]); - } - }); + if (args[0].type_code() == kTVMStr) { + this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); + } else { + this->DebugGetNodeOutput(args[0], args[1]); + } + }); } else if (name == "run_individual") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { int number = args[0]; @@ -203,21 +200,18 @@ PackedFunc GraphRuntimeDebug::GetFunction( * \param m Compiled module which will be loaded. * \param ctxs All devices contexts. */ -Module GraphRuntimeDebugCreate(const std::string& sym_json, - const tvm::runtime::Module& m, +Module GraphRuntimeDebugCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { auto exec = make_object(); exec->Init(sym_json, m, ctxs); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) - << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); - }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index f3bcedf088dd5..239e43d93e50a 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -20,6 +20,8 @@ /*! * \file graph_runtime.cc */ +#include "graph_runtime.h" + #include #include #include @@ -35,8 +37,6 @@ #include #include -#include "graph_runtime.h" - namespace tvm { namespace runtime { namespace details { @@ -64,8 +64,7 @@ void GraphRuntime::Run() { * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ -void GraphRuntime::Init(const std::string& graph_json, - tvm::runtime::Module module, +void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module, const std::vector& ctxs) { std::istringstream is(graph_json); dmlc::JSONReader reader(&is); @@ -133,9 +132,7 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { * * \return The number of outputs from graph. */ -int GraphRuntime::NumOutputs() const { - return outputs_.size(); -} +int GraphRuntime::NumOutputs() const { return outputs_.size(); } /*! * \brief Return NDArray for given input index. * \param index The input index. 
@@ -188,21 +185,16 @@ void GraphRuntime::LoadParams(const std::string& param_blob) { void GraphRuntime::LoadParams(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; - CHECK(strm->Read(&names)) - << "Invalid parameters file format"; + CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; strm->Read(&sz); size_t size = static_cast(sz); - CHECK(size == names.size()) - << "Invalid parameters file format"; + CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(names[i]); CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; @@ -217,13 +209,10 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { } void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; @@ -268,15 +257,14 @@ void GraphRuntime::SetupStorage() { CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; DLDataType t = vtype[i]; size_t bits = t.bits * t.lanes; - CHECK(bits % 8U == 0U || bits ==1U); + CHECK(bits % 8U == 0U || bits == 1U); size_t bytes = ((bits + 7U) / 8U) * size; uint32_t sid = static_cast(storage_id); if (sid >= pool_entry.size()) { pool_entry.resize(sid + 1, {0, -1}); } else { - CHECK(pool_entry[sid].device_type == -1 || - pool_entry[sid].device_type == device_type) + CHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type) << "The same pool entry cannot be assigned to multiple devices"; } pool_entry[sid].size = std::max(pool_entry[sid].size, bytes); @@ -288,14 +276,12 @@ void GraphRuntime::SetupStorage() { std::vector shape; // This for loop is very fast since there are usually only a couple of // devices available on the same hardware. - const auto& cit = - std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { - return pit.device_type == static_cast(c.device_type); - }); + const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { + return pit.device_type == static_cast(c.device_type); + }); TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit; shape.push_back(static_cast(pit.size + 3) / 4); - storage_pool_.push_back( - NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); + storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); } // Assign the pooled entries. 
A unified memory pool is used to simplify
@@ -306,8 +292,7 @@ void GraphRuntime::SetupStorage() {
   for (size_t i = 0; i < data_entry_.size(); ++i) {
     int storage_id = attrs_.storage_id[i];
     CHECK_LT(static_cast(storage_id), storage_pool_.size());
-    data_entry_[i] =
-        storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
+    data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
     const DLTensor* tmp = data_entry_[i].operator->();
     data_alignment_[i] = details::GetDataAlignment(*tmp);
   }
@@ -338,24 +323,20 @@ void GraphRuntime::SetupOpExecs() {
     CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";

     std::shared_ptr op_args = nullptr;
-    std::tie(op_execs_[nid], op_args) =
-        CreateTVMOp(inode.param, args, inode.inputs.size());
+    std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size());

     for (size_t i = 0; i < inode.inputs.size(); i++) {
       uint32_t eid = this->entry_id(inode.inputs[i]);
       // check if op input is model input
       if (input_node_eids.count(eid) > 0) {
-        input_dltensors_[eid].push_back(
-            static_cast(op_args->arg_values[i].v_handle));
+        input_dltensors_[eid].push_back(static_cast(op_args->arg_values[i].v_handle));
       }
     }
   }
 }

 std::pair, std::shared_ptr > GraphRuntime::CreateTVMOp(
-    const TVMOpParam& param,
-    const std::vector& args,
-    size_t num_inputs) {
+    const TVMOpParam& param, const std::vector& args, size_t num_inputs) {
   std::shared_ptr arg_ptr = std::make_shared();
   // setup address.
   arg_ptr->args = args;
@@ -369,15 +350,15 @@ std::pair, std::shared_ptr > GraphRu
     arg_ptr->arg_values.push_back(v);
     arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle);
     if (param.flatten_data) {
-      arg_ptr->shape_data[i] = std::accumulate(
-          t->shape, t->shape + t->ndim, 1, std::multiplies());
+      arg_ptr->shape_data[i] =
+          std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies());
       t->ndim = 1;
       t->shape = &(arg_ptr->shape_data[i]);
     }
   }

   if (param.func_name == "__nop") {
-    return {[](){}, arg_ptr};
+    return {[]() {}, arg_ptr};
   } else if (param.func_name == "__copy") {
     // Perform cross device data copy.
     // Directly copy data from the input to the output.
@@ -396,27 +377,25 @@ std::pair, std::shared_ptr > GraphRu

   auto fexec = [arg_ptr, pf]() {
     TVMRetValue rv;
-    TVMArgs targs(arg_ptr->arg_values.data(),
-                  arg_ptr->arg_tcodes.data(),
+    TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(),
                   static_cast(arg_ptr->arg_values.size()));
     pf.CallPacked(targs, &rv);
   };
   return {fexec, arg_ptr};
 }

-PackedFunc GraphRuntime::GetFunction(
-    const std::string& name,
-    const ObjectPtr& sptr_to_self) {
+PackedFunc GraphRuntime::GetFunction(const std::string& name,
+                                     const ObjectPtr& sptr_to_self) {
   // Return member functions during query.
if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - int in_idx = this->GetInputIndex(args[0]); - if (in_idx >= 0) this->SetInput(in_idx, args[1]); - } else { - this->SetInput(args[0], args[1]); - } - }); + if (args[0].type_code() == kTVMStr) { + int in_idx = this->GetInputIndex(args[0]); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); } else if (name == "set_input_zero_copy") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (args[0].type_code() == kTVMStr) { @@ -436,42 +415,38 @@ PackedFunc GraphRuntime::GetFunction( }); } else if (name == "get_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = 0; - if (args[0].type_code() == kTVMStr) { - in_idx = this->GetInputIndex(args[0]); - } else { - in_idx = args[0]; - } - CHECK_GE(in_idx, 0); - *rv = this->GetInput(in_idx); - }); + int in_idx = 0; + if (args[0].type_code() == kTVMStr) { + in_idx = this->GetInputIndex(args[0]); + } else { + in_idx = args[0]; + } + CHECK_GE(in_idx, 0); + *rv = this->GetInput(in_idx); + }); } else if (name == "get_num_outputs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->NumOutputs(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); } else if (name == "run") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Run(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); } else if (name == "load_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); + this->LoadParams(args[0].operator std::string()); + }); } else if (name == "share_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - const auto& module = args[0].operator Module(); - CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); - const auto& param_blob = args[1].operator std::string(); - dmlc::MemoryStringStream strm(const_cast(¶m_blob)); - this->ShareParams(dynamic_cast(*module.operator->()), &strm); - }); + const auto& module = args[0].operator Module(); + CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); + const auto& param_blob = args[1].operator std::string(); + dmlc::MemoryStringStream strm(const_cast(¶m_blob)); + this->ShareParams(dynamic_cast(*module.operator->()), &strm); + }); } else { return PackedFunc(); } } -Module GraphRuntimeCreate(const std::string& sym_json, - const tvm::runtime::Module& m, +Module GraphRuntimeCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { auto exec = make_object(); exec->Init(sym_json, m, ctxs); @@ -497,14 +472,12 @@ std::vector GetAllContext(const TVMArgs& args) { // execution support yet. For heterogenenous execution, at least 5 arguments will // be passed in. The third one is the number of devices. // Eventually, we will only probably pass TVMContext for all the languages. 
-TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) - << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - const auto& contexts = GetAllContext(args); - *rv = GraphRuntimeCreate(args[0], args[1], contexts); - }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + const auto& contexts = GetAllContext(args); + *rv = GraphRuntimeCreate(args[0], args[1], contexts); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index b787c0a53726c..d0c982281b34b 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -26,26 +26,25 @@ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_ #include -#include #include +#include #include #include #include +#include #include #include #include -#include namespace tvm { namespace runtime { /*! \brief macro to do C API call */ -#define TVM_CCALL(func) \ - { \ - int ret = (func); \ - CHECK_EQ(ret, 0) \ - << TVMGetLastError(); \ +#define TVM_CCALL(func) \ + { \ + int ret = (func); \ + CHECK_EQ(ret, 0) << TVMGetLastError(); \ } /*! \brief Magic number for NDArray list file */ @@ -80,15 +79,12 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "GraphRuntime"; - } + const char* type_key() const final { return "GraphRuntime"; } void Run(); /*! @@ -100,8 +96,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { * executed on. */ - void Init(const std::string& graph_json, - tvm::runtime::Module module, + void Init(const std::string& graph_json, tvm::runtime::Module module, const std::vector& ctxs); /*! @@ -172,14 +167,9 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \brief Get total number of nodes. * \return Total number of nodes. */ - uint32_t GetNumOfNodes() const { - return static_cast(nodes_.size()); - } - - std::string GetNodeName(uint32_t nid) const { - return nodes_[nid].name; - } + uint32_t GetNumOfNodes() const { return static_cast(nodes_.size()); } + std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } protected: // Memory pool entry. 
@@ -194,7 +184,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { uint32_t index; uint32_t version; // JSON Loader - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginArray(); CHECK(reader->NextArrayItem()) << "invalid json format"; reader->Read(&node_id); @@ -221,7 +211,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { // control deps std::vector control_deps; // JSON Loader - void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) { + void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) { int bitmask = 0; std::string key, value; reader->BeginObject(); @@ -241,10 +231,10 @@ class TVM_DLL GraphRuntime : public ModuleNode { bitmask |= 8; } } - CHECK_EQ(bitmask, 1|2|4|8) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format"; } // JSON Loader - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginObject(); int bitmask = 0; std::string key; @@ -266,7 +256,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { LOG(FATAL) << "do not support key " << key; } } - CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format"; } }; struct GraphAttr { @@ -274,9 +264,9 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector storage_id; std::vector device_index; std::vector dltype; - std::vector > shape; + std::vector> shape; // The graph attribute fields. - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginObject(); int bitmask = 0; std::string key, type; @@ -334,37 +324,37 @@ class TVM_DLL GraphRuntime : public ModuleNode { CHECK(!reader->NextArrayItem()); } } - CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format"; } }; // The graph attribute fields. - void Load(dmlc::JSONReader *reader) { - reader->BeginObject(); - int bitmask = 0; - std::string key; - while (reader->NextObjectItem(&key)) { - if (key == "nodes") { - reader->Read(&nodes_); - bitmask |= 1; - } else if (key == "arg_nodes") { - reader->Read(&input_nodes_); - bitmask |= 2; - } else if (key == "node_row_ptr") { - reader->Read(&node_row_ptr_); - bitmask |= 4; - } else if (key == "heads") { - reader->Read(&outputs_); - bitmask |= 8; - } else if (key == "attrs") { - reader->Read(&attrs_); - bitmask |= 16; - } else if (key == "metadata") { - break; - } else { - LOG(FATAL) << "key " << key << " is not supported"; - } + void Load(dmlc::JSONReader* reader) { + reader->BeginObject(); + int bitmask = 0; + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "nodes") { + reader->Read(&nodes_); + bitmask |= 1; + } else if (key == "arg_nodes") { + reader->Read(&input_nodes_); + bitmask |= 2; + } else if (key == "node_row_ptr") { + reader->Read(&node_row_ptr_); + bitmask |= 4; + } else if (key == "heads") { + reader->Read(&outputs_); + bitmask |= 8; + } else if (key == "attrs") { + reader->Read(&attrs_); + bitmask |= 16; + } else if (key == "metadata") { + break; + } else { + LOG(FATAL) << "key " << key << " is not supported"; } - CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format"; + } + CHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format"; } /*! \brief Setup the temporal storage */ void SetupStorage(); @@ -377,21 +367,14 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \param num_inputs Number of inputs. * \return The created executor. 
*/ - std::pair, std::shared_ptr > CreateTVMOp( - const TVMOpParam& attrs, const std::vector& args, - size_t num_inputs); + std::pair, std::shared_ptr> CreateTVMOp( + const TVMOpParam& attrs, const std::vector& args, size_t num_inputs); // Get node entry index. - uint32_t entry_id(uint32_t nid, uint32_t index) const { - return node_row_ptr_[nid] + index; - } + uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; } // Get node entry index. - uint32_t entry_id(const NodeEntry& e) const { - return entry_id(e.node_id, e.index); - } + uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); } // Number of node entries. - uint32_t num_node_entries() const { - return node_row_ptr_.back(); - } + uint32_t num_node_entries() const { return node_row_ptr_.back(); } /*! \brief The graph nodes. */ std::vector nodes_; /*! \brief The argument nodes. */ @@ -417,7 +400,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { /*! \brief Data alignment of each node. */ std::vector data_alignment_; /*! \brief Operator on each node. */ - std::vector > op_execs_; + std::vector> op_execs_; }; std::vector GetAllContext(const TVMArgs& args); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index d88e6d7284a3e..fd6f323740055 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -33,21 +33,17 @@ class HexagonDeviceAPI : public DeviceAPI { public: void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t num_bytes, TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, - TVMStreamHandle stream) final; + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; - void* AllocWorkspace(TVMContext ctx, size_t nbytes, - DLDataType type_hint = {}) final; + void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final; void FreeWorkspace(TVMContext ctx, void* ptr) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; @@ -56,13 +52,11 @@ class HexagonDeviceAPI : public DeviceAPI { inline void HexagonDeviceAPI::SetDevice(TVMContext ctx) {} -inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, - TVMRetValue* rv) { +inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { if (kind == kExist) *rv = 1; } -inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, - size_t alignment, +inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id)); return hexagon::Device::Global()->Alloc(nbytes, alignment); @@ -73,10 +67,10 @@ inline void HexagonDeviceAPI::FreeDataSpace(TVMContext ctx, void* ptr) { hexagon::Device::Global()->Free(ptr); } 
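The entry_id helpers in the graph_runtime.h hunk above flatten (node, output) pairs through node_row_ptr_, a prefix sum of per-node output counts; num_node_entries() is then just the last element. A self-contained sketch of that indexing (example counts are made up):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Prefix sum of output counts: node 0 has 1 output, node 1 has 2, node 2 has 1.
      std::vector<uint32_t> node_row_ptr = {0, 1, 3, 4};
      auto entry_id = [&](uint32_t nid, uint32_t index) {
        return node_row_ptr[nid] + index;
      };
      assert(entry_id(1, 1) == 2);       // second output of node 1
      assert(node_row_ptr.back() == 4);  // total entries, cf. num_node_entries()
      return 0;
    }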
-inline void HexagonDeviceAPI::CopyDataFromTo( - const void* from, size_t from_offset, void* to, size_t to_offset, - size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, - DLDataType type_hint, TVMStreamHandle stream) { +inline void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t num_bytes, + TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) { const char* src = static_cast(from) + from_offset; char* dst = static_cast(to) + to_offset; @@ -110,11 +104,9 @@ inline void HexagonDeviceAPI::CopyDataFromTo( } } -inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, - TVMStreamHandle stream) {} +inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, TVMStreamHandle stream) {} -inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, - DLDataType type_hint) { +inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint) { CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id)); if (type_hint.code == 100) { size_t align = std::min(nbytes, 2048lu); @@ -128,11 +120,10 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { DeviceAPI::FreeWorkspace(ctx, ptr); } -TVM_REGISTER_GLOBAL("device_api.hexagon") - .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); - *rv = ptr; - }); +TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); + *rv = ptr; +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index e14843688b73e..f76ac1670e242 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -176,8 +176,7 @@ void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { if (!InReg) { // Allocate on stack. - CHECK_EQ((t_align & (t_align - 1)), 0) - << "Alignment should be a power of 2"; + CHECK_EQ((t_align & (t_align - 1)), 0) << "Alignment should be a power of 2"; CHECK_GE(t_align, 4) << "Alignment should be at least 4"; // Round t_size up to a multiple of 4. 
unsigned s_size = Stack.size(); @@ -193,9 +192,8 @@ void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { class HexagonModuleNode final : public runtime::ModuleNode { public: HexagonModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) : hexagon_device_(hexagon::Device::Global()), data_(data), @@ -214,13 +212,11 @@ class HexagonModuleNode final : public runtime::ModuleNode { } } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; const char* type_key() const final { return "hexagon"; } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -240,8 +236,7 @@ class HexagonModuleNode final : public runtime::ModuleNode { CHECK(!bc_.empty()) << "LLVM IR bitcode not available"; SaveBinaryToFile(file_name, bc_); } else { - LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt - << "'"; + LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt << "'"; } } void SaveToBinary(dmlc::Stream* stream) final { @@ -251,10 +246,8 @@ class HexagonModuleNode final : public runtime::ModuleNode { } private: - void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const; - void CallRemoteDirect(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const; + void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; + void CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; void RemapArgs(const TVMArgs& args, std::vector& values, // NOLINT(*) std::vector& type_codes, // NOLINT(*) @@ -274,8 +267,7 @@ class HexagonModuleNode final : public runtime::ModuleNode { std::set packed_c_abi_funcs_; }; -void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, - const TVMArgs& args, +void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const { // Remap all arguments, creating remote DLTensors. std::vector values; @@ -297,8 +289,8 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, int num_args = args.size(); int values_size = num_args * sizeof(TVMValue); int codes_size = num_args * sizeof(int); - void* remote = hexagon_device_->Alloc( - values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); + void* remote = + hexagon_device_->Alloc(values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); // Copy all argument TVMValues to the remote space. 
void* remote_values = remote; @@ -316,12 +308,12 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, temp_values[2].v_int64 = num_args; temp_values[3].v_handle = remote_ret_value; temp_values[4].v_handle = remote_ret_code; - int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, - kTVMOpaqueHandle, kTVMOpaqueHandle}; + int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, kTVMOpaqueHandle, + kTVMOpaqueHandle}; TVMArgs temp_args(temp_values, temp_codes, 5); hexagon::ArgLayout as = BuildArgLayout(temp_args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), - as.Stack.data(), as.Stack.size()); + hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), + as.Stack.size()); // TODO(kparzysz-quic): copy return value back std::for_each(remote_tensors.begin(), remote_tensors.end(), @@ -332,12 +324,12 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, void HexagonModuleNode::CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const { hexagon::ArgLayout as = BuildArgLayout(args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), - as.Stack.data(), as.Stack.size()); + hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), + as.Stack.size()); } -PackedFunc HexagonModuleNode::GetFunction( - const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc HexagonModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { auto f = fmap_.find(name); if (f == fmap_.end()) return PackedFunc(nullptr); @@ -363,8 +355,7 @@ PackedFunc HexagonModuleNode::GetFunction( } } -void HexagonModuleNode::RemapArgs(const TVMArgs& args, - std::vector& values, +void HexagonModuleNode::RemapArgs(const TVMArgs& args, std::vector& values, std::vector& type_codes, std::vector& remote_tensors) const { for (unsigned i = 0, e = args.size(); i != e; ++i) { @@ -437,18 +428,17 @@ void* HexagonModuleNode::CreateRemoteTensor(const DLTensor* t) const { uint32_t remote_as_int = reinterpret_cast(remote); void* remote_ss = reinterpret_cast(remote_as_int + size_ht); - HexagonDLTensor local = { - .data = static_cast(reinterpret_cast(t->data)), - .ctx_device_type = uint8_t(t->ctx.device_type), - .pad0 = {0, 0, 0}, - .ctx_device_id = t->ctx.device_id, - .ndim = t->ndim, - .dtype_code = t->dtype.code, - .dtype_bits = t->dtype.bits, - .dtype_lanes = t->dtype.lanes, - .shape = remote_as_int + size_ht, - .strides = t->strides ? remote_as_int + size_ht + size_s : 0u, - .byte_offset = t->byte_offset}; + HexagonDLTensor local = {.data = static_cast(reinterpret_cast(t->data)), + .ctx_device_type = uint8_t(t->ctx.device_type), + .pad0 = {0, 0, 0}, + .ctx_device_id = t->ctx.device_id, + .ndim = t->ndim, + .dtype_code = t->dtype.code, + .dtype_bits = t->dtype.bits, + .dtype_lanes = t->dtype.lanes, + .shape = remote_as_int + size_ht, + .strides = t->strides ? 
remote_as_int + size_ht + size_s : 0u, + .byte_offset = t->byte_offset}; std::vector local_ss(size_ss / 8); for (int i = 0; i != ndim; ++i) local_ss[i] = t->shape[i]; @@ -505,18 +495,16 @@ hexagon::ArgLayout HexagonModuleNode::BuildArgLayout(const TVMArgs& As) const { } Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, - ir_str, bc_str, packed_c_abi); + auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str, + packed_c_abi); return Module(n); } // Load module from file. -Module HexagonModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module HexagonModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data = file_name; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -552,10 +540,9 @@ std::shared_ptr Device::Global() { } // namespace hexagon -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = HexagonModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = HexagonModuleLoadFile(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index c9e23a77776e1..b922b169bd61a 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -47,9 +47,8 @@ namespace runtime { * convention. */ Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi); namespace hexagon { @@ -91,24 +90,21 @@ class Device { * \param src Pointer (local to device) of the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyDeviceToDevice(void* dst, const void* src, - unsigned len) = 0; + virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0; /*! * \brief Copy a block of data from device to host. * \param host_dst Pointer (local to host) to the destination buffer. * \param src Pointer (local to device) to the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) = 0; + virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0; /*! * \brief Copy a block of data from host to device. * \param dst Pointer (local to device) to the destination buffer. * \param host_src Pointer (local to host) to the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyHostToDevice(void* dst, const void* host_src, - unsigned len) = 0; + virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0; /*! * \brief Load a module (typically a shared library) into device. * \param data Name of the shared library. @@ -141,8 +137,8 @@ class Device { * for padding. * \param st_num Number of values in the "stack" array. 
*/ - virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, - uint32_t* stack, unsigned st_num) = 0; + virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) = 0; virtual ~Device() = 0; diff --git a/src/runtime/hexagon/hexagon_posix.cc b/src/runtime/hexagon/hexagon_posix.cc index 627963f384f5e..e98fefd1da221 100644 --- a/src/runtime/hexagon/hexagon_posix.cc +++ b/src/runtime/hexagon/hexagon_posix.cc @@ -23,12 +23,10 @@ #include extern "C" { -int posix_memalign(void** memptr, size_t alignment, size_t size) - __attribute__((nothrow)); +int posix_memalign(void** memptr, size_t alignment, size_t size) __attribute__((nothrow)); } -__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, - size_t size) { +__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, size_t size) { if (void* p = memalign(alignment, size)) { *memptr = p; return 0; diff --git a/src/runtime/hexagon/sim/driver/CMakeLists.txt b/src/runtime/hexagon/sim/driver/CMakeLists.txt new file mode 100644 index 0000000000000..8632b491f2591 --- /dev/null +++ b/src/runtime/hexagon/sim/driver/CMakeLists.txt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +project(SIM_DEV C CXX) +cmake_minimum_required(VERSION 3.0.2) + +set(CMAKE_SYSTEM_NAME "Linux") + +if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) + include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) +endif() + +set(EXTRA_CXX_FLAGS + "-O2" + "-Wno-format" + "-mhvx -mhvx-length=128b" + "-mv60" + "-stdlib=libc++" +) + +set(EXTRA_LINK_FLAGS + "-stdlib=libc++" + "-G0" + "-Wl,--force-dynamic" + "-Wl,--export-dynamic" + "-Wl,--whole-archive" # This should link entire libc, libc++ and libc+abi. + "-Wl,--defsym=HEAP_SIZE=0x40000000" +) + +string(REGEX REPLACE ";" " " EXTRA_CXX_FLAGS_STR "${EXTRA_CXX_FLAGS}") +string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_FLAGS "${EXTRA_CXX_FLAGS_STR} ${CMAKE_CXX_FLAGS}") +set(CMAKE_EXE_LINKER_FLAGS "${EXTRA_LINK_FLAGS_STR} ${CMAKE_EXE_LINKER_FLAGS}") + +# Set project properties. + +file(GLOB SOURCE_FILES "*.cc") +add_executable(sim_dev ${SOURCE_FILES}) +target_include_directories(sim_dev + PUBLIC "." + PUBLIC ".." 
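+  # (Note, ours) Include paths below are relative to src/runtime/hexagon/sim/driver.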
+ PUBLIC "../../../../../include" + PUBLIC "../../../../../3rdparty/dlpack/include" +) + +target_link_libraries(sim_dev "-ldl") diff --git a/tests/webgl/README.md b/src/runtime/hexagon/sim/driver/README.md similarity index 50% rename from tests/webgl/README.md rename to src/runtime/hexagon/sim/driver/README.md index 5303cc0597403..3aee1a14b7968 100644 --- a/tests/webgl/README.md +++ b/src/runtime/hexagon/sim/driver/README.md @@ -15,10 +15,24 @@ -## Test cases for the WebGL backend +# Hexagon simulator driver -Any test case with name `test_local_...` tests the C++ OpenGL backend on the -local OS, which can be executed automatically. +The driver (`sim_dev` executable) is the process running on the Hexagon simulator that handles the Hexagon-side communication with the TVM runtime running on x86. The location of `sim_dev` should be added to `PATH` before running any python code that uses Hexagon. The `sim_dev` executable is not intended to be run by users, it is automatically loaded by the simulator control code (in `hexagon_device_sim.cc`). -Any test case with name `test_remote_...` tests the WebGL backend within the -browser, which must be run manually. See instruction within the test. +### Prerequisites + +1. Hexagon C/C++ toolchain (such as the one in Hexagon SDK version 3.5.0 or later). + +Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. + +### Configuring + +Set +``` +CMAKE_C_COMPILER=hexagon-clang +CMAKE_CXX_COMPILER=hexagon-clang++ +``` + +### Building + +There are no special options required for `make` (or the tool selected with `cmake`). The location of the resulting binary `sim_dev` should be added to `PATH`. diff --git a/src/runtime/hexagon/sim/driver/fake_pthread.cc b/src/runtime/hexagon/sim/driver/fake_pthread.cc new file mode 100644 index 0000000000000..74090d0bf7960 --- /dev/null +++ b/src/runtime/hexagon/sim/driver/fake_pthread.cc @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "pthread.h" +#include "sched.h" + +/*! + * Implementation of a subset of pthread API for single-threaded execution. + * + * They main idea is that the thread function ("start_routine" in the call + * to pthread_create) is executed immediately. When pthread_create returns, + * the thread function has already finished. + * + * Since the thread routine can itself call pthread_create, it is possible + * to have multiple threads existing at the same time, although only the + * last one is running. + * + * There are two main things that need to be taken care of: + * - thread-specific data, i.e. pthread_setspecific, pthread_getspecific, + * and the handling of thread keys, + * - handling of thread return values. 
+ *
+ * Threads are identified by thread ids (of type pthread_t). The main process
+ * thread has the id of 0; the remaining threads have ids starting at 1 and
+ * incrementing by 1. For each thread there is some data (thread_info_t)
+ * associated with it, stored in the "thread_data" map. When a thread
+ * terminates, the corresponding entry from "thread_data" cannot be removed
+ * until the return value is claimed (pthread_join), unless it is explicitly
+ * discarded (pthread_detach). When a new thread is created, it gets the
+ * first available id for which there is no entry in "thread_data". This
+ * could be an id that was never allocated, or an id that was used, but
+ * has since been removed from the map.
+ * A thread can terminate through pthread_exit. This means that when the
+ * thread function calls pthread_exit, the execution should return to the
+ * pthread_create call that ran it. This is implemented via setjmp/longjmp
+ * (neither longjmp nor pthread_exit unwinds the stack).
+ *
+ * Mutexes and condition variables never block, since blocking would cause
+ * a deadlock: there is only one thread running at a time, so locking a
+ * mutex or waiting for a condition always succeeds (returns immediately).
+ */
+
+struct key_entry_t {
+  key_entry_t(void* v, void (*d)(void*)) : value(v), dtor(d) {}
+  void* value = nullptr;
+  void (*dtor)(void*) = nullptr;
+};
+
+struct thread_info_t {
+  thread_info_t() = default;
+  std::map<pthread_key_t, key_entry_t> keys;
+  std::jmp_buf env;
+  void* ret_value = nullptr;
+  bool finished = false;
+  bool detached = false;
+};
+
+static pthread_t main_thread_id = 0;
+
+static std::map<pthread_t, thread_info_t> thread_data = {
+    // Reserve the 0th entry.
+    {main_thread_id, {}}};
+
+static std::vector<pthread_t> running_threads = {main_thread_id};
+
+template <typename K, typename V>
+K first_available_key(const std::map<K, V>& m) {
+  auto i = m.begin(), e = m.end();
+  K key = 1;
+  for (; i != e && key == i->first; ++i, ++key) {
+  }
+  return key;
+}
+
+int pthread_cond_destroy(pthread_cond_t* cond) { return 0; }
+
+int pthread_cond_init(pthread_cond_t* __restrict cond,
+                      const pthread_condattr_t* __restrict attr) {
+  return 0;
+}
+
+int pthread_cond_signal(pthread_cond_t* cond) { return 0; }
+
+int pthread_cond_broadcast(pthread_cond_t* cond) { return 0; }
+
+int pthread_cond_timedwait(pthread_cond_t* __restrict cond,
+                           pthread_mutex_t* __restrict mutex,
+                           const struct timespec* __restrict abstime) {
+  return 0;
+}
+
+int pthread_cond_wait(pthread_cond_t* __restrict cond,
+                      pthread_mutex_t* __restrict mutex) {
+  return 0;
+}
+
+int pthread_mutexattr_init(pthread_mutexattr_t* attr) { return 0; }
+
+int pthread_mutexattr_destroy(pthread_mutexattr_t* attr) { return 0; }
+
+int pthread_mutexattr_settype(pthread_mutexattr_t* attr, int type) {
+  return 0;
+}
+
+int pthread_mutexattr_gettype(const pthread_mutexattr_t* __restrict attr,
+                              int* __restrict type) {
+  *type = PTHREAD_MUTEX_NORMAL;
+  return 0;
+}
+
+int pthread_mutex_init(pthread_mutex_t* __restrict mutex,
+                       const pthread_mutexattr_t* __restrict attr) {
+  return 0;
+}
+
+int pthread_mutex_destroy(pthread_mutex_t* mutex) { return 0; }
+
+int pthread_mutex_lock(pthread_mutex_t* mutex) { return 0; }
+
+int pthread_mutex_trylock(pthread_mutex_t* mutex) { return 0; }
+
+int pthread_mutex_unlock(pthread_mutex_t* mutex) { return 0; }
+
+int pthread_once(pthread_once_t* once_control, void (*init_routine)(void)) {
+  static_assert(PTHREAD_ONCE_INIT != PTHREAD_ONCE_DONE,
+                "PTHREAD_ONCE_INIT must be different from PTHREAD_ONCE_DONE");
+  if (*once_control == PTHREAD_ONCE_INIT) {
+    init_routine();
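+    // (ours) Mark completion so subsequent pthread_once calls skip init_routine.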
*once_control = PTHREAD_ONCE_DONE; + } + return 0; +} + +int pthread_equal(pthread_t t1, pthread_t t2) { return t1 == t2; } + +int pthread_create(pthread_t* thread, const pthread_attr_t* attr, + void* (*start_routine)(void*), void* arg) { + std::jmp_buf& env = thread_data[pthread_self()].env; + volatile pthread_t tid; + if (setjmp(env) == 0) { + tid = first_available_key(thread_data); + *thread = tid; + running_threads.push_back(pthread_t(tid)); + thread_info_t& thr = thread_data[pthread_t(tid)]; + thr.ret_value = start_routine(arg); + } + thread_info_t& thr = thread_data[pthread_t(tid)]; + thr.finished = true; + running_threads.pop_back(); + + // Destroy all keys. + bool repeat = true; + size_t iter = 0; + while (repeat && iter++ < PTHREAD_DESTRUCTOR_ITERATIONS) { + repeat = false; + // Assume that destructors can create new keys (i.e. modify the map). + for (size_t k = 0; k != PTHREAD_KEYS_MAX; ++k) { + auto f = thr.keys.find(k); + if (f == thr.keys.end()) { + continue; + } + key_entry_t& key = f->second; + if (key.dtor == nullptr || key.value == nullptr) { + continue; + } + key.dtor(key.value); + repeat = true; + } + } + + if (thr.detached) { + thread_data.erase(pthread_t(tid)); + } + + return 0; +} + +int pthread_join(pthread_t thread, void** retval) { + auto f = thread_data.find(thread); + if (f == thread_data.end()) { + return ESRCH; + } + thread_info_t& thr = f->second; + if (!thr.finished) { + return EDEADLK; + } + if (retval != nullptr) { + *retval = thr.ret_value; + } + thread_data.erase(f); + return 0; +} + +int pthread_detach(pthread_t thread) { + auto f = thread_data.find(thread); + if (f == thread_data.end()) { + return ESRCH; + } + // Can discard the return value. + f->second.detached = true; + return 0; +} + +void pthread_exit(void* retval) { + pthread_t sid = pthread_self(); + if (sid != main_thread_id) { + thread_info_t& self = thread_data[sid]; + self.ret_value = retval; + self.finished = true; + longjmp(self.env, 1); + } + exit(0); // Only executes for the main thread, plus silences + // the "should not return" warning. +} + +int pthread_key_create(pthread_key_t* key, void (*destructor)(void*)) { + if (key == nullptr) { + return EINVAL; + } + auto& keys = thread_data[pthread_self()].keys; + pthread_key_t k = first_available_key(keys); + if (k >= PTHREAD_KEYS_MAX) { + return EAGAIN; + } + *key = k; + keys.emplace(k, key_entry_t{nullptr, destructor}); + return 0; +} + +int pthread_key_delete(pthread_key_t key) { + auto& keys = thread_data[pthread_self()].keys; + auto f = keys.find(key); + if (f == keys.end()) { + return EINVAL; + } + // pthread_key_delete does not call key destructors. 
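+  // (POSIX note, ours) Freeing any values still associated with the key is
+  // the application's responsibility.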
+  keys.erase(f);
+  return 0;
+}
+
+int pthread_setspecific(pthread_key_t key, const void* value) {
+  auto& keys = thread_data[pthread_self()].keys;
+  auto f = keys.find(key);
+  if (f == keys.end()) {
+    return EINVAL;
+  }
+  f->second.value = const_cast<void*>(value);
+  return 0;
+}
+
+void* pthread_getspecific(pthread_key_t key) {
+  auto& keys = thread_data[pthread_self()].keys;
+  auto f = keys.find(key);
+  if (f != keys.end()) {
+    return f->second.value;
+  }
+  return nullptr;
+}
+
+pthread_t pthread_self(void) { return running_threads.back(); }
+
+int sched_yield(void) { return 0; }
+
+#ifdef __cplusplus
+extern "C" int nanosleep(const struct timespec* req, struct timespec* rem);
+#endif
+
+int nanosleep(const struct timespec* req, struct timespec* rem) { return 0; }
diff --git a/src/runtime/hexagon/sim/driver/pthread.h b/src/runtime/hexagon/sim/driver/pthread.h
new file mode 100644
index 0000000000000..1748d614cbbf8
--- /dev/null
+++ b/src/runtime/hexagon/sim/driver/pthread.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +#ifndef TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ +#define TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ + +#define _PROVIDE_POSIX_TIME_DECLS 1 +#include +#undef _PROVIDE_POSIX_TIME_DECLS + +typedef int pthread_t; +typedef int pthread_attr_t; +typedef int pthread_cond_t; +typedef int pthread_condattr_t; +typedef int pthread_key_t; +typedef int pthread_mutex_t; +typedef int pthread_mutexattr_t; +typedef int pthread_once_t; + +enum { + PTHREAD_COND_INITIALIZER, + PTHREAD_MUTEX_DEFAULT, + PTHREAD_MUTEX_ERRORCHECK, + PTHREAD_MUTEX_INITIALIZER, + PTHREAD_MUTEX_NORMAL, + PTHREAD_MUTEX_RECURSIVE, + PTHREAD_ONCE_INIT = 0, // Must be same as in QuRT + PTHREAD_ONCE_DONE, // Non-standard +}; + +const size_t PTHREAD_KEYS_MAX = 128; +const size_t PTHREAD_DESTRUCTOR_ITERATIONS = 4; + +#ifdef __cplusplus +extern "C" { +#endif +int pthread_cond_destroy(pthread_cond_t* cond); +int pthread_cond_init(pthread_cond_t* __restrict cond, + const pthread_condattr_t* __restrict attr); +int pthread_cond_signal(pthread_cond_t* cond); +int pthread_cond_broadcast(pthread_cond_t* cond); +int pthread_cond_timedwait(pthread_cond_t* __restrict cond, + pthread_mutex_t* __restrict mutex, + const struct timespec* __restrict abstime); +int pthread_cond_wait(pthread_cond_t* __restrict cond, + pthread_mutex_t* __restrict mutex); + +int pthread_mutexattr_init(pthread_mutexattr_t* attr); +int pthread_mutexattr_destroy(pthread_mutexattr_t* attr); +int pthread_mutexattr_gettype(const pthread_mutexattr_t* __restrict attr, + int* __restrict type); +int pthread_mutexattr_settype(pthread_mutexattr_t* attr, int type); + +int pthread_mutex_init(pthread_mutex_t* __restrict mutex, + const pthread_mutexattr_t* __restrict attr); +int pthread_mutex_destroy(pthread_mutex_t* mutex); +int pthread_mutex_lock(pthread_mutex_t* mutex); +int pthread_mutex_trylock(pthread_mutex_t* mutex); +int pthread_mutex_unlock(pthread_mutex_t* mutex); + +int pthread_once(pthread_once_t* once_control, void (*init_routine)(void)); +int pthread_equal(pthread_t t1, pthread_t t2); + +int pthread_create(pthread_t* thread, const pthread_attr_t* attr, + void* (*start_routine)(void*), void* arg); +int pthread_join(pthread_t thread, void** retval); +int pthread_detach(pthread_t thread); +void pthread_exit(void* retval) __attribute__((__noreturn__)); + +int pthread_key_create(pthread_key_t* key, void (*destructor)(void*)); +int pthread_key_delete(pthread_key_t key); +int pthread_setspecific(pthread_key_t key, const void* value); +void* pthread_getspecific(pthread_key_t key); + +pthread_t pthread_self(void); +#ifdef __cplusplus +} +#endif + +#endif // TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ diff --git a/web/.eslintrc.js b/src/runtime/hexagon/sim/driver/sched.h similarity index 68% rename from web/.eslintrc.js rename to src/runtime/hexagon/sim/driver/sched.h index 2e82ba50e3c46..cc63630f20723 100644 --- a/web/.eslintrc.js +++ b/src/runtime/hexagon/sim/driver/sched.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,29 +17,15 @@ * under the License. 
*/ -module.exports = { - "env": { - "browser": true, - "node": true, - "es6": true - }, - "extends": "eslint:recommended", - "rules": { - "indent": [ - "error", - 2 - ], - "linebreak-style": [ - "error", - "unix" - ], - "quotes": [ - "error", - "double" - ], - "semi": [ - "error", - "always" - ] - } -}; +#ifndef TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ +#define TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ + +#ifdef __cplusplus +extern "C" { +#endif +int sched_yield(void); +#ifdef __cplusplus +} +#endif + +#endif // TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ diff --git a/src/runtime/hexagon/sim/driver/sim_device.cc b/src/runtime/hexagon/sim/driver/sim_device.cc new file mode 100644 index 0000000000000..23dc053070381 --- /dev/null +++ b/src/runtime/hexagon/sim/driver/sim_device.cc @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + Required options: + -ldl -G0 For dlinit/dlopen/dlclose. + -Wl,--force-dynamic Make this a dynamic executable (with dynamic + symbol table). + -Wl,-E Export all defined symbols as dynamic. + -Wl,--whole-archive Link the entire contents of libc. + -mhvx -mhvx-length=128b Enable HVX. + -Wno-format Silence format warning (unsigned vs uint32_t). +*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "hexagon_sim_proto.h" +#include "pthread.h" +#include "tvm/runtime/c_runtime_api.h" + +static std::string timeNow() { + char str[11]; // [hh:mm:ss] + time_t time_value = time(NULL); + tm* pnow = localtime(&time_value); // NOLINT(runtime/threadsafe_fn) + + snprintf(str, sizeof(str), "[%02d:%02d:%02d]", pnow->tm_hour, pnow->tm_min, + pnow->tm_sec); + return std::string(str); +} + +#define LOG(FMT, ...) \ + fprintf(stderr, "%s %s:%d: " FMT "\n", timeNow().c_str(), __FILE__, \ + __LINE__, ##__VA_ARGS__) + +using HVX_Vector = + int __attribute__((__vector_size__(128))) __attribute__((aligned(128))); + +static unsigned getVectorLength() { + HVX_Vector v = __builtin_HEXAGON_V6_lvsplatw_128B(0x01010101); + unsigned char* p = reinterpret_cast(&v); + if (p[127] == 1) return 128; + assert(p[63] == 1); + return 64; +} + +extern "C" { +// Print vector functions. They can be used to help debug tensorized +// code, via +// ib.emit(tvm.call_extern('int32', 'V6_pv8', 'vector:', v)) +// ib.emit(tvm.call_extern('int32', 'V6_pv16', 'info:', v)) +// ib.emit(tvm.call_extern('int32', 'V6_pv32', 'value:', v)) + +// The first argument is a string printed before the vector contents. 
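+// Example (ours): V6_pv8("acc", v) prints "acc:" followed by each byte of
+// the vector as a two-digit hex value on stderr.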
+int V6_pv8(const char* s, HVX_Vector v);
+int V6_pv16(const char* s, HVX_Vector v);
+int V6_pv32(const char* s, HVX_Vector v);
+}
+
+int V6_pv8(const char* s, HVX_Vector v) {
+  unsigned vlen = getVectorLength();
+  uint8_t* ptr = reinterpret_cast<uint8_t*>(&v);
+  fprintf(stderr, "%s:", s);
+  for (unsigned i = 0; i != vlen; ++i) {
+    fprintf(stderr, " %02x", ptr[i]);
+  }
+  fprintf(stderr, "\n");
+  return 0;
+}
+
+int V6_pv16(const char* s, HVX_Vector v) {
+  unsigned vlen = getVectorLength();
+  uint16_t* ptr = reinterpret_cast<uint16_t*>(&v);
+  fprintf(stderr, "%s:", s);
+  for (unsigned i = 0; i != vlen / sizeof(uint16_t); ++i) {
+    fprintf(stderr, " %04x", ptr[i]);
+  }
+  fprintf(stderr, "\n");
+  return 0;
+}
+
+int V6_pv32(const char* s, HVX_Vector v) {
+  unsigned vlen = getVectorLength();
+  uint32_t* ptr = reinterpret_cast<uint32_t*>(&v);
+  fprintf(stderr, "%s:", s);
+  for (unsigned i = 0; i != vlen / sizeof(uint32_t); ++i) {
+    fprintf(stderr, " %08x", ptr[i]);
+  }
+  fprintf(stderr, "\n");
+  return 0;
+}
+
+extern "C" {
+// Function referenced from libc++.a, but not defined in libc.a.
+int clock_gettime(clockid_t clock_id, struct timespec* tp);
+// pthread_create is wrapped so that we can set a bigger stack size
+// for QuRT. Here this isn't needed, but we still need to implement
+// the wrapper.
+int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr,
+                          void* (*start_routine)(void*), void* arg);
+}
+
+int clock_gettime(clockid_t clock_id, struct timespec* tp) {
+  // Stub implementation.
+  return 0;
+}
+
+int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr,
+                          void* (*start_routine)(void*), void* arg) {
+  LOG("%s", __func__);
+  return pthread_create(thread, attr, start_routine, arg);
+}
+
+// FIXME(kparzysz-quic): query the cfg register to compute the VTCM base.
+// This works now.
+const unsigned int TCM_BASE = 0xD8000000;
+const unsigned int VTCM_BASE = TCM_BASE + 0x400000;
+
+class Allocator {
+ private:
+  struct Block {
+    Block(void* p, size_t s) : ptr_(p), size_(s), vtcm_(false) {}
+    Block(void* p, size_t s, bool v) : ptr_(p), size_(s), vtcm_(v) {}
+    bool operator<(const Block& b) const {
+      return uintptr_t(ptr_) < uintptr_t(b.ptr_);
+    }
+    void* ptr_;
+    size_t size_;
+    bool vtcm_;
+  };
+
+  using vector_type = std::vector<Block>;
+  using iterator = vector_type::iterator;
+  vector_type allocations_;
+
+  uintptr_t cur_vtcm = VTCM_BASE;
+
+ public:
+  void* alloc(unsigned size, size_t align);
+  void* vtcm_alloc(unsigned size, size_t align);
+  void free(void* p);
+};
+
+void* Allocator::alloc(unsigned size, size_t align) {
+  void* ptr = aligned_alloc(align, size);
+  if (ptr == nullptr) {
+    perror("device: error allocating memory:");
+    return ptr;
+  }
+
+  Block b(ptr, size);
+  iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), b);
+  iterator w = allocations_.insert(i, b);
+  if (w != allocations_.begin()) {
+    iterator pw = w - 1;
+    assert(uintptr_t(pw->ptr_) + pw->size_ < uintptr_t(w->ptr_));
+  }
+  if (w + 1 != allocations_.end()) {
+    iterator nw = w + 1;
+    assert(uintptr_t(w->ptr_) + w->size_ <= uintptr_t(nw->ptr_));
+  }
+
+  LOG("device: allocated %d bytes aligned at %d: %p", size, align, ptr);
+  return ptr;
+}
+
+// For now, just allocate sequentially. This needs to be improved to use a
+// free list.
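+// A minimal sketch (ours, illustrative only) of the bump-pointer alignment
+// step performed in vtcm_alloc below; `align` is assumed to be a power of two.
+static inline uintptr_t align_up_sketch(uintptr_t a, size_t align) {
+  // e.g. align_up_sketch(VTCM_BASE + 4, 128) == VTCM_BASE + 128
+  return (a + (align - 1)) & -align;
+}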
+void* Allocator::vtcm_alloc(unsigned size, size_t align) { + uintptr_t a = cur_vtcm; + a = (a + (align - 1)) & -align; + cur_vtcm = a + size; + void* ptr = reinterpret_cast(a); + if (ptr == nullptr) { + perror("device: error allocating vtcm memory:"); + return ptr; + } + + Block b(ptr, size, true); + iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), b); + iterator w = allocations_.insert(i, b); + if (w != allocations_.begin()) { + iterator pw = w - 1; + assert(uintptr_t(pw->ptr_) + pw->size_ <= uintptr_t(w->ptr_)); + } + if (w + 1 != allocations_.end()) { + iterator nw = w + 1; + assert(uintptr_t(w->ptr_) + w->size_ <= uintptr_t(nw->ptr_)); + } + + LOG("device: allocated vtcm %d bytes aligned at %d: %p", size, align, ptr); + return ptr; +} + +void Allocator::free(void* ptr) { + LOG("device: freeing %p", ptr); + iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), + Block(ptr, 0)); + assert(i != allocations_.end()); + assert(i->ptr_ == ptr); + if (!i->vtcm_) ::free(i->ptr_); + allocations_.erase(i); +} + +static void printMsgCall(const MsgCall& mc) { + auto to_dec_string = [](int v) { + char tmp[11]; + snprintf(tmp, sizeof(tmp), "%d", v); + return std::string(tmp); + }; + auto to_hex_string = [](uint32_t v) { + char tmp[9]; + snprintf(tmp, sizeof(tmp), "%lx", v); + return std::string(tmp); + }; + std::string str = "device: launching " + to_hex_string(mc.func_va) + + " sc:" + to_dec_string(mc.scalar_num) + " {"; + for (unsigned i = 0; i != mc.scalar_num; ++i) { + str += ' ' + to_hex_string(mc.data[i]); + if (i + 1 != mc.scalar_num) str += ','; + } + str += " }, st:" + to_dec_string(mc.stack_num) + " {"; + for (unsigned i = 0; i != mc.stack_num; ++i) { + str += ' ' + to_hex_string(mc.data[i + mc.scalar_num]); + if (i + 1 != mc.stack_num) str += ','; + } + str += " }"; + LOG("%s", str.c_str()); +} + +static std::vector task_queue; + +struct Environment { + Allocator alloc; + void* dl_handle = nullptr; +}; + +extern "C" { +volatile Message message_buffer; +int dispatch(Environment* env) __attribute__((noinline)); +} + +static volatile unsigned char payload_buffer[4096]; + +static void setMsg(uint32_t code, uint32_t len, uint32_t va) { + message_buffer.code = code; + message_buffer.len = len; + message_buffer.va = va; +} + +inline void* pointer(uint32_t v) { + return reinterpret_cast(static_cast(v)); +} + +inline uint32_t va(const volatile void* p) { + return static_cast(reinterpret_cast(p)); +} + +__attribute__((naked)) uint32_t launcher(volatile MsgCall* mc, uint64_t* pcc) { + __asm__( + "// This function is intentionally written to be readable, \n" + "// rather than fast. \n" + "// r0 = value of 'volatile MsgCall *mc' \n" + "// r1 = address where to store the program cycle count \n" + "{ memd(r29+#-16) = r21:20 \n" + " allocframe(#24) } \n" + "{ memd(r29+#0) = r17:16 \n" + " memd(r29+#8) = r19:18 } \n" + "{ r17:16 = combine(r1,r0) \n" + " r18 = r29 \n" + " r1 = memw(r0+#4) // scalar_num \n" + " r2 = memw(r0+#8) } // stack_num \n" + "// If there are no stack values, skip the stack setup. \n" + "{ p0 = cmp.eq(r2,#0) \n" + " if (p0.new) jump:t .Llauncher1 } \n" + + "// Allocate space on the stack. Let r2 = needed space \n" + "// rounded up to a multiple of 8. \n" + "{ loop0(.Llauncher0,r2) \n" + " r2 = asl(r2,#2) } \n" + "{ r2 = add(r2,#4) } \n" + "{ r2 = clrbit(r2,#2) } \n" + "{ r29 = sub(r29,r2) } \n" + + "// Copy stack contents onto the stack. 
Stack contents start \n" + "// at r3 = r0 + offsetof(data) + scalar_num*4 \n" + "{ r3 = addasl(r0,r1,#2) \n" + " r4 = r29 } \n" + "{ r3 = add(r3,#12) } // offsetof(data) \n" + ".Llauncher0: \n" + "{ r5 = memw(r3++#4) \n" + " memw(r4++#4) = r5.new } :endloop0 \n" + + "// Load registers. Some of the loaded data may actually be \n" + "// values from the stack part of 'data', but it's not an issue.\n" + ".Llauncher1: \n" + "{ r0 = memw(r16+#12) // mc + offsetof(data) \n" + " r1 = memw(r16+#16) } \n" + "{ r2 = memw(r16+#20) \n" + " r3 = memw(r16+#24) } \n" + "{ r4 = memw(r16+#28) \n" + " r5 = memw(r16+#32) } \n" + + "// Call. \n" + "{ r6 = memw(r16+#0) \n" + " r21:20 = upcycle } \n" + "{ callr r6 } \n" + + "// Restore stack pointer (free up r18), calculate cycle count. \n" + "{ r29 = r18 \n" + " r19:18 = upcycle } \n" + "{ r19:18 = sub(r19:18, r21:20) } \n" + + "// Store pcount, restore non-volatile registers, and return. \n" + "{ memd(r17+#0) = r19:18 \n" + " r21:20 = memd(r29+#16) } \n" + "{ r19:18 = memd(r29+#8) \n" + " r17:16 = memd(r29+#0) } \n" + "{ dealloc_return } // implicit-use r1:0 \n"); +} + +int dispatch(Environment* env) { + uint32_t code = message_buffer.code; + // Special handling of MsgReq. + if (code == kMsgReq) { + assert(message_buffer.len <= sizeof(payload_buffer)); + setMsg(kMsgAck, sizeof(payload_buffer), va(payload_buffer)); + return 0; + } + + switch (code) { + case kAlloc: { + LOG("device: {kAlloc, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgAlloc)); + auto* ma = reinterpret_cast(message_buffer.va); + void* p = env->alloc.alloc(ma->size, ma->align); + reinterpret_cast(payload_buffer)->va = va(p); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kFree: { + LOG("device: {kFree, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgPointer)); + auto* mp = reinterpret_cast(message_buffer.va); + env->alloc.free(pointer(mp->va)); + setMsg(kNone, 0u, 0u); + break; + } + case kAllocVtcm: { + LOG("device: {kAllocVtcm, %lu, %lx}", message_buffer.len, + message_buffer.va); + assert(message_buffer.len == sizeof(MsgAlloc)); + auto* ma = reinterpret_cast(message_buffer.va); + void* p = env->alloc.vtcm_alloc(ma->size, ma->align); + reinterpret_cast(payload_buffer)->va = va(p); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kCopy: { + LOG("device: {kCopy, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgCopy)); + auto* mc = reinterpret_cast(message_buffer.va); + memcpy(pointer(mc->dst), pointer(mc->src), mc->len); + setMsg(kNone, 0u, 0u); + break; + } + case kLoad: { + if (env->dl_handle != nullptr) dlclose(env->dl_handle); + const char* name = static_cast(pointer(message_buffer.va)); + // LOG(stderr, "device: dlopen(%s)", name); + env->dl_handle = dlopen(name, RTLD_LAZY); + if (env->dl_handle == nullptr) LOG("dlopen: %s\n", dlerror()); + assert(env->dl_handle != nullptr); + reinterpret_cast(payload_buffer)->va = + va(env->dl_handle); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kUnload: { + assert(env->dl_handle != nullptr); + assert(message_buffer.len == sizeof(MsgPointer)); + auto* mp = reinterpret_cast(message_buffer.va); + assert(pointer(mp->va) == env->dl_handle); + dlclose(env->dl_handle); + env->dl_handle = nullptr; + setMsg(kNone, 0u, 0u); + break; + } + case kResolve: { + LOG("device: {kResolve, %lu, %lx}", message_buffer.len, + message_buffer.va); + 
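+      // (ours) Resolution requires a module loaded via a prior kLoad; dlerror()
+      // below clears stale error state ahead of the dlsym() call.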
assert(env->dl_handle != nullptr); + dlerror(); + const char* name = static_cast(pointer(message_buffer.va)); + void* s = dlsym(env->dl_handle, name); + reinterpret_cast(payload_buffer)->va = va(s); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kCall: { + LOG("device: {kCall, %lu, %lx}", message_buffer.len, message_buffer.va); + // Add the task to the queue. + auto* mc = reinterpret_cast(message_buffer.va); + uint32_t size = 4 * (3 + mc->scalar_num + mc->stack_num); + MsgCall* t = static_cast(malloc(size)); + memcpy(t, mc, size); + task_queue.push_back(t); + // Return 0. + *reinterpret_cast(payload_buffer) = 0; + setMsg(kNone, sizeof(uint32_t), va(payload_buffer)); + break; + } + case kFlush: { + LOG("device: {kFlush}"); + LOG("device: %d tasks in the queue", task_queue.size()); + // Execute all tasks from the queue and release memory buffers + // for as long as the return values are 0. Upon receiving a non-zero + // return value, continue freeing memory but no longer execute + // any tasks. The task queue will be cleared in any case. + uint32_t rv = 0; + uint64_t pcc; // Pcycle counter, will be 0 under simulator (upcycle). + for (MsgCall* t : task_queue) { + if (rv == 0) { + printMsgCall(*t); + rv = launcher(t, &pcc); + LOG("device: execution took %lld pcycles", pcc); + } + free(t); + } + task_queue.clear(); + *reinterpret_cast(payload_buffer) = rv; + setMsg(kNone, sizeof(uint32_t), va(payload_buffer)); + break; + } + default: + LOG("device: unknown code: %lu", message_buffer.code); + abort(); + break; + } + return 0; +} + +extern "C" { +int acquire_vector_unit(int); +void release_vector_unit(); +} + +static void makePathList(const std::string& arg, + std::vector* list) { + size_t p = 0, e = arg.size(); + std::vector tmp; + + while (p < e) { + tmp.clear(); + bool check_next = true; + size_t i = p; + for (; i != e; ++i) { + char c = arg[i]; + if (check_next) { + if (c == '\\') { + check_next = false; + continue; + } else if (c == ':') { + break; + } + } + check_next = true; + tmp.push_back(c); + } + if (!tmp.empty()) list->emplace_back(tmp.begin(), tmp.end()); + p = i + 1; + } +} + +static std::string findInPaths(const std::string& filename, + const std::string& paths) { + std::vector path_list; + makePathList(paths, &path_list); + + for (const auto& p : path_list) { + std::string pf = p + '/' + filename; + if (access(pf.c_str(), X_OK) == 0) return std::move(pf); + } + // If the search failed, try bare filename. If it cannot be loaded, + // dlerror will print a meaningful message. + return filename; +} + +// Presence of this function indicates that sim_dev is running. +extern "C" int running_in_sim_dev_17bc90206f6cf5a7(); +int running_in_sim_dev_17bc90206f6cf5a7() { return 0; } + +int main(int argc, char* argv[]) { + int opt; + std::string ld_path; + while ((opt = getopt(argc, argv, "L:")) != -1) { + switch (opt) { + case 'L': + ld_path += ':' + std::string(optarg); + break; + case '?': + LOG("Usage %s: [-L path1[:path2...]]", argv[0]); + return 1; + } + } + + std::string rt_path = findInPaths("libtvm_runtime.so", ld_path); + LOG("TVM runtime path: %s", rt_path.c_str()); + + Environment env; + acquire_vector_unit(0); + + const char* builtin[] = { + "libgcc.so", "libc.so", "libc++.so", + "libc++abi.so", "libc++.so.1", "libc++abi.so.1" // Alternative names. 
+ }; + dlinit(sizeof(builtin) / sizeof(builtin[0]), const_cast(builtin)); + void* rt_handle = dlopen(rt_path.c_str(), RTLD_GLOBAL); + if (rt_handle == nullptr) { + LOG("error loading TVM runtime: %s", dlerror()); + return 1; + } + + // When running TVM runtime on Hexagon there is no longer a device + // for Hexagon, but standalone ops can still refer to it. All of + // required DeviceAPI's functionality is adequately implemented + // via the CPU device, so remap device_api.hexagon to device_api.cpu. + auto* get_global = reinterpret_cast( + dlsym(rt_handle, "TVMFuncGetGlobal")); + assert(get_global != nullptr); + auto* register_global = reinterpret_cast( + dlsym(rt_handle, "TVMFuncRegisterGlobal")); + assert(register_global != nullptr); + + TVMFunctionHandle cpu_api; + if (get_global("device_api.cpu", &cpu_api) != 0 || + register_global("device_api.hexagon", cpu_api, true) != 0) { + LOG("error setting device_api.hexagon"); + return 1; + } + + while (!dispatch(&env)) { + } + + dlclose(rt_handle); + release_vector_unit(); + return 0; +} diff --git a/src/runtime/hexagon/sim/hexagon_device_sim.cc b/src/runtime/hexagon/sim/hexagon_device_sim.cc index b58377baa947e..477da09c1c652 100644 --- a/src/runtime/hexagon/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/sim/hexagon_device_sim.cc @@ -41,8 +41,7 @@ namespace tvm { namespace runtime { namespace hexagon { -static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), - "Hexagon VA must be uint32"); +static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), "Hexagon VA must be uint32"); template struct unalign { @@ -89,8 +88,7 @@ std::unique_ptr make_unique(size_t size) { // user from memory reallocation and copying. struct non_const_str { non_const_str() {} - explicit non_const_str(const std::string& str) - : non_const_str(std::vector{str}) {} + explicit non_const_str(const std::string& str) : non_const_str(std::vector{str}) {} explicit non_const_str(const std::vector& vec) { for (const std::string& s : vec) { auto c = detail::make_unique(s.size() + 1); @@ -220,8 +218,7 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { void* Load(const std::string& data, const std::string& fmt) final; void Unload(void* mod) final; void* Resolve(const std::string& sym) final; - void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, - unsigned st_num) final; + void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, unsigned st_num) final; static std::string to_string(HEXAPI_Status status); @@ -312,10 +309,8 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { bool should_parse_next(const string_list& rest); llvm::Optional to_interval(const detail::MaybeString& str); - llvm::Optional to_timingmode( - const detail::MaybeString& str); - llvm::Optional to_verbosemode( - const detail::MaybeString& str); + llvm::Optional to_timingmode(const detail::MaybeString& str); + llvm::Optional to_verbosemode(const detail::MaybeString& str); llvm::Optional to_nullptr(const detail::MaybeString& str); MaybeUIntRange ahb_, axi2_; @@ -399,12 +394,11 @@ decltype(HexagonSimulator::opt_map_) HexagonSimulator::opt_map_ = { {"--verbose", &HexagonSimulator::HandleVerbose}, }; -#define CHECKED_CALL(func, ...) \ - do { \ - HEXAPI_Status s = sim_->func(__VA_ARGS__); \ - CHECK_EQ(s, HEX_STAT_SUCCESS) \ - << "HexagonSimulator: " #func " failed with code " \ - << HexagonSimulator::to_string(s); \ +#define CHECKED_CALL(func, ...) 
\ + do { \ + HEXAPI_Status s = sim_->func(__VA_ARGS__); \ + CHECK_EQ(s, HEX_STAT_SUCCESS) << "HexagonSimulator: " #func " failed with code " \ + << HexagonSimulator::to_string(s); \ } while (false) inline HEX_VA_t HexagonSimulator::p2va(const void* p) { @@ -444,8 +438,7 @@ void HexagonSimulator::CopyNFromV(void* host_dst, HEX_VA_t src) { pd->value = v; } -void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, - unsigned len) { +void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, unsigned len) { const uint8_t* src = static_cast(host_src); while (len >= 8) { @@ -556,18 +549,15 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { using iterator = std::istream_iterator; auto sim_args = string_list(iterator(sim_args_iss), iterator()); - std::string target_str = - !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); + std::string target_str = !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); arch_ = target_str; - sim_ = - detail::make_unique(detail::non_const_str(target_str)); + sim_ = detail::make_unique(detail::non_const_str(target_str)); LOG(INFO) << "HexagonSimulator: Core version: " << arch_; // Locate the sim_dev binary in PATH, or in the current working directory. llvm::StringRef sim_dev = "sim_dev"; - detail::MaybeString path_sim_dev = - llvm::sys::Process::FindInEnvPath("PATH", sim_dev); + detail::MaybeString path_sim_dev = llvm::sys::Process::FindInEnvPath("PATH", sim_dev); if (!path_sim_dev) { if (!llvm::sys::fs::exists(sim_dev)) { LOG(FATAL) << "Cannot find sim_dev in PATH."; @@ -615,8 +605,7 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { } void* HexagonSimulator::Alloc(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align - << ')'; + LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align << ')'; Message m = {kAlloc, sizeof(MsgAlloc), 0u}; MsgAlloc ma = {size, align}; SendMsg(m, &ma, true); @@ -631,8 +620,7 @@ void* HexagonSimulator::Alloc(unsigned size, unsigned align) { } void HexagonSimulator::Free(void* ptr) { - LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec - << ')'; + LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec << ')'; if (task_queuing_) { Message mf = {kFlush, 0, 0}; SendMsg(mf, 0, true); @@ -643,8 +631,7 @@ void HexagonSimulator::Free(void* ptr) { } void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size - << ", align=" << align << ')'; + LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size << ", align=" << align << ')'; Message m = {kAllocVtcm, sizeof(MsgAlloc), 0u}; MsgAlloc ma = {size, align}; SendMsg(m, &ma, true); @@ -653,28 +640,25 @@ void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { MsgPointer mp; CopyFromV(&mp, m.va, m.len); - LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va - << std::dec; + LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va << std::dec; CHECK_NE(mp.va, 0); return va2p(mp.va); } void HexagonSimulator::FreeVtcm(void* ptr) {} -void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst - << ", src=" << src << ", len=" << std::dec << len << ')'; +void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst 
<< ", src=" << src + << ", len=" << std::dec << len << ')'; CHECK(dst != nullptr && src != nullptr); Message m = {kCopy, sizeof(MsgCopy), 0u}; MsgCopy mc = {p2va(dst), p2va(src), len}; SendMsg(m, &mc, true); } -void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst - << ", src=" << src << ", len=" << len << ')'; +void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst << ", src=" << src + << ", len=" << len << ')'; if (task_queuing_) { Message mf = {kFlush, 0, 0}; SendMsg(mf, 0, true); @@ -682,10 +666,9 @@ void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, CopyFromV(host_dst, p2va(src), len); } -void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst - << ", host_src=" << host_src << ", len=" << len << ')'; +void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst << ", host_src=" << host_src + << ", len=" << len << ')'; CopyToV(p2va(dst), host_src, len); } @@ -717,19 +700,17 @@ void* HexagonSimulator::Resolve(const std::string& sym) { MsgPointer mp; CopyFromV(&mp, m.va, sizeof(mp)); - LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va - << std::dec; + LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va << std::dec; return va2p(mp.va); } -void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, - uint32_t* stack, unsigned st_num) { - LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func - << ", scalar=" << scalar << ", sc_num=" << std::dec +void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) { + LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func << ", scalar=" << scalar + << ", sc_num=" << std::dec << sc_num // NOLINTNEXTLINE(build/include_what_you_use) - << ", stack=" << std::hex << stack << ", st_num=" << std::dec - << st_num; + << ", stack=" << std::hex << stack << ", st_num=" << std::dec << st_num; std::vector data; @@ -753,8 +734,7 @@ void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, log_data << std::dec << " }" << std::flush; LOG(INFO) << log_data.str(); - Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), - 0u}; + Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), 0u}; SendMsg(m, data.data(), true); if (!task_queuing_) { @@ -768,8 +748,7 @@ void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, std::ostringstream log_rv; log_rv << "HexagonSimulator::Call -> {" << std::hex; for (unsigned i = 0, e = std::min(rv.size(), 4u); i != e; ++i) { - log_rv << ' ' << std::setw(2) << std::setfill('0') - << static_cast(rv[i]); + log_rv << ' ' << std::setw(2) << std::setfill('0') << static_cast(rv[i]); } if (rv.size() > 4) log_rv << "..."; log_rv << std::dec << " }"; @@ -1059,8 +1038,7 @@ bool HexagonSimulator::HandlePacketAnalyze(string_list& rest) { } bool HexagonSimulator::HandlePCFilter(string_list& rest) { - auto range = - detail::to_range(detail::pop_front(rest)); + auto range = detail::to_range(detail::pop_front(rest)); if (range) { CHECKED_CALL(ConfigurePCRangeFilter, range->first, range->second); } @@ -1222,11 +1200,9 @@ bool 
HexagonSimulator::HandleTCMLowAddr(string_list& rest) { } bool HexagonSimulator::HandleTimeFilterNS(string_list& rest) { - auto range = - detail::to_range(detail::pop_front(rest)); + auto range = detail::to_range(detail::pop_front(rest)); if (range) { - CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, - range->second, HEX_NANOSEC); + CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, range->second, HEX_NANOSEC); } return static_cast(range); } @@ -1284,8 +1260,7 @@ bool HexagonSimulator::should_parse_next(const string_list& rest) { return false; } -llvm::Optional HexagonSimulator::to_interval( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; @@ -1309,8 +1284,7 @@ llvm::Optional HexagonSimulator::to_interval( .Default(none); } -llvm::Optional HexagonSimulator::to_timingmode( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_timingmode(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; @@ -1357,8 +1331,7 @@ llvm::Optional HexagonSimulator::to_verbosemode( .Default(none); } -llvm::Optional HexagonSimulator::to_nullptr( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; diff --git a/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl b/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl index 6aa194ee9fe27..bb7d8a29550d0 100644 --- a/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl +++ b/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl @@ -34,8 +34,8 @@ interface tvm_remote : remote_handle64 { rout handle_t sym_ptr); long kernel(in handle_t mod, in handle_t symbol, - inrout sequence scalar, - inrout sequence stack, + in sequence scalar, + in sequence stack, in sequence scalar_in_octet, rout sequence scalar_out_octet, in sequence stack_in_octet, diff --git a/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl b/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl index d49e094c9f0fe..845ddeffa26f9 100644 --- a/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl +++ b/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl @@ -36,8 +36,8 @@ interface tvm_remote_nd { rout handle_t sym_ptr); long kernel(in handle_t mod, in handle_t symbol, - inrout sequence scalar, - inrout sequence stack, + in sequence scalar, + in sequence stack, in sequence scalar_in_octet, rout sequence scalar_out_octet, in sequence stack_in_octet, diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc index f245d7dcba4ed..c9e3332d59a76 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc @@ -35,8 +35,8 @@ // Stub functions for targets that don't support VTCM. 
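// (Note, ours) On VTCM-capable targets these would come from the Hexagon
// SDK's HAP interface; the stubs below just report that no VTCM is available.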
static void* HAP_request_VTCM(int a, int b) { return 0; }
 static int HAP_release_VTCM(void* a) { return 0; }
-static int HAP_query_avail_VTCM(unsigned* avail_block_size,
-                                unsigned* max_page_size, unsigned* num_pages) {
+static int HAP_query_avail_VTCM(unsigned* avail_block_size, unsigned* max_page_size,
+                                unsigned* num_pages) {
   FARF(ALWAYS, "%s: running on architecture V62 or less", __func__);
   return AEE_ENOMEMORY;
 }
@@ -62,8 +62,7 @@ int tvm_remote_open(const char* uri, remote_handle64* handle_ptr) {
     return rc;
   }
-  *handle_ptr =
-      static_cast<remote_handle64>(reinterpret_cast<uintptr_t>(malloc(1)));
+  *handle_ptr = static_cast<remote_handle64>(reinterpret_cast<uintptr_t>(malloc(1)));
   if (!*handle_ptr) {
     FARF(ERROR, "%s: cannot allocate memory", __func__);
     return AEE_ENOMEMORY;
@@ -98,9 +97,7 @@ int tvm_remote_close(remote_handle64 handle) {
  * This function is present as a workaround. See comment at the call site
  * in hexagon_device_target.cc.
  */
-int tvm_remote_call_mmap64(remote_handle64 handle) {
-  return AEE_SUCCESS;
-}
+int tvm_remote_call_mmap64(remote_handle64 handle) { return AEE_SUCCESS; }
 /*!
  * \brief Load a shared library.
@@ -112,8 +109,8 @@ int tvm_remote_call_mmap64(remote_handle64 handle) {
  *
  * \return 0 on success, negative value on error.
  */
-int tvm_remote_load_library(remote_handle64 handle, const char* soname,
-                            int soname_len, tvm_remote_handle_t* lib_ptr) {
+int tvm_remote_load_library(remote_handle64 handle, const char* soname, int soname_len,
+                            tvm_remote_handle_t* lib_ptr) {
   return tvm_remote_nd_load_library(soname, soname_len, lib_ptr);
 }
@@ -128,9 +125,8 @@ int tvm_remote_load_library(remote_handle64 handle, const char* soname,
  *
  * \return 0 on success, negative value on error.
  */
-int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib,
-                          const char* name, int name_len,
-                          tvm_remote_handle_t* sym_ptr) {
+int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, const char* name,
+                          int name_len, tvm_remote_handle_t* sym_ptr) {
   return tvm_remote_nd_get_symbol(lib, name, name_len, sym_ptr);
 }
@@ -163,24 +159,20 @@ int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib,
  * The 8 "octet" arguments in this function are used for cache operations
  * only. They are not used for processing.
*/ -int tvm_remote_kernel( - remote_handle64 handle, tvm_remote_handle_t lib, - tvm_remote_handle_t symbol, int* scalar, int scalar_len, int* stack, - int stack_len, const tvm_remote_buffer* scalar_in_octet, - int scalar_in_octet_len, tvm_remote_buffer* scalar_out_octet, - int scalar_out_octet_len, const tvm_remote_buffer* stack_in_octet, - int stack_in_octet_len, tvm_remote_buffer* stack_out_octet, - int stack_out_octet_len, uint64* pcycles, uint64* time_usec) { +int tvm_remote_kernel(remote_handle64 handle, tvm_remote_handle_t lib, tvm_remote_handle_t symbol, + const int* scalar, int scalar_len, const int* stack, int stack_len, + const tvm_remote_buffer* scalar_in_octet, int scalar_in_octet_len, + tvm_remote_buffer* scalar_out_octet, int scalar_out_octet_len, + const tvm_remote_buffer* stack_in_octet, int stack_in_octet_len, + tvm_remote_buffer* stack_out_octet, int stack_out_octet_len, uint64* pcycles, + uint64* time_usec) { return tvm_remote_nd_kernel( lib, symbol, scalar, scalar_len, stack, stack_len, - reinterpret_cast(scalar_in_octet), - scalar_in_octet_len, - reinterpret_cast(scalar_out_octet), - scalar_out_octet_len, - reinterpret_cast(stack_in_octet), - stack_in_octet_len, - reinterpret_cast(stack_out_octet), - stack_out_octet_len, pcycles, time_usec); + reinterpret_cast(scalar_in_octet), scalar_in_octet_len, + reinterpret_cast(scalar_out_octet), scalar_out_octet_len, + reinterpret_cast(stack_in_octet), stack_in_octet_len, + reinterpret_cast(stack_out_octet), stack_out_octet_len, pcycles, + time_usec); } /*! @@ -191,8 +183,7 @@ int tvm_remote_kernel( * * \return 0 on success, negative value on error. */ -int tvm_remote_release_library(remote_handle64 handle, - tvm_remote_handle_t lib) { +int tvm_remote_release_library(remote_handle64 handle, tvm_remote_handle_t lib) { // FARF(ALWAYS, "tvm_remote_release_library begin "); return tvm_remote_nd_release_library(lib); } @@ -208,8 +199,7 @@ int tvm_remote_release_library(remote_handle64 handle, * * \return 0 on success, negative value on error. 
*/ -int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, - unsigned align, unsigned* dsp_va) { +int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, unsigned align, unsigned* dsp_va) { FARF(ALWAYS, "%s: size=%u, align=%u", __func__, size, align); unsigned avail_block_size, max_page_size, num_pages; int rc = HAP_query_avail_VTCM(&avail_block_size, &max_page_size, &num_pages); @@ -217,12 +207,11 @@ int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, FARF(ERROR, "%s: HAP_query_avail_VTCM failed, rc=%08x", __func__, rc); return rc; } - FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", - __func__, avail_block_size, max_page_size, num_pages); + FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", __func__, + avail_block_size, max_page_size, num_pages); if (max_page_size < MIN_VTCM_SZ) { - FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, - MIN_VTCM_SZ / 1024); + FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, MIN_VTCM_SZ / 1024); return AEE_ENOMEMORY; } diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc index 24af924da65f1..c0f6f22172c0e 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc @@ -41,8 +41,7 @@ struct msg_call { uint32_t data[]; } __attribute__((packed)); -__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, - uint64_t* pcc) { +__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, uint64_t* pcc) { __asm__( "// This function is intentionally written to be readable, \n" "// rather than fast. \n" @@ -114,8 +113,7 @@ __attribute__((naked)) uint32_t launcher(volatile msg_call* mc, extern "C" { #pragma weak __wrap_pthread_create -int __wrap_pthread_create(pthread_t* restrict thread, - const pthread_attr_t* restrict attr, +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, void* (*start)(void*), void* restrict arg) { FARF(ERROR, "Wrong %s called", __func__); abort(); @@ -133,15 +131,13 @@ static void* lib_thread = nullptr; int tvm_remote_nd_open() { lib_thread = dlopen("libtvm_wrap_pthread.so", RTLD_NOW | RTLD_GLOBAL); if (lib_thread == nullptr) { - FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, - dlerror()); + FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, dlerror()); return AEE_EUNABLETOLOAD; } lib_rt = dlopen("libtvm_runtime.so", RTLD_NOW | RTLD_GLOBAL); if (lib_rt == nullptr) { - FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, - dlerror()); + FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, dlerror()); return AEE_EUNABLETOLOAD; } return AEE_SUCCESS; @@ -174,9 +170,7 @@ int tvm_remote_nd_close() { * This function is present as a workaround. See comment at the call site * in hexagon_device_target.cc. */ -int tvm_remote_nd_call_mmap64() { - return AEE_SUCCESS; -} +int tvm_remote_nd_call_mmap64() { return AEE_SUCCESS; } /*! * \brief Load a shared library. @@ -210,8 +204,8 @@ int tvm_remote_nd_load_library(const char* soname, int soname_len, * * \return 0 on success, negative value on error. 
*/ -int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, - int name_len, tvm_remote_nd_handle_t* sym_ptr) { +int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, int name_len, + tvm_remote_nd_handle_t* sym_ptr) { FARF(ALWAYS, "%s: name=%s", __func__, name); if (void* p = dlsym(reinterpret_cast(lib), name)) { *sym_ptr = reinterpret_cast(p); @@ -223,8 +217,8 @@ int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, } static void print_msg_call(const msg_call& mc) { - FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, - mc.scalar_num, mc.stack_num); + FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, mc.scalar_num, + mc.stack_num); for (unsigned i = 0; i != mc.scalar_num; ++i) { FARF(ALWAYS, "scalar_data[%d] %x", i, mc.data[i]); } @@ -261,14 +255,13 @@ static void print_msg_call(const msg_call& mc) { * The 8 "octet" arguments in this function are used for cache operations * only. They are not used for procesing. */ -int tvm_remote_nd_kernel( - tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol, int* scalar, - int scalar_len, int* stack, int stack_len, - const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len, - tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len, - const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len, - tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len, - uint64* pcycles, uint64* time_usec) { +int tvm_remote_nd_kernel(tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol, + const int* scalar, int scalar_len, const int* stack, int stack_len, + const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len, + tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len, + const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len, + tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len, + uint64* pcycles, uint64* time_usec) { hvx::config_t hvx_info = {0}; hvx::prepare_mt_job(&hvx_info); @@ -277,18 +270,16 @@ int tvm_remote_nd_kernel( if (hvx_info.num_reserved > 0) { lock_result = hvx::lock(hvx::MODE_128B); if (lock_result < 0) { - FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", - __func__, lock_result, hvx_info.num_reserved); + FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", __func__, lock_result, + hvx_info.num_reserved); } else { - FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, - lock_result); + FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, lock_result); } } else { FARF(ERROR, "%s: there are no HVX units available", __func__); } - struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * - (3 + scalar_len + stack_len)); + struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * (3 + scalar_len + stack_len)); if (mc == nullptr) { FARF(ERROR, "%s: failed to allocate memory for mc", __func__); return AEE_ENOMEMORY; @@ -312,8 +303,7 @@ int tvm_remote_nd_kernel( uint64_t start_time = HAP_perf_get_time_us(); int result = launcher(mc, pcycles); *time_usec = HAP_perf_get_time_us() - start_time; - FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, - *time_usec); + FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, *time_usec); if (lock_result > 0) hvx::unlock(); hvx::cleanup_mt_job(&hvx_info); if (mc) free(mc); diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc index 
1192f7a6ac737..d26073af8ae13 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc @@ -44,13 +44,11 @@ static constexpr size_t kThreadStackSize = 128 * 1024; // 128kB // Make sure the function has C linkage. extern "C" { -int __wrap_pthread_create(pthread_t* restrict thread, - const pthread_attr_t* restrict attr, +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, void* (*start)(void*), void* restrict arg); } -int __wrap_pthread_create(pthread_t* restrict thread, - const pthread_attr_t* restrict attr, +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, void* (*start)(void*), void* restrict arg) { pthread_attr_t def_attr; if (attr == nullptr) { @@ -72,8 +70,7 @@ int __wrap_pthread_create(pthread_t* restrict thread, FARF(ALWAYS, "launching thread with stack_size=%zu", stack_size); int t = pthread_create(thread, attr, start, arg); if (int rc = pthread_attr_destroy(&def_attr)) { - FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", - rc); + FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", rc); } return t; } diff --git a/src/runtime/hexagon/target/hexagon_device_target.cc b/src/runtime/hexagon/target/hexagon_device_target.cc index a62aa479bf143..ee326ca0b159b 100644 --- a/src/runtime/hexagon/target/hexagon_device_target.cc +++ b/src/runtime/hexagon/target/hexagon_device_target.cc @@ -45,10 +45,8 @@ // The downside is that the format string must be given as a string literal, // but it seems to be a minor issue. #define VA_EXPANDER(...) , ##__VA_ARGS__ -#define TVM_LOGD_HT(fmt, ...) \ - TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) -#define TVM_LOGE_HT(fmt, ...) \ - TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) +#define TVM_LOGD_HT(fmt, ...) TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) +#define TVM_LOGE_HT(fmt, ...) 
TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) namespace tvm { namespace runtime { @@ -74,8 +72,7 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { unsigned stack_num) final; private: - std::pair AddAddrMapping(const void* dsp_addr, - void* apps_addr, size_t size); + std::pair AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size); std::pair GetAppsAddr(const void* dsp_addr, bool exact) const; void RemoveAddrMapping(const void* dsp_addr); int OpenDomainChannel(bool set_unsigned_pd); @@ -102,24 +99,19 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { void* const HexagonTarget::vtcm_mark_ = reinterpret_cast(~0); -std::shared_ptr CreateHexagonTarget() { - return std::make_shared(); -} +std::shared_ptr CreateHexagonTarget() { return std::make_shared(); } -std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, - void* apps_addr, +std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size) { crit_section_.lock(); auto p = dsp_to_apps_.insert({dsp_addr, {apps_addr, size}}); crit_section_.unlock(); if (!p.second) { - TVM_LOGE_HT( - "failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", - dsp_addr, apps_addr, size); + TVM_LOGE_HT("failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, + apps_addr, size); return std::make_pair(nullptr, 0); } - TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, - apps_addr, size); + TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, apps_addr, size); return p.first->second; } @@ -135,8 +127,7 @@ void HexagonTarget::RemoveAddrMapping(const void* dsp_addr) { crit_section_.unlock(); } -std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, - bool exact) const { +std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, bool exact) const { struct AutoUnlock { explicit AutoUnlock(std::mutex& m) : m(m) {} ~AutoUnlock() { m.unlock(); } @@ -192,16 +183,14 @@ int HexagonTarget::OpenDomainChannel(bool use_unsigned_pd) { data.domain = CDSP_DOMAIN_ID; int rc = rsc_ptr(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data)); if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", - rc); + TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", rc); } } } else { TVM_LOGD_HT("remote_session_control not available"); } - int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", - &domain_channel_handle_); + int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", &domain_channel_handle_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to open channel rc=0x%x", rc); } else { @@ -231,8 +220,7 @@ void HexagonTarget::ReleaseLibrary() { crit_section_.lock(); if (module_pointer_ != AEE_EUNKNOWN) { const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, - module_pointer_); + int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, module_pointer_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to unload device library rc=0x%x", rc); } else { @@ -267,24 +255,20 @@ void* HexagonTarget::Alloc(unsigned size, unsigned align) { // thread then remote_mmap64 fails. FastRPC expects one call to be made to // DSP before calling remote_map64. Hence this call is needed for now untill // FastRPC comes up with a fix. 
- int rc_call_mmap_64 = - stub_api->tvm_remote_call_mmap64(domain_channel_handle_); + int rc_call_mmap_64 = stub_api->tvm_remote_call_mmap64(domain_channel_handle_); if (rc_call_mmap_64 != AEE_SUCCESS) { - TVM_LOGE_HT("mmap64 failed for domain channel %lu", - domain_channel_handle_); + TVM_LOGE_HT("mmap64 failed for domain channel %lu", domain_channel_handle_); return nullptr; } - void* mem = - stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); + void* mem = stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); if (mem == nullptr) { TVM_LOGE_HT("mem alloc failed for size=0x%x alignment=0x%x", size, align); return nullptr; } int mem_fd = stub_api->rpcmem_to_fd_ptr()(mem); uintptr_t dsp_va = 0; - int rc = dsp_api->remote_mmap64_ptr()( - mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); + int rc = dsp_api->remote_mmap64_ptr()(mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT( "buffer mapping failed for remote_map64 fd=0x%x rc=0x%x " @@ -313,8 +297,7 @@ void HexagonTarget::Free(void* ptr) { auto aa = GetAppsAddr(ptr, true); if (aa.first == nullptr) return; - int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), - aa.second); + int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), aa.second); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("buffer unmapping failed rc=0x%x", rc); } @@ -326,8 +309,7 @@ void* HexagonTarget::AllocVtcm(unsigned size, unsigned align) { const StubAPI* stub_api = StubAPI::Global(); unsigned int dsp_va = 0; - int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, - &dsp_va); + int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, &dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("VTCM allocation failed size=%u, align=%u", size, align); return nullptr; @@ -350,8 +332,7 @@ void HexagonTarget::FreeVtcm(void* ptr) { TVM_LOGD_HT("Done VTCM free from HexagonTarget::FreeVtcm"); } -void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, - unsigned len) { +void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { auto aa_src = GetAppsAddr(src, false); auto aa_dst = GetAppsAddr(dst, false); if (aa_src.first == vtcm_mark_ || aa_dst.first == vtcm_mark_) { @@ -375,13 +356,12 @@ void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, len, aa_dst.second); } len = std::min({size_t(len), aa_src.second, aa_dst.second}); - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, - aa_src.first, dst, aa_dst.first, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, aa_src.first, dst, + aa_dst.first, len); std::memcpy(aa_dst.first, aa_src.first, len); } -void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) { +void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { auto aa = GetAppsAddr(src, false); if (aa.first == vtcm_mark_) { TVM_LOGE_HT("VTCM address. 
Copy operation not supported"); @@ -392,18 +372,14 @@ void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, return; } if (aa.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than buffer size:%zu, copy truncated", len, - aa.second); + TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); len = aa.second; } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, - host_dst, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, host_dst, len); std::memcpy(host_dst, aa.first, len); } -void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, - unsigned len) { +void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { auto aa = GetAppsAddr(dst, false); if (aa.first == vtcm_mark_) { TVM_LOGE_HT("VTCM address. Copy operation not supported"); @@ -414,13 +390,10 @@ void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, return; } if (aa.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than buffer size:%zu, copy truncated", len, - aa.second); + TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); len = aa.second; } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, - host_src, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, host_src, len); std::memcpy(aa.first, host_src, len); } @@ -429,8 +402,7 @@ void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { int rc_oc = OpenDomainChannel(/*use_unsigned_pd*/ unsigned_pd); crit_section_.unlock(); if (rc_oc != AEE_SUCCESS) { - TVM_LOGE_HT("loading of %s failed: unable to open domain channel", - data.c_str()); + TVM_LOGE_HT("loading of %s failed: unable to open domain channel", data.c_str()); return nullptr; } @@ -440,8 +412,8 @@ void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { crit_section_.lock(); TVM_LOGD_HT("loading library %s ", data.c_str()); const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_remote_load_library( - domain_channel_handle_, data.c_str(), data.size() + 1, &module_pointer_); + int rc = stub_api->tvm_remote_load_library(domain_channel_handle_, data.c_str(), data.size() + 1, + &module_pointer_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to load device library rc=0x%x", rc); } @@ -473,9 +445,8 @@ void* HexagonTarget::Resolve(const std::string& sym) { tvm_remote_handle_t pf; TVM_LOGD_HT("resolving symbol %s", sym.c_str()); - int rc = - stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, - sym.c_str(), sym.size() + 1, &pf); + int rc = stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, sym.c_str(), + sym.size() + 1, &pf); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to get symbol from CDSP rc=0x%x", rc); return nullptr; @@ -485,13 +456,11 @@ void* HexagonTarget::Resolve(const std::string& sym) { return addr; } -void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, - uint32_t* stack, unsigned stack_num) { +void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, uint32_t* stack, + unsigned stack_num) { uint64 pcycles = 0, execution_time_usec = 0; - auto scalar_octet = - std::unique_ptr(new tvm_remote_buffer[scalar_num]); - auto stack_octet = - std::unique_ptr(new tvm_remote_buffer[stack_num]); + auto scalar_octet = std::unique_ptr(new tvm_remote_buffer[scalar_num]); + auto stack_octet = std::unique_ptr(new 
tvm_remote_buffer[stack_num]); TVM_LOGD_HT("scalars=%p, stack=%p", scalar, stack); if (scalar_octet == nullptr || stack_octet == nullptr) { @@ -501,8 +470,7 @@ void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, std::memset(scalar_octet.get(), 0, scalar_num * sizeof(tvm_remote_buffer)); std::memset(stack_octet.get(), 0, stack_num * sizeof(tvm_remote_buffer)); - auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers, - unsigned num) { + auto ProcessInputs = [this](uint32_t* inputs, tvm_remote_buffer* buffers, unsigned num) { for (unsigned i = 0; i != num; ++i) { void* ptr = reinterpret_cast(static_cast(inputs[i])); auto aa = GetAppsAddr(ptr, false); @@ -534,16 +502,15 @@ void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, int rc = stub_api->tvm_remote_kernel( domain_channel_handle_, module_pointer_, static_cast(reinterpret_cast(func)), - reinterpret_cast(scalar), scalar_num, - reinterpret_cast(stack), stack_num, scalar_octet.get(), scalar_num, - scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, + reinterpret_cast(scalar), scalar_num, reinterpret_cast(stack), stack_num, + scalar_octet.get(), scalar_num, scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, stack_octet.get(), stack_num, &pcycles, &execution_time_usec); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to run kernel on CDSP rc=0x%x", rc); } else { - TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", - pcycles, execution_time_usec, scalar_num); + TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", pcycles, + execution_time_usec, scalar_num); } } diff --git a/src/runtime/hexagon/target/hexagon_stubapi.cc b/src/runtime/hexagon/target/hexagon_stubapi.cc index 939c382364ab6..2ed33471b98f9 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.cc +++ b/src/runtime/hexagon/target/hexagon_stubapi.cc @@ -44,8 +44,7 @@ StubAPI::StubAPI() { constexpr auto domain_lib_name = "libtvm_remote_stub.so"; constexpr auto nondomain_lib_name = "libtvm_remote_nd_stub.so"; - const char* lib_name = - enable_domains_ ? domain_lib_name : nondomain_lib_name; + const char* lib_name = enable_domains_ ? domain_lib_name : nondomain_lib_name; CHECK(lib_handle_ = dlopen(lib_name, RTLD_LAZY | RTLD_LOCAL)); #define RESOLVE(fn) p##fn##_ = GetSymbol(#fn) diff --git a/src/runtime/hexagon/target/hexagon_stubapi.h b/src/runtime/hexagon/target/hexagon_stubapi.h index 6f3828fb2c8ab..5213b6d0d7af6 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.h +++ b/src/runtime/hexagon/target/hexagon_stubapi.h @@ -162,8 +162,7 @@ class StubAPI { // two types identical in the function types created below. // For example, int foo(tvm_remote_buffer*) and // int bar(tvm_remote_nd_buffer*) should both have the same type. -#define MAPTYPE(fn, ty) \ - using fn##_t = typename map_func_type::type; +#define MAPTYPE(fn, ty) using fn##_t = typename map_func_type::type; MAPTYPE(tvm_remote_load_library, tvm_remote_buffer) MAPTYPE(tvm_remote_release_library, tvm_remote_buffer) MAPTYPE(tvm_remote_get_symbol, tvm_remote_buffer) @@ -196,8 +195,7 @@ class StubAPI { public: template - int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, - Ts... args) const { + int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, Ts... 
args) const { if (enable_domains_) { return func_d(handle, args...); } @@ -219,11 +217,10 @@ class StubAPI { #define FUNC_ND(name) CONCAT_STR(tvm_remote_nd_, name) #define PTRNAME(fn) CONCAT_STR(p, CONCAT_STR(fn, _)) -#define DECLFUNC(name) \ - template \ - int FUNC(name)(remote_handle64 handle, Ts... args) const { \ - return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, \ - args...); \ +#define DECLFUNC(name) \ + template \ + int FUNC(name)(remote_handle64 handle, Ts... args) const { \ + return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, args...); \ } #define DECLFUNC_D(name) \ diff --git a/src/runtime/hexagon/target/hexagon_target_log.h b/src/runtime/hexagon/target/hexagon_target_log.h index ae09503cd35b5..c7684fc561970 100644 --- a/src/runtime/hexagon/target/hexagon_target_log.h +++ b/src/runtime/hexagon/target/hexagon_target_log.h @@ -23,18 +23,12 @@ #include -#define TVM_LOGV(...) \ - __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) -#define TVM_LOGD(...) \ - __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) -#define TVM_LOGI(...) \ - __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) -#define TVM_LOGW(...) \ - __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) -#define TVM_LOGE(...) \ - __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) -#define TVM_LOGF(...) \ - __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) +#define TVM_LOGV(...) __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) +#define TVM_LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) +#define TVM_LOGI(...) __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) +#define TVM_LOGW(...) __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) +#define TVM_LOGE(...) __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) +#define TVM_LOGF(...) __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) #endif // __ANDROID__ #endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_ diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 306a7e990516d..7c3323c56229e 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -21,13 +21,15 @@ * \file module_util.cc * \brief Utilities for module. */ +#include "library_module.h" + #include #include #include + #include -#include #include -#include "library_module.h" +#include namespace tvm { namespace runtime { @@ -35,22 +37,16 @@ namespace runtime { // Library module that exposes symbols from a library. 
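The invoke template reflowed above is the heart of the stub API's domain/non-domain split: every DECLFUNC-generated method funnels into it, and it picks between the tvm_remote_* entry point (which takes the FastRPC session handle) and the tvm_remote_nd_* one. A condensed sketch of the member, with the non-domain branch inferred from the surrounding DECLFUNC definitions:

template <typename Fd, typename Fnd, typename... Ts>
int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, Ts... args) const {
  if (enable_domains_) return func_d(handle, args...);  // domain variant keeps the handle
  return func_nd(args...);                              // non-domain variant has no handle parameter
}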
class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib) - : lib_(lib) { - } + explicit LibraryModuleNode(ObjectPtr lib) : lib_(lib) {} - const char* type_key() const final { - return "library"; - } + const char* type_key() const final { return "library"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - lib_->GetSymbol(runtime::symbol::tvm_module_main)); - CHECK(entry_name!= nullptr) + const char* entry_name = + reinterpret_cast(lib_->GetSymbol(runtime::symbol::tvm_module_main)); + CHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; faddr = reinterpret_cast(lib_->GetSymbol(entry_name)); } else { @@ -70,35 +66,27 @@ class LibraryModuleNode final : public ModuleNode { class ModuleInternal { public: // Get mutable reference of imports. - static std::vector* GetImportsAddr(ModuleNode* node) { - return &(node->imports_); - } + static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; -PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, - const ObjectPtr& sptr_to_self) { +PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - TVMValue ret_value; - int ret_type_code = kTVMNullptr; - int ret = (*faddr)( - const_cast(args.values), - const_cast(args.type_codes), - args.num_args, - &ret_value, - &ret_type_code); - CHECK_EQ(ret, 0) << TVMGetLastError(); - if (ret_type_code != kTVMNullptr) { - *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); - } - }); + TVMValue ret_value; + int ret_type_code = kTVMNullptr; + int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), + args.num_args, &ret_value, &ret_type_code); + CHECK_EQ(ret, 0) << TVMGetLastError(); + if (ret_type_code != kTVMNullptr) { + *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); + } + }); } void InitContextFunctions(std::function fgetsymbol) { - #define TVM_INIT_CONTEXT_FUNC(FuncName) \ - if (auto *fp = reinterpret_cast \ - (fgetsymbol("__" #FuncName))) { \ - *fp = FuncName; \ - } +#define TVM_INIT_CONTEXT_FUNC(FuncName) \ + if (auto* fp = reinterpret_cast(fgetsymbol("__" #FuncName))) { \ + *fp = FuncName; \ + } // Initialize the functions TVM_INIT_CONTEXT_FUNC(TVMFuncCall); TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); @@ -108,7 +96,7 @@ void InitContextFunctions(std::function fgetsymbol) { TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); - #undef TVM_INIT_CONTEXT_FUNC +#undef TVM_INIT_CONTEXT_FUNC } /*! 
@@ -123,10 +111,10 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { uint64_t nbytes = 0; for (size_t i = 0; i < sizeof(nbytes); ++i) { uint64_t c = mblob[i]; - nbytes |= (c & 0xffUL) << (i * 8); + nbytes |= (c & 0xffUL) << (i * 8); } - dmlc::MemoryFixedSizeStream fs( - const_cast(mblob + sizeof(nbytes)), static_cast(nbytes)); + dmlc::MemoryFixedSizeStream fs(const_cast(mblob + sizeof(nbytes)), + static_cast(nbytes)); dmlc::Stream* stream = &fs; uint64_t size; CHECK(stream->Read(&size)); @@ -147,9 +135,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { } else { std::string fkey = "runtime.module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented."; Module m = (*f)(static_cast(stream)); modules.emplace_back(m); } @@ -180,14 +166,11 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { } Module CreateModuleFromLibrary(ObjectPtr lib) { - InitContextFunctions([lib](const char* fname) { - return lib->GetSymbol(fname); - }); + InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); auto n = make_object(lib); // Load the imported modules const char* dev_mblob = - reinterpret_cast( - lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); Module root_mod; if (dev_mblob != nullptr) { root_mod = ProcessModuleBlob(dev_mblob, lib); @@ -197,8 +180,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { } // allow lookup of symbol from root (so all symbols are visible). - if (auto *ctx_addr = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { + if (auto* ctx_addr = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { *ctx_addr = root_mod.operator->(); } diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 61e62661f149e..91918c1ccaa3a 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -24,9 +24,10 @@ #ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ #define TVM_RUNTIME_LIBRARY_MODULE_H_ -#include -#include #include +#include +#include + #include namespace tvm { @@ -47,7 +48,7 @@ class Library : public Object { * \param name The name of the symbol. * \return The symbol. */ - virtual void *GetSymbol(const char* name) = 0; + virtual void* GetSymbol(const char* name) = 0; // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting. 
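ProcessModuleBlob above starts by decoding the first eight bytes of tvm_dev_mblob as a little-endian length, one byte at a time, so it never performs a potentially unaligned 64-bit load on the raw symbol data. The decode in isolation (mblob as in the hunk):

uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
  uint64_t c = mblob[i];              // widen each byte before shifting
  nbytes |= (c & 0xffUL) << (i * 8);  // byte i contributes bits [8*i, 8*i+8)
}
// The following nbytes bytes form a dmlc stream of (type_key, module payload)
// records, each dispatched to a "runtime.module.loadbinary_" + type_key loader.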
}; @@ -77,4 +78,4 @@ void InitContextFunctions(std::function fgetsymbol); Module CreateModuleFromLibrary(ObjectPtr lib); } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ +#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 22f2e9aa0909b..451c0e88fcb0c 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_META_DATA_H_ #define TVM_RUNTIME_META_DATA_H_ -#include #include +#include #include + #include #include + #include "runtime_base.h" namespace tvm { @@ -40,10 +42,10 @@ struct FunctionInfo { std::vector arg_types; std::vector thread_axis_tags; - void Save(dmlc::JSONWriter *writer) const; - void Load(dmlc::JSONReader *reader); - void Save(dmlc::Stream *writer) const; - bool Load(dmlc::Stream *reader); + void Save(dmlc::JSONWriter* writer) const; + void Load(dmlc::JSONReader* reader); + void Save(dmlc::Stream* writer) const; + bool Load(dmlc::Stream* reader); }; } // namespace runtime } // namespace tvm diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 8a7c9fe53018a..ca369d46e5baa 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -24,21 +24,22 @@ #ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_ #define TVM_RUNTIME_METAL_METAL_COMMON_H_ +#import #import -#import #import -#import +#import #import #import - +#include #include -#include #include -#include +#include + +#include #include #include #include -#include + #include "../workspace_pool.h" namespace tvm { @@ -64,14 +65,14 @@ class MetalWorkspace final : public DeviceAPI { // Get command queue for given context. id GetCommandQueue(TVMContext ctx) { CHECK_EQ(ctx.device_type, kDLMetal); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid Metal device_id=" << ctx.device_id; return queues[ctx.device_id]; } // Get device for given context id GetDevice(TVMContext ctx) { CHECK_EQ(ctx.device_type, kDLMetal); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) << "Invalid Metal device_id=" << ctx.device_id; return devices[ctx.device_id]; } @@ -81,19 +82,10 @@ class MetalWorkspace final : public DeviceAPI { // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_size, - void* to, - size_t to_size, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; @@ -112,8 +104,7 @@ class MetalThreadEntry { /*! 
\brief workspace pool */ WorkspacePool pool; // constructor - MetalThreadEntry() - : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { + MetalThreadEntry() : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { context.device_id = 0; context.device_type = static_cast(kDLMetal); } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index a49f8a5cfc96b..3bad2c3e9debb 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -20,8 +20,8 @@ /*! * \file metal_device_api.mm */ -#include #include +#include #include "metal_common.h" namespace tvm { @@ -29,25 +29,21 @@ namespace metal { const std::shared_ptr& MetalWorkspace::Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } -void MetalWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { this->Init(); size_t index = static_cast(ctx.device_id); if (kind == kExist) { - *rv = int(index< devices.size()); + *rv = int(index < devices.size()); return; } - CHECK_LT(index, devices.size()) - << "Invalid device id " << index; + CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { case kMaxThreadsPerBlock: { - *rv = static_cast( - [devices[ctx.device_id] maxThreadsPerThreadgroup].width); + *rv = static_cast([devices[ctx.device_id] maxThreadsPerThreadgroup].width); break; } case kWarpSize: { @@ -55,14 +51,22 @@ *rv = 1; break; } - case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: return; - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; - case kExist: break; - case kGcnArch: return; + case kMaxSharedMemoryPerBlock: + return; + case kComputeVersion: + return; + case kDeviceName: + return; + case kMaxClockRate: + return; + case kMultiProcessorCount: + return; + case kMaxThreadDimensions: + return; + case kExist: + break; + case kGcnArch: + return; } } @@ -87,22 +91,13 @@ kernel void CopyKernel( // But we keep this code. 
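For orientation, the attribute query served by the reflowed switch above looks like this from the caller's side (a sketch using the generic DeviceAPI conventions; unsupported keys simply return without touching the result):

TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(kDLMetal);
ctx.device_id = 0;

TVMRetValue rv;
MetalWorkspace::Global()->GetAttr(ctx, kMaxThreadsPerBlock, &rv);
int64_t max_threads = rv;  // width of maxThreadsPerThreadgroup on that device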
int GetWarpSize(id dev) { NSError* error_msg = nil; - id lib = - [dev - newLibraryWithSource: - [NSString stringWithUTF8String:kDummyKernel] - options:nil - error:&error_msg]; + id lib = [dev newLibraryWithSource:[NSString stringWithUTF8String:kDummyKernel] + options:nil + error:&error_msg]; CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String]; - id f = - [lib - newFunctionWithName: - [NSString stringWithUTF8String:"CopyKernel"]]; - CHECK(f!= nil); - id state = - [dev - newComputePipelineStateWithFunction:f - error:&error_msg]; + id f = [lib newFunctionWithName:[NSString stringWithUTF8String:"CopyKernel"]]; + CHECK(f != nil); + id state = [dev newComputePipelineStateWithFunction:f error:&error_msg]; CHECK(state != nil) << [[error_msg localizedDescription] UTF8String]; return static_cast(state.threadExecutionWidth); } @@ -123,20 +118,19 @@ int GetWarpSize(id dev) { initialized_ = true; if (devices.size() != 0) return; #if TARGET_OS_IPHONE - // on iPhone - id d = MTLCreateSystemDefaultDevice(); + // on iPhone + id d = MTLCreateSystemDefaultDevice(); + devices.push_back([d retain]); + queues.push_back([[d newCommandQueue] retain]); +#else + NSArray >* devs = MTLCopyAllDevices(); + for (size_t i = 0; i < devs.count; ++i) { + id d = [devs objectAtIndex:i]; devices.push_back([d retain]); queues.push_back([[d newCommandQueue] retain]); -#else - NSArray>* devs = MTLCopyAllDevices(); - for (size_t i = 0; i < devs.count; ++i) { - id d = [devs objectAtIndex:i]; - devices.push_back([d retain]); - queues.push_back([[d newCommandQueue] retain]); - LOG(INFO) << "Intializing Metal device " << i - << ", name=" << [d.name UTF8String]; - warp_size.push_back(GetWarpSize(d)); - } + LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; + warp_size.push_back(GetWarpSize(d)); + } #endif } @@ -144,8 +138,8 @@ int GetWarpSize(id dev) { MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id; } -void* MetalWorkspace::AllocDataSpace( - TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { +void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) { this->Init(); id dev = GetDevice(ctx); // GPU memory only @@ -157,9 +151,7 @@ int GetWarpSize(id dev) { storage_mode = MTLResourceStorageModeManaged; #endif */ - id buf = [ - dev newBufferWithLength:nbytes - options:storage_mode]; + id buf = [dev newBufferWithLength:nbytes options:storage_mode]; CHECK(buf != nil); return (__bridge void*)([buf retain]); } @@ -169,14 +161,9 @@ int GetWarpSize(id dev) { CFRelease(ptr); } -void MetalWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { this->Init(); CHECK(stream == nullptr); @@ -188,65 +175,54 @@ int GetWarpSize(id dev) { int to_dev_type = static_cast(ctx_to.device_type); if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) { - CHECK_EQ(ctx_from.device_id, ctx_to.device_id) - << "Metal disallow cross device copy."; + CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy."; id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:(__bridge id)(from) - sourceOffset:from_offset - toBuffer:(__bridge id)(to) - destinationOffset:to_offset - 
size:size]; + sourceOffset:from_offset + toBuffer:(__bridge id)(to)destinationOffset:to_offset + size:size]; [encoder endEncoding]; [cb commit]; } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) { // copy to a local buffer before get into global buffer. id from_buf = (__bridge id)(from); if (from_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal() - ->GetTempBuffer(ctx_from, size); + id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:from_buf - sourceOffset:from_offset - toBuffer:temp - destinationOffset:0 - size:size]; + sourceOffset:from_offset + toBuffer:temp + destinationOffset:0 + size:size]; [encoder endEncoding]; [cb commit]; [cb waitUntilCompleted]; - memcpy(static_cast(to) + to_offset, - static_cast([temp contents]), - size); + memcpy(static_cast(to) + to_offset, static_cast([temp contents]), size); } else { memcpy(static_cast(to) + to_offset, - static_cast([from_buf contents]) + from_offset, - size); + static_cast([from_buf contents]) + from_offset, size); } } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) { id to_buf = (__bridge id)(to); if (to_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal() - ->GetTempBuffer(ctx_to, size); - memcpy([temp contents], - static_cast(from) + from_offset, - size); + id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size); + memcpy([temp contents], static_cast(from) + from_offset, size); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:temp - sourceOffset:0 - toBuffer:to_buf - destinationOffset:to_offset - size:size]; + sourceOffset:0 + toBuffer:to_buf + destinationOffset:to_offset + size:size]; [encoder endEncoding]; [cb commit]; [cb waitUntilCompleted]; } else { memcpy(static_cast([to_buf contents]) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); } } else { LOG(FATAL) << "Expect copy from/to Metal or between Metal" - << ", from=" << from_dev_type - << ", to=" << to_dev_type; + << ", from=" << from_dev_type << ", to=" << to_dev_type; } } @@ -259,9 +235,7 @@ int GetWarpSize(id dev) { [cb waitUntilCompleted]; } -void* MetalWorkspace::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); } @@ -279,30 +253,25 @@ int GetWarpSize(id dev) { if (temp_buffer_.size() <= static_cast(ctx.device_id)) { temp_buffer_.resize(ctx.device_id + 1, nil); } - if (temp_buffer_[ctx.device_id] == nil || - temp_buffer_[ctx.device_id].length < size) { + if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length < size) { id dev = MetalWorkspace::Global()->GetDevice(ctx); if (temp_buffer_[ctx.device_id] != nil) { [temp_buffer_[ctx.device_id] release]; } - temp_buffer_[ctx.device_id] = [ - [dev newBufferWithLength:size - options:MTLStorageModeShared] retain]; + temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size + options:MTLStorageModeShared] retain]; } return temp_buffer_[ctx.device_id]; } typedef dmlc::ThreadLocalStore MetalThreadStore; -MetalThreadEntry* MetalThreadEntry::ThreadLocal() { - return MetalThreadStore::Get(); -} +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* 
ptr = MetalWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MetalWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace metal } // namespace runtime diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index 0d2d429fcf61b..77cdf64df8bc7 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -44,11 +46,8 @@ static constexpr const int kMetalMaxNumDevice = 32; * \param fmap The map function information map of each function. * \param source Optional, source file */ -Module MetalModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 41269b9f1a5d2..9bdebf3d06c10 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -20,18 +20,18 @@ /*! * \file metal_module.cc */ +#include "metal_module.h" #include -#include #include +#include #include -#include #include -#include "metal_module.h" -#include "metal_common.h" +#include +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "metal_common.h" namespace tvm { namespace runtime { @@ -39,27 +39,18 @@ // Module to support thread-safe multi-GPU execution. 
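The TVM_REGISTER_GLOBAL reflow above is purely cosmetic: the workspace is still published as an opaque handle under the key "device_api.metal". A sketch of the retrieval side, using the standard registry API (error handling trimmed):

#include <tvm/runtime/registry.h>

using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

const PackedFunc* f = Registry::Get("device_api.metal");
CHECK(f != nullptr) << "device_api.metal is not registered";
void* handle = (*f)();  // the registered lambda stored the DeviceAPI* as a void*
auto* api = static_cast<tvm::runtime::DeviceAPI*>(handle);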
// The runtime will contain a per-device module table // The modules will be lazily loaded -class MetalModuleNode final :public runtime::ModuleNode { +class MetalModuleNode final : public runtime::ModuleNode { public: - explicit MetalModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { - } - const char* type_key() const final { - return "metal"; - } + explicit MetalModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) + : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + const char* type_key() const final { return "metal"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -81,8 +72,7 @@ void SaveToBinary(dmlc::Stream* stream) final { } } // get a from primary context in device_id - id GetPipelineState( - size_t device_id, const std::string& func_name) { + id GetPipelineState(size_t device_id, const std::string& func_name) { metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get(); CHECK_LT(device_id, w->devices.size()); // start lock scope. @@ -97,53 +87,43 @@ void SaveToBinary(dmlc::Stream* stream) final { NSError* err_msg = nil; if (e.lib == nil) { if (fmt_ == "metal") { - MTLCompileOptions *opts = [MTLCompileOptions alloc]; + MTLCompileOptions* opts = [MTLCompileOptions alloc]; // Use the Metal 1.2 for now. opts.languageVersion = MTLLanguageVersion1_2; opts.fastMathEnabled = YES; // opts = nil; - e.lib = [ - w->devices[device_id] - newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] - options:opts - error:&err_msg]; + e.lib = [w->devices[device_id] + newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] + options:opts + error:&err_msg]; [opts dealloc]; if (e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" - << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } if (err_msg != nil) { - LOG(INFO) << "Warning: " - << [[err_msg localizedDescription] UTF8String]; + LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; } } else { // Build from library. 
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); - auto data = dispatch_data_create( - data_.c_str(), data_.length(), q, ^{}); - e.lib = [ - w->devices[device_id] - newLibraryWithData:data - error:&err_msg]; + auto data = dispatch_data_create(data_.c_str(), data_.length(), q, + ^{ + }); + e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; if (err_msg != nil || e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" - << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } } [e.lib retain]; } - id f = [ - e.lib - newFunctionWithName: - [NSString stringWithUTF8String:func_name.c_str()]]; + id f = + [e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; CHECK(f != nil) << "cannot find function " << func_name; id state = - [w->devices[device_id] - newComputePipelineStateWithFunction:f - error:&err_msg]; - CHECK(state != nil) - << "cannot get state:" << " for function " << func_name - << [[err_msg localizedDescription] UTF8String]; + [w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg]; + CHECK(state != nil) << "cannot get state:" + << " for function " << func_name + << [[err_msg localizedDescription] UTF8String]; // The state.threadExecutionWidth can change dynamically according // to the resource constraint in kernel, so it is not strictly hold // Turn of warp aware optimziation for now. @@ -162,7 +142,7 @@ void SaveToBinary(dmlc::Stream* stream) final { ~DeviceEntry() { if (lib != nil) [lib release]; - for (auto &&kv : smap) { + for (auto&& kv : smap) { [kv.second release]; } } @@ -185,11 +165,8 @@ void SaveToBinary(dmlc::Stream* stream) final { class MetalWrappedFunc { public: // initialize the METAL function. 
- void Init(MetalModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_buffer_args, - size_t num_pack_args, + void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { w_ = metal::MetalWorkspace::Global().get(); m_ = m; @@ -204,9 +181,7 @@ void Init(MetalModuleNode* m, scache_[dev_id] = m->GetPipelineState(dev_id, func_name); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - const ArgUnion* pack_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->context.device_id; if (scache_[device_id] == nil) { @@ -223,16 +198,13 @@ void operator()(TVMArgs args, } if (num_pack_args_ != 0) { [encoder setBytes:pack_args - length:num_pack_args_ * sizeof(ArgUnion) - atIndex:num_buffer_args_]; + length:num_pack_args_ * sizeof(ArgUnion) + atIndex:num_buffer_args_]; } // launch - MTLSize dimGrid = MTLSizeMake( - wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - MTLSize dimBlock = MTLSizeMake( - wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); - [encoder dispatchThreadgroups: dimGrid - threadsPerThreadgroup: dimBlock]; + MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); + [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock]; [encoder endEncoding]; [cb commit]; } @@ -257,36 +229,29 @@ void operator()(TVMArgs args, ThreadAxisConfig thread_axis_cfg_; }; -PackedFunc MetalModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc MetalModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); - f.Init(this, sptr_to_self, name, - num_buffer_args, info.arg_types.size() - num_buffer_args, + f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, info.thread_axis_tags); return PackFuncNonBufferArg(f, info.arg_types); } -Module MetalModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { metal::MetalWorkspace::Global()->Init(); auto n = make_object(data, fmt, fmap, source); return Module(n); } // Load module from module. 
-Module MetalModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -307,10 +272,8 @@ Module MetalModuleLoadBinary(void* strm) { return MetalModuleCreate(data, fmt, fmap, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal") -.set_body_typed(MetalModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal") -.set_body_typed(MetalModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s b/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s index 300deb8079a07..f5720f4d7b280 100644 --- a/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s +++ b/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s @@ -17,11 +17,6 @@ * under the License. */ -/*! - * \file utvm_init.s - * \brief uTVM init definition for STM32F746XX-series boards - */ - .syntax unified .cpu cortex-m7 .fpu softvfp diff --git a/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c b/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c index 1b8376150fce9..0f13a7dede883 100644 --- a/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c +++ b/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c @@ -29,100 +29,51 @@ extern "C" { #include #include "utvm_runtime.h" +// NOTE: This expects ST CMSIS to be in your include path. +// Download STM32CubeF7 here: +// https://www.st.com/content/st_com/en/products/embedded-software/mcu-mpu-embedded-software/stm32-embedded-software/stm32cube-mcu-mpu-packages/stm32cubef7.html +// and add Drivers/CMSIS to your C include path. +#include "Device/ST/STM32F7xx/Include/stm32f746xx.h" -// There are two implementations of cycle counters on the STM32F7X: SysTick and -// CYCCNT. SysTick is preferred, as it gives better error handling, but the -// counter is only 24 bits wide. If a larger timer is needed, use the CYCCNT -// implementation, which has a 32-bit counter. 
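The TIM2-based implementation that replaces the SysTick/CYCCNT code below reports errors through a new int32_t* out-parameter instead of a sentinel return, and converts the raw 32-bit count in two steps: whole milliseconds first, then the residual microseconds. Factored out, and assuming the 216 MHz utvm_SystemCoreClock defined in this patch, the math is:

#include <stdint.h>

/* Sketch of the tick-to-microsecond conversion used by UTVMTimerStop. */
uint32_t TicksToMicros(uint32_t tim_cnt) {
  uint32_t ticks_per_ms = 216000000UL / 1000;     /* 216000 ticks per millisecond */
  uint32_t ticks_per_us = 216000000UL / 1000000;  /* 216 ticks per microsecond */
  uint32_t millis = tim_cnt / ticks_per_ms;
  uint32_t micros = (tim_cnt - millis * ticks_per_ms) / ticks_per_us;
  return millis * 1000 + micros;  /* e.g. 216216 ticks -> 1 ms + 1 us -> 1001 */
}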
-#define USE_SYSTICK
-#ifdef USE_SYSTICK
-
-#define SYST_CSR (*((volatile uint32_t *) 0xE000E010))
-#define SYST_RVR (*((volatile uint32_t *) 0xE000E014))
-#define SYST_CVR (*((volatile uint32_t *) 0xE000E018))
-#define SYST_CALIB (*((volatile uint32_t *) 0xE000E01C))
-
-#define SYST_CSR_ENABLE 0
-#define SYST_CSR_TICKINT 1
-#define SYST_CSR_CLKSOURCE 2
-#define SYST_COUNTFLAG 16
-
-#define SYST_CALIB_NOREF 31
-#define SYST_CALIB_SKEW 30
-
-uint32_t start_time = 0;
-uint32_t stop_time = 0;
+#define utvm_SystemCoreClock 216000000UL
 int32_t UTVMTimerStart() {
-  SYST_CSR = (1 << SYST_CSR_ENABLE) | (1 << SYST_CSR_CLKSOURCE);
-  // wait until timer starts
-  while (SYST_CVR == 0) {}
-  start_time = SYST_CVR;
-  return 0;
-}
-
-void UTVMTimerStop() {
-  SYST_CSR = 0;
-  stop_time = SYST_CVR;
+  UTVMTimerReset();
+  TIM2->CR1 =
+      TIM_CR1_CEN;  // Start counter
+  return UTVM_ERR_OK;
 }
-void UTVMTimerReset() {
-  SYST_CSR = 0;
-  // maximum reload value (24-bit)
-  SYST_RVR = (~((uint32_t) 0)) >> 8;
-  SYST_CVR = 0;
-}
-
-uint32_t UTVMTimerRead() {
-  if (SYST_CSR & SYST_COUNTFLAG) {
-    TVMAPISetLastError("timer overflowed");
-    return -1;
-  } else {
-    return start_time - stop_time;
+uint32_t UTVMTimerStop(int32_t* err) {
+  TIM2->CR1 &= ~TIM_CR1_CEN;  // Stop counter: the ~ is required; a plain &= TIM_CR1_CEN would leave CEN set
+  if (TIM2->SR & TIM_SR_UIF_Msk) {
+    *err = UTVM_ERR_TIMER_OVERFLOW;
+    return 0;
   }
+  *err = UTVM_ERR_OK;
+  uint32_t tim_cnt = TIM2->CNT;
+  uint32_t millis = tim_cnt / (utvm_SystemCoreClock / 1000);
+  uint32_t micros =
+      (tim_cnt - (millis * (utvm_SystemCoreClock / 1000))) /
+      (utvm_SystemCoreClock / 1000000);
+  return millis * 1000 + micros;
 }
-#else  // !USE_SYSTICK
-
-#define DWT_CTRL (*((volatile uint32_t *) 0xE0001000))
-#define DWT_CYCCNT (*((volatile uint32_t *) 0xE0001004))
-
-#define DWT_CTRL_NOCYCCNT 25
-#define DWT_CTRL_CYCCNTENA 0
-
-uint32_t start_time = 0;
-uint32_t stop_time = 0;
-
 void UTVMTimerReset() {
-  DWT_CYCCNT = 0;
-}
-
-int32_t UTVMTimerStart() {
-  if (DWT_CTRL & DWT_CTRL_NOCYCCNT) {
-    TVMAPISetLastError("cycle counter not implemented on device");
-    return -1;
+  RCC->APB1RSTR |= RCC_APB1RSTR_TIM2RST;  // Hold TIM2 in reset
+  RCC->DCKCFGR1 = (RCC->DCKCFGR1 & ~RCC_DCKCFGR1_TIMPRE_Msk);  // disable 2x clock boost to TIM2
+  RCC->CFGR = (RCC->CFGR & ~RCC_CFGR_PPRE1_Msk);  // No AHB clock division to APB1 (1:1).
+  RCC->APB1ENR |= RCC_APB1ENR_TIM2EN;  // Enable TIM2 clock.
+  RCC->APB1RSTR &= ~RCC_APB1RSTR_TIM2RST;  // Exit TIM2 reset.
+
+  DBGMCU->APB1FZ |= DBGMCU_APB1_FZ_DBG_TIM2_STOP;  // stop TIM2 clock during debug halt.
+  TIM2->ARR = 0xffffffff;
+  if (TIM2->SR & TIM_SR_UIF_Msk) {
+    for (;;) ;
   }
-  start_time = DWT_CYCCNT;
-  DWT_CTRL |= (1 << DWT_CTRL_CYCCNTENA);
 }
-void UTVMTimerStop() {
-  stop_time = DWT_CYCCNT;
-  DWT_CTRL &= ~(1 << DWT_CTRL_CYCCNTENA);
-}
-
-int32_t UTVMTimerRead() {
-  if (stop_time > start_time) {
-    return stop_time - start_time;
-  } else {
-    uint32_t largest = ~0;
-    return (largest - start_time) + stop_time;
-  }
-}
-
-#endif  // USE_SYSTICK
-
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
 #endif
diff --git a/src/runtime/micro/device/host/utvm_timer.c b/src/runtime/micro/device/host/utvm_timer.c
index 56a36ebae86d7..6ab585a88f244 100644
--- a/src/runtime/micro/device/host/utvm_timer.c
+++ b/src/runtime/micro/device/host/utvm_timer.c
@@ -22,26 +22,16 @@
 * \brief uTVM timer API stubs for the host emulated device
 */
-#ifdef __cplusplus
-extern "C" {
-#endif
-
+#include
 #include "utvm_runtime.h"
// TODO(weberlo): use this?
https://stackoverflow.com/questions/5141960/get-the-current-time-in-c int32_t UTVMTimerStart() { - return 0; + return UTVM_ERR_OK; } -void UTVMTimerStop() { } - -void UTVMTimerReset() { } - -uint32_t UTVMTimerRead() { - return 1; +uint32_t UTVMTimerStop(int32_t* err) { + *err = UTVM_ERR_OK; + return 0; } - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif diff --git a/src/runtime/micro/device/riscv_spike/utvm_init.s b/src/runtime/micro/device/riscv_spike/utvm_init.s new file mode 100644 index 0000000000000..68662cce97e71 --- /dev/null +++ b/src/runtime/micro/device/riscv_spike/utvm_init.s @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +UTVMInit: + /* set stack pointer */ + la sp, _utvm_stack_pointer_init + call UTVMMain diff --git a/web/example_rpc_node.js b/src/runtime/micro/device/riscv_spike/utvm_timer.c similarity index 60% rename from web/example_rpc_node.js rename to src/runtime/micro/device/riscv_spike/utvm_timer.c index 45f917a3234ba..5cf38559feab0 100644 --- a/web/example_rpc_node.js +++ b/src/runtime/micro/device/riscv_spike/utvm_timer.c @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,17 +17,26 @@ * under the License. */ -// Javascript RPC server example -// Start and connect to websocket proxy. +/*! + * \file utvm_timer.c + * \brief uTVM timer API stubs for Spike + */ + +#ifdef __cplusplus +extern "C" { +#endif + +#include "utvm_runtime.h" + +int32_t UTVMTimerStart() { + return UTVM_ERR_OK; +} -// Load Emscripten Module, need to change path to root/lib -const path = require("path"); -process.chdir(path.join(__dirname, "../lib")); -var Module = require("../lib/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. 
-const tvm_runtime = require("../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +uint32_t UTVMTimerStop(int32_t* err) { + *err = UTVM_ERR_OK; + return 0; +} -var websock_proxy = "ws://localhost:9190/ws"; -var num_sess = 100; -tvm.startRPCServer(websock_proxy, "js", num_sess) +#ifdef __cplusplus +} // TVM_EXTERN_C +#endif diff --git a/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c b/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c index a8c600ed347b2..9fabce6bdc1e5 100644 --- a/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c +++ b/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c @@ -32,10 +32,11 @@ extern "C" { #include #include -void *(*TVMBackendAllocWorkspace_)(int, int, uint64_t, int, int) = - (void *(*)(int, int, uint64_t, int, int)) NULL; -int (*TVMBackendFreeWorkspace_)(int, int, void*) = (int (*)(int, int, void*)) NULL; -void (*TVMAPISetLastError_)(const char*) = (void (*)(const char*)) NULL; +// TODO(weberlo, areusch): compiler errors say volatile qualifier is discarded. +// should we just get rid of em? +void* (* volatile TVMBackendAllocWorkspace_)(int, int, uint64_t, int, int) = NULL; +int (* volatile TVMBackendFreeWorkspace_)(int, int, void*) = NULL; +void (* volatile TVMAPISetLastError_)(const char*) = NULL; void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, int dtype_bits_hint) { @@ -51,6 +52,41 @@ void TVMAPISetLastError(const char* msg) { (*TVMAPISetLastError_)(msg); } +void *memset(void *s, int c, size_t n) { + char *p = (char*) s; // NOLINT(readability/casting): linter is configured for c++ + while (n > 0) { + *p = (char) c; // NOLINT(readability/casting): linter is configured for c++ + p++; + n--; + } + return s; +} + +void *memmove(void *to, const void *from, size_t n) { + // TODO(weberlo, areusch): will need to factor memmove calls into workspace size calculation + // NOLINTNEXTLINE(readability/casting): linter is configured for c++ + char *temp = (char*) TVMBackendAllocWorkspace(1, 1, (uint64_t) n, 2, 8); + if (temp == NULL) { + return NULL; + } + + const char *from_pp = (char*) from; // NOLINT(readability/casting): linter is configured for c++ + for (size_t i = 0; i < n; i++) { + temp[i] = from_pp[i]; + } + char *to_pp = (char*) to; // NOLINT(readability/casting): linter is configured for c++ + for (size_t i = 0; i < n; i++) { + to_pp[i] = temp[i]; + } + + // NOLINTNEXTLINE(readability/casting): linter is configured for c++ + if (TVMBackendFreeWorkspace(1, (uint64_t) 1, (void*) temp) != 0) { + return NULL; + } + + return to; +} + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/src/runtime/micro/host_driven/utvm_runtime.c b/src/runtime/micro/host_driven/utvm_runtime.c index a4de495a185c0..2f2f0c1e0dea9 100644 --- a/src/runtime/micro/host_driven/utvm_runtime.c +++ b/src/runtime/micro/host_driven/utvm_runtime.c @@ -34,89 +34,148 @@ extern "C" { #include "utvm_runtime.h" -// Task pointers must be patched before calling a function. -UTVMTask utvm_task = { - .func = NULL, - .arg_values = NULL, - .arg_type_codes = NULL, - .num_args = 0, -}; - -size_t utvm_word_size = 0; // NOLINT(*) +// TODO(weberlo, areusch): move defines into header +// TODO(weberlo, areusch): unify TASK_QUEUE_SIZE and MicroSession::kTaskQueueCapacity. 
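// For orientation, the contract these globals set up between host and device,
// in outline (a paraphrase, not part of the patch):
//   host:   write utvm_tasks[0..n) and utvm_num_tasks = n, then start the
//           device at UTVMInit;
//   device: UTVMMain (below) times and runs each task in order, aborting the
//           batch on the first error, then raises utvm_done;
//   host:   waits on utvm_done, then reads back utvm_task_times and
//           utvm_last_error.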
+#define TASK_QUEUE_SIZE 20 +volatile UTVMTask utvm_tasks[TASK_QUEUE_SIZE] = { }; +volatile uint32_t utvm_num_tasks = 0; +volatile uint32_t utvm_task_times[TASK_QUEUE_SIZE] = { }; // These pointers are patched at load time to point to the workspace section. -char* utvm_workspace_start = NULL; // NOLINT(*) -char* utvm_workspace_end = NULL; // NOLINT(*) -char* utvm_workspace_curr = NULL; // NOLINT(*) +volatile char* utvm_workspace_start = NULL; // NOLINT(*) +volatile char* utvm_workspace_end = NULL; // NOLINT(*) +volatile char* utvm_workspace_curr = NULL; // NOLINT(*) +#define MAX_WS_ALLOCS 10 +volatile char* utvm_alloc_ends[MAX_WS_ALLOCS] = {}; // NOLINT(*) +volatile uint32_t utvm_alloc_idx = 0; // Keep track of how many active allocations there are on the workspace. -size_t utvm_num_active_allocs = 0; +volatile uint32_t utvm_num_active_allocs = 0; + +volatile uint32_t utvm_word_size = 0; -const char* utvm_last_error = NULL; // NOLINT(*) -int32_t utvm_return_code = 0; // NOLINT(*) +volatile int32_t utvm_last_error = 0; // NOLINT(*) -uint32_t utvm_task_time = 0; +volatile uint32_t utvm_done = 0; // Gets called by UTVMInit, after device-specific initialization is finished. void UTVMMain() { + utvm_done = 0; + // loss of precision should be fine here, since we only care about the lower bits + if (((uint32_t) utvm_workspace_start) % utvm_word_size) { + utvm_last_error = UTVM_ERR_WS_UNALIGNED_START; + UTVMDone(); + return; + } utvm_workspace_curr = utvm_workspace_start; utvm_num_active_allocs = 0; - utvm_last_error = NULL; // NOLINT(*) - utvm_return_code = 0; - utvm_task_time = 0; - UTVMTimerReset(); - int32_t err = UTVMTimerStart(); - if (err < 0) { - utvm_return_code = err; - UTVMDone(); + utvm_alloc_idx = 0; + utvm_last_error = UTVM_ERR_NOT_FINISHED; + for (uint32_t i = 0; i < utvm_num_tasks; i++) { + int32_t err = UTVM_ERR_OK; + utvm_task_times[i] = 0; + err = UTVMTimerStart(); + if (err < 0) { + utvm_last_error = err; + UTVMDone(); + return; + } + err = utvm_tasks[i].func( + (void*) utvm_tasks[i].arg_values, // NOLINT(*) + (void*) utvm_tasks[i].arg_type_codes, // NOLINT(*) + utvm_tasks[i].num_args); + if (err < 0) { + UTVMDone(); + return; + } + utvm_task_times[i] = UTVMTimerStop(&err); + if (err < 0) { + utvm_last_error = err; + UTVMDone(); + return; + } + } + if (utvm_last_error == UTVM_ERR_NOT_FINISHED) { + utvm_last_error = UTVM_ERR_OK; } - utvm_return_code = utvm_task.func( - (void*) utvm_task.arg_values, // NOLINT(*) - (void*) utvm_task.arg_type_codes, // NOLINT(*) - utvm_task.num_args); - UTVMTimerStop(); - utvm_task_time = UTVMTimerRead(); UTVMDone(); } // We use a dummy function to signal execution is finished for device // backends which require breakpoints. -void UTVMDone() { } +void __attribute__((noinline)) UTVMDone() { + utvm_done = 1; +} + +#define ALIGNED_UP(x, word_size) \ + ((((word_size) - (((uintptr_t) (x)) % (word_size))) % (word_size)) + (x)) void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, int dtype_bits_hint) { - // Align up to 8 bytes. 
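// Aside (illustration only, not part of the patch): the replacement allocator
// below rounds each request up to a whole number of target words instead of
// realigning the cursor on every call, i.e.
static size_t RoundUpToWords(size_t size_bytes, size_t word_bytes) {
  return ((size_bytes + word_bytes - 1) / word_bytes) * word_bytes;  // 1 -> 4, 5 -> 8 for 4-byte words
}
// Since UTVMMain rejects an unaligned utvm_workspace_start, the cursor then
// stays word-aligned by induction.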
-  utvm_workspace_curr +=
-      (utvm_word_size - ((uintptr_t) utvm_workspace_curr % utvm_word_size)) % utvm_word_size;  // NOLINT(*)
-  if (utvm_workspace_curr + size > utvm_workspace_end) {
+  if (size == 0) {
+    utvm_last_error = UTVM_ERR_WS_ZERO_SIZE_ALLOC;
+    return NULL;
+  }
+  size_t alloc_requested_bytes = size;
+  size_t alloc_size_words = (alloc_requested_bytes + utvm_word_size - 1) / utvm_word_size;
+  size_t alloc_size_bytes = alloc_size_words * utvm_word_size;
+
+  // Check that the word-aligned request still fits in the workspace.
+  if (utvm_workspace_curr + alloc_size_bytes > utvm_workspace_end) {
     // Out of space in workspace.
+    utvm_last_error = UTVM_ERR_WS_OUT_OF_SPACE;
+    return NULL;
+  }
+  if (utvm_alloc_idx == MAX_WS_ALLOCS - 1) {
+    // Exceeded number of allocs we can keep track of.
+    utvm_last_error = UTVM_ERR_WS_TOO_MANY_ALLOCS;
     return NULL;
   }
   void* ret_ptr = (void*) utvm_workspace_curr;  // NOLINT(*)
-  utvm_workspace_curr += size;
+  utvm_workspace_curr = utvm_workspace_curr + alloc_size_bytes;
+  // store the *end* of the alloc, so we can restore the WS pointer when freeing
+  utvm_alloc_ends[utvm_alloc_idx] = utvm_workspace_curr;
+  utvm_alloc_idx++;
   utvm_num_active_allocs++;
   return ret_ptr;
 }

 int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
-  utvm_num_active_allocs--;
-  if (utvm_num_active_allocs < 0) {
+  // TODO(weberlo, areusch): add dev type check
+  if (utvm_num_active_allocs == 0) {
     TVMAPISetLastError("free called with no active workspace allocations");
     // Reset allocations and workspace (for future task executions).
     utvm_num_active_allocs = 0;
     utvm_workspace_curr = utvm_workspace_start;
+    utvm_last_error = UTVM_ERR_WS_DOUBLE_FREE;
     return -1;
-  } else if (utvm_num_active_allocs == 0) {
-    // No more allocations.  Reset workspace.
-    utvm_workspace_curr = utvm_workspace_start;
-    return 0;
   } else {
+    utvm_num_active_allocs--;
+    if (ptr == utvm_workspace_start) {
+      // it's the first allocation
+      utvm_alloc_ends[0] = NULL;
+    } else {
+      // Walk down from the newest allocation. The index runs one past the
+      // entry under test so the unsigned counter can't wrap below zero if
+      // ptr is never found.
+      for (uint32_t i = utvm_alloc_idx; i > 0; i--) {
+        if (utvm_alloc_ends[i - 1] == ptr) {
+          utvm_alloc_ends[i] = NULL;
+          break;
+        }
+      }
+    }
+    while (utvm_alloc_idx > 0 && utvm_alloc_ends[utvm_alloc_idx - 1] == NULL) {
+      utvm_alloc_idx--;
+    }
+    if (utvm_alloc_idx == 0) {
+      utvm_workspace_curr = utvm_workspace_start;
+    } else {
+      // TODO(weberlo, areusch): could you possibly have utvm_alloc_idx pointing to a NULL entry in
+      // this branch?
+      utvm_workspace_curr = utvm_alloc_ends[utvm_alloc_idx - 1];
+    }
     return 0;
   }
 }

-void TVMAPISetLastError(const char* msg) {
-  utvm_last_error = msg;
-}
+// utvm_last_error is now an int32 error code, so string messages are dropped.
+void TVMAPISetLastError(const char* msg) { }

 #ifdef __cplusplus
 }  // TVM_EXTERN_C
 #endif
diff --git a/src/runtime/micro/host_driven/utvm_runtime.h b/src/runtime/micro/host_driven/utvm_runtime.h
index c364ecf407929..1a4486c12d7c3 100644
--- a/src/runtime/micro/host_driven/utvm_runtime.h
+++ b/src/runtime/micro/host_driven/utvm_runtime.h
@@ -29,8 +29,24 @@ extern "C" {
 #endif

 #include
-#include
 #include
+#include <stddef.h>

+/*!
+ * \brief Error codes returned by the on-device runtime and reported back to
+ * the host through `utvm_last_error`.
+ */
+enum UTVMReturnCode {
+  UTVM_ERR_OK = 0,
+  UTVM_ERR_NOT_FINISHED = -1,
+  UTVM_ERR_TIMER_NOT_IMPLEMENTED = -2,
+  UTVM_ERR_TIMER_OVERFLOW = -3,
+  UTVM_ERR_WS_DOUBLE_FREE = -4,
+  UTVM_ERR_WS_OUT_OF_SPACE = -5,
+  UTVM_ERR_WS_TOO_MANY_ALLOCS = -6,
+  UTVM_ERR_WS_ZERO_SIZE_ALLOC = -7,
+  UTVM_ERR_WS_UNALIGNED_START = -8,
+  UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE = -9,
+};

 /*!
  * \brief Task structure for uTVM
@@ -46,20 +62,46 @@ typedef struct {
   int32_t num_args;
 } UTVMTask;

+/*!
+ * \brief microTVM processor startup.
+ * Expected to reset the stack pointer, configure any hardware required to support the CRT + * (i.e. FPU), and then jump to UTVMMain. + */ extern void UTVMInit(); -extern void UTVMTimerReset(); - +/*! + * \brief Start the on-device timer. + * \return UTVMReturnCode indicating the outcome of the operation. + */ extern int32_t UTVMTimerStart(); -extern void UTVMTimerStop(); - -extern uint32_t UTVMTimerRead(); +/*! + * \brief Stop the on-device timer. + * TODO(areusch): Use an SI specification of timer units here. + * \param err Receives a UTVMReturnCode indicating the outcome of the operation. + * \return elapsed time since UTVMTimerStart returned, in device timer ticks. + */ +extern uint32_t UTVMTimerStop(int32_t* err); +/*! + * \brief Main entry point for UTVM runtime. + * Waits for "go" signal, then executes tasks and reports result. Should never return. + */ void UTVMMain(); +/*! + * \brief Function entered when UTVMMain is complete. + * Should never return. The host sets a breakpoint here to detect end of computation. + */ void UTVMDone(); +// GCC -O3 begins to inject memset and memmove calls, so we provide impls in +// the runtime for this case and for general usage. + +void* memset(void* s, int c, size_t n); + +void* memmove(void* to, const void* from, size_t n); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/src/runtime/micro/host_low_level_device.cc b/src/runtime/micro/host_low_level_device.cc index a24994a2a0e5a..7c3e7a2abad8f 100644 --- a/src/runtime/micro/host_low_level_device.cc +++ b/src/runtime/micro/host_low_level_device.cc @@ -23,10 +23,12 @@ */ #include + #include #include -#include "micro_common.h" + #include "low_level_device.h" +#include "micro_common.h" namespace tvm { namespace runtime { @@ -43,38 +45,35 @@ class HostLowLevelDevice final : public LowLevelDevice { * \brief constructor to initialize on-host memory region to act as device * \param num_bytes size of the emulated on-device memory region */ - explicit HostLowLevelDevice(size_t num_bytes, void** base_addr) : size_(num_bytes) { + explicit HostLowLevelDevice(size_t num_bytes, TargetPtr* base_addr) : size_(num_bytes) { size_t size_in_pages = (num_bytes + kPageSize - 1) / kPageSize; // TODO(weberlo): Set permissions per section (e.g., read-write perms for // the heap, execute perms for text, etc.). int mmap_prot = PROT_READ | PROT_WRITE | PROT_EXEC; int mmap_flags = MAP_ANONYMOUS | MAP_PRIVATE; base_addr_ = mmap(nullptr, size_in_pages * kPageSize, mmap_prot, mmap_flags, -1, 0); - *base_addr = base_addr_; + *base_addr = + TargetPtr(TargetWordSize(sizeof(size_t) * 8), reinterpret_cast(base_addr_)); } /*! * \brief destructor to deallocate on-host device region */ - virtual ~HostLowLevelDevice() { - munmap(base_addr_, size_); - } + virtual ~HostLowLevelDevice() { munmap(base_addr_, size_); } - void Read(DevPtr addr, void* buf, size_t num_bytes) { + void Read(TargetPtr addr, void* buf, size_t num_bytes) { std::memcpy(buf, addr.cast_to(), num_bytes); } - void Write(DevPtr addr, const void* buf, size_t num_bytes) { + void Write(TargetPtr addr, const void* buf, size_t num_bytes) { std::memcpy(addr.cast_to(), buf, num_bytes); } - void Execute(DevPtr func_addr, DevPtr breakpoint_addr) { - reinterpret_cast(func_addr.value().val64)(); + void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) { + reinterpret_cast(func_addr.value().uint64())(); } - const char* device_type() const final { - return "host"; - } + const char* device_type() const final { return "host"; } private: /*! 
\brief base address of the micro device memory region */ @@ -83,9 +82,9 @@ class HostLowLevelDevice final : public LowLevelDevice { size_t size_; }; -const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, void** base_addr) { - std::shared_ptr lld = - std::make_shared(num_bytes, base_addr); +const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, + TargetPtr* base_addr) { + std::shared_ptr lld = std::make_shared(num_bytes, base_addr); return lld; } diff --git a/src/runtime/micro/low_level_device.h b/src/runtime/micro/low_level_device.h index 3158e2fe20de5..6cc0e1dc5af06 100644 --- a/src/runtime/micro/low_level_device.h +++ b/src/runtime/micro/low_level_device.h @@ -45,9 +45,7 @@ class LowLevelDevice { * \param buffer on-host buffer to be read into * \param num_bytes number of bytes to read */ - virtual void Read(DevPtr addr, - void* buffer, - size_t num_bytes) = 0; + virtual void Read(TargetPtr addr, void* buffer, size_t num_bytes) = 0; /*! * \brief writes num_bytes from buffer to device memory at addr @@ -55,16 +53,14 @@ class LowLevelDevice { * \param buffer host buffer to write from * \param num_bytes number of bytes to write */ - virtual void Write(DevPtr addr, - const void* buffer, - size_t num_bytes) = 0; + virtual void Write(TargetPtr addr, const void* buffer, size_t num_bytes) = 0; /*! * \brief starts execution of device at func_addr * \param func_addr offset of the init stub function * \param breakpoint_addr address at which to stop function execution */ - virtual void Execute(DevPtr func_addr, DevPtr breakpoint_addr) = 0; + virtual void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) = 0; /*! * \brief getter function for low-level device type @@ -78,7 +74,8 @@ class LowLevelDevice { * \param num_bytes size of the memory region * \param base_addr pointer to write the host device's resulting base address into */ -const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, void** base_addr); +const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, + TargetPtr* base_addr); /*! 
* \brief connect to OpenOCD and create an OpenOCD low-level device diff --git a/src/runtime/micro/micro_common.cc b/src/runtime/micro/micro_common.cc index 632b6048b182a..020df625c7090 100644 --- a/src/runtime/micro/micro_common.cc +++ b/src/runtime/micro/micro_common.cc @@ -22,65 +22,65 @@ * \brief common utilties for uTVM */ +#include "micro_common.h" + #include #include + +#include #include -#include #include -#include -#include "micro_session.h" -#include "micro_common.h" +#include + #include "low_level_device.h" +#include "micro_session.h" namespace tvm { namespace runtime { const char* SectionToString(SectionKind section) { switch (section) { - case SectionKind::kText: return "text"; - case SectionKind::kRodata: return "rodata"; - case SectionKind::kData: return "data"; - case SectionKind::kBss: return "bss"; - case SectionKind::kArgs: return "args"; - case SectionKind::kHeap: return "heap"; - case SectionKind::kWorkspace: return "workspace"; - case SectionKind::kStack: return "stack"; - default: return ""; + case SectionKind::kText: + return "text"; + case SectionKind::kRodata: + return "rodata"; + case SectionKind::kData: + return "data"; + case SectionKind::kBss: + return "bss"; + case SectionKind::kArgs: + return "args"; + case SectionKind::kHeap: + return "heap"; + case SectionKind::kWorkspace: + return "workspace"; + case SectionKind::kStack: + return "stack"; + default: + return ""; } } -std::string RelocateBinarySections( - const std::string& binary_path, - size_t word_size, - DevPtr text_start, - DevPtr rodata_start, - DevPtr data_start, - DevPtr bss_start, - DevPtr stack_end, - const std::string& toolchain_prefix) { +std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, + TargetPtr text_start, TargetPtr rodata_start, + TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, + const std::string& toolchain_prefix) { const auto* f = Registry::Get("tvm_callback_relocate_binary"); - CHECK(f != nullptr) - << "Require tvm_callback_relocate_binary to exist in registry"; - std::string relocated_bin = (*f)(binary_path, - word_size, - text_start.cast_to(), - rodata_start.cast_to(), - data_start.cast_to(), - bss_start.cast_to(), - stack_end.cast_to(), - toolchain_prefix); + CHECK(f != nullptr) << "Require tvm_callback_relocate_binary to exist in registry"; + std::string relocated_bin = + (*f)(binary_path, word_size.bytes(), text_start.cast_to(), + rodata_start.cast_to(), data_start.cast_to(), + bss_start.cast_to(), stack_end.cast_to(), toolchain_prefix); return relocated_bin; } -std::string ReadSection(const std::string& binary, - SectionKind section, +std::string ReadSection(const std::string& binary, SectionKind section, const std::string& toolchain_prefix) { CHECK(section == SectionKind::kText || section == SectionKind::kRodata || section == SectionKind::kData || section == SectionKind::kBss) << "ReadSection requires section to be one of text, rodata, data, or bss."; const auto* f = Registry::Get("tvm_callback_read_binary_section"); - CHECK(f != nullptr) - << "Require tvm_callback_read_binary_section to exist in registry"; + CHECK(f != nullptr) << "Require tvm_callback_read_binary_section to exist in registry"; TVMByteArray arr; arr.data = &binary[0]; arr.size = binary.length(); @@ -88,18 +88,15 @@ std::string ReadSection(const std::string& binary, return section_contents; } -size_t GetSectionSize(const std::string& binary_path, - SectionKind section, - const std::string& toolchain_prefix, - size_t align) { +size_t GetSectionSize(const 
std::string& binary_path, SectionKind section, + const std::string& toolchain_prefix, TargetWordSize word_size) { CHECK(section == SectionKind::kText || section == SectionKind::kRodata || section == SectionKind::kData || section == SectionKind::kBss) << "GetSectionSize requires section to be one of text, rodata, data, or bss."; const auto* f = Registry::Get("tvm_callback_get_section_size"); - CHECK(f != nullptr) - << "Require tvm_callback_get_section_size to exist in registry"; + CHECK(f != nullptr) << "Require tvm_callback_get_section_size to exist in registry"; int size = (*f)(binary_path, SectionToString(section), toolchain_prefix); - return UpperAlignValue(size, align); + return UpperAlignValue(size, word_size.bytes()); } } // namespace runtime diff --git a/src/runtime/micro/micro_common.h b/src/runtime/micro/micro_common.h index 4a0189b3e89ee..4375791e02103 100644 --- a/src/runtime/micro/micro_common.h +++ b/src/runtime/micro/micro_common.h @@ -24,12 +24,12 @@ #define TVM_RUNTIME_MICRO_MICRO_COMMON_H_ #include - #include #include #include #include +#include namespace tvm { namespace runtime { @@ -52,28 +52,108 @@ enum class SectionKind : size_t { kNumKinds, }; -/*! \brief union for storing values on varying target word sizes */ -union TargetVal { - /*! \brief 32-bit pointer */ - uint32_t val32; - /*! \brief 64-bit pointer */ - uint64_t val64; +/*! \brief data type for word sizes */ +class TargetWordSize { + public: + explicit TargetWordSize(size_t word_size_bits) : word_size_bits_{word_size_bits} { + CHECK(word_size_bits == 32 || word_size_bits == 64) + << "only 32-bit and 64-bit are supported now"; + } + + size_t bytes() const { return word_size_bits_ / 8; } + + size_t bits() const { return word_size_bits_; } + + private: + size_t word_size_bits_; }; -/*! \brief absolute device address */ -class DevPtr { +/*! \brief class for storing values on varying target word sizes */ +class TargetVal { + private: + size_t width_bits_; + uint64_t value_; + public: - /*! \brief construct a device address with value `value` */ - explicit DevPtr(std::uintptr_t value) : value_(TargetVal { .val64 = value }) {} + /*! \brief construct a TargetVal matching the size of the given integral argument */ + template ::value, T>::type> + explicit constexpr TargetVal(T value) : TargetVal(sizeof(T) * 8, value) {} + + /*! \brief construct an uninitialized value */ + TargetVal() : width_bits_{0}, value_{0} {} + + /*! \brief construct a TargetVal with explicit size and value */ + TargetVal(size_t width_bits, uint64_t value) : width_bits_{width_bits} { + CHECK(width_bits >= 8 && width_bits <= 64 && (width_bits & (width_bits - 1)) == 0) + << "width_bits must be a power of 2 in [8, 64], got " << width_bits; + value_ = value & Bitmask(); + } + + bool IsInitialized() const { return width_bits_ != 0; } + + size_t width_bits() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + return width_bits_; + } + + uint64_t Bitmask() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; - /*! 
\brief default constructor */ - DevPtr() : value_(TargetVal { .val64 = 0 }) {} + if (width_bits_ == 64) { + return ~0UL; + } else { + return (1UL << width_bits_) - 1; + } + } + + uint32_t uint32() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + CHECK(width_bits_ <= 32) << "TargetVal: requested 32-bit value, actual width is " + << width_bits_; + return uint32_t(value_ & Bitmask()); + } + + uint64_t uint64() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + return value_; + } + + TargetVal& operator=(const TargetVal& other) { + CHECK(other.IsInitialized()) << "Cannot assign an uninitialized TargetVal"; + + if (!IsInitialized()) { + width_bits_ = other.width_bits_; + } + + CHECK(width_bits_ >= other.width_bits_) + << "Cannot assign TargetVal with width " << other.width_bits_ + << "bits to TargetVal with width " << width_bits_ << "bits"; + + value_ = other.value_ & Bitmask(); + return *this; + } +}; + +// TODO(weberlo, areusch): just get rid of `TargetPtr`. +/*! \brief absolute device address */ +class TargetPtr { + public: + /*! \brief construct a device address with variable-length value `value` */ + TargetPtr(TargetWordSize word_size, std::uint64_t value) + : value_(TargetVal(word_size.bits(), value)) {} /*! \brief construct a null address */ - explicit DevPtr(std::nullptr_t value) : value_(TargetVal { .val64 = 0 }) {} + TargetPtr(TargetWordSize word_size, std::nullptr_t value) + : value_{TargetVal(word_size.bits(), 0)} {} + + /*! \brief construct an uninitialized pointer whose word_size can be changed once */ + TargetPtr() = default; + + /*! \brief construct a device address using the given TargetVal */ + explicit TargetPtr(const TargetVal& value) : value_{value} {} /*! \brief destructor */ - ~DevPtr() {} + ~TargetPtr() {} /*! * \brief get value of pointer @@ -86,33 +166,35 @@ class DevPtr { * \return casted result */ template - T cast_to() const { return reinterpret_cast(value_.val64); } + T cast_to() const { + return reinterpret_cast(value_.uint64()); + } /*! \brief check if location is null */ - bool operator==(std::nullptr_t) const { return value_.val64 == 0; } + bool operator==(std::nullptr_t) const { return value_.uint64() == 0; } /*! \brief check if location is not null */ - bool operator!=(std::nullptr_t) const { return value_.val64 != 0; } + bool operator!=(std::nullptr_t) const { return value_.uint64() != 0; } /*! \brief add an integer to this absolute address to get a larger absolute address */ - DevPtr operator+(size_t n) const { - return DevPtr(value_.val64 + n); + TargetPtr operator+(size_t n) const { + return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() + n); } /*! \brief mutably add an integer to this absolute address */ - DevPtr& operator+=(size_t n) { - value_.val64 += n; + TargetPtr& operator+=(size_t n) { + value_ = TargetVal(value_.width_bits(), value_.uint64() + n); return *this; } /*! \brief subtract an integer from this absolute address to get a smaller absolute address */ - DevPtr operator-(size_t n) const { - return DevPtr(value_.val64 - n); + TargetPtr operator-(size_t n) const { + return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() - n); } /*! 
\brief mutably subtract an integer from this absolute address */ - DevPtr& operator-=(size_t n) { - value_.val64 -= n; + TargetPtr& operator-=(size_t n) { + value_ = TargetVal(value_.width_bits(), value_.uint64() - n); return *this; } @@ -136,8 +218,8 @@ class SymbolMap { * \param binary contents of binary object file * \param toolchain_prefix prefix of compiler toolchain to use */ - SymbolMap(const std::string& binary, - const std::string& toolchain_prefix) { + SymbolMap(const std::string& binary, const std::string& toolchain_prefix, + TargetWordSize word_size) { const auto* f = Registry::Get("tvm_callback_get_symbol_map"); CHECK(f != nullptr) << "require tvm_callback_get_symbol_map to exist in registry"; TVMByteArray arr; @@ -152,7 +234,7 @@ class SymbolMap { stream >> name; stream >> std::hex >> addr; while (stream) { - map_[name] = DevPtr(addr); + map_.emplace(std::make_pair(name, TargetPtr(word_size, addr))); stream >> name; stream >> std::hex >> addr; } @@ -163,25 +245,29 @@ class SymbolMap { * \param name name of the symbol * \return on-device offset of the symbol */ - DevPtr operator[](const std::string& name) const { + TargetPtr operator[](const std::string& name) const { auto result = map_.find(name); CHECK(result != map_.end()) << "\"" << name << "\" not in symbol map"; return result->second; } - bool HasSymbol(const std::string& name) const { - return map_.find(name) != map_.end(); + bool HasSymbol(const std::string& name) const { return map_.find(name) != map_.end(); } + + void Dump(std::ostream& stream) const { + for (auto e : map_) { + stream << "Entry:" << e.first << std::endl; + } } private: /*! \brief backing map */ - std::unordered_map map_; + std::unordered_map map_; }; /*! \brief struct containing start and size of a device memory region */ struct DevMemRegion { /*! \brief section start offset */ - DevPtr start; + TargetPtr start; /*! \brief size of section */ size_t size; }; @@ -237,15 +323,10 @@ const char* SectionToString(SectionKind section); * \param toolchain_prefix prefix of compiler toolchain to use * \return relocated binary file contents */ -std::string RelocateBinarySections( - const std::string& binary_path, - size_t word_size, - DevPtr text_start, - DevPtr rodata_start, - DevPtr data_start, - DevPtr bss_start, - DevPtr stack_end, - const std::string& toolchain_prefix); +std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, + TargetPtr text_start, TargetPtr rodata_start, + TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, + const std::string& toolchain_prefix); /*! * \brief reads section from binary @@ -254,8 +335,7 @@ std::string RelocateBinarySections( * \param toolchain_prefix prefix of compiler toolchain to use * \return contents of the section */ -std::string ReadSection(const std::string& binary, - SectionKind section, +std::string ReadSection(const std::string& binary, SectionKind section, const std::string& toolchain_prefix); /*! 
@@ -263,13 +343,11 @@ std::string ReadSection(const std::string& binary, * \param binary input binary contents * \param section section type * \param toolchain_prefix prefix of compiler toolchain to use - * \param align alignment of the returned size (default: 8) + * \param word_size word size of the target, for alignment * \return size of the section if it exists, 0 otherwise */ -size_t GetSectionSize(const std::string& binary_name, - SectionKind section, - const std::string& toolchain_prefix, - size_t align); +size_t GetSectionSize(const std::string& binary_name, SectionKind section, + const std::string& toolchain_prefix, TargetWordSize word_size); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_device_api.cc b/src/runtime/micro/micro_device_api.cc index 3d0a6889c4f7d..68480786ac872 100644 --- a/src/runtime/micro/micro_device_api.cc +++ b/src/runtime/micro/micro_device_api.cc @@ -21,9 +21,10 @@ * \file micro_device_api.cc */ -#include -#include #include +#include +#include + #include "../workspace_pool.h" #include "micro_session.h" @@ -35,7 +36,7 @@ namespace runtime { class MicroDeviceAPI final : public DeviceAPI { public: /*! \brief constructor */ - MicroDeviceAPI() { } + MicroDeviceAPI() {} void SetDevice(TVMContext ctx) final {} @@ -45,100 +46,93 @@ class MicroDeviceAPI final : public DeviceAPI { } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { ObjectPtr& session = MicroSession::Current(); - void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to(); + TargetPtr data = session->AllocateInSection(SectionKind::kHeap, nbytes); CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap"; - MicroDevSpace* dev_space = new MicroDevSpace(); - dev_space->data = data; - dev_space->session = session; - return static_cast(dev_space); + return reinterpret_cast(new MicroDevSpace{data, session}); } void FreeDataSpace(TVMContext ctx, void* ptr) final { MicroDevSpace* dev_space = static_cast(ptr); - dev_space->session->FreeInSection( - SectionKind::kHeap, DevPtr(reinterpret_cast(dev_space->data))); + dev_space->session->FreeInSection(SectionKind::kHeap, dev_space->data); delete dev_space; } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { std::tuple type_from_to(ctx_from.device_type, ctx_to.device_type); if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) { // Copying from the device to the device. 
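      // There is no on-device memcpy to call here, so device-to-device copies
      // are staged through a temporary host-side buffer: Read() the source
      // bytes out, then Write() them back to the destination.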
- MicroDevSpace* from_space = static_cast(const_cast(from)); MicroDevSpace* to_space = static_cast(const_cast(to)); CHECK(from_space->session == to_space->session) - << "attempt to copy data between different micro sessions (" - << from_space->session.get() + << "attempt to copy data between different micro sessions (" << from_space->session.get() << " != " << to_space->session.get() << ")"; CHECK(ctx_from.device_id == ctx_to.device_id) - << "can only copy between the same micro device"; + << "can only copy between the same micro device"; ObjectPtr& session = from_space->session; + // flush all pending tasks to ensure data is consistent + session->FlushTaskQueue(); const std::shared_ptr& lld = session->low_level_device(); - DevPtr from_dev_addr = GetDevLoc(from_space, from_offset); - DevPtr to_dev_addr = GetDevLoc(to_space, to_offset); + TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset); + TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset); std::vector buffer(size); lld->Read(from_dev_addr, static_cast(buffer.data()), size); lld->Write(to_dev_addr, static_cast(buffer.data()), size); + } else if (type_from_to == std::make_tuple(kDLMicroDev, kDLCPU)) { // Reading from the device. - MicroDevSpace* from_space = static_cast(const_cast(from)); ObjectPtr& session = from_space->session; + // flush all pending tasks to ensure data is consistent + session->FlushTaskQueue(); const std::shared_ptr& lld = session->low_level_device(); - DevPtr from_dev_addr = GetDevLoc(from_space, from_offset); + TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset); void* to_host_ptr = GetHostLoc(to, to_offset); lld->Read(from_dev_addr, to_host_ptr, size); + } else if (type_from_to == std::make_tuple(kDLCPU, kDLMicroDev)) { // Writing to the device. - MicroDevSpace* to_space = static_cast(const_cast(to)); ObjectPtr& session = to_space->session; + // flush all pending tasks to ensure data is consistent + session->FlushTaskQueue(); const std::shared_ptr& lld = session->low_level_device(); void* from_host_ptr = GetHostLoc(from, from_offset); - DevPtr to_dev_addr = GetDevLoc(to_space, to_offset); + TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset); lld->Write(to_dev_addr, from_host_ptr, size); + } else { LOG(FATAL) << "Expect copy from/to micro device or between micro device\n"; } } void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + MicroSession::Current()->FlushTaskQueue(); } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { + CHECK(false) << "the on-device workspace allocator isn't aware of this function"; ObjectPtr& session = MicroSession::Current(); - void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to(); - CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace"; - MicroDevSpace* dev_space = new MicroDevSpace(); - dev_space->data = data; - dev_space->session = session; - return static_cast(dev_space); + TargetPtr data = session->AllocateInSection(SectionKind::kWorkspace, size); + CHECK(data.value().uint64() != 0) + << "unable to allocate " << size << " bytes on device workspace"; + return static_cast(new MicroDevSpace{data, session}); } void FreeWorkspace(TVMContext ctx, void* data) final { + CHECK(false) << "the on-device workspace allocator isn't aware of this function"; MicroDevSpace* dev_space = static_cast(data); ObjectPtr& session = dev_space->session; - session->FreeInSection(SectionKind::kWorkspace, - DevPtr(reinterpret_cast(dev_space->data))); + 
session->FreeInSection(SectionKind::kWorkspace, dev_space->data); delete dev_space; } @@ -152,9 +146,7 @@ class MicroDeviceAPI final : public DeviceAPI { } private: - DevPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { - return DevPtr(reinterpret_cast(dev_space->data) + offset); - } + TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { return dev_space->data + offset; } void* GetHostLoc(const void* ptr, size_t offset) { return reinterpret_cast(reinterpret_cast(ptr) + offset); @@ -162,10 +154,9 @@ class MicroDeviceAPI final : public DeviceAPI { }; // register device that can be obtained from Python frontend -TVM_REGISTER_GLOBAL("device_api.micro_dev") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MicroDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MicroDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_module.cc b/src/runtime/micro/micro_module.cc index 50cee34be4a6f..b4770ec6f9344 100644 --- a/src/runtime/micro/micro_module.cc +++ b/src/runtime/micro/micro_module.cc @@ -21,15 +21,17 @@ * \file micro_module.cc */ -#include #include #include -#include +#include + #include -#include "micro_session.h" +#include + +#include "../pack_args.h" #include "low_level_device.h" #include "micro_common.h" -#include "../pack_args.h" +#include "micro_session.h" namespace tvm { namespace runtime { @@ -42,18 +44,17 @@ class MicroModuleNode final : public ModuleNode { ~MicroModuleNode() {} - const char* type_key() const final { - return "micro"; - } + const char* type_key() const final { return "micro"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief initializes module by establishing device connection and loads binary * \param binary_path path of the binary to be loaded */ void InitMicroModule(const std::string& binary_path) { + // std::cout << "[MicroModuleNode::InitMicroModule]" << std::endl; + // std::cout << " start" << std::endl; session_ = MicroSession::Current(); symbol_map_ = session_->LoadBinary(binary_path, true).symbol_map; } @@ -66,27 +67,25 @@ class MicroModuleNode final : public ModuleNode { class MicroWrappedFunc { public: - MicroWrappedFunc(ObjectPtr session, - DevPtr func_ptr) { + MicroWrappedFunc(ObjectPtr session, TargetPtr func_ptr) { session_ = session; func_ptr_ = func_ptr; } void operator()(TVMArgs args, TVMRetValue* rv) const { - *rv = session_->PushToExecQueue(func_ptr_, args); + session_->PushToTaskQueue(func_ptr_, args); } private: /*! \brief reference to the session for this function (to keep the session alive) */ ObjectPtr session_; /*! 
\brief offset of the function to be called */ - DevPtr func_ptr_; + TargetPtr func_ptr_; }; -PackedFunc MicroModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { - DevPtr func_ptr; +PackedFunc MicroModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + TargetPtr func_ptr; if (name == tvm::runtime::symbol::tvm_module_main) { if (symbol_map_.HasSymbol(tvm::runtime::symbol::tvm_module_main)) { func_ptr = symbol_map_[tvm::runtime::symbol::tvm_module_main]; @@ -102,10 +101,10 @@ PackedFunc MicroModuleNode::GetFunction( // register loadfile function to load module from Python frontend TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->InitMicroModule(args[0]); - *rv = runtime::Module(n); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->InitMicroModule(args[0]); + *rv = runtime::Module(n); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_section_allocator.h b/src/runtime/micro/micro_section_allocator.h index 5c75f92737ab7..5cafb41bbc4b9 100644 --- a/src/runtime/micro/micro_section_allocator.h +++ b/src/runtime/micro/micro_section_allocator.h @@ -23,7 +23,9 @@ #ifndef TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ #define TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ +#include #include + #include "micro_common.h" namespace tvm { @@ -38,16 +40,18 @@ class MicroSectionAllocator { * \brief constructor that specifies section boundaries * \param region location and size of the section on the device */ - explicit MicroSectionAllocator(DevMemRegion region, size_t word_size) - : start_addr_(region.start), - size_(0), - capacity_(region.size), - word_size_(word_size) { - CHECK_EQ(start_addr_.value().val64 % word_size, 0) - << "micro section start not aligned to " << word_size << " bytes"; - CHECK_EQ(capacity_ % word_size, 0) - << "micro section end not aligned to " << word_size << " bytes"; - } + explicit MicroSectionAllocator(std::string section_name, DevMemRegion region, + TargetWordSize word_size) + : section_name_(section_name), + start_addr_(region.start), + size_(0), + capacity_(region.size), + word_size_(word_size) { + CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0) + << "micro section start not aligned to " << word_size.bytes() << " bytes"; + CHECK_EQ(capacity_ % word_size.bytes(), 0) + << "micro section end not aligned to " << word_size.bytes() << " bytes"; + } /*! * \brief destructor @@ -56,17 +60,18 @@ class MicroSectionAllocator { /*! 
* \brief memory allocator - * \param size size of allocated memory in bytes + * \param alloc_size size of allocated memory in bytes * \return pointer to allocated memory region in section, nullptr if out of space */ - DevPtr Allocate(size_t size) { - size_ = UpperAlignValue(size_, word_size_); + TargetPtr Allocate(size_t size) { + size_ = UpperAlignValue(size_, word_size_.bytes()); CHECK(size_ + size < capacity_) - << "cannot alloc " << size << " bytes in section with start_addr " << - start_addr_.cast_to(); - DevPtr alloc_addr = start_addr_ + size_; + << "cannot alloc " << size << " bytes in section \"" << section_name_ + << "\" (start_addr=" << start_addr_.cast_to() << ", used=" << size_ + << ", capacity=" << capacity_ << ")"; + TargetPtr alloc_addr = start_addr_ + size_; size_ += size; - alloc_map_[alloc_addr.value().val64] = size; + alloc_map_[alloc_addr.value().uint64()] = size; return alloc_addr; } @@ -75,10 +80,10 @@ class MicroSectionAllocator { * \param offs offset to allocated memory * \note simple allocator scheme, more complex versions will be implemented later */ - void Free(DevPtr addr) { - CHECK(alloc_map_.find(addr.value().val64) != alloc_map_.end()) - << "freed pointer was never allocated"; - alloc_map_.erase(addr.value().val64); + void Free(TargetPtr addr) { + CHECK(alloc_map_.find(addr.value().uint64()) != alloc_map_.end()) + << "freed pointer was never allocated"; + alloc_map_.erase(addr.value().uint64()); if (alloc_map_.empty()) { size_ = 0; } @@ -87,17 +92,17 @@ class MicroSectionAllocator { /*! * \brief start offset of the memory region managed by this allocator */ - DevPtr start_addr() const { return start_addr_; } + TargetPtr start_addr() const { return start_addr_; } /*! * \brief current end addr of the space being used in this memory region */ - DevPtr curr_end_addr() const { return start_addr_ + size_; } + TargetPtr curr_end_addr() const { return start_addr_ + size_; } /*! * \brief end addr of the memory region managed by this allocator */ - DevPtr max_addr() const { return start_addr_ + capacity_; } + TargetPtr max_addr() const { return start_addr_ + capacity_; } /*! * \brief size of the section @@ -110,14 +115,16 @@ class MicroSectionAllocator { size_t capacity() const { return capacity_; } private: + /*! \brief name of the section (for debugging) */ + std::string section_name_; /*! \brief start address of the section */ - DevPtr start_addr_; + TargetPtr start_addr_; /*! \brief current size of the section */ size_t size_; /*! \brief total storage capacity of the section */ size_t capacity_; /*! \brief number of bytes in a word on the target device */ - size_t word_size_; + TargetWordSize word_size_; /*! 
\brief allocation map for allocation sizes */ std::unordered_map alloc_map_; }; diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 4bdc8ed697975..a9efa0f567716 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -21,13 +21,19 @@ * \file micro_session.cc */ +#include "micro_session.h" + #include +#include #include + +#include +#include #include #include #include #include -#include "micro_session.h" + #include "low_level_device.h" #include "target_data_layout_encoder.h" @@ -41,164 +47,183 @@ struct TVMMicroSessionThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMMicroSessionThreadLocalStore; ObjectPtr& MicroSession::Current() { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); CHECK_GT(entry->session_stack.size(), 0) << "No current session"; return entry->session_stack.top(); } void MicroSession::EnterWithScope(ObjectPtr session) { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); entry->session_stack.push(session); } void MicroSession::ExitWithScope() { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); CHECK(!entry->session_stack.empty()); entry->session_stack.pop(); } -MicroSession::MicroSession( - const std::string& comms_method, - const std::string& binary_path, - const std::string& toolchain_prefix, - uint64_t text_start, - size_t text_size, - uint64_t rodata_start, - size_t rodata_size, - uint64_t data_start, - size_t data_size, - uint64_t bss_start, - size_t bss_size, - uint64_t args_start, - size_t args_size, - uint64_t heap_start, - size_t heap_size, - uint64_t workspace_start, - size_t workspace_size, - uint64_t stack_start, - size_t stack_size, - size_t word_size, - bool thumb_mode, - const std::string& server_addr, - int port) - : toolchain_prefix_(toolchain_prefix) - , word_size_(word_size) - , thumb_mode_(thumb_mode) { - CHECK(word_size_ == 4 || word_size_ == 8) << "unsupported word size " << word_size_; +MicroSession::MicroSession(const std::string& comms_method, const std::string& binary_path, + const std::string& toolchain_prefix, uint64_t text_start, + size_t text_size, uint64_t rodata_start, size_t rodata_size, + uint64_t data_start, size_t data_size, uint64_t bss_start, + size_t bss_size, uint64_t args_start, size_t args_size, + uint64_t heap_start, size_t heap_size, uint64_t workspace_start, + size_t workspace_size, uint64_t stack_start, size_t stack_size, + TargetWordSize word_size, bool thumb_mode, bool use_device_timer, + const std::string& server_addr, int port) + : toolchain_prefix_(toolchain_prefix), + word_size_(word_size), + thumb_mode_(thumb_mode), + use_device_timer_(use_device_timer), + batch_args_encoder_(args_size, word_size) { if (comms_method == "host") { // TODO(weberlo): move checks to python - CHECK( - text_start == 0 && - rodata_start == 0 && - data_start == 0 && - bss_start == 0 && - args_start == 0 && - heap_start == 0 && - workspace_start == 0 && - stack_start == 0) << "unable to specify section addresses for host device"; - size_t memory_size = - text_size + rodata_size + data_size + bss_size + - args_size + heap_size + workspace_size + stack_size; - void* base_addr; + CHECK(text_start == 0 && rodata_start == 0 && data_start 
== 0 && bss_start == 0 && + args_start == 0 && heap_start == 0 && workspace_start == 0 && stack_start == 0) + << "unable to specify section addresses for host device"; + size_t memory_size = text_size + rodata_size + data_size + bss_size + args_size + heap_size + + workspace_size + stack_size; + TargetPtr base_addr; low_level_device_ = HostLowLevelDeviceCreate(memory_size, &base_addr); - CHECK_EQ(reinterpret_cast(base_addr) % word_size_, 0) - << "base address not aligned to " << word_size_ << " bytes"; - DevPtr curr_addr = DevPtr(reinterpret_cast(base_addr)); - - section_allocators_[0] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = text_size, - }, word_size_); + CHECK_EQ(base_addr.value().uint64() % word_size.bytes(), 0) + << "base address not aligned to " << word_size.bytes() << " bytes"; + TargetPtr curr_addr = base_addr; + + section_allocators_[0] = std::make_shared("text", + DevMemRegion{ + .start = curr_addr, + .size = text_size, + }, + word_size_); curr_addr += text_size; - section_allocators_[1] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = rodata_size, - }, word_size_); + section_allocators_[1] = std::make_shared("rodata", + DevMemRegion{ + .start = curr_addr, + .size = rodata_size, + }, + word_size_); curr_addr += rodata_size; - section_allocators_[2] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = data_size, - }, word_size_); + section_allocators_[2] = std::make_shared("data", + DevMemRegion{ + .start = curr_addr, + .size = data_size, + }, + word_size_); curr_addr += data_size; - section_allocators_[3] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = bss_size, - }, word_size_); + section_allocators_[3] = std::make_shared("bss", + DevMemRegion{ + .start = curr_addr, + .size = bss_size, + }, + word_size_); curr_addr += bss_size; - section_allocators_[4] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = args_size, - }, word_size_); + section_allocators_[4] = std::make_shared("args", + DevMemRegion{ + .start = curr_addr, + .size = args_size, + }, + word_size_); curr_addr += args_size; - section_allocators_[5] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = heap_size, - }, word_size_); + section_allocators_[5] = std::make_shared("heap", + DevMemRegion{ + .start = curr_addr, + .size = heap_size, + }, + word_size_); curr_addr += heap_size; - section_allocators_[6] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = workspace_size, - }, word_size_); + section_allocators_[6] = std::make_shared("workspace", + DevMemRegion{ + .start = curr_addr, + .size = workspace_size, + }, + word_size_); curr_addr += workspace_size; - section_allocators_[7] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = stack_size, - }, word_size_); + section_allocators_[7] = std::make_shared("stack", + DevMemRegion{ + .start = curr_addr, + .size = stack_size, + }, + word_size_); curr_addr += stack_size; } else if (comms_method == "openocd") { low_level_device_ = OpenOCDLowLevelDeviceCreate(server_addr, port); - section_allocators_[0] = std::make_shared(DevMemRegion { - .start = DevPtr(text_start), - .size = text_size, - }, word_size_); - section_allocators_[1] = std::make_shared(DevMemRegion { - .start = DevPtr(rodata_start), - .size = rodata_size, - }, word_size_); - section_allocators_[2] = std::make_shared(DevMemRegion { - .start = DevPtr(data_start), - .size = data_size, - }, word_size_); - section_allocators_[3] = std::make_shared(DevMemRegion { - .start = DevPtr(bss_start), - 
.size = bss_size, - }, word_size_); - section_allocators_[4] = std::make_shared(DevMemRegion { - .start = DevPtr(args_start), - .size = args_size, - }, word_size_); - section_allocators_[5] = std::make_shared(DevMemRegion { - .start = DevPtr(heap_start), - .size = heap_size, - }, word_size_); - section_allocators_[6] = std::make_shared(DevMemRegion { - .start = DevPtr(workspace_start), - .size = workspace_size, - }, word_size_); - section_allocators_[7] = std::make_shared(DevMemRegion { - .start = DevPtr(stack_start), - .size = stack_size, - }, word_size_); + section_allocators_[0] = + std::make_shared("text", + DevMemRegion{ + .start = TargetPtr(word_size_, text_start), + .size = text_size, + }, + word_size_); + section_allocators_[1] = + std::make_shared("rodata", + DevMemRegion{ + .start = TargetPtr(word_size_, rodata_start), + .size = rodata_size, + }, + word_size_); + section_allocators_[2] = + std::make_shared("data", + DevMemRegion{ + .start = TargetPtr(word_size_, data_start), + .size = data_size, + }, + word_size_); + section_allocators_[3] = + std::make_shared("bss", + DevMemRegion{ + .start = TargetPtr(word_size_, bss_start), + .size = bss_size, + }, + word_size_); + section_allocators_[4] = + std::make_shared("args", + DevMemRegion{ + .start = TargetPtr(word_size_, args_start), + .size = args_size, + }, + word_size_); + section_allocators_[5] = + std::make_shared("heap", + DevMemRegion{ + .start = TargetPtr(word_size_, heap_start), + .size = heap_size, + }, + word_size_); + section_allocators_[6] = + std::make_shared("workspace", + DevMemRegion{ + .start = TargetPtr(word_size_, workspace_start), + .size = workspace_size, + }, + word_size_); + section_allocators_[7] = + std::make_shared("stack", + DevMemRegion{ + .start = TargetPtr(word_size_, stack_start), + .size = stack_size, + }, + word_size_); } else { LOG(FATAL) << "unsupported micro low-level device"; } + TargetPtr args_start_addr = GetAllocator(SectionKind::kArgs)->start_addr(); + batch_args_encoder_.set_start_addr(args_start_addr); + runtime_symbol_map_ = LoadBinary(binary_path, false).symbol_map; // Patch pointers to define the bounds of the workspace section and the word // size (for allocation alignment). 
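  // In outline (illustration; not the exact DevSymbolWrite implementation):
  //   TargetPtr addr = runtime_symbol_map_["utvm_workspace_start"];
  //   low_level_device()->Write(addr, &value, sizeof(value));
  // i.e. a symbol-map lookup plus a raw write, which is how the device-side
  // globals declared in utvm_runtime.c receive their values before any task
  // runs.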
std::shared_ptr ws_allocator = GetAllocator(SectionKind::kWorkspace); - TargetVal ws_start = ws_allocator->start_addr().value(); - TargetVal ws_end = ws_allocator->max_addr().value(); - TargetVal target_word_size { .val64 = word_size_ }; - if (word_size_ == 4) { - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_start.val32); - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_end.val32); - DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", target_word_size.val32); - } else if (word_size_ == 8) { - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_start.val64); - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_end.val64); - DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", target_word_size.val64); + DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_allocator->start_addr()); + DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_allocator->max_addr()); + if (word_size.bytes() == 4) { + DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint32_t(word_size.bytes())); + } else if (word_size.bytes() == 8) { + DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint64_t(word_size.bytes())); + } else { + CHECK(false) << "Unsupported word size unexpectedly here"; } } @@ -209,59 +234,114 @@ MicroSession::~MicroSession() { low_level_device_ = nullptr; } -double MicroSession::PushToExecQueue(DevPtr func_ptr, const TVMArgs& args) { +void MicroSession::PushToTaskQueue(TargetPtr func_ptr, const TVMArgs& args) { if (thumb_mode_) { + // TODO(areusch): should be |= func_ptr += 1; } + TargetVal func_dev_addr = func_ptr.value(); - // Create an allocator stream for the memory region after the most recent - // allocation in the args section. - DevPtr args_addr = GetAllocator(SectionKind::kArgs)->curr_end_addr(); - TargetDataLayoutEncoder encoder(args_addr, word_size_); - - std::tuple arg_field_addrs = EncoderAppend(&encoder, args); - - // Flush `stream` to device memory. - DevPtr stream_dev_addr = - GetAllocator(SectionKind::kArgs)->Allocate(encoder.buf_size()); - low_level_device()->Write(stream_dev_addr, - reinterpret_cast(encoder.data()), - encoder.buf_size()); - - TargetVal arg_values_dev_addr = std::get<0>(arg_field_addrs).value(); - TargetVal arg_type_codes_dev_addr = std::get<1>(arg_field_addrs).value(); - if (word_size_ == 4) { - UTVMTask32 task = { - .func = func_ptr.value().val32, - .arg_values = arg_values_dev_addr.val32, - .arg_type_codes = arg_type_codes_dev_addr.val32, - .num_args = args.num_args, - }; - // Write the task. - DevSymbolWrite(runtime_symbol_map_, "utvm_task", task); - } else if (word_size_ == 8) { - UTVMTask64 task = { - .func = func_ptr.value().val64, - .arg_values = arg_values_dev_addr.val64, - .arg_type_codes = arg_type_codes_dev_addr.val64, - .num_args = args.num_args, - }; - // Write the task. 
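  // (Both the per-call path being removed here and the batched
  // FlushTaskQueuePriv below keep 32- and 64-bit mirrors of the task struct:
  // UTVMTask holds target pointers, so the byte layout written to device
  // memory must match the target word size, not the host's.)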
-    DevSymbolWrite(runtime_symbol_map_, "utvm_task", task);
+  std::tuple<TargetPtr, TargetPtr> arg_field_addrs = EncoderAppend(&batch_args_encoder_, args);
+  TargetVal arg_values_dev_addr{std::get<0>(arg_field_addrs).value()};
+  TargetVal arg_type_codes_dev_addr{std::get<1>(arg_field_addrs).value()};
+
+  task_queue_.push_back(DevTask{.func = func_dev_addr,
+                                .arg_values = arg_values_dev_addr,
+                                .arg_type_codes = arg_type_codes_dev_addr,
+                                .num_args = args.num_args});
+
+  if (task_queue_.size() == MicroSession::kTaskQueueCapacity) {
+    FlushTaskQueue();
   }
+}

-  DevPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"];
-  DevPtr utvm_done_addr = runtime_symbol_map_["UTVMDone"];
+void MicroSession::FlushTaskQueue() {
+  if (task_queue_.size() == 0) {
+    // nothing to run
+    return;
+  }
+  if (word_size_.bytes() == 4) {
+    FlushTaskQueuePriv<StructUTVMTask32>();
+  } else if (word_size_.bytes() == 8) {
+    FlushTaskQueuePriv<StructUTVMTask64>();
+  }
+}
+
+template <typename T>
+void MicroSession::FlushTaskQueuePriv() {
+  std::vector<T> prepped_tasks;
+  for (const auto& task : task_queue_) {
+    prepped_tasks.push_back(T(task));
+  }
+
+  // Flush `args` to device memory.
+  low_level_device()->Write(batch_args_encoder_.start_addr(),
+                            reinterpret_cast<void*>(batch_args_encoder_.data()),
+                            batch_args_encoder_.buf_size());
+
+  // Flush `tasks` to device memory.
+  TargetPtr dev_tasks_addr = runtime_symbol_map_["utvm_tasks"];
+  low_level_device()->Write(dev_tasks_addr, reinterpret_cast<void*>(prepped_tasks.data()),
+                            prepped_tasks.size() * sizeof(T));
+  DevSymbolWrite(runtime_symbol_map_, "utvm_num_tasks", prepped_tasks.size());
+
+  TargetPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"];
+  TargetPtr utvm_done_addr = runtime_symbol_map_["UTVMDone"];
   if (thumb_mode_) {
+    // TODO(areusch): should be |=
     utvm_init_addr += 1;
   }

+  std::chrono::time_point<std::chrono::high_resolution_clock> tbegin, tend;
+  tbegin = std::chrono::high_resolution_clock::now();
   low_level_device()->Execute(utvm_init_addr, utvm_done_addr);
+  tend = std::chrono::high_resolution_clock::now();
+
   // Check if there was an error during execution. If so, log it.
   CheckDeviceError();
-  uint32_t task_time = DevSymbolRead<uint32_t>(runtime_symbol_map_, "utvm_task_time");
-  GetAllocator(SectionKind::kArgs)->Free(stream_dev_addr);
-  return static_cast<double>(task_time);
+
+  if (use_device_timer_) {
+    uint64_t sum = 0;
+    std::vector<uint32_t> times;
+    times.resize(task_queue_.size());
+    low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(),
+                             task_queue_.size() * sizeof(uint32_t));
+    int i = 0;
+    for (uint32_t time : times) {
+      LOG(INFO) << "Time " << i++ << ": " << time;
+      sum += time;
+    }
+    last_batch_time_ += static_cast<double>(sum) / 1e3;
+  } else {
+    last_batch_time_ +=
+        std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() * 1000;
+    // TODO(weberlo): Reading internal data structure is hacky.
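+    // Even on the host-timer path, the per-task cycle counts are read back
+    // below so callers can compare wall-clock time against device cycles.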
+ uint64_t sum = 0; + std::vector times; + times.resize(task_queue_.size()); + low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), + task_queue_.size() * sizeof(uint32_t)); + for (uint32_t time : times) { + sum += time; + } + last_batch_cycles_ += static_cast(sum); + } + + batch_args_encoder_.Clear(); + task_queue_.clear(); } BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_dylib_pointers) { @@ -270,32 +350,22 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d DevMemRegion data_section; DevMemRegion bss_section; - text_section.size = GetSectionSize( - binary_path, SectionKind::kText, toolchain_prefix_, word_size_); - rodata_section.size = GetSectionSize( - binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); - data_section.size = GetSectionSize( - binary_path, SectionKind::kData, toolchain_prefix_, word_size_); - bss_section.size = GetSectionSize( - binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); + text_section.size = + GetSectionSize(binary_path, SectionKind::kText, toolchain_prefix_, word_size_); + rodata_section.size = + GetSectionSize(binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); + data_section.size = + GetSectionSize(binary_path, SectionKind::kData, toolchain_prefix_, word_size_); + bss_section.size = GetSectionSize(binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); text_section.start = AllocateInSection(SectionKind::kText, text_section.size); rodata_section.start = AllocateInSection(SectionKind::kRodata, rodata_section.size); data_section.start = AllocateInSection(SectionKind::kData, data_section.size); bss_section.start = AllocateInSection(SectionKind::kBss, bss_section.size); - CHECK(text_section.start != nullptr && rodata_section.start != nullptr && - data_section.start != nullptr && bss_section.start != nullptr) - << "not enough space to load module on device"; std::string relocated_bin = RelocateBinarySections( - binary_path, - word_size_, - text_section.start, - rodata_section.start, - data_section.start, - bss_section.start, - GetAllocator(SectionKind::kStack)->max_addr(), - toolchain_prefix_); + binary_path, word_size_, text_section.start, rodata_section.start, data_section.start, + bss_section.start, GetAllocator(SectionKind::kStack)->max_addr(), toolchain_prefix_); std::string text_contents = ReadSection(relocated_bin, SectionKind::kText, toolchain_prefix_); std::string rodata_contents = ReadSection(relocated_bin, SectionKind::kRodata, toolchain_prefix_); std::string data_contents = ReadSection(relocated_bin, SectionKind::kData, toolchain_prefix_); @@ -305,7 +375,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d low_level_device_->Write(rodata_section.start, &rodata_contents[0], rodata_section.size); low_level_device_->Write(data_section.start, &data_contents[0], data_section.size); low_level_device_->Write(bss_section.start, &bss_contents[0], bss_section.size); - SymbolMap symbol_map {relocated_bin, toolchain_prefix_}; + SymbolMap symbol_map{relocated_bin, toolchain_prefix_, word_size_}; if (patch_dylib_pointers) { // Patch device lib pointers. 
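The hunk that follows finishes LoadBinary by patching "impl holes": for each
runtime API function, a `_`-suffixed device global is rewritten to point at the
host-driven runtime's implementation (see PatchImplHole further below). A
minimal sketch of the assumed device-side counterpart; only
`TVMAPISetLastError_` is named in this patch, and the mechanism is assumed to
be the same for the other holes:

  /* In the generated device module: a function-pointer hole that the host
   * fills in via PatchImplHole/DevSymbolWrite before execution. */
  void* TVMAPISetLastError_ = 0;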
@@ -314,7 +384,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d
     PatchImplHole(symbol_map, "TVMAPISetLastError");
   }

-  return BinaryInfo {
+  return BinaryInfo{
       .text_section = text_section,
       .rodata_section = rodata_section,
       .data_section = data_section,
@@ -323,8 +393,8 @@
   };
 }

-std::tuple<DevPtr, DevPtr> MicroSession::EncoderAppend(
-    TargetDataLayoutEncoder* encoder, const TVMArgs& args) {
+std::tuple<TargetPtr, TargetPtr> MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder,
+                                                             const TVMArgs& args) {
   const int* type_codes = args.type_codes;
   int num_args = args.num_args;
@@ -341,12 +411,13 @@ std::tuple<DevPtr, DevPtr> MicroSession::EncoderAppend(
       // order to prevent premature session destruction.
       void* old_data = base_arr_handle->data;
       // Mutate the array to unwrap the `data` field.
-      base_arr_handle->data = reinterpret_cast<MicroDevSpace*>(old_data)->data;
+      MicroDevSpace* dev_arr_ptr = reinterpret_cast<MicroDevSpace*>(old_data);
+      base_arr_handle->data = reinterpret_cast<void*>(dev_arr_ptr->data.value().uint64());
       // Now, encode the unwrapped version.
       void* arr_ptr = nullptr;
-      if (word_size_ == 4) {
+      if (word_size_.bytes() == 4) {
         arr_ptr = EncoderAppend<TVMArray32>(encoder, *base_arr_handle).cast_to<void*>();
-      } else if (word_size_ == 8) {
+      } else if (word_size_.bytes() == 8) {
         arr_ptr = EncoderAppend<TVMArray64>(encoder, *base_arr_handle).cast_to<void*>();
       }
       // And restore the original wrapped version.
@@ -371,7 +442,7 @@
 }

 template <typename T>
-DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) {
+TargetPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) {
   auto tvm_arr_slot = encoder->Alloc<T>();
   auto shape_slot = encoder->Alloc<int64_t>(arr.ndim);
@@ -379,24 +450,19 @@ DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTen
   // the device first. The `data` field is already allocated on the device and
   // is a device pointer, so we don't need to write it.
   shape_slot.WriteArray(arr.shape, arr.ndim);
-  DevPtr shape_dev_addr = shape_slot.start_addr();
-  DevPtr strides_dev_addr = DevPtr(nullptr);
+  TargetPtr shape_dev_addr = shape_slot.start_addr();
+  TargetPtr strides_dev_addr = TargetPtr(word_size_, nullptr);
   if (arr.strides != nullptr) {
     auto stride_slot = encoder->Alloc<int64_t>(arr.ndim);
     stride_slot.WriteArray(arr.strides, arr.ndim);
     strides_dev_addr = stride_slot.start_addr();
   }

-  T dev_arr(
-      TargetVal { .val64 = reinterpret_cast<uint64_t>(arr.data) },
-      arr.ctx,
-      arr.ndim,
-      arr.dtype,
-      shape_dev_addr.value(),
-      strides_dev_addr.value(),
-      TargetVal { .val64 = arr.byte_offset });
+  T dev_arr(TargetVal{word_size_.bits(), reinterpret_cast<uint64_t>(arr.data)}, arr.ctx, arr.ndim,
+            arr.dtype, shape_dev_addr.value(), strides_dev_addr.value(),
+            TargetVal{word_size_.bits(), arr.byte_offset});
   CHECK(dev_arr.ctx.device_type == static_cast<DLDeviceType>(kDLMicroDev))
-    << "attempt to write DLTensor with non-micro device type";
+      << "attempt to write DLTensor with non-micro device type";
   // Update the device type to CPU, because from the microcontroller's
   // perspective, it is.
   dev_arr.ctx.device_type = DLDeviceType::kDLCPU;
@@ -404,39 +470,69 @@ DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTen
   return tvm_arr_slot.start_addr();
 }

+// TODO(weberlo): switch over entirely to error codes that expand to error
+// messages on the host side.
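CheckDeviceError below assumes a simple device-side error protocol: the runtime
stores the first UTVM_ERR_* code it hits into the global `utvm_last_error`,
where zero means "no error" (implied by the host's `if (last_error)` check).
A minimal sketch of that contract, with illustrative names for everything
except `utvm_last_error` and the UTVM_ERR_* codes:

  int32_t utvm_last_error = 0;  /* assumed: 0 == no error */

  void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size,
                                 int dtype_code_hint, int dtype_bits_hint) {
    if (utvm_num_active_allocs == UTVM_MAX_WS_ALLOCS) {  /* illustrative names */
      utvm_last_error = UTVM_ERR_WS_TOO_MANY_ALLOCS;
      return 0;
    }
    /* ... */
  }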
void MicroSession::CheckDeviceError() { - int32_t return_code = DevSymbolRead(runtime_symbol_map_, "utvm_return_code"); - - if (return_code) { - std::uintptr_t last_error = - DevSymbolRead(runtime_symbol_map_, "utvm_last_error"); - std::string last_error_str; - if (last_error) { - DevPtr last_err_addr = DevPtr(last_error); - last_error_str = ReadString(last_err_addr); + int32_t last_error = DevSymbolRead(runtime_symbol_map_, "utvm_last_error"); + + if (last_error) { + if (!use_device_timer_ && + (last_error == UTVM_ERR_TIMER_OVERFLOW || last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) { + // these errors don't matter if we're not using the on-device timer + return; + } + std::string err_msg; + switch (last_error) { + case UTVM_ERR_NOT_FINISHED: + err_msg = "execution timed out"; + break; + case UTVM_ERR_TIMER_NOT_IMPLEMENTED: + err_msg = "timer is not implemented for the target device"; + break; + case UTVM_ERR_TIMER_OVERFLOW: + // TODO(weberlo): this should be remedied by using interrupts to accumulate the + // timer into a larger datatype (ARM timers are only 24 bits) + err_msg = "timer overflowed during execution"; + break; + case UTVM_ERR_WS_DOUBLE_FREE: + err_msg = "free called with no active workspace allocations"; + break; + case UTVM_ERR_WS_OUT_OF_SPACE: + err_msg = "ran out of space in workspace section"; + break; + case UTVM_ERR_WS_TOO_MANY_ALLOCS: + err_msg = "exceeded number of allocs the runtime can keep track of"; + break; + case UTVM_ERR_WS_ZERO_SIZE_ALLOC: + err_msg = "attempt to allocate scratchpad of size zero"; + break; + case UTVM_ERR_WS_UNALIGNED_START: + err_msg = "start of workspace section is not word-aligned"; + break; + case UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE: + err_msg = "scratchpad allocation size is not a multiple of the word size"; + break; + default: + err_msg = "unknown error code"; + break; } LOG(FATAL) << "error during micro function execution:\n" - << " return code: " << std::dec << return_code << "\n" - << " dev str addr: 0x" << std::hex << last_error << "\n" - << " dev str data: " << last_error_str << std::endl; + << " error ID: " << std::dec << last_error << std::endl + << " error message: " << err_msg; } } void MicroSession::PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name) { - DevPtr runtime_impl_addr = runtime_symbol_map_[func_name]; + TargetPtr runtime_impl_addr = runtime_symbol_map_[func_name]; if (thumb_mode_) { runtime_impl_addr += 1; } std::ostringstream func_name_underscore; func_name_underscore << func_name << "_"; - if (word_size_ == 4) { - DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr.value().val32); - } else if (word_size_ == 8) { - DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr.value().val64); - } + DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr); } -std::string MicroSession::ReadString(DevPtr str_addr) { +std::string MicroSession::ReadString(TargetPtr str_addr) { std::ostringstream result; const size_t buf_size = 256; std::vector buf(buf_size, 0); @@ -454,98 +550,127 @@ std::string MicroSession::ReadString(DevPtr str_addr) { return result.str(); } -DevPtr MicroSession::AllocateInSection(SectionKind type, size_t size) { +TargetPtr MicroSession::AllocateInSection(SectionKind type, size_t size) { return GetAllocator(type)->Allocate(size); } -void MicroSession::FreeInSection(SectionKind type, DevPtr addr) { +void MicroSession::FreeInSection(SectionKind type, TargetPtr addr) { return GetAllocator(type)->Free(addr); } template T 
MicroSession::DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol) {
-  DevPtr sym_addr = symbol_map[symbol];
+  TargetPtr sym_addr = symbol_map[symbol];
   T result;
   low_level_device()->Read(sym_addr, &result, sizeof(T));
   return result;
 }

+void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol,
+                                  const TargetPtr& ptr) {
+  if (word_size_.bytes() == 4) {
+    DevSymbolWrite(symbol_map, symbol, ptr.value().uint32());
+  } else if (word_size_.bytes() == 8) {
+    DevSymbolWrite(symbol_map, symbol, ptr.value().uint64());
+  } else {
+    CHECK(false) << "unsupported word size " << word_size_.bytes();
+  }
+}
+
 template <typename T>
-void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map,
-                                  const std::string& symbol,
+void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol,
                                   const T& value) {
-  DevPtr sym_addr = symbol_map[symbol];
+  TargetPtr sym_addr = symbol_map[symbol];
   low_level_device()->Write(sym_addr, &value, sizeof(T));
 }

-PackedFunc MicroSession::GetFunction(
-    const std::string& name,
-    const ObjectPtr<Object>& sptr_to_self) {
+PackedFunc MicroSession::GetFunction(const std::string& name,
+                                     const ObjectPtr<Object>& sptr_to_self) {
   if (name == "enter") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       MicroSession::EnterWithScope(GetObjectPtr<MicroSession>(this));
     });
   } else if (name == "exit") {
-    return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-      MicroSession::ExitWithScope();
-    });
+    return PackedFunc(
+        [sptr_to_self](TVMArgs args, TVMRetValue* rv) { MicroSession::ExitWithScope(); });
+    // TODO(weberlo): add a `clear_batch_timer` func
+  } else if (name == "get_last_batch_time") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchTime(); });
+    // TODO(weberlo): remove this func
+  } else if (name == "get_last_batch_cycles") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchCycles(); });
   } else {
     return PackedFunc();
   }
 }

+TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator").set_body([](TVMArgs args, TVMRetValue* rv) {
+  PackedFunc pf = args[0];
+  TVMContext ctx = args[1];
+  uint64_t number = args[2];
+  uint64_t repeat = args[3];
+
+  auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue* rv) mutable {
+    TVMRetValue temp;
+    std::ostringstream os;
+
+    for (unsigned int i = 0; i < repeat; ++i) {
+      // start timing
+      CHECK(number < MicroSession::kTaskQueueCapacity)
+          << "`number` must be less than uTVM task queue capacity";
+      for (unsigned int j = 0; j < number; ++j) {
+        pf.CallPacked(args, &temp);
+      }
+      ObjectPtr<MicroSession> session = MicroSession::Current();
+      DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+      double time_per_batch = session->GetLastBatchTime() / number;
+      os.write(reinterpret_cast<char*>(&time_per_batch), sizeof(time_per_batch));
+    }
+    std::string blob = os.str();
+    TVMByteArray arr;
+    arr.size = blob.length();
+    arr.data = blob.data();
+    // return the time.
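+    // One double per repeat (mean time per call, in ms); the host-side
+    // time-evaluator wrapper is assumed to unpack this blob as native
+    // doubles, mirroring the standard RPC time_evaluator convention.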
+ *rv = arr; + }; + *rv = PackedFunc(ftimer); +}); + // create micro session and low-level device from Python frontend -TVM_REGISTER_GLOBAL("micro._CreateSession") -.set_body([](TVMArgs args, TVMRetValue* rv) { - const std::string& comms_method = args[0]; - const std::string& binary_path = args[1]; - const std::string& toolchain_prefix = args[2]; - uint64_t text_start = args[3]; - size_t text_size = args[4]; - uint64_t rodata_start = args[5]; - size_t rodata_size = args[6]; - uint64_t data_start = args[7]; - size_t data_size = args[8]; - uint64_t bss_start = args[9]; - size_t bss_size = args[10]; - uint64_t args_start = args[11]; - size_t args_size = args[12]; - uint64_t heap_start = args[13]; - size_t heap_size = args[14]; - uint64_t workspace_start = args[15]; - size_t workspace_size = args[16]; - uint64_t stack_start = args[17]; - size_t stack_size = args[18]; - size_t word_size = args[19]; - bool thumb_mode = args[20]; - const std::string& server_addr = args[21]; - int port = args[22]; - ObjectPtr session = make_object( - comms_method, - binary_path, - toolchain_prefix, - text_start, - text_size, - rodata_start, - rodata_size, - data_start, - data_size, - bss_start, - bss_size, - args_start, - args_size, - heap_start, - heap_size, - workspace_start, - workspace_size, - stack_start, - stack_size, - word_size, - thumb_mode, - server_addr, - port); - *rv = Module(session); - }); +TVM_REGISTER_GLOBAL("micro._CreateSession").set_body([](TVMArgs args, TVMRetValue* rv) { + const std::string& comms_method = args[0]; + const std::string& binary_path = args[1]; + const std::string& toolchain_prefix = args[2]; + uint64_t text_start = args[3]; + size_t text_size = uint64_t(args[4]); + uint64_t rodata_start = args[5]; + size_t rodata_size = uint64_t(args[6]); + uint64_t data_start = args[7]; + size_t data_size = uint64_t(args[8]); + uint64_t bss_start = args[9]; + size_t bss_size = uint64_t(args[10]); + uint64_t args_start = args[11]; + size_t args_size = uint64_t(args[12]); + uint64_t heap_start = args[13]; + size_t heap_size = uint64_t(args[14]); + uint64_t workspace_start = args[15]; + size_t workspace_size = uint64_t(args[16]); + uint64_t stack_start = args[17]; + size_t stack_size = uint64_t(args[18]); + TargetWordSize word_size{uint64_t(args[19])}; + bool thumb_mode = args[20]; + bool use_device_timer = args[21]; + const std::string& server_addr = args[22]; + int port = args[23]; + ObjectPtr session = make_object( + comms_method, binary_path, toolchain_prefix, text_start, text_size, rodata_start, rodata_size, + data_start, data_size, bss_start, bss_size, args_start, args_size, heap_start, heap_size, + workspace_start, workspace_size, stack_start, stack_size, word_size, thumb_mode, + use_device_timer, server_addr, port); + *rv = Module(session); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_session.h b/src/runtime/micro/micro_session.h index 9e844e8b21408..ab3afcc5bce89 100644 --- a/src/runtime/micro/micro_session.h +++ b/src/runtime/micro/micro_session.h @@ -34,24 +34,25 @@ #ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_ #define TVM_RUNTIME_MICRO_MICRO_SESSION_H_ -#include "micro_common.h" -#include "micro_section_allocator.h" - -#include #include +#include #include #include +#include #include #include -#include #include "low_level_device.h" +#include "micro_common.h" +#include "micro_section_allocator.h" #include "target_data_layout_encoder.h" namespace tvm { namespace runtime { +struct DevTask; + /*! 
 * \brief session for facilitating micro device interaction
 */
@@ -63,15 +64,15 @@ class MicroSession : public ModuleNode {
    * \param sptr_to_self The pointer to the module node.
    * \return The corresponding member function.
    */
-  virtual PackedFunc GetFunction(const std::string& name,
-                                 const ObjectPtr<Object>& sptr_to_self);
+  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
+
+  // TODO(areusch): this duplicates the task queue capacity defined in
+  // utvm_runtime.c; the two values must be kept in sync.
+  static const size_t kTaskQueueCapacity = 20;

   /*!
    * \return The type key of the executor.
    */
-  const char* type_key() const final {
-    return "MicroSession";
-  }
+  const char* type_key() const final { return "MicroSession"; }

   /*!
   * \brief creates session by setting up a low-level device and initializing allocators for it
@@ -94,35 +95,19 @@
   * \param workspace_size workspace section size
   * \param stack_start stack section start address
   * \param stack_size stack section size
   * \param word_size number of bytes in a word on the target device
   * \param thumb_mode whether the target device requires a thumb-mode bit on function addresses
+  * \param use_device_timer whether to use the device-side timer for profiling
   * \param server_addr address of the OpenOCD server to connect to (if `comms_method == "openocd"`)
   * \param port port of the OpenOCD server to connect to (if `comms_method == "openocd"`)
   */
-  MicroSession(
-      const std::string& comms_method,
-      const std::string& binary_path,
-      const std::string& toolchain_prefix,
-      uint64_t text_start,
-      size_t text_size,
-      uint64_t rodata_start,
-      size_t rodata_size,
-      uint64_t data_start,
-      size_t data_size,
-      uint64_t bss_start,
-      size_t bss_size,
-      uint64_t args_start,
-      size_t args_size,
-      uint64_t heap_start,
-      size_t heap_size,
-      uint64_t workspace_start,
-      size_t workspace_size,
-      uint64_t stack_start,
-      size_t stack_size,
-      size_t word_size,
-      bool thumb_mode,
-      const std::string& server_addr,
-      int port);
+  MicroSession(const std::string& comms_method, const std::string& binary_path,
+               const std::string& toolchain_prefix, uint64_t text_start, size_t text_size,
+               uint64_t rodata_start, size_t rodata_size, uint64_t data_start, size_t data_size,
+               uint64_t bss_start, size_t bss_size, uint64_t args_start, size_t args_size,
+               uint64_t heap_start, size_t heap_size, uint64_t workspace_start,
+               size_t workspace_size, uint64_t stack_start, size_t stack_size,
+               TargetWordSize word_size, bool thumb_mode, bool use_device_timer,
+               const std::string& server_addr, int port);

   /*!
    * \brief destructor
@@ -137,7 +122,19 @@ class MicroSession : public ModuleNode {
    * \param args args to the packed function
    */
-  double PushToExecQueue(DevPtr func, const TVMArgs& args);
+  void PushToTaskQueue(TargetPtr func, const TVMArgs& args);
+
+  /*!
+   * \brief serializes runtime metadata for all enqueued tasks and executes them on the device
+   */
+  void FlushTaskQueue();
+
+  /*!
+   * \brief word-size-specialized implementation of FlushTaskQueue
+   */
+  template <typename T>
+  void FlushTaskQueuePriv();

   /*!
    * \brief loads binary onto device
@@ -153,36 +150,44 @@
    * \param size size of allocated memory in bytes
    * \return pointer to allocated memory region in section, nullptr if out of space
    */
-  DevPtr AllocateInSection(SectionKind type, size_t size);
+  TargetPtr AllocateInSection(SectionKind type, size_t size);

  /*!
   * \brief free prior allocation from section
   * \param type type of section to allocate in
   * \param addr device address of allocated memory
   */
-  void FreeInSection(SectionKind type, DevPtr addr);
+  void FreeInSection(SectionKind type, TargetPtr addr);

   /*!
    * \brief read string from device to host
    * \param str_addr device address of first character of string
    * \return host copy of device string that was read
    */
-  std::string ReadString(DevPtr str_addr);
+  std::string ReadString(TargetPtr str_addr);

   /*!
-  * \brief read value of symbol from device memory
-  * \param symbol_map symbol map to read location of symbol from
-  * \param symbol name of symbol being read from
-  * \return value at symbol in memory
-  */
+   * \brief read value of symbol from device memory
+   * \param symbol_map symbol map to read location of symbol from
+   * \param symbol name of symbol being read from
+   * \return value at symbol in memory
+   */
   template <typename T>
   T DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol);

   /*!
-  * \brief write value into device memory corresponding to symbol
-  * \param symbol_map symbol map to read location of symbol from
-  * \param symbol name of symbol being written to
-  * \param value value being written into symbol
+   * \brief write pointer value into device memory corresponding to symbol
+   * \param symbol_map symbol map to read location of symbol from
+   * \param symbol name of symbol being written to
+   * \param ptr pointer value to write into symbol
+   */
+  void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr);
+
+  /*!
+   * \brief write value into device memory corresponding to symbol
+   * \param symbol_map symbol map to read location of symbol from
+   * \param symbol name of symbol being written to
+   * \param value value being written into symbol
    */
   template <typename T>
   void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value);
@@ -196,6 +201,18 @@ class MicroSession : public ModuleNode {
     return low_level_device_;
   }

+  /*! \brief returns the wall-clock time (ms) of the last batch and resets the counter */
+  const double GetLastBatchTime() {
+    double result = last_batch_time_;
+    last_batch_time_ = 0.0;
+    return result;
+  }
+
+  /*! \brief returns the device cycle count of the last batch and resets the counter */
+  const double GetLastBatchCycles() {
+    double result = last_batch_cycles_;
+    last_batch_cycles_ = 0.0;
+    return result;
+  }
+
 private:
  /*! \brief low-level device pointer */
  std::shared_ptr<LowLevelDevice> low_level_device_;
@@ -205,7 +222,7 @@
  std::shared_ptr<MicroSectionAllocator> section_allocators_[static_cast<size_t>(SectionKind::kNumKinds)];
  /*! \brief number of bytes in a word on the target device */
-  size_t word_size_;
+  TargetWordSize word_size_;
  /*! \brief whether the target device requires a thumb-mode bit on function addresses
   *
   * ARM and other manufacturers use the lowest bit of a function address to determine
@@ -213,8 +230,20 @@
   * results in more compact binaries.
   */
  bool thumb_mode_;
+  /*! \brief whether to use the device-side timer (rather than the host timer) for profiling */
+  bool use_device_timer_;
  /*! \brief symbol map for the device runtime */
  SymbolMap runtime_symbol_map_;
+  /*! \brief queue of tasks to be flushed to the device as one batch */
+  std::vector<DevTask> task_queue_;
+  // TODO(weberlo): we don't even need an allocator mechanism for the args
+  // section. there's only ever one allocation.
+  /*! \brief encoder that batches the argument data for all queued tasks */
+  TargetDataLayoutEncoder batch_args_encoder_;
+  /*! \brief wall-clock execution time (ms) accumulated over the last batch */
+  double last_batch_time_;
+  /*! \brief device cycle count accumulated over the last batch */
+  double last_batch_cycles_;

 /*!
* \brief patches a function pointer in this module to an implementation @@ -228,7 +257,8 @@ class MicroSession : public ModuleNode { * \param args args to be appended * \return device address of the allocated args */ - std::tuple EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMArgs& args); + std::tuple EncoderAppend(TargetDataLayoutEncoder* encoder, + const TVMArgs& args); /*! * \brief appends a `DLTensor` to the host-side buffer of `encoder` @@ -237,7 +267,7 @@ class MicroSession : public ModuleNode { * \return device address of the allocated `DLTensor` */ template - DevPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr); + TargetPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr); /*! * \brief checks and logs if there was an error during the device's most recent execution @@ -254,15 +284,15 @@ class MicroSession : public ModuleNode { } /*! - * \brief Push a new session context onto the thread-local stack. - * The session on top of the stack is used as the current global session. - */ + * \brief Push a new session context onto the thread-local stack. + * The session on top of the stack is used as the current global session. + */ static void EnterWithScope(ObjectPtr session); /*! - * \brief Pop a session off the thread-local context stack, - * restoring the previous session as the current context. - */ + * \brief Pop a session off the thread-local context stack, + * restoring the previous session as the current context. + */ static void ExitWithScope(); }; @@ -274,7 +304,7 @@ class MicroSession : public ModuleNode { */ struct MicroDevSpace { /*! \brief data being wrapped */ - void* data; + TargetPtr data; /*! \brief shared ptr to session where this data is valid */ ObjectPtr session; }; @@ -283,26 +313,24 @@ struct MicroDevSpace { /*! \brief TVM array for serialization to 32-bit devices */ struct TVMArray32 { - TVMArray32( - TargetVal data, - DLContext ctx, - int32_t ndim, - DLDataType dtype, - TargetVal shape, - TargetVal strides, - TargetVal byte_offset) - : data(data.val32), - ctx(ctx), - ndim(ndim), - pad0(0), - dtype(dtype), - shape(shape.val32), - strides(strides.val32), - pad1(0), - byte_offset(byte_offset.val32), - pad2(0) { } - - /*! \brief opaque pointer to the allocated data */ + TVMArray32(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, + TargetVal strides, TargetVal byte_offset) + : data(data.uint32()), + ctx(ctx), + ndim(ndim), + pad0(0), + dtype(dtype), + shape(shape.uint32()), + strides(strides.uint32()), + pad1(0), + byte_offset(byte_offset.uint32()), + pad2(0) {} + + /*! + * \brief The opaque data pointer points to the allocated data. + * This will be CUDA device pointer or cl_mem handle in OpenCL. + * This pointer is always aligns to 256 bytes as in CUDA. + */ uint32_t data; /*! \brief The device context of the tensor */ DLContext ctx; @@ -329,24 +357,21 @@ struct TVMArray32 { /*! \brief TVM array for serialization to 64-bit devices */ struct TVMArray64 { - TVMArray64( - TargetVal data, - DLContext ctx, - int32_t ndim, - DLDataType dtype, - TargetVal shape, - TargetVal strides, - TargetVal byte_offset) - : data(data.val64), - ctx(ctx), - ndim(ndim), - pad0(0), - dtype(dtype), - shape(shape.val64), - strides(strides.val64), - byte_offset(byte_offset.val64) { } - - /*! 
\brief opaque pointer to the allocated data */ + TVMArray64(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, + TargetVal strides, TargetVal byte_offset) + : data(data.uint64()), + ctx(ctx), + ndim(ndim), + pad0(0), + dtype(dtype), + shape(shape.uint64()), + strides(strides.uint64()), + byte_offset(byte_offset.uint64()) {} + /*! + * \brief The opaque data pointer points to the allocated data. + * This will be CUDA device pointer or cl_mem handle in OpenCL. + * This pointer is always aligns to 256 bytes as in CUDA. + */ uint64_t data; /*! \brief The device context of the tensor */ DLContext ctx; @@ -367,8 +392,26 @@ struct TVMArray64 { uint64_t byte_offset; }; +/*! \brief MicroTVM task to store in task queue before specializing to word size */ +struct DevTask { + /*! \brief Pointer to function to call for this task */ + TargetVal func; + /*! \brief Array of argument values */ + TargetVal arg_values; + /*! \brief Array of type codes for each argument value */ + TargetVal arg_type_codes; + /*! \brief Number of arguments */ + int32_t num_args; +}; + /*! \brief MicroTVM task for serialization to 32-bit devices */ typedef struct StructUTVMTask32 { + StructUTVMTask32(DevTask task) + : func(task.func.uint32()), + arg_values(task.arg_values.uint32()), + arg_type_codes(task.arg_type_codes.uint32()), + num_args(task.num_args) {} + /*! \brief Pointer to function to call for this task */ uint32_t func; /*! \brief Array of argument values */ @@ -377,10 +420,16 @@ typedef struct StructUTVMTask32 { uint32_t arg_type_codes; /*! \brief Number of arguments */ int32_t num_args; -} UTVMTask32; +} StructUTVMTask32; /*! \brief MicroTVM task for serialization to 64-bit devices */ typedef struct StructUTVMTask64 { + StructUTVMTask64(DevTask task) + : func(task.func.uint64()), + arg_values(task.arg_values.uint64()), + arg_type_codes(task.arg_type_codes.uint64()), + num_args(task.num_args) {} + /*! \brief Pointer to function to call for this task */ uint64_t func; /*! \brief Array of argument values */ @@ -389,7 +438,7 @@ typedef struct StructUTVMTask64 { uint64_t arg_type_codes; /*! \brief Number of arguments */ int32_t num_args; -} UTVMTask64; +} StructUTVMTask64; } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/openocd_low_level_device.cc b/src/runtime/micro/openocd_low_level_device.cc index e5c83e590c36e..610ca8590dd12 100644 --- a/src/runtime/micro/openocd_low_level_device.cc +++ b/src/runtime/micro/openocd_low_level_device.cc @@ -20,11 +20,11 @@ /*! 
* \file openocd_low_level_device.cc */ -#include #include +#include -#include "micro_common.h" #include "low_level_device.h" +#include "micro_common.h" #include "tcl_socket.h" namespace tvm { @@ -40,17 +40,19 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { * \param server_addr address of the OpenOCD server to connect to * \param port port of the OpenOCD server to connect to */ - explicit OpenOCDLowLevelDevice(const std::string& server_addr, - int port) : socket_() { + explicit OpenOCDLowLevelDevice(const std::string& server_addr, int port) : socket_() { server_addr_ = server_addr; port_ = port; socket_.Connect(tvm::support::SockAddr(server_addr_.c_str(), port_)); - socket_.cmd_builder() << "halt 0"; + socket_.cmd_builder() << "reset run"; + socket_.SendCommand(); + + socket_.cmd_builder() << "halt 500"; socket_.SendCommand(); } - void Read(DevPtr addr, void* buf, size_t num_bytes) { + void Read(TargetPtr addr, void* buf, size_t num_bytes) override { if (num_bytes == 0) { return; } @@ -77,18 +79,17 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.cmd_builder() << "array unset output"; socket_.SendCommand(); - socket_.cmd_builder() - << "mem2array output" - << " " << std::dec << kWordSize - << " " << addr.cast_to() - // Round up any request sizes under a byte, since OpenOCD doesn't support - // sub-byte-sized transfers. - << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); + socket_.cmd_builder() << "mem2array output" + << " " << std::dec << kWordSize << " " + << addr.cast_to() + // Round up any request sizes under a byte, since OpenOCD doesn't + // support sub-byte-sized transfers. + << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); socket_.SendCommand(); } { - socket_.cmd_builder() << "ocd_echo $output"; + socket_.cmd_builder() << "return $output"; socket_.SendCommand(); const std::string& reply = socket_.last_reply(); @@ -101,9 +102,8 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { // The response from this command pairs indices with the contents of the // memory at that index. values >> index; - CHECK(index < num_bytes) - << "index " << index << - " out of bounds (length " << num_bytes << ")"; + CHECK(index < num_bytes) << "index " << index << " out of bounds (length " << num_bytes + << ")"; // Read the value into `curr_val`, instead of reading directly into // `buf_iter`, because otherwise it's interpreted as the ASCII value and // not the integral value. 
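For intuition, the command strings assembled above via cmd_builder() form an
exchange like the following for an 8-byte read (the address and reply values
are illustrative; the reply pairs each byte index with its value, which is
exactly what the parsing loop above consumes):

  array unset output
  mem2array output 8 0x10000000 8
  return $output
  => "0 222 1 173 2 190 3 239 4 0 5 0 6 0 7 0"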
@@ -119,7 +119,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { } } - void Write(DevPtr addr, const void* buf, size_t num_bytes) { + void Write(TargetPtr addr, const void* buf, size_t num_bytes) override { if (num_bytes == 0) { return; } @@ -162,16 +162,14 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.SendCommand(); } { - socket_.cmd_builder() - << "array2mem input" - << " " << std::dec << kWordSize - << " " << addr.cast_to() - << " " << std::dec << num_bytes; + socket_.cmd_builder() << "array2mem input" + << " " << std::dec << kWordSize << " " << addr.cast_to() << " " + << std::dec << num_bytes; socket_.SendCommand(); } } - void Execute(DevPtr func_addr, DevPtr breakpoint_addr) { + void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) override { socket_.cmd_builder() << "halt 0"; socket_.SendCommand(); @@ -193,9 +191,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.SendCommand(); } - const char* device_type() const final { - return "openocd"; - } + const char* device_type() const final { return "openocd"; } private: /*! \brief socket used to communicate with the device through Tcl */ @@ -207,18 +203,17 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { /*! \brief number of bytes in a word on the target device (64-bit) */ static const constexpr ssize_t kWordSize = 8; - // NOTE: OpenOCD will call any request larger than this constant an "absurd - // request". + // NOTE: The OS pipe buffer must be able to handle a line long enough to + // print this transfer request. /*! \brief maximum number of bytes allowed in a single memory transfer */ - static const constexpr ssize_t kMemTransferLimit = 64000; + static const constexpr ssize_t kMemTransferLimit = 8000; /*! \brief number of milliseconds to wait for function execution to halt */ - static const constexpr int kWaitTime = 10000; + static const constexpr int kWaitTime = 30000; }; const std::shared_ptr OpenOCDLowLevelDeviceCreate(const std::string& server_addr, int port) { - std::shared_ptr lld = - std::make_shared(server_addr, port); + std::shared_ptr lld = std::make_shared(server_addr, port); return lld; } diff --git a/src/runtime/micro/standalone/minimal_vector.h b/src/runtime/micro/standalone/minimal_vector.h index 4d04e526329fe..74bea06ebcfd4 100644 --- a/src/runtime/micro/standalone/minimal_vector.h +++ b/src/runtime/micro/standalone/minimal_vector.h @@ -27,7 +27,6 @@ namespace tvm { namespace micro { - // A minimal wrapper, derived from https://github.com/Robbepop/dynarray/, that // supports a minimal subset of the std::vector API with a minimized code size. template diff --git a/src/runtime/micro/standalone/utvm_graph_runtime.cc b/src/runtime/micro/standalone/utvm_graph_runtime.cc index 546ed7d4988b9..db55634de66af 100644 --- a/src/runtime/micro/standalone/utvm_graph_runtime.cc +++ b/src/runtime/micro/standalone/utvm_graph_runtime.cc @@ -20,8 +20,10 @@ #include "utvm_graph_runtime.h" #include + #include #include + #include "picojson.h" namespace tvm { diff --git a/src/runtime/micro/standalone/utvm_runtime.cc b/src/runtime/micro/standalone/utvm_runtime.cc index 418443818bf10..73d616b6d482c 100644 --- a/src/runtime/micro/standalone/utvm_runtime.cc +++ b/src/runtime/micro/standalone/utvm_runtime.cc @@ -16,15 +16,15 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include "tvm/runtime/micro/standalone/utvm_runtime.h" + #include -#include "tvm/runtime/micro/standalone/utvm_runtime.h" #include "utvm_graph_runtime.h" void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module) { - return new tvm::micro::MicroGraphRuntime( - std::string(json, json + json_len), - reinterpret_cast(module)); + return new tvm::micro::MicroGraphRuntime(std::string(json, json + json_len), + reinterpret_cast(module)); } void UTVMRuntimeDestroy(void* handle) { diff --git a/src/runtime/micro/standalone/utvm_runtime_api.cc b/src/runtime/micro/standalone/utvm_runtime_api.cc index 896ff578da9e9..a6ac420feec20 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.cc +++ b/src/runtime/micro/standalone/utvm_runtime_api.cc @@ -20,6 +20,7 @@ #include "utvm_runtime_api.h" #include + #include #include diff --git a/src/runtime/micro/standalone/utvm_runtime_api.h b/src/runtime/micro/standalone/utvm_runtime_api.h index 1b87052840d4d..b38aa0a47a8c7 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.h +++ b/src/runtime/micro/standalone/utvm_runtime_api.h @@ -21,6 +21,7 @@ #include #include + #include // The subset of the TVM runtime API that is implemented by the minimal runtime API. diff --git a/src/runtime/micro/target_data_layout_encoder.h b/src/runtime/micro/target_data_layout_encoder.h index e0275165e7746..97781773eeccd 100644 --- a/src/runtime/micro/target_data_layout_encoder.h +++ b/src/runtime/micro/target_data_layout_encoder.h @@ -25,12 +25,13 @@ #define TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ #include + #include "host_driven/utvm_runtime.h" namespace tvm { namespace runtime { -// TODO(weberlo): Handle endianness. +// TODO(weberlo, areusch): Handle endianness. /*! * \brief data encoder for uTVM that builds a host-side buffer @@ -50,7 +51,7 @@ class TargetDataLayoutEncoder { * \param size size (in bytes) of the memory region allocated for this slot * \param start_addr start address of the slot in the device's memory */ - Slot(TargetDataLayoutEncoder* parent, size_t start_offset, size_t size, DevPtr start_addr); + Slot(TargetDataLayoutEncoder* parent, size_t start_offset, size_t size, TargetPtr start_addr); ~Slot(); @@ -71,7 +72,7 @@ class TargetDataLayoutEncoder { * \brief returns start address of the slot in device memory * \return device start address */ - DevPtr start_addr(); + TargetPtr start_addr(); /*! * \brief returns number of bytes allocated for this slot @@ -89,17 +90,19 @@ class TargetDataLayoutEncoder { /*! \brief size (in bytes) of the memory region allocated for this slot */ size_t size_; /*! \brief start address of the slot in the device's memory */ - DevPtr start_addr_; + TargetPtr start_addr_; }; /*! * \brief constructor * \param start_addr start address of the encoder in device memory */ - explicit TargetDataLayoutEncoder(DevPtr start_addr, size_t word_size) - : buf_(std::vector()), curr_offset_(0), word_size_(word_size) { - start_addr_ = DevPtr(UpperAlignValue(start_addr.value().val64, word_size_)); - } + explicit TargetDataLayoutEncoder(size_t capacity, TargetWordSize word_size) + : buf_(std::vector()), + curr_offset_(0), + start_addr_(word_size, nullptr), + capacity_(capacity), + word_size_(word_size) {} /*! 
* \brief allocates a slot for `sizeof(T) * num_elems` bytes of data @@ -108,30 +111,43 @@ class TargetDataLayoutEncoder { */ template Slot Alloc(size_t num_elems = 1) { - curr_offset_ = UpperAlignValue(curr_offset_, word_size_); + curr_offset_ = UpperAlignValue(curr_offset_, word_size_.bytes()); size_t size = sizeof(T) * num_elems; if (curr_offset_ + size > buf_.size()) { buf_.resize(curr_offset_ + size); } + CHECK(buf_.size() < capacity_) << "out of space in data encoder"; size_t slot_start_offset = curr_offset_; curr_offset_ += size; - return Slot(this, slot_start_offset, size, start_addr_ + slot_start_offset); + return Slot(this, slot_start_offset, size, start_addr() + slot_start_offset); + } + + void Clear() { + buf_.clear(); + curr_offset_ = 0; } /*! * \brief returns the array backing the encoder's buffer * \return array backing the encoder's buffer */ - uint8_t* data() { - return buf_.data(); - } + uint8_t* data() { return buf_.data(); } /*! * \brief returns current size of the encoder's buffer * \return buffer size */ - size_t buf_size() { - return buf_.size(); + size_t buf_size() const { return buf_.size(); } + + TargetPtr start_addr() const { + CHECK_NE(start_addr_.value().uint64(), 0) << "start addr uninitialized"; + return start_addr_; + } + + void set_start_addr(TargetPtr start_addr) { + CHECK_EQ(buf_.size(), 0) << "cannot change encoder start addr unless empty"; + start_addr_ = + TargetPtr(word_size_, UpperAlignValue(start_addr.value().uint64(), word_size_.bytes())); } private: @@ -140,16 +156,16 @@ class TargetDataLayoutEncoder { /*! \brief current offset */ size_t curr_offset_; /*! \brief start address of the encoder in device memory */ - DevPtr start_addr_; + TargetPtr start_addr_; + /*! \brief number of bytes available in device memory */ + size_t capacity_; /*! \brief number of bytes in a word on the target device */ - size_t word_size_; + TargetWordSize word_size_; }; template -TargetDataLayoutEncoder::Slot::Slot(TargetDataLayoutEncoder* parent, - size_t start_offset, - size_t size, - DevPtr start_addr) +TargetDataLayoutEncoder::Slot::Slot(TargetDataLayoutEncoder* parent, size_t start_offset, + size_t size, TargetPtr start_addr) : parent_(parent), start_offset_(start_offset), curr_offset_(0), @@ -158,7 +174,10 @@ TargetDataLayoutEncoder::Slot::Slot(TargetDataLayoutEncoder* parent, template TargetDataLayoutEncoder::Slot::~Slot() { - CHECK(curr_offset_ == size_) << "unwritten space in slot"; + // TODO(weberlo, areusch): this can mask the exception thrown by slot allocation... even though + // that doesn't make sense. + CHECK(curr_offset_ == size_) << "unwritten space in slot; curr_offset=" << curr_offset_ + << ", size=" << size_; } template @@ -177,7 +196,7 @@ void TargetDataLayoutEncoder::Slot::WriteValue(const T& val) { } template -DevPtr TargetDataLayoutEncoder::Slot::start_addr() { +TargetPtr TargetDataLayoutEncoder::Slot::start_addr() { return start_addr_; } diff --git a/src/runtime/micro/tcl_socket.cc b/src/runtime/micro/tcl_socket.cc index 64dfbf2183884..8f482b874260c 100644 --- a/src/runtime/micro/tcl_socket.cc +++ b/src/runtime/micro/tcl_socket.cc @@ -20,10 +20,10 @@ /*! 
* \file tcl_socket.cc */ -#include - #include "tcl_socket.h" +#include + namespace tvm { namespace runtime { @@ -33,9 +33,7 @@ TclSocket::TclSocket() { reply_buf_.reserve(kReplyBufSize); } -TclSocket::~TclSocket() { - tcp_socket_.Close(); -} +TclSocket::~TclSocket() { tcp_socket_.Close(); } void TclSocket::Connect(tvm::support::SockAddr addr) { CHECK(tcp_socket_.Connect(addr)) << "failed to connect"; @@ -45,8 +43,8 @@ void TclSocket::SendCommand() { const char terminate_token = kCommandTerminateToken; cmd_builder_ << terminate_token; std::string full_cmd = cmd_builder_.str(); - CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) - << "failed to send command"; + + CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) << "failed to send command"; cmd_builder_.str(std::string()); reply_builder_.str(std::string()); @@ -66,8 +64,7 @@ void TclSocket::SendCommand() { CHECK(bytes_read != -1) << "failed to read command reply"; } while (last_read != terminate_token); last_reply_ = reply_builder_.str(); - CHECK_EQ(last_reply_[last_reply_.length()-1], terminate_token) - << "missing command terminator"; + CHECK_EQ(last_reply_[last_reply_.length() - 1], terminate_token) << "missing command terminator"; } } // namespace runtime diff --git a/src/runtime/micro/tcl_socket.h b/src/runtime/micro/tcl_socket.h index 0b23e7f1b07f9..4aef2aef36e2e 100644 --- a/src/runtime/micro/tcl_socket.h +++ b/src/runtime/micro/tcl_socket.h @@ -66,12 +66,12 @@ class TclSocket { /* * \return string stream for current command being built - */ + */ std::ostringstream& cmd_builder() { return cmd_builder_; } /* * \return reply from most recently sent command - */ + */ const std::string& last_reply() { return last_reply_; } private: diff --git a/src/runtime/module.cc b/src/runtime/module.cc index d2ed7ff9e2b7b..19f1f39062274 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -22,10 +22,12 @@ * \brief TVM module system */ #include -#include #include -#include +#include + #include +#include + #include "file_util.h" namespace tvm { @@ -36,7 +38,7 @@ void ModuleNode::Import(Module other) { if (!std::strcmp(this->type_key(), "rpc")) { static const PackedFunc* fimport_ = nullptr; if (fimport_ == nullptr) { - fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); + fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule"); CHECK(fimport_ != nullptr); } (*fimport_)(GetRef(this), other); @@ -55,8 +57,7 @@ void ModuleNode::Import(Module other) { stack.push_back(next); } } - CHECK(!visited.count(this)) - << "Cyclic dependency detected during import"; + CHECK(!visited.count(this)) << "Cyclic dependency detected during import"; this->imports_.emplace_back(std::move(other)); } @@ -73,25 +74,20 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) return pf; } -Module Module::LoadFromFile(const std::string& file_name, - const std::string& format) { +Module Module::LoadFromFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); - CHECK(fmt.length() != 0) - << "Cannot deduce format of file " << file_name; + CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name; if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { fmt = "so"; } std::string load_f_name = "runtime.module.loadfile_" + fmt; const PackedFunc* f = Registry::Get(load_f_name); - CHECK(f != nullptr) - << "Loader of " << format << "(" - << load_f_name << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << format << "(" << 
load_f_name << ") is not presented."; Module m = (*f)(file_name, format); return m; } -void ModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; } @@ -114,9 +110,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { } if (pf == nullptr) { const PackedFunc* f = Registry::Get(name); - CHECK(f != nullptr) - << "Cannot find function " << name - << " in the imported modules or global registry"; + CHECK(f != nullptr) << "Cannot find function " << name + << " in the imported modules or global registry"; return f; } else { import_cache_.insert(std::make_pair(name, std::make_shared(pf))); @@ -158,36 +153,30 @@ bool RuntimeEnabled(const std::string& target) { return runtime::Registry::Get(f_name) != nullptr; } -TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled") -.set_body_typed(RuntimeEnabled); +TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); -TVM_REGISTER_GLOBAL("runtime.ModuleGetSource") -.set_body_typed([](Module mod, std::string fmt) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { return mod->GetSource(fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize") -.set_body_typed([](Module mod) { +TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { return static_cast(mod->imports().size()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetImport") -.set_body_typed([](Module mod, int index) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { return mod->imports().at(index); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey") -.set_body_typed([](Module mod) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") -.set_body_typed(Module::LoadFromFile); +TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") -.set_body_typed([](Module mod, std::string name, std::string fmt) { - mod->SaveToFile(name, fmt); -}); + .set_body_typed([](Module mod, std::string name, std::string fmt) { + mod->SaveToFile(name, fmt); + }); TVM_REGISTER_OBJECT_TYPE(ModuleNode); } // namespace runtime diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index ac12472a903ea..d97d01b0feab9 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -22,9 +22,10 @@ * \brief NDArray container infratructure. */ #include -#include #include #include +#include + #include "runtime_base.h" extern "C" { @@ -45,9 +46,12 @@ inline void VerifyDataType(DLDataType dtype) { // allow uint1 as a special flag for bool. 
if (dtype.bits == 1 && dtype.code == kDLUInt) return; // allow int1/uint4/int4 - else if (dtype.bits == 1 && dtype.code == kDLInt) return; - else if (dtype.bits == 4 && dtype.code == kDLUInt) return; - else if (dtype.bits == 4 && dtype.code == kDLInt) return; + else if (dtype.bits == 1 && dtype.code == kDLInt) + return; + else if (dtype.bits == 4 && dtype.code == kDLUInt) + return; + else if (dtype.bits == 4 && dtype.code == kDLInt) + return; else CHECK_EQ(dtype.bits % 8, 0); } @@ -65,12 +69,10 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "ArrayCopyFromBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - data, 0, - handle->data, static_cast(handle->byte_offset), - nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr); + CHECK_EQ(arr_size, nbytes) << "ArrayCopyFromBytes: size mismatch"; + DeviceAPI::Get(handle->ctx) + ->CopyDataFromTo(data, 0, handle->data, static_cast(handle->byte_offset), nbytes, + cpu_ctx, handle->ctx, handle->dtype, nullptr); } void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { @@ -78,12 +80,10 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "ArrayCopyToBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - handle->data, static_cast(handle->byte_offset), - data, 0, - nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr); + CHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + DeviceAPI::Get(handle->ctx) + ->CopyDataFromTo(handle->data, static_cast(handle->byte_offset), data, 0, nbytes, + handle->ctx, cpu_ctx, handle->dtype, nullptr); } struct NDArray::Internal { @@ -93,8 +93,8 @@ struct NDArray::Internal { if (ptr->manager_ctx != nullptr) { static_cast(ptr->manager_ctx)->DecRef(); } else if (ptr->dl_tensor.data != nullptr) { - tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace( - ptr->dl_tensor.ctx, ptr->dl_tensor.data); + tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx) + ->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data); } delete ptr; } @@ -113,9 +113,7 @@ struct NDArray::Internal { } // Local create function which allocates tensor metadata // but does not allocate space for the data. 
- static NDArray Create(std::vector shape, - DLDataType dtype, - DLContext ctx) { + static NDArray Create(std::vector shape, DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); // critical zone: construct header @@ -140,13 +138,11 @@ struct NDArray::Internal { ObjectRef::FFIClearAfterMove(&arr); return handle; } - static void FFIDecRef(TVMArrayHandle tensor) { - NDArray::FFIDecRef(tensor); - } + static void FFIDecRef(TVMArrayHandle tensor) { NDArray::FFIDecRef(tensor); } // Container to DLManagedTensor static DLManagedTensor* ToDLPack(TVMArrayHandle handle) { - auto* from = static_cast( - reinterpret_cast(handle)); + auto* from = + static_cast(reinterpret_cast(handle)); return ToDLPack(from); } @@ -168,11 +164,9 @@ struct NDArray::Internal { NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { CHECK(data_ != nullptr); - CHECK(get_mutable()->dl_tensor.strides == nullptr) - << "Can only create view for compact tensor"; + CHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx); - ret.get_mutable()->dl_tensor.byte_offset = - this->get_mutable()->dl_tensor.byte_offset; + ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset; size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor); CHECK_LE(view_size, curr_size) @@ -184,20 +178,15 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { return ret; } -DLManagedTensor* NDArray::ToDLPack() const { - return Internal::ToDLPack(get_mutable()); -} +DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, - DLDataType dtype, - DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); ret.get_mutable()->dl_tensor.data = - DeviceAPI::Get(ret->ctx)->AllocDataSpace( - ret->ctx, size, alignment, ret->dtype); + DeviceAPI::Get(ret->ctx)->AllocDataSpace(ret->ctx, size, alignment, ret->dtype); return ret; } @@ -227,34 +216,26 @@ void NDArray::CopyFromBytes(const void* data, size_t nbytes) { ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes); } -void NDArray::CopyFromTo(const DLTensor* from, - DLTensor* to, - TVMStreamHandle stream) { +void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { size_t from_size = GetDataSize(*from); size_t to_size = GetDataSize(*to); - CHECK_EQ(from_size, to_size) - << "TVMArrayCopyFromTo: The size must exactly match"; + CHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size must exactly match"; - CHECK(from->ctx.device_type == to->ctx.device_type - || from->ctx.device_type == kDLCPU - || to->ctx.device_type == kDLCPU - || from->ctx.device_type == kDLCPUPinned - || to->ctx.device_type == kDLCPUPinned) - << "Can not copy across different ctx types directly"; + CHECK(from->ctx.device_type == to->ctx.device_type || from->ctx.device_type == kDLCPU || + to->ctx.device_type == kDLCPU || from->ctx.device_type == kDLCPUPinned || + to->ctx.device_type == kDLCPUPinned) + << "Can not copy across different ctx types directly"; // Use the context that is *not* a cpu context to get the correct device // api manager. 
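  // e.g. for a CPU-to-GPU copy, `from->ctx` is kDLCPU, so `to->ctx` selects
  // the GPU's DeviceAPI, which knows how to perform the host-to-device copy.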
TVMContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx; - DeviceAPI::Get(ctx)->CopyDataFromTo( - from->data, static_cast(from->byte_offset), - to->data, static_cast(to->byte_offset), - from_size, from->ctx, to->ctx, from->dtype, stream); + DeviceAPI::Get(ctx)->CopyDataFromTo(from->data, static_cast(from->byte_offset), to->data, + static_cast(to->byte_offset), from_size, from->ctx, + to->ctx, from->dtype, stream); } -std::vector NDArray::Shape() const { - return get_mutable()->shape_; -} +std::vector NDArray::Shape() const { return get_mutable()->shape_; } TVM_REGISTER_OBJECT_TYPE(NDArray::Container); @@ -273,14 +254,8 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -300,43 +275,33 @@ int TVMArrayFree(TVMArrayHandle handle) { API_END(); } -int TVMArrayCopyFromTo(TVMArrayHandle from, - TVMArrayHandle to, - TVMStreamHandle stream) { +int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream) { API_BEGIN(); NDArray::CopyFromTo(from, to, stream); API_END(); } -int TVMArrayFromDLPack(DLManagedTensor* from, - TVMArrayHandle* out) { +int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { API_BEGIN(); *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); API_END(); } -int TVMArrayToDLPack(TVMArrayHandle from, - DLManagedTensor** out) { +int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { API_BEGIN(); *out = NDArray::Internal::ToDLPack(from); API_END(); } -void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { - (*(dltensor->deleter))(dltensor); -} +void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { (*(dltensor->deleter))(dltensor); } -int TVMArrayCopyFromBytes(TVMArrayHandle handle, - void* data, - size_t nbytes) { +int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); ArrayCopyFromBytes(handle, data, nbytes); API_END(); } -int TVMArrayCopyToBytes(TVMArrayHandle handle, - void* data, - size_t nbytes) { +int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); ArrayCopyToBytes(handle, data, nbytes); API_END(); diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 03012006bd797..c8e6671d5ee66 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -21,13 +21,15 @@ * \brief Object type management system. 
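 *
 * A hypothetical sketch of claiming a dynamic type slot (illustrative only;
 * the key "test.MyObject" is made up):
 * \code
 *   uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(
 *       "test.MyObject", TypeIndex::kDynamic,  // allocate index dynamically
 *       0,      // parent_tindex: root Object
 *       0,      // num_child_slots: reserve none
 *       true);  // child_slots_can_overflow
 *   CHECK_EQ(Object::TypeIndex2Key(tindex), "test.MyObject");
 * \endcode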
*/ #include -#include #include +#include + #include #include -#include -#include #include +#include +#include + #include "object_internal.h" #include "runtime_base.h" @@ -75,10 +77,8 @@ class TypeContext { return child_tindex == parent_tindex; } - uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t num_child_slots, + uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex, + uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { std::lock_guard lock(mutex_); auto it = type_key2index_.find(skey); @@ -105,10 +105,8 @@ class TypeContext { allocated_tindex = static_tindex; CHECK_LT(static_tindex, type_table_.size()); CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) - << "Conflicting static index " << static_tindex - << " between " << type_table_[allocated_tindex].name - << " and " - << skey; + << "Conflicting static index " << static_tindex << " between " + << type_table_[allocated_tindex].name << " and " << skey; } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) { // allocate the slot from parent's reserved pool allocated_tindex = parent_tindex + pinfo.allocated_slots; @@ -129,8 +127,7 @@ class TypeContext { type_table_[allocated_tindex].parent_index = parent_tindex; type_table_[allocated_tindex].num_slots = num_slots; type_table_[allocated_tindex].allocated_slots = 1; - type_table_[allocated_tindex].child_slots_can_overflow = - child_slots_can_overflow; + type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow; type_table_[allocated_tindex].name = skey; type_table_[allocated_tindex].name_hash = std::hash()(skey); // update the key2index mapping. @@ -140,16 +137,14 @@ class TypeContext { std::string TypeIndex2Key(uint32_t tindex) { std::lock_guard lock(mutex_); - CHECK(tindex < type_table_.size() && - type_table_[tindex].allocated_slots != 0) + CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) << "Unknown type index " << tindex; return type_table_[tindex].name; } size_t TypeIndex2KeyHash(uint32_t tindex) { std::lock_guard lock(mutex_); - CHECK(tindex < type_table_.size() && - type_table_[tindex].allocated_slots != 0) + CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) << "Unknown type index " << tindex; return type_table_[tindex].name_hash; } @@ -173,7 +168,7 @@ class TypeContext { for (const auto& info : type_table_) { if (info.index != 0 && num_children[info.index] >= min_children_count) { - std::cerr <<'[' << info.index << "] "<< info.name + std::cerr << '[' << info.index << "] " << info.name << "\tparent=" << type_table_[info.parent_index].name << "\tnum_child_slots=" << info.num_slots - 1 << "\tnum_children=" << num_children[info.index] << std::endl; @@ -198,18 +193,15 @@ class TypeContext { std::unordered_map type_key2index_; }; -uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t num_child_slots, +uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, + uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { return TypeContext::Global()->GetOrAllocRuntimeTypeIndex( key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow); } bool Object::DerivedFrom(uint32_t parent_tindex) const { - return TypeContext::Global()->DerivedFrom( - this->type_index_, parent_tindex); + return 
TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex); } std::string Object::TypeIndex2Key(uint32_t tindex) { @@ -224,14 +216,11 @@ uint32_t Object::TypeKey2Index(const std::string& key) { return TypeContext::Global()->TypeKey2Index(key); } - -TVM_REGISTER_GLOBAL("runtime.ObjectHash") -.set_body_typed([](ObjectRef obj) { +TVM_REGISTER_GLOBAL("runtime.ObjectHash").set_body_typed([](ObjectRef obj) { return static_cast(ObjectHash()(obj)); }); -TVM_REGISTER_GLOBAL("runtime.DumpTypeTable") -.set_body_typed([](int min_child_count) { +TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) { TypeContext::Global()->Dump(min_child_count); }); } // namespace runtime @@ -252,7 +241,6 @@ int TVMObjectFree(TVMObjectHandle obj) { int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); - out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index( - type_key); + out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 79551309d67c0..d56046cfde3ce 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -24,8 +24,9 @@ #ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_ #define TVM_RUNTIME_OBJECT_INTERNAL_H_ -#include #include +#include + #include namespace tvm { @@ -68,4 +69,4 @@ class ObjectInternal { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ +#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/opencl/aocl/aocl_common.h b/src/runtime/opencl/aocl/aocl_common.h index d9251f8aaf535..1b98d4b2d221c 100644 --- a/src/runtime/opencl/aocl/aocl_common.h +++ b/src/runtime/opencl/aocl/aocl_common.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_OPENCL_AOCL_AOCL_COMMON_H_ #include + #include "../opencl_common.h" namespace tvm { @@ -44,7 +45,6 @@ class AOCLWorkspace final : public OpenCLWorkspace { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace for AOCL */ class AOCLThreadEntry : public OpenCLThreadEntry { public: diff --git a/src/runtime/opencl/aocl/aocl_device_api.cc b/src/runtime/opencl/aocl/aocl_device_api.cc index 84c29eea33ec3..07057ff297161 100644 --- a/src/runtime/opencl/aocl/aocl_device_api.cc +++ b/src/runtime/opencl/aocl/aocl_device_api.cc @@ -20,17 +20,16 @@ /*! 
* \file aocl_device_api.cc */ -#include #include +#include + #include "aocl_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { - return AOCLThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); } const std::shared_ptr& AOCLWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); @@ -47,15 +46,12 @@ bool AOCLWorkspace::IsOpenCLDevice(TVMContext ctx) { typedef dmlc::ThreadLocalStore AOCLThreadStore; -AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { - return AOCLThreadStore::Get(); -} +AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.aocl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = AOCLWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = AOCLWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/aocl/aocl_module.cc b/src/runtime/opencl/aocl/aocl_module.cc index abda5b179a6a3..747188cf7b2dd 100644 --- a/src/runtime/opencl/aocl/aocl_module.cc +++ b/src/runtime/opencl/aocl/aocl_module.cc @@ -20,23 +20,24 @@ /*! * \file aocl_module.cc */ +#include "aocl_module.h" + #include #include -#include + #include #include +#include + #include "aocl_common.h" -#include "aocl_module.h" namespace tvm { namespace runtime { class AOCLModuleNode : public OpenCLModuleNode { public: - explicit AOCLModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit AOCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} const std::shared_ptr& GetGlobalWorkspace() final; }; @@ -45,18 +46,14 @@ const std::shared_ptr& AOCLModuleNode::GetGlobalWorkspace() return cl::AOCLWorkspace::Global(); } -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } -Module AOCLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module AOCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -66,8 +63,7 @@ Module AOCLModuleLoadFile(const std::string& file_name, return AOCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx") -.set_body_typed(AOCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx").set_body_typed(AOCLModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/aocl/aocl_module.h b/src/runtime/opencl/aocl/aocl_module.h index 70955cc655284..199a94decdd83 100644 --- a/src/runtime/opencl/aocl/aocl_module.h +++ b/src/runtime/opencl/aocl/aocl_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "aocx" * \param fmap The map function information map of each function. 
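 * \param source The source code, may be empty for binary-only modules.
 *
 * A minimal usage sketch (illustrative; the aocx bitstream bytes and fmap
 * contents are assumed):
 * \code
 *   std::unordered_map<std::string, FunctionInfo> fmap;
 *   Module m = AOCLModuleCreate(aocx_bytes, "aocx", fmap, "");
 * \endcode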
*/ -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 8f9d5d6352bac..a892bff75342a 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -24,10 +24,10 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ #define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ +#include #include -#include #include -#include +#include /* There are many OpenCL platforms that do not yet support OpenCL 2.0, * hence we use 1.2 APIs, some of which are now deprecated. In order @@ -45,73 +45,120 @@ #include #endif +#include #include #include -#include -#include #include -#include "../workspace_pool.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "../workspace_pool.h" namespace tvm { namespace runtime { namespace cl { -static_assert(sizeof(cl_mem) ==sizeof(void*), - "Required to store cl_mem inside void*"); +static_assert(sizeof(cl_mem) == sizeof(void*), "Required to store cl_mem inside void*"); inline const char* CLGetErrorString(cl_int error) { switch (error) { - case CL_SUCCESS: return "CL_SUCCESS"; - case CL_DEVICE_NOT_FOUND: return "CL_DEVICE_NOT_FOUND"; - case CL_DEVICE_NOT_AVAILABLE: return "CL_DEVICE_NOT_AVAILABLE"; - case CL_COMPILER_NOT_AVAILABLE: return "CL_COMPILER_NOT_AVAILABLE"; - case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; - case CL_OUT_OF_RESOURCES: return "CL_OUT_OF_RESOURCES"; - case CL_OUT_OF_HOST_MEMORY: return "CL_OUT_OF_HOST_MEMORY"; - case CL_PROFILING_INFO_NOT_AVAILABLE: return "CL_PROFILING_INFO_NOT_AVAILABLE"; - case CL_MEM_COPY_OVERLAP: return "CL_MEM_COPY_OVERLAP"; - case CL_IMAGE_FORMAT_MISMATCH: return "CL_IMAGE_FORMAT_MISMATCH"; - case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; - case CL_BUILD_PROGRAM_FAILURE: return "CL_BUILD_PROGRAM_FAILURE"; - case CL_MAP_FAILURE: return "CL_MAP_FAILURE"; - case CL_INVALID_VALUE: return "CL_INVALID_VALUE"; - case CL_INVALID_DEVICE_TYPE: return "CL_INVALID_DEVICE_TYPE"; - case CL_INVALID_PLATFORM: return "CL_INVALID_PLATFORM"; - case CL_INVALID_DEVICE: return "CL_INVALID_DEVICE"; - case CL_INVALID_CONTEXT: return "CL_INVALID_CONTEXT"; - case CL_INVALID_QUEUE_PROPERTIES: return "CL_INVALID_QUEUE_PROPERTIES"; - case CL_INVALID_COMMAND_QUEUE: return "CL_INVALID_COMMAND_QUEUE"; - case CL_INVALID_HOST_PTR: return "CL_INVALID_HOST_PTR"; - case CL_INVALID_MEM_OBJECT: return "CL_INVALID_MEM_OBJECT"; - case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; - case CL_INVALID_IMAGE_SIZE: return "CL_INVALID_IMAGE_SIZE"; - case CL_INVALID_SAMPLER: return "CL_INVALID_SAMPLER"; - case CL_INVALID_BINARY: return "CL_INVALID_BINARY"; - case CL_INVALID_BUILD_OPTIONS: return "CL_INVALID_BUILD_OPTIONS"; - case CL_INVALID_PROGRAM: return "CL_INVALID_PROGRAM"; - case CL_INVALID_PROGRAM_EXECUTABLE: return "CL_INVALID_PROGRAM_EXECUTABLE"; - case CL_INVALID_KERNEL_NAME: return "CL_INVALID_KERNEL_NAME"; - case CL_INVALID_KERNEL_DEFINITION: return "CL_INVALID_KERNEL_DEFINITION"; - case CL_INVALID_KERNEL: return "CL_INVALID_KERNEL"; - case CL_INVALID_ARG_INDEX: return 
"CL_INVALID_ARG_INDEX"; - case CL_INVALID_ARG_VALUE: return "CL_INVALID_ARG_VALUE"; - case CL_INVALID_ARG_SIZE: return "CL_INVALID_ARG_SIZE"; - case CL_INVALID_KERNEL_ARGS: return "CL_INVALID_KERNEL_ARGS"; - case CL_INVALID_WORK_DIMENSION: return "CL_INVALID_WORK_DIMENSION"; - case CL_INVALID_WORK_GROUP_SIZE: return "CL_INVALID_WORK_GROUP_SIZE"; - case CL_INVALID_WORK_ITEM_SIZE: return "CL_INVALID_WORK_ITEM_SIZE"; - case CL_INVALID_GLOBAL_OFFSET: return "CL_INVALID_GLOBAL_OFFSET"; - case CL_INVALID_EVENT_WAIT_LIST: return "CL_INVALID_EVENT_WAIT_LIST"; - case CL_INVALID_EVENT: return "CL_INVALID_EVENT"; - case CL_INVALID_OPERATION: return "CL_INVALID_OPERATION"; - case CL_INVALID_GL_OBJECT: return "CL_INVALID_GL_OBJECT"; - case CL_INVALID_BUFFER_SIZE: return "CL_INVALID_BUFFER_SIZE"; - case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL"; - default: return "Unknown OpenCL error code"; + case CL_SUCCESS: + return "CL_SUCCESS"; + case CL_DEVICE_NOT_FOUND: + return "CL_DEVICE_NOT_FOUND"; + case CL_DEVICE_NOT_AVAILABLE: + return "CL_DEVICE_NOT_AVAILABLE"; + case CL_COMPILER_NOT_AVAILABLE: + return "CL_COMPILER_NOT_AVAILABLE"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: + return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; + case CL_OUT_OF_RESOURCES: + return "CL_OUT_OF_RESOURCES"; + case CL_OUT_OF_HOST_MEMORY: + return "CL_OUT_OF_HOST_MEMORY"; + case CL_PROFILING_INFO_NOT_AVAILABLE: + return "CL_PROFILING_INFO_NOT_AVAILABLE"; + case CL_MEM_COPY_OVERLAP: + return "CL_MEM_COPY_OVERLAP"; + case CL_IMAGE_FORMAT_MISMATCH: + return "CL_IMAGE_FORMAT_MISMATCH"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: + return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; + case CL_BUILD_PROGRAM_FAILURE: + return "CL_BUILD_PROGRAM_FAILURE"; + case CL_MAP_FAILURE: + return "CL_MAP_FAILURE"; + case CL_INVALID_VALUE: + return "CL_INVALID_VALUE"; + case CL_INVALID_DEVICE_TYPE: + return "CL_INVALID_DEVICE_TYPE"; + case CL_INVALID_PLATFORM: + return "CL_INVALID_PLATFORM"; + case CL_INVALID_DEVICE: + return "CL_INVALID_DEVICE"; + case CL_INVALID_CONTEXT: + return "CL_INVALID_CONTEXT"; + case CL_INVALID_QUEUE_PROPERTIES: + return "CL_INVALID_QUEUE_PROPERTIES"; + case CL_INVALID_COMMAND_QUEUE: + return "CL_INVALID_COMMAND_QUEUE"; + case CL_INVALID_HOST_PTR: + return "CL_INVALID_HOST_PTR"; + case CL_INVALID_MEM_OBJECT: + return "CL_INVALID_MEM_OBJECT"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: + return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; + case CL_INVALID_IMAGE_SIZE: + return "CL_INVALID_IMAGE_SIZE"; + case CL_INVALID_SAMPLER: + return "CL_INVALID_SAMPLER"; + case CL_INVALID_BINARY: + return "CL_INVALID_BINARY"; + case CL_INVALID_BUILD_OPTIONS: + return "CL_INVALID_BUILD_OPTIONS"; + case CL_INVALID_PROGRAM: + return "CL_INVALID_PROGRAM"; + case CL_INVALID_PROGRAM_EXECUTABLE: + return "CL_INVALID_PROGRAM_EXECUTABLE"; + case CL_INVALID_KERNEL_NAME: + return "CL_INVALID_KERNEL_NAME"; + case CL_INVALID_KERNEL_DEFINITION: + return "CL_INVALID_KERNEL_DEFINITION"; + case CL_INVALID_KERNEL: + return "CL_INVALID_KERNEL"; + case CL_INVALID_ARG_INDEX: + return "CL_INVALID_ARG_INDEX"; + case CL_INVALID_ARG_VALUE: + return "CL_INVALID_ARG_VALUE"; + case CL_INVALID_ARG_SIZE: + return "CL_INVALID_ARG_SIZE"; + case CL_INVALID_KERNEL_ARGS: + return "CL_INVALID_KERNEL_ARGS"; + case CL_INVALID_WORK_DIMENSION: + return "CL_INVALID_WORK_DIMENSION"; + case CL_INVALID_WORK_GROUP_SIZE: + return "CL_INVALID_WORK_GROUP_SIZE"; + case CL_INVALID_WORK_ITEM_SIZE: + return "CL_INVALID_WORK_ITEM_SIZE"; + case CL_INVALID_GLOBAL_OFFSET: + return "CL_INVALID_GLOBAL_OFFSET"; + case 
CL_INVALID_EVENT_WAIT_LIST: + return "CL_INVALID_EVENT_WAIT_LIST"; + case CL_INVALID_EVENT: + return "CL_INVALID_EVENT"; + case CL_INVALID_OPERATION: + return "CL_INVALID_OPERATION"; + case CL_INVALID_GL_OBJECT: + return "CL_INVALID_GL_OBJECT"; + case CL_INVALID_BUFFER_SIZE: + return "CL_INVALID_BUFFER_SIZE"; + case CL_INVALID_MIP_LEVEL: + return "CL_INVALID_MIP_LEVEL"; + default: + return "Unknown OpenCL error code"; } } @@ -119,16 +166,13 @@ inline const char* CLGetErrorString(cl_int error) { * \brief Protected OpenCL call * \param func Expression to call. */ -#define OPENCL_CHECK_ERROR(e) \ - { \ - CHECK(e == CL_SUCCESS) \ - << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \ - } +#define OPENCL_CHECK_ERROR(e) \ + { CHECK(e == CL_SUCCESS) << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); } -#define OPENCL_CALL(func) \ - { \ - cl_int e = (func); \ - OPENCL_CHECK_ERROR(e); \ +#define OPENCL_CALL(func) \ + { \ + cl_int e = (func); \ + OPENCL_CHECK_ERROR(e); \ } class OpenCLThreadEntry; @@ -172,37 +216,24 @@ class OpenCLWorkspace : public DeviceAPI { // Initialzie the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); - virtual void Init() { - Init("opencl", "gpu"); - } + virtual void Init() { Init("opencl", "gpu"); } // Check whether the context is OpenCL or not. - virtual bool IsOpenCLDevice(TVMContext ctx) { - return ctx.device_type == kDLOpenCL; - } + virtual bool IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == kDLOpenCL; } // get the queue of the context cl_command_queue GetQueue(TVMContext ctx) { CHECK(IsOpenCLDevice(ctx)); this->Init(); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid OpenCL device_id=" << ctx.device_id; return queues[ctx.device_id]; } // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t size, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; @@ -217,7 +248,6 @@ class OpenCLWorkspace : public DeviceAPI { static const std::shared_ptr& Global(); }; - /*! 
\brief Thread local workspace */ class OpenCLThreadEntry { public: @@ -240,8 +270,7 @@ class OpenCLThreadEntry { context.device_id = 0; context.device_type = device_type; } - OpenCLThreadEntry() - : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {} + OpenCLThreadEntry() : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {} // get the global workspace static OpenCLThreadEntry* ThreadLocal(); @@ -260,10 +289,8 @@ class OpenCLModuleNode : public ModuleNode { size_t kernel_id; size_t version; }; - explicit OpenCLModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit OpenCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} // destructor ~OpenCLModuleNode(); @@ -275,20 +302,15 @@ class OpenCLModuleNode : public ModuleNode { const char* type_key() const final { return workspace_->type_key.c_str(); } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; std::string GetSource(const std::string& format) final; // Initialize the programs void Init(); // install a new kernel to thread local entry - cl_kernel InstallKernel(cl::OpenCLWorkspace* w, - cl::OpenCLThreadEntry* t, - const std::string& func_name, - const KTRefEntry& e); + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e); private: // The workspace, need to keep reference to use it in destructor. diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 99d2b0cb24e6c..6d9835e6231cc 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -20,17 +20,16 @@ /*! 
* \file opencl_device_api.cc */ -#include #include +#include + #include "opencl_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { - return OpenCLThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); } const std::shared_ptr& OpenCLWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); @@ -41,23 +40,21 @@ void OpenCLWorkspace::SetDevice(TVMContext ctx) { GetThreadEntry()->context.device_id = ctx.device_id; } -void OpenCLWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { this->Init(); size_t index = static_cast(ctx.device_id); if (kind == kExist) { - *rv = static_cast(index< devices.size()); + *rv = static_cast(index < devices.size()); return; } - CHECK_LT(index, devices.size()) - << "Invalid device id " << index; + CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { - case kExist: break; + case kExist: + break; case kMaxThreadsPerBlock: { size_t value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, - sizeof(size_t), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), + &value, nullptr)); *rv = static_cast(value); break; } @@ -72,58 +69,55 @@ void OpenCLWorkspace::GetAttr( } case kMaxSharedMemoryPerBlock: { cl_ulong value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_LOCAL_MEM_SIZE, - sizeof(cl_ulong), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong), + &value, nullptr)); *rv = static_cast(value); break; } - case kComputeVersion: return; + case kComputeVersion: + return; case kDeviceName: { char value[128] = {0}; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_NAME, - sizeof(value) - 1, value, nullptr)); + OPENCL_CALL( + clGetDeviceInfo(devices[index], CL_DEVICE_NAME, sizeof(value) - 1, value, nullptr)); *rv = std::string(value); break; } case kMaxClockRate: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, - sizeof(cl_uint), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint), + &value, nullptr)); *rv = static_cast(value); break; } case kMultiProcessorCount: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, - sizeof(cl_uint), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint), + &value, nullptr)); *rv = static_cast(value); break; } case kMaxThreadDimensions: { size_t dims[3]; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, + nullptr)); std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; *rv = ss.str(); break; } - case kGcnArch: return; + case kGcnArch: + return; } } -void* OpenCLWorkspace::AllocDataSpace( - TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) { +void* OpenCLWorkspace::AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, + DLDataType type_hint) { this->Init(); 
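  // Illustrative note (not part of this change): every raw OpenCL status in
  // this file funnels through OPENCL_CALL / OPENCL_CHECK_ERROR, which map the
  // cl_int code through CLGetErrorString and abort via CHECK on failure, e.g.
  //
  //   cl_int e = clFinish(queue);   // hypothetical call site
  //   OPENCL_CHECK_ERROR(e);        // "OpenCL Error, code=..." on failure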
CHECK(context != nullptr) << "No OpenCL device"; cl_int err_code; - cl_mem mptr = clCreateBuffer( - this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); + cl_mem mptr = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); return mptr; } @@ -137,38 +131,27 @@ void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { OPENCL_CALL(clReleaseMemObject(mptr)); } -void OpenCLWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void OpenCLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { this->Init(); CHECK(stream == nullptr); if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) { - OPENCL_CALL(clEnqueueCopyBuffer( - this->GetQueue(ctx_to), - static_cast((void*)from), // NOLINT(*) - static_cast(to), - from_offset, to_offset, size, 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(ctx_to), + static_cast((void*)from), // NOLINT(*) + static_cast(to), from_offset, to_offset, size, 0, + nullptr, nullptr)); } else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) { - OPENCL_CALL(clEnqueueReadBuffer( - this->GetQueue(ctx_from), - static_cast((void*)from), // NOLINT(*) - CL_FALSE, from_offset, size, - static_cast(to) + to_offset, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueReadBuffer(this->GetQueue(ctx_from), + static_cast((void*)from), // NOLINT(*) + CL_FALSE, from_offset, size, static_cast(to) + to_offset, + 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_from))); } else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) { - OPENCL_CALL(clEnqueueWriteBuffer( - this->GetQueue(ctx_to), - static_cast(to), - CL_FALSE, to_offset, size, - static_cast(from) + from_offset, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueWriteBuffer(this->GetQueue(ctx_to), static_cast(to), CL_FALSE, + to_offset, size, static_cast(from) + from_offset, + 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_to))); } else { LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL"; @@ -180,9 +163,7 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { OPENCL_CALL(clFinish(this->GetQueue(ctx))); } -void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return GetThreadEntry()->pool.AllocWorkspace(ctx, size); } @@ -192,12 +173,9 @@ void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) { typedef dmlc::ThreadLocalStore OpenCLThreadStore; -OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { - return OpenCLThreadStore::Get(); -} +OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { return OpenCLThreadStore::Get(); } -std::string GetPlatformInfo( - cl_platform_id pid, cl_platform_info param_name) { +std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name) { size_t ret_size; OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size)); std::string ret; @@ -206,8 +184,7 @@ std::string GetPlatformInfo( return ret; } -std::string GetDeviceInfo( - cl_device_id pid, cl_device_info param_name) { +std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) { size_t ret_size; OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, 
nullptr, &ret_size)); std::string ret; @@ -226,8 +203,7 @@ std::vector GetPlatformIDs() { return ret; } -std::vector GetDeviceIDs( - cl_platform_id pid, std::string device_type) { +std::vector GetDeviceIDs(cl_platform_id pid, std::string device_type) { cl_device_type dtype = CL_DEVICE_TYPE_ALL; if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU; if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU; @@ -241,10 +217,7 @@ std::vector GetDeviceIDs( return ret; } -bool MatchPlatformInfo( - cl_platform_id pid, - cl_platform_info param_name, - std::string value) { +bool MatchPlatformInfo(cl_platform_id pid, cl_platform_info param_name, std::string value) { if (value.length() == 0) return true; std::string param_value = GetPlatformInfo(pid, param_name); return param_value.find(value) != std::string::npos; @@ -286,25 +259,22 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic return; } cl_int err_code; - this->context = clCreateContext( - nullptr, this->devices.size(), &(this->devices[0]), - nullptr, nullptr, &err_code); + this->context = clCreateContext(nullptr, this->devices.size(), &(this->devices[0]), nullptr, + nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); CHECK_EQ(this->queues.size(), 0U); for (size_t i = 0; i < this->devices.size(); ++i) { cl_device_id did = this->devices[i]; - this->queues.push_back( - clCreateCommandQueue(this->context, did, 0, &err_code)); + this->queues.push_back(clCreateCommandQueue(this->context, did, 0, &err_code)); OPENCL_CHECK_ERROR(err_code); } initialized_ = true; } -TVM_REGISTER_GLOBAL("device_api.opencl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = OpenCLWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index fefde72b9508b..95d0481c31d55 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -20,13 +20,16 @@ /*! * \file opencl_module.cc */ +#include "opencl_module.h" + #include #include -#include + #include #include +#include + #include "opencl_common.h" -#include "opencl_module.h" namespace tvm { namespace runtime { @@ -34,12 +37,9 @@ namespace runtime { class OpenCLWrappedFunc { public: // initialize the OpenCL function. - void Init(OpenCLModuleNode* m, - ObjectPtr sptr, - OpenCLModuleNode::KTRefEntry entry, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags) { + void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, + std::string func_name, std::vector arg_size, + const std::vector& thread_axis_tags) { w_ = m->GetGlobalWorkspace().get(); m_ = m; sptr_ = sptr; @@ -49,9 +49,7 @@ class OpenCLWrappedFunc { thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void** void_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { CHECK(w_->context != nullptr) << "No OpenCL device"; cl::OpenCLThreadEntry* t = w_->GetThreadEntry(); // get the kernel from thread local kernel table. 
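    // Illustrative sketch of the lookup performed here (the code is elided
    // from this hunk): each thread caches built kernels by kernel_id and
    // rebuilds on a version mismatch, roughly:
    //
    //   if (entry_.kernel_id >= t->kernel_table.size() ||
    //       t->kernel_table[entry_.kernel_id].version != entry_.version) {
    //     kernel = m_->InstallKernel(w_, t, func_name_, entry_);  // rebuild
    //   }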
@@ -74,11 +72,8 @@ class OpenCLWrappedFunc { wl.work_size[i] *= wl.work_size[i + 3]; } // launch kernel - OPENCL_CALL(clEnqueueNDRangeKernel( - queue, kernel, work_dim, nullptr, - wl.work_size, - wl.work_size + 3, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueNDRangeKernel(queue, kernel, work_dim, nullptr, wl.work_size, + wl.work_size + 3, 0, nullptr, nullptr)); } private: @@ -119,12 +114,10 @@ const std::shared_ptr& OpenCLModuleNode::GetGlobalWorkspace return cl::OpenCLWorkspace::Global(); } -PackedFunc OpenCLModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -143,16 +136,13 @@ PackedFunc OpenCLModuleNode::GetFunction( } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), - name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); return PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void OpenCLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -193,10 +183,8 @@ void OpenCLModuleNode::Init() { } } -cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, - cl::OpenCLThreadEntry* t, - const std::string& func_name, - const KTRefEntry& e) { +cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) { std::lock_guard lock(build_lock_); int device_id = t->context.device_id; if (!device_built_flag_[device_id]) { @@ -210,7 +198,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, OPENCL_CHECK_ERROR(err); } } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { - const unsigned char* s = (const unsigned char *)data_.c_str(); + const unsigned char* s = (const unsigned char*)data_.c_str(); size_t len = data_.length(); cl_int err; cl_device_id dev = w->devices[device_id]; @@ -226,11 +214,9 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, if (err != CL_SUCCESS) { size_t len; std::string log; - clGetProgramBuildInfo( - program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); + clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); log.resize(len); - clGetProgramBuildInfo( - program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); + clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); LOG(FATAL) << "OpenCL build error for device=" << dev << log; } device_built_flag_[device_id] = true; @@ -245,19 +231,15 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, return kernel; } -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module 
OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } // Load module from module. -Module OpenCLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module OpenCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -278,13 +260,10 @@ Module OpenCLModuleLoadBinary(void* strm) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl") -.set_body_typed(OpenCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin") -.set_body_typed(OpenCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl") -.set_body_typed(OpenCLModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 3b7ebb9c16590..77f4b80107794 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. */ -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/sdaccel/sdaccel_common.h b/src/runtime/opencl/sdaccel/sdaccel_common.h index 2100b50678b34..803cbe67b9a74 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_common.h +++ b/src/runtime/opencl/sdaccel/sdaccel_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_COMMON_H_ #include + #include "../opencl_common.h" namespace tvm { @@ -44,7 +45,6 @@ class SDAccelWorkspace final : public OpenCLWorkspace { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace for SDAccel*/ class SDAccelThreadEntry : public OpenCLThreadEntry { public: diff --git a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc index 59e8a25c834ee..6bac0c916aade 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,26 +20,23 @@ /*! * \file sdaccel_device_api.cc */ -#include #include +#include + #include "sdaccel_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { - return SDAccelThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); } const std::shared_ptr& SDAccelWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); return inst; } -void SDAccelWorkspace::Init() { - OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); -} +void SDAccelWorkspace::Init() { OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); } bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == static_cast(kDLSDAccel); @@ -47,15 +44,12 @@ bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { typedef dmlc::ThreadLocalStore SDAccelThreadStore; -SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { - return SDAccelThreadStore::Get(); -} +SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.sdaccel") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = SDAccelWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = SDAccelWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.cc b/src/runtime/opencl/sdaccel/sdaccel_module.cc index 4569ec3946dfa..b4edca32a9982 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_module.cc @@ -20,23 +20,24 @@ /*! 
* \file sdaccel_module.cc */ +#include "sdaccel_module.h" + #include #include -#include + #include #include +#include + #include "sdaccel_common.h" -#include "sdaccel_module.h" namespace tvm { namespace runtime { class SDAccelModuleNode : public OpenCLModuleNode { public: - explicit SDAccelModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit SDAccelModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} const std::shared_ptr& GetGlobalWorkspace() final; }; @@ -45,18 +46,14 @@ const std::shared_ptr& SDAccelModuleNode::GetGlobalWorkspac return cl::SDAccelWorkspace::Global(); } -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } -Module SDAccelModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module SDAccelModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -77,10 +74,8 @@ Module SDAccelModuleLoadBinary(void* strm) { return SDAccelModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin") -.set_body_typed(SDAccelModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin").set_body_typed(SDAccelModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin") -.set_body_typed(SDAccelModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin").set_body_typed(SDAccelModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.h b/src/runtime/opencl/sdaccel/sdaccel_module.h index e126291f3f03a..322decc4460c3 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.h +++ b/src/runtime/opencl/sdaccel/sdaccel_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "xclbin", "awsxclbin" * \param fmap The map function information map of each function. 
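 * \param source The source code, may be empty for binary-only modules.
 *
 * A minimal load-path sketch (illustrative; the file name is hypothetical and
 * dispatches through the runtime.module.loadfile_xclbin registration below):
 * \code
 *   Module m = Module::LoadFromFile("kernel.xclbin");
 * \endcode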
*/ -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_ diff --git a/src/runtime/opengl/opengl_common.h b/src/runtime/opengl/opengl_common.h index 009ea6c9111d2..eca45d78e17a0 100644 --- a/src/runtime/opengl/opengl_common.h +++ b/src/runtime/opengl/opengl_common.h @@ -24,19 +24,20 @@ #ifndef TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ #define TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ +#include #include -#include #include -#include +#include #if defined(__APPLE__) #define GLFW_INCLUDE_GLCOREARB #endif #include + +#include #include #include #include #include -#include namespace tvm { namespace runtime { @@ -54,8 +55,7 @@ inline GLFWglproc GetProcAddress(const char* procname) { return proc; } -#define SetGLFunctionPointer(NAME) \ - NAME(decltype(NAME)(GetProcAddress("gl" #NAME))) +#define SetGLFunctionPointer(NAME) NAME(decltype(NAME)(GetProcAddress("gl" #NAME))) /*! * \brief The function pointers of all OpenGL APIs that are used. @@ -117,8 +117,7 @@ class GLFunctionPointers { void (*BindFramebuffer)(GLenum target, GLuint framebuffer); void (*BindTexture)(GLenum target, GLuint texture); void (*BindVertexArray)(GLuint array); - void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data, - GLenum usage); + void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data, GLenum usage); GLenum (*CheckFramebufferStatus)(GLenum target); void (*Clear)(GLbitfield mask); void (*CompileShader)(GLuint shader); @@ -133,8 +132,8 @@ class GLFunctionPointers { void (*DrawBuffers)(GLsizei n, const GLenum* bufs); void (*EnableVertexAttribArray)(GLuint index); void (*Finish)(); - void (*FramebufferTexture2D)(GLenum target, GLenum attachment, - GLenum textarget, GLuint texture, GLint level); + void (*FramebufferTexture2D)(GLenum target, GLenum attachment, GLenum textarget, GLuint texture, + GLint level); void (*GenBuffers)(GLsizei n, GLuint* buffers); void (*GenFramebuffers)(GLsizei n, GLuint* ids); void (*GenTextures)(GLsizei n, GLuint* textures); @@ -142,32 +141,26 @@ class GLFunctionPointers { GLint (*GetAttribLocation)(GLuint program, const GLchar* name); GLenum (*GetError)(); void (*GetIntegerv)(GLenum pname, GLint* data); - void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length, - GLchar* info_log); + void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length, GLchar* info_log); void (*GetProgramiv)(GLuint program, GLenum pname, GLint* params); - void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length, - GLchar* info_log); + void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length, GLchar* info_log); void (*GetShaderiv)(GLuint shader, GLenum pname, GLint* params); - const GLubyte *(*GetString)(GLenum name); + const GLubyte* (*GetString)(GLenum name); GLint (*GetUniformLocation)(GLuint program, const GLchar* name); void (*LinkProgram)(GLuint program); - void (*ReadPixels)(GLint x, GLint y, GLsizei width, GLsizei height, - GLenum format, GLenum type, GLvoid* data); - void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string, - const GLint* length); - void (*TexImage2D)(GLenum target, GLint level, GLint internal_format, - GLsizei width, GLsizei height, GLint border, GLenum format, - GLenum type, const GLvoid* data); + void (*ReadPixels)(GLint x, 
GLint y, GLsizei width, GLsizei height, GLenum format, GLenum type, + GLvoid* data); + void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string, const GLint* length); + void (*TexImage2D)(GLenum target, GLint level, GLint internal_format, GLsizei width, + GLsizei height, GLint border, GLenum format, GLenum type, const GLvoid* data); void (*TexParameteri)(GLenum target, GLenum pname, GLint param); - void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset, - GLint yoffset, GLsizei width, GLsizei height, - GLenum format, GLenum type, const GLvoid* data); + void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset, GLint yoffset, GLsizei width, + GLsizei height, GLenum format, GLenum type, const GLvoid* data); void (*Uniform1f)(GLint location, GLfloat v0); void (*Uniform1i)(GLint location, GLint v0); void (*UseProgram)(GLuint program); - void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type, - GLboolean normalized, GLsizei stride, - const GLvoid* pointer); + void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type, GLboolean normalized, + GLsizei stride, const GLvoid* pointer); void (*Viewport)(GLint x, GLint y, GLsizei width, GLsizei height); }; @@ -181,19 +174,10 @@ class OpenGLWorkspace final : public DeviceAPI { // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; @@ -225,10 +209,7 @@ class OpenGLWorkspace final : public DeviceAPI { * \param nelems The number of elements to be written to. * \param data The user data. */ - void PutTextureData(Texture* texture, - GLint begin, - GLsizei nelems, - const GLvoid* data); + void PutTextureData(Texture* texture, GLint begin, GLsizei nelems, const GLvoid* data); /*! * \brief Download a sub-region of an OpenGL texture. * \param texture The texture to download from. @@ -236,10 +217,7 @@ class OpenGLWorkspace final : public DeviceAPI { * \param nelems The number of elements to download from. * \param data The user buffer. */ - void GetTextureData(const Texture* texture, - GLint begin, - GLsizei nelems, - GLvoid* data); + void GetTextureData(const Texture* texture, GLint begin, GLsizei nelems, GLvoid* data); /*! * \brief Set currently used OpenGL program. @@ -254,10 +232,7 @@ class OpenGLWorkspace final : public DeviceAPI { * \param type The type of the uniform. * \param value The value to pass in. */ - void SetUniform(const Program& program, - const std::string& name, - DLDataType type, - void* value); + void SetUniform(const Program& program, const std::string& name, DLDataType type, void* value); /*! * \brief Set input texture for an OpenGL program. @@ -268,9 +243,7 @@ class OpenGLWorkspace final : public DeviceAPI { * different unit. * \param texture The OpenGL texture to pass in. 
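 *
 * A hypothetical call binding a texture to the sampler uniform "A" on
 * texture unit 0 (illustrative only):
 * \code
 *   workspace->SetInputTexture(program, "A", 0, &texture);
 * \endcode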
*/ - void SetInputTexture(const Program& program, - const std::string& name, - GLuint unit, + void SetInputTexture(const Program& program, const std::string& name, GLuint unit, Texture* texture); /*! @@ -354,8 +327,7 @@ class OpenGLWorkspace final : public DeviceAPI { class Program { public: // Move constructor. - Program(Program&& other) noexcept - : workspace_(other.workspace_), program_(other.program_) { + Program(Program&& other) noexcept : workspace_(other.workspace_), program_(other.program_) { other.program_ = kInvalidProgram; } @@ -406,11 +378,14 @@ struct TextureFormat { GLsizei elemsz() const { switch (type) { - case GL_BYTE: case GL_UNSIGNED_BYTE: + case GL_BYTE: + case GL_UNSIGNED_BYTE: return 1; - case GL_SHORT: case GL_UNSIGNED_SHORT: + case GL_SHORT: + case GL_UNSIGNED_SHORT: return 2; - case GL_INT: case GL_UNSIGNED_INT: + case GL_INT: + case GL_UNSIGNED_INT: return 4; case GL_FLOAT: return 4; @@ -422,7 +397,7 @@ struct TextureFormat { bool operator==(const TextureFormat& other) const { return std::make_tuple(internal_format, format, type) == - std::make_tuple(other.internal_format, other.format, other.type); + std::make_tuple(other.internal_format, other.format, other.type); } GLint internal_format; // OpenGL says this is GLint, not GLenum. @@ -439,8 +414,11 @@ class Texture { public: // Move constructor. Texture(Texture&& other) noexcept - : workspace_(other.workspace_), texture_(other.texture_), - format_(other.format_), width_(other.width_), height_(other.height_) { + : workspace_(other.workspace_), + texture_(other.texture_), + format_(other.format_), + width_(other.width_), + height_(other.height_) { other.texture_ = kInvalidTexture; } @@ -489,11 +467,9 @@ class Texture { // We enforce this to make sure OpenGL is initialized. // Always only use the first dimension of a 2D texture. // The reason is that texelFetch only supports 2D textures. - explicit Texture(OpenGLWorkspace* workspace, GLuint texture, - TextureFormat format, - GLsizei width, GLsizei height) - : workspace_(workspace), texture_(texture), format_(format), - width_(width), height_(height) {} + explicit Texture(OpenGLWorkspace* workspace, GLuint texture, TextureFormat format, GLsizei width, + GLsizei height) + : workspace_(workspace), texture_(texture), format_(format), width_(width), height_(height) {} // The internal texture ID. GLuint texture() const { return texture_; } diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc index 0be921cb4ae58..b3e4f5924c3ac 100644 --- a/src/runtime/opengl/opengl_device_api.cc +++ b/src/runtime/opengl/opengl_device_api.cc @@ -21,7 +21,9 @@ * \file opengl_device_api.cc */ #include + #include + #include "opengl_common.h" #include "opengl_module.h" @@ -60,26 +62,23 @@ static const char* GLGetErrorString(GLenum error) { */ void OpenGLWorkspace::CheckOpenGLError() { GLenum err = gl->GetError(); - CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": " - << gl::GLGetErrorString(err); + CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": " << gl::GLGetErrorString(err); } /*! * \brief Protected OpenGL call. * \param func Expression to call. */ -#define OPENGL_CALL(func) \ - { \ - (func); \ - CheckOpenGLError(); \ +#define OPENGL_CALL(func) \ + { \ + (func); \ + CheckOpenGLError(); \ } /*! * \brief The error handling callback passed to GLFW. 
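 *        GLFW invokes it on any library error; it is assumed to be
 *        registered once (e.g. via glfwSetErrorCallback) before glfwInit,
 *        and it aborts through LOG(FATAL).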
*/ -void GlfwErrorCallback(int err, const char* str) { - LOG(FATAL) << "Error: [" << err << "] " << str; -} +void GlfwErrorCallback(int err, const char* str) { LOG(FATAL) << "Error: [" << err << "] " << str; } const std::shared_ptr& OpenGLWorkspace::Global() { static std::shared_ptr inst(new OpenGLWorkspace); @@ -87,13 +86,11 @@ const std::shared_ptr& OpenGLWorkspace::Global() { } void OpenGLWorkspace::SetDevice(TVMContext ctx) { - CHECK_EQ(ctx.device_type, static_cast(kOpenGL)) - << "Device type must be OpenGL."; + CHECK_EQ(ctx.device_type, static_cast(kOpenGL)) << "Device type must be OpenGL."; CHECK_EQ(ctx.device_id, 0) << "Only support 1 OpenGL \"device\"."; } -void OpenGLWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void OpenGLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { switch (kind) { case kExist: { *rv = static_cast(ctx.device_id == 0); @@ -108,20 +105,26 @@ void OpenGLWorkspace::GetAttr( *rv = 1; break; } - case kMaxSharedMemoryPerBlock: return; + case kMaxSharedMemoryPerBlock: + return; case kComputeVersion: { break; } - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; - case kGcnArch: return; + case kDeviceName: + return; + case kMaxClockRate: + return; + case kMultiProcessorCount: + return; + case kMaxThreadDimensions: + return; + case kGcnArch: + return; } } -void* OpenGLWorkspace::AllocDataSpace( - TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { +void* OpenGLWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) { return reinterpret_cast(new Texture(CreateTexture(type_hint, nbytes))); } @@ -129,14 +132,9 @@ void OpenGLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { delete reinterpret_cast(ptr); } -void OpenGLWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void OpenGLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { CHECK(stream == nullptr); @@ -159,7 +157,7 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from, } else if (type_from_to == std::make_tuple(gl_devtype, kDLCPU)) { auto texture = static_cast(from); - void *data = static_cast(to) + to_offset; + void* data = static_cast(to) + to_offset; auto elemsz = texture->elemsz(); auto begin = static_cast(from_offset / elemsz); auto nelems = static_cast(size / elemsz); @@ -213,8 +211,7 @@ OpenGLWorkspace::OpenGLWorkspace() { GLuint vertex_buffer; OPENGL_CALL(gl->GenBuffers(1, &vertex_buffer)); OPENGL_CALL(gl->BindBuffer(GL_ARRAY_BUFFER, vertex_buffer)); - OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, - GL_STATIC_DRAW)); + OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, GL_STATIC_DRAW)); GLuint vertex_array; OPENGL_CALL(gl->GenVertexArrays(1, &vertex_array)); @@ -244,9 +241,7 @@ void OpenGLWorkspace::OnDeleteTexture(GLuint texture) { OPENGL_CALL(gl->DeleteTextures(1, &texture)); } -void OpenGLWorkspace::OnDeleteProgram(GLuint program) { - OPENGL_CALL(gl->DeleteProgram(program)); -} +void OpenGLWorkspace::OnDeleteProgram(GLuint program) { OPENGL_CALL(gl->DeleteProgram(program)); } GLuint OpenGLWorkspace::NumTextureUnits() { GLint num_units; @@ -255,28 +250,22 @@ GLuint 
OpenGLWorkspace::NumTextureUnits() { } const OpenGLWorkspace::Vertex OpenGLWorkspace::vertices[OpenGLWorkspace::kNumVertices] = { - {-1.f, -1.f}, - {1.0f, -1.f}, - {1.0f, 1.0f}, - {-1.f, -1.f}, - {-1.f, 1.0f}, - {1.0f, 1.0f}, + {-1.f, -1.f}, {1.0f, -1.f}, {1.0f, 1.0f}, {-1.f, -1.f}, {-1.f, 1.0f}, {1.0f, 1.0f}, }; // Don't need to change this. // The vertex shader only needs to take in the triangle points. // No need for point transformations. -const char* OpenGLWorkspace::vertex_shader_text_ = "#version 300 es\n" +const char* OpenGLWorkspace::vertex_shader_text_ = + "#version 300 es\n" "in vec2 point; // input to vertex shader\n" "void main() {\n" " gl_Position = vec4(point, 0.0, 1.0);\n" "}\n"; -Program OpenGLWorkspace::CreateProgram( - const char* fragment_shader_src) { +Program OpenGLWorkspace::CreateProgram(const char* fragment_shader_src) { // Create and compile the shaders. - GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER, - fragment_shader_src); + GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER, fragment_shader_src); // Link the shaders and create the program. Program program = CreateProgram(fragment_shader); @@ -286,8 +275,7 @@ Program OpenGLWorkspace::CreateProgram( return program; } -GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, - const char* shader_src) { +GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, const char* shader_src) { // Create the shader. GLuint shader = gl->CreateShader(shader_kind); gl->ShaderSource(shader, 1, &shader_src, nullptr); @@ -367,20 +355,14 @@ Texture OpenGLWorkspace::CreateTexture(DLDataType type, size_t nbytes) { auto nelems = static_cast(nbytes / (type.bits / 8)); auto height = (nelems + kTextureRowSize - 1) / kTextureRowSize; auto width = (height == 1) ? nelems : kTextureRowSize; - OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0, - texture_format.internal_format, - width, height, /*border=*/0, - texture_format.format, texture_format.type, + OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0, texture_format.internal_format, width, + height, /*border=*/0, texture_format.format, texture_format.type, /*data=*/nullptr)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + OPENGL_CALL(gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); return Texture(this, texture, texture_format, width, height); } @@ -414,8 +396,8 @@ Program OpenGLWorkspace::CreateProgram(GLuint fragment_shader) { auto point_attrib = GLuint(gl->GetAttribLocation(program, "point")); OPENGL_CALL(gl->EnableVertexAttribArray(point_attrib)); - OPENGL_CALL(gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE, - sizeof(Vertex), nullptr)); + OPENGL_CALL( + gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE, sizeof(Vertex), nullptr)); return Program(this, program); } @@ -465,29 +447,22 @@ static void Visit1DRange(GLint beg, GLint end, F&& on_2d_block) { on_2d_block(0, ylast, xlast + 1, 1); } -void OpenGLWorkspace::PutTextureData(Texture *texture, - GLint begin, - 
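The CreateTexture shape computation above packs a flat buffer into rows of kTextureRowSize texels. A standalone sketch of that arithmetic, with an illustrative row size (the real kTextureRowSize constant lives elsewhere in the OpenGL runtime):

#include <cassert>
#include <cstddef>

constexpr size_t kRowSize = 1024;  // stand-in for kTextureRowSize

// Round the element count up to full rows; a buffer that fits in a single
// row stays one texel tall, mirroring the width/height logic in CreateTexture.
void TextureShape(size_t nelems, size_t* width, size_t* height) {
  *height = (nelems + kRowSize - 1) / kRowSize;
  *width = (*height == 1) ? nelems : kRowSize;
}

int main() {
  size_t w, h;
  TextureShape(1000, &w, &h);
  assert(w == 1000 && h == 1);  // fits in one row
  TextureShape(2500, &w, &h);
  assert(w == 1024 && h == 3);  // spills into three rows
  return 0;
}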
GLsizei nelems, +void OpenGLWorkspace::PutTextureData(Texture* texture, GLint begin, GLsizei nelems, const GLvoid* data) { // Bind to temporary unit. BindTextureUnit(NumTextureUnits() - 1, texture->texture()); - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { + Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) { auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz(); const GLvoid* ptr = static_cast(data) + offset; // Similar to cudaMemcpy. - OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0, - xbeg, ybeg, width, height, - texture->format_.format, - texture->format_.type, ptr)); + OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0, xbeg, ybeg, width, height, + texture->format_.format, texture->format_.type, ptr)); }); } -void OpenGLWorkspace::GetTextureData(const Texture *texture, - GLint begin, - GLsizei nelems, +void OpenGLWorkspace::GetTextureData(const Texture* texture, GLint begin, GLsizei nelems, GLvoid* data) { BindTextureUnit(NumTextureUnits() - 1, texture->texture()); @@ -497,8 +472,8 @@ void OpenGLWorkspace::GetTextureData(const Texture *texture, OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer)); // Bind texture to framebuffer's attachment 0. - OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_TEXTURE_2D, texture->texture(), 0)); + OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + texture->texture(), 0)); // Always check that our framebuffer is okay. if (gl->CheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) { @@ -521,28 +496,24 @@ void OpenGLWorkspace::GetTextureData(const Texture *texture, auto nchannels = 4; auto padded_data_size = nchannels * nelems * elemsz; auto padded_data = std::unique_ptr(new char[padded_data_size]); - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { + Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) { auto data_offset = (ybeg * kTextureRowSize + xbeg - begin) * elemsz; auto padded_data_offset = data_offset * nchannels; - OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, - GL_RGBA, GL_FLOAT, + OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, GL_RGBA, GL_FLOAT, padded_data.get() + padded_data_offset)); }); for (GLsizei i = 0; i != nelems; ++i) { - auto dst = reinterpret_cast(data) + i * elemsz; + auto dst = reinterpret_cast(data) + i * elemsz; auto src = padded_data.get() + nchannels * i * elemsz; std::memcpy(dst, src, elemsz); } #else - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { + Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, GLsizei width, GLsizei height) { auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz(); GLvoid* ptr = static_cast(data) + offset; - OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, - texture->format_.format, texture->format_.type, - ptr)); + OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, texture->format_.format, + texture->format_.type, ptr)); }); #endif @@ -553,9 +524,7 @@ void OpenGLWorkspace::SetCurrentProgram(const Program& program) { OPENGL_CALL(gl->UseProgram(program.program())); } -void OpenGLWorkspace::SetUniform(const Program& program, - const std::string& name, - DLDataType type, +void OpenGLWorkspace::SetUniform(const Program& program, const std::string& name, DLDataType type, void* value) { 
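PutTextureData relies on Visit1DRange to translate a flat [begin, begin + nelems) interval into at most three 2-D texture blocks: a partial head row, a run of full middle rows, and a partial tail row. A hedged sketch of that decomposition (the real helper differs in variable names and edge handling):

#include <cstdio>

constexpr int kRow = 1024;  // stand-in for kTextureRowSize

template <typename F>
void Visit1DRangeSketch(int beg, int end, F&& on_2d_block) {
  if (beg >= end) return;
  int xbeg = beg % kRow, ybeg = beg / kRow;
  int xend = (end - 1) % kRow + 1, yend = (end - 1) / kRow;
  if (ybeg == yend) {  // the whole range sits in one (possibly partial) row
    on_2d_block(xbeg, ybeg, xend - xbeg, 1);
    return;
  }
  on_2d_block(xbeg, ybeg, kRow - xbeg, 1);  // partial head row
  if (yend - ybeg > 1) {                    // full middle rows
    on_2d_block(0, ybeg + 1, kRow, yend - ybeg - 1);
  }
  on_2d_block(0, yend, xend, 1);            // partial tail row
}

int main() {
  // Emits blocks covering 24 + 1024 + 952 = 2000 elements.
  Visit1DRangeSketch(1000, 3000, [](int x, int y, int w, int h) {
    std::printf("block x=%d y=%d w=%d h=%d\n", x, y, w, h);
  });
  return 0;
}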
GLint location = gl->GetUniformLocation(program.program(), name.c_str()); switch (type.code) { @@ -582,9 +551,7 @@ void OpenGLWorkspace::SetUniform(const Program& program, } } -void OpenGLWorkspace::SetInputTexture(const Program& program, - const std::string& name, - GLuint unit, +void OpenGLWorkspace::SetInputTexture(const Program& program, const std::string& name, GLuint unit, Texture* texture) { // We always use the last texture unit as temporary. // Therefore, we can have "NumTextureUnits() - 1" input textures. @@ -602,8 +569,8 @@ void OpenGLWorkspace::Render(Texture* output) { OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer)); // Set "renderedTexture" as our colour attachement 0. - OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_TEXTURE_2D, output->texture(), 0)); + OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + output->texture(), 0)); // Specify that we will render to color attachment 0. GLenum DrawBuffers[1] = {GL_COLOR_ATTACHMENT0}; @@ -622,8 +589,7 @@ void OpenGLWorkspace::Render(Texture* output) { OPENGL_CALL(gl->DeleteFramebuffers(1, &frame_buffer)); } -TVM_REGISTER_GLOBAL("device_api.opengl") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("device_api.opengl").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = OpenGLWorkspace::Global().get(); *rv = static_cast(ptr); }); diff --git a/src/runtime/opengl/opengl_module.cc b/src/runtime/opengl/opengl_module.cc index 6435aca1bfdd6..ee490f2e72071 100644 --- a/src/runtime/opengl/opengl_module.cc +++ b/src/runtime/opengl/opengl_module.cc @@ -20,35 +20,35 @@ /*! * \file opengl_module.cc */ +#include "opengl_module.h" + #include -#include + #include -#include "opengl_common.h" -#include "opengl_module.h" +#include + +#include "../file_util.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../file_util.h" +#include "opengl_common.h" namespace tvm { namespace runtime { class OpenGLModuleNode final : public ModuleNode { public: - OpenGLModuleNode(std::unordered_map shaders, - std::string fmt, + OpenGLModuleNode(std::unordered_map shaders, std::string fmt, std::unordered_map fmap); ~OpenGLModuleNode() override = default; const char* type_key() const final { return "opengl"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; std::string GetSource(const std::string& format) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final; + void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; @@ -72,11 +72,8 @@ class OpenGLModuleNode final : public ModuleNode { class OpenGLWrappedFunc { public: - OpenGLWrappedFunc(OpenGLModuleNode* m, - ObjectPtr sptr, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags); + OpenGLWrappedFunc(OpenGLModuleNode* m, ObjectPtr sptr, std::string func_name, + std::vector arg_size, const std::vector& thread_axis_tags); void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const; @@ -93,30 +90,32 @@ class OpenGLWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; -OpenGLModuleNode::OpenGLModuleNode( - std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap) - : workspace_(gl::OpenGLWorkspace::Global()), shaders_(std::move(shaders)), - fmt_(std::move(fmt)), fmap_(std::move(fmap)), 
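The device_api.opengl registration above exposes the workspace through TVM's global function registry. A hedged sketch of how a caller retrieves it (this mirrors what the runtime's device API lookup does; Registry::Get and the void* conversion are existing TVM APIs, while GetOpenGLDeviceAPI is a made-up helper):

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

tvm::runtime::DeviceAPI* GetOpenGLDeviceAPI() {
  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("device_api.opengl");
  CHECK(f != nullptr) << "OpenGL device API is not registered";
  void* handle = (*f)();  // the registered body returns the DeviceAPI pointer
  return static_cast<tvm::runtime::DeviceAPI*>(handle);
}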
programs_() { +OpenGLModuleNode::OpenGLModuleNode(std::unordered_map shaders, + std::string fmt, + std::unordered_map fmap) + : workspace_(gl::OpenGLWorkspace::Global()), + shaders_(std::move(shaders)), + fmt_(std::move(fmt)), + fmap_(std::move(fmap)), + programs_() { CHECK_EQ(fmt_, "gl") << "Unknown OpenGL format " << fmt_; - for (auto &pair : shaders_) { - auto &func_name = pair.first; - auto &shader = pair.second; - programs_.emplace(func_name, - workspace_->CreateProgram(shader.source.c_str())); + for (auto& pair : shaders_) { + auto& func_name = pair.first; + auto& shader = pair.second; + programs_.emplace(func_name, workspace_->CreateProgram(shader.source.c_str())); } } -PackedFunc OpenGLModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenGLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto func_info_it = fmap_.find(name); - if (func_info_it == fmap_.end()) { return PackedFunc(); } - auto &func_info = func_info_it->second; + if (func_info_it == fmap_.end()) { + return PackedFunc(); + } + auto& func_info = func_info_it->second; std::vector arg_size(func_info.arg_types.size()); for (size_t i = 0; i < func_info.arg_types.size(); ++i) { @@ -128,26 +127,27 @@ PackedFunc OpenGLModuleNode::GetFunction( } // Initialize the wrapped func. - OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size, - func_info.thread_axis_tags); + OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size, func_info.thread_axis_tags); return PackFuncVoidAddr(f, func_info.arg_types); } std::string OpenGLModuleNode::GetSource(const std::string& format) { - if (format != fmt_ && fmt_ != "gl") { return ""; } + if (format != fmt_ && fmt_ != "gl") { + return ""; + } std::ostringstream os; - for (auto &pair : shaders_) { - auto &name = pair.first; - auto &shader = pair.second; - os << "[" << name << "]" << "\n"; - os << shader.source <<"\n"; + for (auto& pair : shaders_) { + auto& name = pair.first; + auto& shader = pair.second; + os << "[" << name << "]" + << "\n"; + os << shader.source << "\n"; } return os.str(); } -void OpenGLModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void OpenGLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); @@ -161,8 +161,7 @@ void OpenGLModuleNode::SaveToBinary(dmlc::Stream* stream) { stream->Write(ToJSON(shaders_)); } -const gl::Program& OpenGLModuleNode::GetProgram( - const std::string& func_name) const { +const gl::Program& OpenGLModuleNode::GetProgram(const std::string& func_name) const { auto it = programs_.find(func_name); if (it == programs_.end()) { LOG(FATAL) << "Cannot find program"; @@ -170,8 +169,7 @@ const gl::Program& OpenGLModuleNode::GetProgram( return it->second; } -const OpenGLShader& OpenGLModuleNode::GetShader( - const std::string& func_name) const { +const OpenGLShader& OpenGLModuleNode::GetShader(const std::string& func_name) const { auto it = shaders_.find(func_name); if (it == shaders_.end()) { LOG(FATAL) << "Cannot find shader"; @@ -179,8 +177,7 @@ const OpenGLShader& OpenGLModuleNode::GetShader( return it->second; } -const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( - const std::string& func_name) const { +const FunctionInfo& 
OpenGLModuleNode::GetFunctionInfo(const std::string& func_name) const { auto it = fmap_.find(func_name); if (it == fmap_.end()) { LOG(FATAL) << "Cannot find shader"; @@ -188,22 +185,20 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( return it->second; } -OpenGLWrappedFunc::OpenGLWrappedFunc( - OpenGLModuleNode* m, - ObjectPtr sptr, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags) - : m_(m), sptr_(std::move(sptr)), func_name_(std::move(func_name)), +OpenGLWrappedFunc::OpenGLWrappedFunc(OpenGLModuleNode* m, ObjectPtr sptr, + std::string func_name, std::vector arg_size, + const std::vector& thread_axis_tags) + : m_(m), + sptr_(std::move(sptr)), + func_name_(std::move(func_name)), arg_size_(std::move(arg_size)) { thread_axis_cfg_.Init(arg_size_.size(), thread_axis_tags); } -void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - void** void_args) const { - auto &shader = m_->GetShader(func_name_); - auto &program = m_->GetProgram(func_name_); - auto &func_info = m_->GetFunctionInfo(func_name_); +void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { + auto& shader = m_->GetShader(func_name_); + auto& program = m_->GetProgram(func_name_); + auto& func_info = m_->GetFunctionInfo(func_name_); size_t nargs = shader.arg_kinds.size(); // Must call this function before setting uniforms & input textures. @@ -213,7 +208,7 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, GLuint texture_unit = 0; gl::Texture* output = nullptr; for (size_t i = 0; i != nargs; ++i) { - auto &name = shader.arg_names.at(i); + auto& name = shader.arg_names.at(i); auto kind = shader.arg_kinds.at(i); auto type = func_info.arg_types.at(i); switch (kind) { @@ -240,24 +235,19 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, // Set "thread_extent" uniform. 
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); std::unique_ptr thread_extent(new GLint(wl.block_dim(0))); - m_->workspace().SetUniform(program, shader.thread_extent_var, - DLDataType{kDLInt, 32, 1}, + m_->workspace().SetUniform(program, shader.thread_extent_var, DLDataType{kDLInt, 32, 1}, static_cast(thread_extent.get())); m_->workspace().Render(output); } -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, +Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap) { - auto n = make_object(std::move(shaders), - std::move(fmt), - std::move(fmap)); + auto n = make_object(std::move(shaders), std::move(fmt), std::move(fmap)); return Module(n); } -Module OpenGLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module OpenGLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -278,20 +268,17 @@ Module OpenGLModuleLoadBinary(void* strm) { return OpenGLModuleCreate(FromJSON(data), fmt, fmap); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = OpenGLModuleLoadFile(args[0], args[1]); +}); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = OpenGLModuleLoadFile(args[0], args[1]); +}); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadBinary(args[0]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = OpenGLModuleLoadBinary(args[0]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opengl/opengl_module.h b/src/runtime/opengl/opengl_module.h index 4d2d1c859253f..27841a8c60519 100644 --- a/src/runtime/opengl/opengl_module.h +++ b/src/runtime/opengl/opengl_module.h @@ -25,12 +25,14 @@ #define TVM_RUNTIME_OPENGL_OPENGL_MODULE_H_ #include + #include #include #include -#include -#include #include +#include +#include + #include "../meta_data.h" namespace tvm { @@ -67,11 +69,10 @@ OpenGLArgKind String2OpenGLArgKind(const std::string& str); */ struct OpenGLShader { OpenGLShader() = default; - OpenGLShader(std::string source, - std::vector arg_names, - std::vector arg_kinds, - std::string thread_extent_var) - : source(std::move(source)), arg_names(std::move(arg_names)), + OpenGLShader(std::string source, std::vector arg_names, + std::vector arg_kinds, std::string thread_extent_var) + : source(std::move(source)), + arg_names(std::move(arg_names)), arg_kinds(std::move(arg_kinds)), thread_extent_var(std::move(thread_extent_var)) { CHECK_EQ(this->arg_names.size(), this->arg_kinds.size()) << "Invalid input"; @@ -96,8 +97,7 @@ std::unordered_map FromJSON(const std::string& str); * \param fmt The format of the data, * \param fmap The map function information map of each function. 
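OpenGLShader instances round-trip through the ToJSON/FromJSON helpers declared in this header. A hedged usage sketch; the constructor field order is taken from the header above, but the OpenGLArgKind enumerator names and the shader contents are illustrative assumptions:

std::unordered_map<std::string, OpenGLShader> shaders;
shaders.emplace("fuse_add",
                OpenGLShader("/* fragment shader source elided */",
                             {"A", "B", "C"},                // arg_names
                             {OpenGLArgKind::kInputTexture,  // arg_kinds (assumed
                              OpenGLArgKind::kInputTexture,  // enumerator names)
                              OpenGLArgKind::kOutputTexture},
                             "threadExtent"));               // thread_extent_var
std::string json = ToJSON(shaders);
auto restored = FromJSON(json);
CHECK_EQ(restored.count("fuse_add"), 1);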
*/ -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, +Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap); inline std::string OpenGLArgKind2String(OpenGLArgKind kind) { @@ -156,8 +156,7 @@ inline void OpenGLShader::Load(dmlc::JSONReader* reader) { } } -inline std::string ToJSON( - const std::unordered_map& shaders) { +inline std::string ToJSON(const std::unordered_map& shaders) { std::ostringstream os; dmlc::JSONWriter writer(&os); writer.BeginObject(); @@ -166,8 +165,7 @@ inline std::string ToJSON( return os.str(); } -inline std::unordered_map FromJSON( - const std::string& str) { +inline std::unordered_map FromJSON(const std::string& str) { std::unordered_map shaders; std::istringstream is(str); dmlc::JSONReader reader(&is); diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 9d24ca9072b4d..ae9771641b232 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -32,8 +32,9 @@ #define TVM_RUNTIME_PACK_ARGS_H_ #include -#include + #include +#include namespace tvm { namespace runtime { @@ -55,7 +56,7 @@ union ArgUnion { * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function only packs buffer arguments. @@ -66,7 +67,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function that takes a packed arguments. @@ -77,7 +78,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_types); /*! * \brief Extract number of buffer argument from the argument types. 
@@ -88,23 +89,21 @@ inline size_t NumBufferArgs(const std::vector& arg_types); // implementations details namespace detail { -template +template class TempArray { public: explicit TempArray(int size) {} - T* data() { - return data_; - } + T* data() { return data_; } + private: T data_[kSize]; }; -template +template class TempArray { public: explicit TempArray(int size) : data_(size) {} - T* data() { - return data_.data(); - } + T* data() { return data_.data(); } + private: std::vector data_; }; @@ -120,8 +119,7 @@ enum ArgConvertCode { }; inline ArgConvertCode GetArgConvertCode(DLDataType t) { - CHECK_EQ(t.lanes, 1U) - << "Cannot pass vector type argument to devic function for now"; + CHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now"; if (t.code == kDLInt) { if (t.bits == 64U) return INT64_TO_INT64; if (t.bits == 32U) return INT64_TO_INT32; @@ -137,7 +135,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { return HANDLE_TO_HANDLE; } -template +template inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { @@ -158,7 +156,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code addr[i] = &(holder[i]); break; } - case INT64_TO_UINT32 : { + case INT64_TO_UINT32: { holder[i].v_uint32 = static_cast(args.values[i].v_int64); addr[i] = &(holder[i]); break; @@ -175,9 +173,8 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code return PackedFunc(ret); } -template -inline PackedFunc PackFuncNonBufferArg_( - F f, int base, const std::vector& codes) { +template +inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { TempArray holder_(num_args); @@ -186,13 +183,14 @@ inline PackedFunc PackFuncNonBufferArg_( switch (codes[i]) { case INT64_TO_INT64: case FLOAT64_TO_FLOAT64: { - LOG(FATAL) << "Do not support 64bit argument to device function"; break; + LOG(FATAL) << "Do not support 64bit argument to device function"; + break; } case INT64_TO_INT32: { holder[i].v_int32 = static_cast(args.values[base + i].v_int64); break; } - case INT64_TO_UINT32 : { + case INT64_TO_UINT32: { holder[i].v_uint32 = static_cast(args.values[base + i].v_int64); break; } @@ -201,7 +199,8 @@ inline PackedFunc PackFuncNonBufferArg_( break; } case HANDLE_TO_HANDLE: { - LOG(FATAL) << "not reached"; break; + LOG(FATAL) << "not reached"; + break; } } } @@ -210,9 +209,8 @@ inline PackedFunc PackFuncNonBufferArg_( return PackedFunc(ret); } -template -inline PackedFunc PackFuncPackedArg_( - F f, const std::vector& codes) { +template +inline PackedFunc PackFuncPackedArg_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { TempArray pack_(num_args); @@ -238,20 +236,19 @@ inline PackedFunc PackFuncPackedArg_( ++ptr; break; } - case INT64_TO_UINT32 : { - *reinterpret_cast(ptr) = - static_cast(args.values[i].v_int64); + case INT64_TO_UINT32: { + *reinterpret_cast(ptr) = static_cast(args.values[i].v_int64); ++ptr; break; } case FLOAT64_TO_FLOAT32: { - *reinterpret_cast(ptr) = - static_cast(args.values[i].v_float64); + *reinterpret_cast(ptr) = static_cast(args.values[i].v_float64); ++ptr; break; } default: { - LOG(FATAL) << "not reached"; break; + LOG(FATAL) << "not reached"; + break; } } } @@ -261,7 +258,7 @@ inline PackedFunc 
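The ArgConvertCode machinery exists because TVM's calling convention always carries 64-bit slots, while device ABIs expect natively sized parameters. A condensed sketch of the per-argument narrowing that PackFuncVoidAddr_ performs with its holder array (the holder type and the numeric case labels here are illustrative; the real code switches on the ArgConvertCode enum above):

#include <cstdint>
#include <tvm/runtime/c_runtime_api.h>

union HolderUnion {  // stand-in for the ArgUnion holder slots
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};

// Narrow one 64-bit TVMValue into a holder slot and return the address the
// device function should actually read from.
void* ConvertArg(int code, const TVMValue& in, HolderUnion* holder) {
  switch (code) {
    case 1 /* INT64_TO_INT32, illustrative value */:
      holder->v_int32 = static_cast<int32_t>(in.v_int64);
      return holder;
    case 2 /* INT64_TO_UINT32, illustrative value */:
      holder->v_uint32 = static_cast<uint32_t>(in.v_int64);
      return holder;
    case 3 /* FLOAT64_TO_FLOAT32, illustrative value */:
      holder->v_float32 = static_cast<float>(in.v_float64);
      return holder;
    default:  // 64-bit and handle arguments pass through unchanged
      return const_cast<TVMValue*>(&in);
  }
}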
PackFuncPackedArg_( } } // namespace detail -template +template inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types) { std::vector codes(arg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -282,17 +279,17 @@ inline size_t NumBufferArgs(const std::vector& arg_types) { size_t base = arg_types.size(); for (size_t i = 0; i < arg_types.size(); ++i) { if (arg_types[i].code != kTVMOpaqueHandle) { - base = i; break; + base = i; + break; } } for (size_t i = base; i < arg_types.size(); ++i) { - CHECK(arg_types[i].code != kTVMOpaqueHandle) - << "Device function need to be organized"; + CHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function needs to be organized"; } return base; } -template +template inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types) { size_t num_buffer = NumBufferArgs(arg_types); std::vector codes; @@ -309,7 +306,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t } } -template +template inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_types) { std::vector codes; for (size_t i = 0; i < arg_types.size(); ++i) { diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 4717d89e33c15..641532a839270 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -24,10 +24,12 @@ #include #include #include -#include -#include -#include + #include +#include +#include +#include + #include "runtime_base.h" namespace tvm { @@ -37,14 +39,13 @@ struct Registry::Manager { // map storing the functions. // We delibrately used raw pointer // This is because PackedFunc can contain callbacks into the host languge(python) - // and the resource can become invalid because of indeterminstic order of destruction. + // and the resource can become invalid because of indeterministic order of destruction and forking. // The resources will only be recycled during program exit. std::unordered_map fmap; // mutex std::mutex mutex; - Manager() { - } + Manager() {} static Manager* Global() { // We deliberately leak the Manager instance, to avoid leak sanitizers @@ -60,20 +61,17 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) return *this; } -Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) +Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*) Manager* m = Manager::Global(); std::lock_guard lock(m->mutex); - auto it = m->fmap.find(name); - if (it == m->fmap.end()) { - Registry* r = new Registry(); - r->name_ = name; - m->fmap[name] = r; - return *r; - } else { - CHECK(override) - << "Global PackedFunc " << name << " is already registered"; - return *it->second; + if (m->fmap.count(name)) { + CHECK(can_override) << "Global PackedFunc " << name << " is already registered"; } + + Registry* r = new Registry(); + r->name_ = name; + m->fmap[name] = r; + return *r; } bool Registry::Remove(const std::string& name) { @@ -98,7 +96,7 @@ std::vector Registry::ListNames() { std::lock_guard lock(m->mutex); std::vector keys; keys.reserve(m->fmap.size()); - for (const auto &kv : m->fmap) { + for (const auto& kv : m->fmap) { keys.push_back(kv.first); } return keys; @@ -112,14 +110,13 @@ struct TVMFuncThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; }; /*! \brief Thread local store that can be used to hold return values. 
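With the rewritten Register, a duplicate name is a hard CHECK failure unless the caller passes can_override. A hedged usage sketch of that contract ("demo.add" and RegisterDemo are made-up names; Registry, TVMArgs, and TVMRetValue are the real TVM types):

#include <tvm/runtime/registry.h>

using tvm::runtime::Registry;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

void RegisterDemo() {
  // First registration: can_override defaults to false.
  Registry::Register("demo.add").set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = static_cast<int>(args[0]) + static_cast<int>(args[1]);
  });
  // Re-registering the same name requires can_override=true; otherwise the
  // CHECK in Register fires. A fresh Registry entry replaces the old mapping.
  Registry::Register("demo.add", /*can_override=*/true)
      .set_body([](TVMArgs args, TVMRetValue* rv) {
        *rv = static_cast<int>(args[0]) + static_cast<int>(args[1]);
      });
}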
*/ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; -int TVMFuncRegisterGlobal( - const char* name, TVMFunctionHandle f, int override) { +int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); tvm::runtime::Registry::Register(name, override != 0) .set_body(*static_cast(f)); @@ -128,8 +125,7 @@ int TVMFuncRegisterGlobal( int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_BEGIN(); - const tvm::runtime::PackedFunc* fp = - tvm::runtime::Registry::Get(name); + const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name); if (fp != nullptr) { *out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*) } else { @@ -138,10 +134,9 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_END(); } -int TVMFuncListGlobalNames(int *out_size, - const char*** out_array) { +int TVMFuncListGlobalNames(int* out_size, const char*** out_array) { API_BEGIN(); - TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get(); + TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get(); ret->ret_vec_str = tvm::runtime::Registry::ListNames(); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { diff --git a/src/runtime/rocm/rocm_common.h b/src/runtime/rocm/rocm_common.h index 5d0d5c972c4b7..2e637f5496bbb 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/runtime/rocm/rocm_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,28 +24,28 @@ #ifndef TVM_RUNTIME_ROCM_ROCM_COMMON_H_ #define TVM_RUNTIME_ROCM_ROCM_COMMON_H_ -#include #include +#include + #include + #include "../workspace_pool.h" namespace tvm { namespace runtime { -#define ROCM_DRIVER_CALL(x) \ - { \ - hipError_t result = x; \ - if (result != hipSuccess && result != hipErrorDeinitialized) { \ - LOG(FATAL) \ - << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ - } \ +#define ROCM_DRIVER_CALL(x) \ + { \ + hipError_t result = x; \ + if (result != hipSuccess && result != hipErrorDeinitialized) { \ + LOG(FATAL) << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ + } \ } -#define ROCM_CALL(func) \ - { \ - hipError_t e = (func); \ - CHECK(e == hipSuccess) \ - << "ROCM HIP: " << hipGetErrorString(e); \ +#define ROCM_CALL(func) \ + { \ + hipError_t e = (func); \ + CHECK(e == hipSuccess) << "ROCM HIP: " << hipGetErrorString(e); \ } /*! 
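Both macros funnel every HIP call through an immediate status check. A short hedged usage sketch (AllocDeviceBuffer is a made-up helper, not part of this file):

#include <hip/hip_runtime_api.h>

void* AllocDeviceBuffer(int device_id, size_t nbytes) {
  void* ptr = nullptr;
  ROCM_CALL(hipSetDevice(device_id));  // aborts on any hipError_t != hipSuccess
  ROCM_CALL(hipMalloc(&ptr, nbytes));
  return ptr;
}

ROCM_DRIVER_CALL differs only in tolerating hipErrorDeinitialized, which matters for calls issued during process teardown.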
\brief Thread local workspace */ diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 25e1ac70c2416..475c4fbffadc5 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -35,9 +35,7 @@ namespace runtime { class ROCMDeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { @@ -53,27 +51,26 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxThreadsPerBlock: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, - ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id)); break; } case kMaxSharedMemoryPerBlock: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock, + ctx.device_id)); break; } case kComputeVersion: { std::ostringstream os; - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); os << value << "."; - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); os << value; *rv = os.str(); return; @@ -86,23 +83,19 @@ class ROCMDeviceAPI final : public DeviceAPI { return; } case kMaxClockRate: { - ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, - ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id)); break; } case kMultiProcessorCount: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); break; } case kMaxThreadDimensions: { int dims[3]; - ROCM_CALL(hipDeviceGetAttribute( - &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); - ROCM_CALL(hipDeviceGetAttribute( - &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); - ROCM_CALL(hipDeviceGetAttribute( - &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); std::stringstream ss; ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; @@ -132,9 +125,8 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipFree(ptr)); } - void CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t size, TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, 
TVMStreamHandle stream) final { hipStream_t hip_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -144,15 +136,12 @@ class ROCMDeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); } else { - hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, - hip_stream); + hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream); } - } else if (ctx_from.device_type == kDLROCM && - ctx_to.device_type == kDLCPU) { + } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { ROCM_CALL(hipSetDevice(ctx_from.device_id)); GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); - } else if (ctx_from.device_type == kDLCPU && - ctx_to.device_type == kDLROCM) { + } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) { ROCM_CALL(hipSetDevice(ctx_to.device_id)); GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); } else { @@ -178,14 +167,13 @@ class ROCMDeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, void* to, size_t size, - hipMemcpyKind kind, hipStream_t stream) { + static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind, + hipStream_t stream) { if (stream != 0) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); } else { @@ -198,14 +186,11 @@ typedef dmlc::ThreadLocalStore ROCMThreadStore; ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} -ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { - return ROCMThreadStore::Get(); -} +ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm") - .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 1f4b830ce4340..79958d20aa1f9 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -20,19 +20,22 @@ /*! 
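GPUCopy above keys the blocking-versus-asynchronous decision purely off the stream handle. The same pattern in isolation, as a hedged sketch (CopySketch is a made-up helper):

#include <hip/hip_runtime_api.h>

void CopySketch(void* dst, const void* src, size_t nbytes, hipStream_t stream) {
  if (stream != nullptr) {
    // Queued on the stream; the caller must synchronize before reading dst.
    ROCM_CALL(hipMemcpyAsync(dst, src, nbytes, hipMemcpyDeviceToDevice, stream));
  } else {
    ROCM_CALL(hipMemcpy(dst, src, nbytes, hipMemcpyDeviceToDevice));
  }
}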
* \file rocm_module.cc */ -#include +#include "rocm_module.h" + #include -#include +#include + #include -#include #include +#include #include -#include "rocm_module.h" -#include "rocm_common.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "rocm_common.h" namespace tvm { namespace runtime { @@ -43,12 +46,10 @@ namespace runtime { // The modules will be lazily loaded class ROCMModuleNode : public runtime::ModuleNode { public: - explicit ROCMModuleNode(std::string data, - std::string fmt, + explicit ROCMModuleNode(std::string data, std::string fmt, std::unordered_map fmap, - std::string hip_source, - std::string assembly) - : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) { + std::string hip_source, std::string assembly) + : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) { std::fill(module_.begin(), module_.end(), nullptr); } // destructor @@ -61,17 +62,11 @@ class ROCMModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { - return "hip"; - } - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + const char* type_key() const final { return "hip"; } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them @@ -87,9 +82,15 @@ class ROCMModuleNode : public runtime::ModuleNode { } std::string GetSource(const std::string& format) final { - if (format == fmt_) { return data_; } - if (format == "llvm" || format == "") { return hip_source_; } - if (format == "asm") { return assembly_; } + if (format == fmt_) { + return data_; + } + if (format == "llvm" || format == "") { + return hip_source_; + } + if (format == "asm") { + return assembly_; + } return ""; } @@ -104,16 +105,13 @@ class ROCMModuleNode : public runtime::ModuleNode { hipFunction_t func; hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str()); if (result != hipSuccess) { - LOG(FATAL) - << "ROCMError: hipModuleGetFunction " << func_name - << " failed with error: " << hipGetErrorString(result); + LOG(FATAL) << "ROCMError: hipModuleGetFunction " << func_name + << " failed with error: " << hipGetErrorString(result); } return func; } // get a global var from primary context in device_id - hipDeviceptr_t GetGlobal(int device_id, - const std::string& global_name, - size_t expect_nbytes) { + hipDeviceptr_t GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); // must recheck under the lock scope if (module_[device_id] == nullptr) { @@ -122,8 +120,7 @@ class ROCMModuleNode : public runtime::ModuleNode { hipDeviceptr_t global = nullptr; size_t nbytes = 0; - ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, - module_[device_id], global_name.c_str())); + ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str())); CHECK_EQ(nbytes, expect_nbytes); return global; } @@ -149,11 +146,8 @@ class ROCMModuleNode : public runtime::ModuleNode { class ROCMWrappedFunc { public: // 
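ROCMModuleNode keeps one hipModule_t per GPU and loads it lazily under the mutex the first time a kernel is requested on that device. A hedged sketch of that pattern (EnsureLoaded and kMaxGPUs are made-up names; the real loading happens inside the module's private methods, bounded by kMaxNumGPUs from rocm_module.h):

#include <hip/hip_runtime_api.h>

#include <array>
#include <mutex>
#include <string>

constexpr int kMaxGPUs = 32;  // mirrors kMaxNumGPUs in rocm_module.h

hipModule_t EnsureLoaded(int device_id, std::array<hipModule_t, kMaxGPUs>* modules,
                         const std::string& hsaco, std::mutex* mu) {
  std::lock_guard<std::mutex> lock(*mu);
  if ((*modules)[device_id] == nullptr) {
    ROCM_CALL(hipSetDevice(device_id));
    ROCM_DRIVER_CALL(hipModuleLoadData(&(*modules)[device_id], hsaco.data()));
  }
  return (*modules)[device_id];
}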
initialize the ROCM function. - void Init(ROCMModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_void_args, - const std::vector& thread_axis_tags) { + void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_void_args, const std::vector& thread_axis_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; @@ -161,10 +155,7 @@ class ROCMWrappedFunc { thread_axis_cfg_.Init(num_void_args, thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void* packed_args, - size_t packed_nbytes) const { + void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { int device_id; ROCM_CALL(hipGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -174,22 +165,12 @@ class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - void* config[] = { - HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, - HIP_LAUNCH_PARAM_END - }; + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, + &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], - wl.grid_dim(0), - wl.grid_dim(1), - wl.grid_dim(2), - wl.block_dim(0), - wl.block_dim(1), - wl.block_dim(2), - 0, strm, nullptr, - reinterpret_cast(&config))); + fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), + wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); } private: @@ -206,13 +187,10 @@ class ROCMWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; - -PackedFunc ROCMModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc ROCMModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -221,18 +199,14 @@ PackedFunc ROCMModuleNode::GetFunction( return PackFuncPackedArg(f, info.arg_types); } -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string hip_source, - std::string assembly) { +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string hip_source, + std::string assembly) { auto n = make_object(data, fmt, fmap, hip_source, assembly); return Module(n); } -Module ROCMModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -253,19 +227,12 @@ Module ROCMModuleLoadBinary(void* strm) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco") -.set_body_typed(ROCMModuleLoadBinary); - - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip") -.set_body_typed(ROCMModuleLoadBinary); - +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); 
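The launch in operator() above uses HIP's "extra" mechanism: instead of an array of per-argument pointers, the whole packed argument buffer travels through the HIP_LAUNCH_PARAM_* triple. The pattern in isolation, as a hedged sketch (LaunchPacked is a made-up wrapper):

#include <hip/hip_runtime.h>

void LaunchPacked(hipFunction_t f, void* packed_args, size_t packed_nbytes,
                  dim3 grid, dim3 block, hipStream_t strm) {
  void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args,
                    HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
                    HIP_LAUNCH_PARAM_END};
  ROCM_DRIVER_CALL(hipModuleLaunchKernel(f, grid.x, grid.y, grid.z,
                                         block.x, block.y, block.z,
                                         /*sharedMemBytes=*/0, strm,
                                         /*kernelParams=*/nullptr, config));
}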
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco") -.set_body_typed(ROCMModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip") -.set_body_typed(ROCMModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h index 7f2a0ce319bf3..c17e123c1a12b 100644 --- a/src/runtime/rocm/rocm_module.h +++ b/src/runtime/rocm/rocm_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_ROCM_ROCM_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -45,12 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param rocm_source Optional, rocm source file */ -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string rocm_source, - std::string assembly); +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string rocm_source, + std::string assembly); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_ diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h new file mode 100644 index 0000000000000..91a900afd900d --- /dev/null +++ b/src/runtime/rpc/minrpc/minrpc_server.h @@ -0,0 +1,581 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file minrpc_server.h + * \brief Minimum RPC server implementation, + * redirects all the calls to C runtime API. + * + * \note This file do not depend on c++ std or c std, + * and only depends on TVM's C runtime API. + */ +#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ +#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ + +#include +#include + +#include "../../../support/arena.h" +#include "../rpc_protocol.h" + +/*! 
\brief Whether or not to enable glog style DLOG */ +#ifndef TVM_MINRPC_ENABLE_LOGGING +#define TVM_MINRPC_ENABLE_LOGGING 0 +#endif + +#ifndef MINRPC_CHECK +#define MINRPC_CHECK(cond) \ + if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); +#endif + +#if TVM_MINRPC_ENABLE_LOGGING +#include +#endif + +namespace tvm { +namespace runtime { + +/*! + * \brief A minimum RPC server that only depends on the tvm C runtime. + * + * All the dependencies are provided by the io arguments. + * + * \tparam TIOHandler The IO provider that handles the underlying reads and writes. + * An IOHandler needs to provide the following functions: + * - PosixWrite, PosixRead, Close: posix-style read, write, and close APIs. + * - Exit: exit with status code. + */ +template +class MinRPCServer { + public: + /*! + * \brief Constructor. + * \param io The IO handler. + */ + explicit MinRPCServer(TIOHandler io) : io_(io), arena_(PageAllocator(io)) {} + + /*! \brief Run the server loop until a shutdown signal is received. */ + void ServerLoop() { + RPCCode code; + uint64_t packet_len; + + while (true) { + arena_.RecycleAll(); + allow_clean_shutdown_ = true; + + this->Read(&packet_len); + if (packet_len == 0) continue; + this->Read(&code); + + allow_clean_shutdown_ = false; + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscallFunc(code); + } else { + switch (code) { + case RPCCode::kCallFunc: { + HandleNormalCallFunc(); + break; + } + case RPCCode::kInitServer: { + HandleInitServer(); + break; + } + case RPCCode::kCopyFromRemote: { + HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + HandleCopyToRemote(); + break; + } + case RPCCode::kShutdown: { + this->Shutdown(); + return; + } + default: { + this->ThrowError(RPCServerStatus::kUnknownRPCCode); + break; + } + } + } + } + } + + void Shutdown() { + arena_.FreeAll(); + io_.Close(); + } + + void HandleNormalCallFunc() { + uint64_t call_handle; + TVMValue* values; + int* tcodes; + int num_args; + TVMValue ret_value[3]; + int ret_tcode[3]; + + this->Read(&call_handle); + RecvPackedSeq(&values, &tcodes, &num_args); + + int call_ecode = TVMFuncCall(reinterpret_cast(call_handle), values, tcodes, num_args, + &(ret_value[1]), &(ret_tcode[1])); + + if (call_ecode == 0) { + // Return value encoding as in LocalSession + int rv_tcode = ret_tcode[1]; + ret_tcode[0] = kDLInt; + ret_value[0].v_int64 = rv_tcode; + if (rv_tcode == kTVMNDArrayHandle) { + ret_tcode[1] = kTVMDLTensorHandle; + ret_value[2].v_handle = ret_value[1].v_handle; + ret_tcode[2] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 3); + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + ret_tcode[1] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } else { + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + uint8_t* data_ptr; + int call_ecode = 0; + if (ctx.device_type == kDLCPU) { + data_ptr = reinterpret_cast(handle) + offset; + } else { + data_ptr = this->ArenaAlloc(num_bytes); + call_ecode = + TVMDeviceCopyDataFromTo(reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, + ctx, DLContext{kDLCPU, 0}, type_hint, nullptr); + // need sync to make sure that the copy is completed. 
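ServerLoop drives everything through the TIOHandler template parameter, whose contract is spelled out in the class comment above. A hedged sketch of a handler over plain POSIX descriptors (FdIOHandler is a made-up name; any type with these four members works):

#include <unistd.h>

#include <cstdlib>

class FdIOHandler {
 public:
  FdIOHandler(int read_fd, int write_fd) : read_fd_(read_fd), write_fd_(write_fd) {}
  // Partial reads/writes are fine: the server's raw-byte loops retry.
  ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); }
  ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); }
  void Close() {
    close(read_fd_);
    close(write_fd_);
  }
  void Exit(int code) { exit(code); }

 private:
  int read_fd_;
  int write_fd_;
};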
+ if (call_ecode == 0) { + call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); + } + } + + if (call_ecode == 0) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + int call_ecode = 0; + + if (ctx.device_type == kDLCPU) { + uint8_t* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + } else { + uint8_t* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + call_ecode = + TVMDeviceCopyDataFromTo(temp_data, 0, reinterpret_cast(handle), offset, num_bytes, + DLContext{kDLCPU, 0}, ctx, type_hint, nullptr); + // need sync to make sure that the copy is completed. + if (call_ecode == 0) { + call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); + } + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleSyscallFunc(RPCCode code) { + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + switch (code) { + case RPCCode::kFreeHandle: { + this->SyscallFreeHandle(values, tcodes, num_args); + break; + } + case RPCCode::kGetGlobalFunc: { + this->SyscallGetGlobalFunc(values, tcodes, num_args); + break; + } + case RPCCode::kDevSetDevice: { + this->ReturnException("SetDevice not supported"); + break; + } + case RPCCode::kDevGetAttr: { + this->ReturnException("GetAttr not supported"); + break; + } + case RPCCode::kDevAllocData: { + this->SyscallDevAllocData(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeData: { + this->SyscallDevFreeData(values, tcodes, num_args); + break; + } + case RPCCode::kDevStreamSync: { + this->SyscallDevStreamSync(values, tcodes, num_args); + break; + } + case RPCCode::kCopyAmongRemote: { + this->SyscallCopyAmongRemote(values, tcodes, num_args); + break; + } + default: { + this->ReturnException("Syscall not recognized"); + break; + } + } + } + + void HandleInitServer() { + uint64_t len; + this->Read(&len); + char* proto_ver = this->ArenaAlloc(len + 1); + this->ReadArray(proto_ver, len); + + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + MINRPC_CHECK(num_args == 0); + this->ReturnVoid(); + } + + void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + + void* handle = values[0].v_handle; + int64_t type_code = values[1].v_int64; + int call_ecode; + + if (type_code == kTVMNDArrayHandle) { + call_ecode = TVMArrayFree(static_cast(handle)); + } else if (type_code == kTVMPackedFuncHandle) { + call_ecode = TVMFuncFree(handle); + } else { + MINRPC_CHECK(type_code == kTVMModuleHandle); + call_ecode = TVMModFree(handle); + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kTVMStr); + + void* handle; + int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + 
} else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 9);
+    // from, from_offset
+    MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle);
+    MINRPC_CHECK(tcodes[1] == kDLInt);
+    // to, to_offset
+    MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle);
+    MINRPC_CHECK(tcodes[3] == kDLInt);
+    // size
+    MINRPC_CHECK(tcodes[4] == kDLInt);
+    // ctx_from, ctx_to
+    MINRPC_CHECK(tcodes[5] == kTVMContext);
+    MINRPC_CHECK(tcodes[6] == kTVMContext);
+    // type_hint, stream
+    MINRPC_CHECK(tcodes[7] == kTVMDataType);
+    MINRPC_CHECK(tcodes[8] == kTVMOpaqueHandle);
+
+    void* from = values[0].v_handle;
+    int64_t from_offset = values[1].v_int64;
+    void* to = values[2].v_handle;
+    int64_t to_offset = values[3].v_int64;
+    int64_t size = values[4].v_int64;
+    TVMContext ctx_from = values[5].v_ctx;
+    TVMContext ctx_to = values[6].v_ctx;
+    DLDataType type_hint = values[7].v_type;
+    TVMStreamHandle stream = values[8].v_handle;
+
+    int call_ecode = TVMDeviceCopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from,
+                                             ctx_to, type_hint, stream);
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 4);
+    MINRPC_CHECK(tcodes[0] == kTVMContext);
+    MINRPC_CHECK(tcodes[1] == kDLInt);
+    MINRPC_CHECK(tcodes[2] == kDLInt);
+    MINRPC_CHECK(tcodes[3] == kTVMDataType);
+
+    TVMContext ctx = values[0].v_ctx;
+    int64_t nbytes = values[1].v_int64;
+    int64_t alignment = values[2].v_int64;
+    DLDataType type_hint = values[3].v_type;
+
+    void* handle;
+    int call_ecode = TVMDeviceAllocDataSpace(ctx, nbytes, alignment, type_hint, &handle);
+
+    if (call_ecode == 0) {
+      this->ReturnHandle(handle);
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 2);
+    MINRPC_CHECK(tcodes[0] == kTVMContext);
+    MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
+
+    TVMContext ctx = values[0].v_ctx;
+    void* handle = values[1].v_handle;
+
+    int call_ecode = TVMDeviceFreeDataSpace(ctx, handle);
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) {
+    MINRPC_CHECK(num_args == 2);
+    MINRPC_CHECK(tcodes[0] == kTVMContext);
+    MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle);
+
+    TVMContext ctx = values[0].v_ctx;
+    void* handle = values[1].v_handle;
+
+    int call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, handle);
+
+    if (call_ecode == 0) {
+      this->ReturnVoid();
+    } else {
+      this->ReturnLastTVMError();
+    }
+  }
+
+  void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
+    io_.Exit(static_cast<int>(code));
+  }
+
+  template <typename T>
+  T* ArenaAlloc(int count) {
+    static_assert(std::is_pod<T>::value, "need to be trivial");
+    return arena_.template allocate_<T>(count);
+  }
+
+  template <typename T>
+  void Read(T* data) {
+    static_assert(std::is_pod<T>::value, "need to be trivial");
+    this->ReadRawBytes(data, sizeof(T));
+  }
+
+  template <typename T>
+  void ReadArray(T* data, size_t count) {
+    static_assert(std::is_pod<T>::value, "need to be trivial");
+    return this->ReadRawBytes(data, sizeof(T) * count);
+  }
+
+  template <typename T>
+  void Write(const T& data) {
+    static_assert(std::is_pod<T>::value, "need to be trivial");
+    return this->WriteRawBytes(&data, sizeof(T));
+  }
+
+  template <typename T>
+  void WriteArray(T* data, size_t count) {
+    static_assert(std::is_pod<T>::value, "need to be trivial");
+    return this->WriteRawBytes(data, sizeof(T) * count);
+  }
+
+ private:
+  // Internal allocator that redirects alloc to TVM's C API.
+  class PageAllocator {
+   public:
+    using ArenaPageHeader = tvm::support::ArenaPageHeader;
+
+    explicit PageAllocator(TIOHandler io) : io_(io) {}
+
+    ArenaPageHeader* allocate(size_t min_size) {
+      size_t npages = ((min_size + kPageSize - 1) / kPageSize);
+      void* data;
+
+      if (TVMDeviceAllocDataSpace(DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign,
+                                  DLDataType{kDLInt, 1, 1}, &data) != 0) {
+        io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
+      }
+
+      ArenaPageHeader* header = static_cast<ArenaPageHeader*>(data);
+      header->size = npages * kPageSize;
+      header->offset = sizeof(ArenaPageHeader);
+      return header;
+    }
+
+    void deallocate(ArenaPageHeader* page) {
+      if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) {
+        io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
+      }
+    }
+
+    static const constexpr int kPageSize = 2 << 10;
+    static const constexpr int kPageAlign = 8;
+
+   private:
+    TIOHandler io_;
+  };
+
+  void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
+    RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
+  }
+
+  void ReturnVoid() {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMNullptr;
+    RPCCode code = RPCCode::kReturn;
+
+    uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
+
+    this->Write(packet_nbytes);
+    this->Write(code);
+    this->Write(num_args);
+    this->Write(tcode);
+  }
+
+  void ReturnHandle(void* handle) {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMOpaqueHandle;
+    RPCCode code = RPCCode::kReturn;
+    uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
+
+    uint64_t packet_nbytes =
+        sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle);
+
+    this->Write(packet_nbytes);
+    this->Write(code);
+    this->Write(num_args);
+    this->Write(tcode);
+    this->Write(encode_handle);
+  }
+
+  void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); }
+
+  void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {
+    RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
+  }
+
+  void ReturnLastTVMError() { this->ReturnException(TVMGetLastError()); }
+
+  void ReadRawBytes(void* data, size_t size) {
+    uint8_t* buf = reinterpret_cast<uint8_t*>(data);
+    size_t ndone = 0;
+    while (ndone < size) {
+      ssize_t ret = io_.PosixRead(buf, size - ndone);
+      if (ret == 0) {
+        if (allow_clean_shutdown_) {
+          this->Shutdown();
+          io_.Exit(0);
+        } else {
+          this->ThrowError(RPCServerStatus::kReadError);
+        }
+      }
+      if (ret == -1) {
+        this->ThrowError(RPCServerStatus::kReadError);
+      }
+      ndone += ret;
+      buf += ret;
+    }
+  }
+
+  void WriteRawBytes(const void* data, size_t size) {
+    const uint8_t* buf = reinterpret_cast<const uint8_t*>(data);
+    size_t ndone = 0;
+    while (ndone < size) {
+      ssize_t ret = io_.PosixWrite(buf, size - ndone);
+      if (ret == 0 || ret == -1) {
+        this->ThrowError(RPCServerStatus::kWriteError);
+      }
+      buf += ret;
+      ndone += ret;
+    }
+  }
+
+  /*! \brief IO handler. */
+  TIOHandler io_;
+  /*! \brief internal arena. */
+  support::GenericArena<PageAllocator> arena_;
+  /*! \brief Whether we are in a state that allows clean shutdown.
*/ + bool allow_clean_shutdown_{true}; + static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian."); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/rpc/minrpc/posix_popen_server.cc new file mode 100644 index 0000000000000..9784780fea183 --- /dev/null +++ b/src/runtime/rpc/minrpc/posix_popen_server.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Disable constructor to bring minimum dep on c++ABI. +#define TVM_ARENA_HAS_DESTRUCTOR 0 + +#include + +#include + +#include "minrpc_server.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief IOHandler based on posix API. + */ +class PosixIOHandler { + public: + explicit PosixIOHandler(int read_fd = 0, int write_fd = 1) + : read_fd_(read_fd), write_fd_(write_fd) {} + + ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); } + + ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); } + + void Exit(int code) { exit(code); } + + void Close() { + if (read_fd_ != 0) close(read_fd_); + if (write_fd_ != 0) close(write_fd_); + } + + private: + int read_fd_{0}; + int write_fd_{1}; +}; + +/*! \brief Type for the posix version of min rpc server. */ +using PosixMinRPCServer = MinRPCServer; + +} // namespace runtime +} // namespace tvm + +int main(int argc, char* argv[]) { + if (argc != 3) return -1; + // pass the descriptor via arguments. + tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2])); + tvm::runtime::PosixMinRPCServer server(handler); + server.ServerLoop(); + return 0; +} diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc new file mode 100644 index 0000000000000..eaa64e3372c67 --- /dev/null +++ b/src/runtime/rpc/rpc_channel.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file rpc_channel.cc
+ */
+#include "rpc_channel.h"
+
+#include <cstring>
+
+namespace tvm {
+namespace runtime {
+
+size_t CallbackChannel::Send(const void* data, size_t size) {
+  TVMByteArray bytes;
+  bytes.data = static_cast<const char*>(data);
+  bytes.size = size;
+  int64_t n = fsend_(bytes);
+  if (n == -1) {
+    LOG(FATAL) << "CallbackChannel::Send";
+  }
+  return static_cast<size_t>(n);
+}
+
+size_t CallbackChannel::Recv(void* data, size_t size) {
+  TVMRetValue ret = frecv_(size);
+
+  if (ret.type_code() != kTVMBytes) {
+    LOG(FATAL) << "CallbackChannel::Recv";
+  }
+  std::string* bytes = ret.ptr<std::string>();
+  memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
+  return bytes->length();
+}
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h
new file mode 100644
index 0000000000000..114bc0a2e7bde
--- /dev/null
+++ b/src/runtime/rpc/rpc_channel.h
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_channel.h
+ * \brief Communication endpoints to connect local and remote RPC sessions.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_
+#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_
+
+#include <tvm/runtime/packed_func.h>
+
+#include <utility>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Abstract channel interface used to create RPCEndpoint.
+ */
+class RPCChannel {
+ public:
+  /*! \brief virtual destructor */
+  virtual ~RPCChannel() {}
+  /*!
+   * \brief Send data over to the channel.
+   * \param data The data pointer.
+   * \param size The size of the data.
+   * \return The actual bytes sent.
+   */
+  virtual size_t Send(const void* data, size_t size) = 0;
+  /*!
+   * \brief Recv data from channel.
+   *
+   * \param data The data pointer.
+   * \param size The size of the data.
+   * \return The actual bytes received.
+   */
+  virtual size_t Recv(void* data, size_t size) = 0;
+};
+
+/*!
+ * \brief RPC channel that calls back into the
+ * frontend's (Python/Java/etc.) send & recv functions.
+ */
+class CallbackChannel final : public RPCChannel {
+ public:
+  /*!
+   * \brief Constructor.
+   *
+   * \param fsend The send function, takes in a TVMByteArray and returns the
+   *              number of bytes sent in that array. Returns -1 if an error happens.
+   * \param frecv The recv function, takes an expected maximum size, and returns
+   *              a byte array with the actual amount of data received.
+   */
+  explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
+      : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
+
+  ~CallbackChannel() {}
+  /*!
+   * \brief Send data over to the channel.
+   * \param data The data pointer.
+   * \param size The size of the data.
+   * \return The actual bytes sent.
+   */
+  size_t Send(const void* data, size_t size) final;
+  /*!
+   * \brief Recv data from channel.
+   *
+   * \param data The data pointer.
+   * \param size The size of the data.
+   * \return The actual bytes received.
+   */
+  size_t Recv(void* data, size_t size) final;
+
+ private:
+  PackedFunc fsend_;
+  PackedFunc frecv_;
+};
+
+} // namespace runtime
+} // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_CHANNEL_H_
diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc
index 9fd45acd14bf4..196a97ecbd665 100644
--- a/src/runtime/rpc/rpc_device_api.cc
+++ b/src/runtime/rpc/rpc_device_api.cc
@@ -21,8 +21,11 @@
  * \file rpc_device_api.cc
  */
 #include
-#include
 #include
+#include
+
+#include
+
 #include "rpc_session.h"
 
 namespace tvm {
@@ -31,20 +34,22 @@ namespace runtime {
 class RPCDeviceAPI final : public DeviceAPI {
  public:
   void SetDevice(TVMContext ctx) final {
-    GetSess(ctx)->CallRemote(
-        RPCCode::kDevSetDevice, ctx);
+    auto remote_ctx = RemoveSessMask(ctx);
+    GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx);
   }
+
   void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
-    *rv = GetSess(ctx)->CallRemote(
-        RPCCode::kDevGetAttr, ctx, static_cast<int>(kind));
+    auto remote_ctx = RemoveSessMask(ctx);
+    GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv);
   }
-  void* AllocDataSpace(TVMContext ctx,
-                       size_t nbytes,
-                       size_t alignment,
+
+  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
                        DLDataType type_hint) final {
     auto sess = GetSess(ctx);
-    void *data = sess->CallRemote(
-        RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
+    auto remote_ctx = RemoveSessMask(ctx);
+    void* data =
+        sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(remote_ctx, nbytes, alignment, type_hint);
+
     RemoteSpace* space = new RemoteSpace();
     space->data = data;
     space->sess = std::move(sess);
@@ -52,68 +57,68 @@ class RPCDeviceAPI final : public DeviceAPI {
   }
   void FreeDataSpace(TVMContext ctx, void* ptr) final {
     RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
+    auto remote_ctx = RemoveSessMask(ctx);
     try {
-      GetSess(ctx)->CallRemote(
-          RPCCode::kDevFreeData, ctx, space->data);
+      GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(remote_ctx, space->data);
     } catch (const dmlc::Error& e) {
       // fault tolerance to remote close.
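+      // The remote session may already be closed at this point; swallow the
+      // error so the local RemoteSpace cleanup below still runs.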
} delete space; } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { int from_dev_type = ctx_from.device_type; int to_dev_type = ctx_to.device_type; - if (from_dev_type > kRPCSessMask && - to_dev_type > kRPCSessMask) { + if (from_dev_type > kRPCSessMask && to_dev_type > kRPCSessMask) { CHECK(ctx_from.device_type == ctx_to.device_type) << "Cannot copy across two different remote session"; - GetSess(ctx_from)->CallRemote( - RPCCode::kCopyAmongRemote, - static_cast(from)->data, from_offset, - static_cast(to)->data, to_offset, - size, ctx_from, ctx_to, type_hint, stream); - } else if (from_dev_type > kRPCSessMask && - to_dev_type == kDLCPU) { - GetSess(ctx_from)->CopyFromRemote( - static_cast(from)->data, from_offset, - to, to_offset, size, ctx_from, type_hint); - } else if (from_dev_type == kDLCPU && - to_dev_type > kRPCSessMask) { - GetSess(ctx_to)->CopyToRemote( - (void*)from, from_offset, // NOLINT(*) - static_cast(to)->data, to_offset, - size, ctx_to, type_hint); + auto remote_ctx_from = RemoveSessMask(ctx_from); + auto remote_ctx_to = RemoveSessMask(ctx_to); + auto remote_ctx = remote_ctx_from; + if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to; + GetSess(ctx_from) + ->GetDeviceAPI(remote_ctx) + ->CopyDataFromTo(static_cast(from)->data, from_offset, + static_cast(to)->data, to_offset, size, + remote_ctx_from, remote_ctx_to, type_hint, stream); + } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) { + auto remote_ctx_from = RemoveSessMask(ctx_from); + GetSess(ctx_from)->CopyFromRemote(static_cast(from)->data, from_offset, + to, to_offset, size, remote_ctx_from, type_hint); + } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) { + auto remote_ctx_to = RemoveSessMask(ctx_to); + GetSess(ctx_to)->CopyToRemote(const_cast(from), from_offset, + static_cast(to)->data, to_offset, size, + remote_ctx_to, type_hint); } else { LOG(FATAL) << "expect copy from/to remote or between remote"; } } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevStreamSync, ctx, stream); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(remote_ctx, stream); } private: std::shared_ptr GetSess(TVMContext ctx) { int dev_type = ctx.device_type; CHECK_GE(dev_type, kRPCSessMask); - int tbl_index = dev_type / kRPCSessMask - 1; + int tbl_index = dev_type / kRPCSessMask - 1; return RPCSession::Get(tbl_index); } + + static TVMContext RemoveSessMask(TVMContext ctx) { + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } }; -TVM_REGISTER_GLOBAL("device_api.rpc") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCDeviceAPI inst; - DeviceAPI* ptr = &inst; - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rpc").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCDeviceAPI inst; + DeviceAPI* ptr = &inst; + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc new file mode 100644 index 0000000000000..bf85dc56dac9f --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -0,0 +1,1034 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_session.cc + * \brief RPC session for remote function call. + */ +#include "rpc_endpoint.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../support/arena.h" +#include "../../support/ring_buffer.h" +#include "../object_internal.h" +#include "rpc_local_session.h" + +namespace tvm { +namespace runtime { + +/*! + * Event-driven state-machine based handlers for RPCEndpoint. + * + * Key functions: + * + * - SendPackedSeq: send the arguments over to the peer + * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol). + */ +class RPCEndpoint::EventHandler : public dmlc::Stream { + public: + EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name, + std::string* remote_key, std::function flush_writer) + : reader_(reader), + writer_(writer), + name_(name), + remote_key_(remote_key), + flush_writer_(flush_writer) { + this->Clear(); + + if (*remote_key == "%toinit") { + state_ = kInitHeader; + remote_key_->resize(0); + pending_request_bytes_ = sizeof(int32_t); + } + } + + /*! + * \brief Bytes needed to fulfill current request + */ + size_t BytesNeeded() const { + if (reader_->bytes_available() < pending_request_bytes_) { + return pending_request_bytes_ - reader_->bytes_available(); + } else { + return 0; + } + } + + /*! + * \brief Request number of bytes from the reader. + * \param nbytes The number of bytes + */ + void RequestBytes(size_t nbytes) { + pending_request_bytes_ += nbytes; + reader_->Reserve(pending_request_bytes_); + } + + /*! \return Whether we are ready to handle next request. */ + bool Ready() const { return reader_->bytes_available() >= pending_request_bytes_; } + + /*! \return Whether we can perform a clean shutdown */ + bool CanCleanShutdown() const { return state_ == kRecvPacketNumBytes; } + + /*! \brief Finish the copy ack stage. */ + void FinishCopyAck() { this->SwitchToState(kRecvPacketNumBytes); } + + /*! + * \brief Enter the io loop until the next event. + * \param client_mode Whether we are in the client. + * \param async_server_mode Whether we are in the async server mode. + * \param setreturn The function to set the return value encoding. + * \return The function to set return values when there is a return event. 
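+   * \note The returned code is kReturn for a normal return, kCopyAck when a
+   *       copy acknowledgement arrives, and kShutdown when the peer closes
+   *       the session.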
+ */ + RPCCode HandleNextEvent(bool client_mode, bool async_server_mode, + RPCSession::FEncodeReturn setreturn) { + std::swap(client_mode_, client_mode); + std::swap(async_server_mode_, async_server_mode); + + RPCCode status = RPCCode::kNone; + + while (status == RPCCode::kNone && state_ != kWaitForAsyncCallback && this->Ready()) { + switch (state_) { + case kInitHeader: + HandleInitHeader(); + break; + case kRecvPacketNumBytes: { + uint64_t packet_nbytes; + CHECK(this->Read(&packet_nbytes)); + if (packet_nbytes != 0) { + this->SwitchToState(kProcessPacket); + this->RequestBytes(packet_nbytes); + } else { + this->SwitchToState(kRecvPacketNumBytes); + } + break; + } + case kProcessPacket: { + this->HandleProcessPacket(setreturn); + break; + } + case kWaitForAsyncCallback: { + break; + } + case kReturnReceived: { + this->SwitchToState(kRecvPacketNumBytes); + status = RPCCode::kReturn; + break; + } + case kCopyAckReceived: { + status = RPCCode::kCopyAck; + break; + } + case kShutdownReceived: { + status = RPCCode::kShutdown; + } + } + } + + std::swap(async_server_mode_, async_server_mode); + std::swap(client_mode_, client_mode); + return status; + } + + /*! \brief Clear all the states in the Handler.*/ + void Clear() { + state_ = kRecvPacketNumBytes; + pending_request_bytes_ = sizeof(uint64_t); + } + + /*! + * \brief Validate that the arguments can be sent through RPC. + * \param arg_values The argument values. + * \param type_codes The type codes. + */ + void ValidateArguments(const TVMValue* arg_values, const int* type_codes, int num_args) { + TVMArgs args(arg_values, type_codes, num_args); + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " + << args[i].AsObjectRef()->GetTypeKey() << " is not supported by RPC"; + } else if (tcode == kTVMContext) { + DLContext ctx = args[i]; + CHECK_LT(static_cast(ctx.device_type), kRPCSessMask) + << "InternalError: cannot pass RPC context in the channel"; + } + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); + } + + uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args, + bool client_mode) { + return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this); + } + + void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, + bool client_mode) { + RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this); + } + + // Endian aware IO handling + using Stream::Read; + using Stream::ReadArray; + using Stream::Write; + using Stream::WriteArray; + + bool Read(RPCCode* code) { + int32_t cdata; + if (!this->Read(&cdata)) return false; + *code = static_cast(cdata); + return true; + } + void Write(RPCCode code) { + int32_t cdata = static_cast(code); + this->Write(cdata); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + protected: + enum State { + kInitHeader, + kRecvPacketNumBytes, + kProcessPacket, + kWaitForAsyncCallback, + kReturnReceived, + kCopyAckReceived, + kShutdownReceived + }; + // Current state; + State state_; + // Initialize remote header + bool init_header_step_{0}; + // Whether current handler is client or server mode. 
+  bool client_mode_{false};
+  // Whether current handler is in the async server mode.
+  bool async_server_mode_{false};
+  // Internal arena
+  support::Arena arena_;
+
+  // State switcher
+  void SwitchToState(State state) {
+    // invariant
+    if (state != kCopyAckReceived) {
+      CHECK_EQ(pending_request_bytes_, 0U) << "state=" << state;
+    }
+    // need to actively flush the writer
+    // so the data gets pushed out.
+    if (state_ == kWaitForAsyncCallback) {
+      flush_writer_();
+    }
+    state_ = state;
+    CHECK(state != kInitHeader) << "cannot switch to init header";
+    if (state == kRecvPacketNumBytes) {
+      this->RequestBytes(sizeof(uint64_t));
+      // recycle arena for the next session.
+      arena_.RecycleAll();
+    }
+  }
+
+  // handler for initial header read
+  void HandleInitHeader() {
+    if (init_header_step_ == 0) {
+      int32_t len;
+      this->Read(&len);
+      remote_key_->resize(len);
+      init_header_step_ = 1;
+      this->RequestBytes(len);
+      return;
+    } else {
+      CHECK_EQ(init_header_step_, 1);
+      this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
+      this->SwitchToState(kRecvPacketNumBytes);
+    }
+  }
+
+  // Handler for read code.
+  void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) {
+    RPCCode code = RPCCode::kNone;
+    this->Read(&code);
+
+    if (code >= RPCCode::kSyscallCodeStart) {
+      this->HandleSyscall(code);
+    } else {
+      switch (code) {
+        case RPCCode::kInitServer: {
+          this->HandleInitServer();
+          break;
+        }
+        case RPCCode::kCallFunc: {
+          this->HandleNormalCallFunc();
+          break;
+        }
+        case RPCCode::kCopyFromRemote: {
+          this->HandleCopyFromRemote();
+          break;
+        }
+        case RPCCode::kCopyToRemote: {
+          this->HandleCopyToRemote();
+          break;
+        }
+        case RPCCode::kException:
+        case RPCCode::kReturn: {
+          this->HandleReturn(code, setreturn);
+          break;
+        }
+        case RPCCode::kCopyAck: {
+          this->SwitchToState(kCopyAckReceived);
+          break;
+        }
+        case RPCCode::kShutdown: {
+          this->SwitchToState(kShutdownReceived);
+          break;
+        }
+        default:
+          LOG(FATAL) << "Unknown event " << static_cast<int>(code);
+      }
+    }
+  }
+
+  /*!
+   * \brief Receive an incoming packed seq from the stream.
+   * \return The received arguments.
+   * \note The TVMArgs is only valid until we switch state.
+   */
+  TVMArgs RecvPackedSeq() {
+    TVMValue* values;
+    int* tcodes;
+    int num_args;
+    RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this);
+    return TVMArgs(values, tcodes, num_args);
+  }
+
+  /*!
+   * \brief Return exception to the remote.
+   * \param err_msg The error message.
+   */
+  void ReturnException(const char* err_msg) { RPCReference::ReturnException(err_msg, this); }
+
+  /*!
+   * \brief Return nullptr to the remote.
+   */
+  void ReturnVoid() { RPCReference::ReturnVoid(this); }
+
+  /*!
+   * \brief Return a packed sequence to the remote.
+   * \param args The arguments.
+   */
+  void ReturnPackedSeq(TVMArgs args) {
+    RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this);
+  }
+
+  /*!
+   * \brief Handle the case when return/exception value is received.
+   * \param code The RPC code.
+   * \param setreturn The function to encode return.
+   */
+  void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
+    TVMArgs args = RecvPackedSeq();
+
+    if (code == RPCCode::kException) {
+      // switch to the state before sending exception.
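+      // This keeps the endpoint aligned with the byte stream in case the
+      // caller catches the error and keeps using the session.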
+ this->SwitchToState(kRecvPacketNumBytes); + std::string msg = args[0]; + LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; + } + + CHECK(setreturn != nullptr) << "fsetreturn not available"; + setreturn(args); + + this->SwitchToState(kReturnReceived); + } + + void HandleSyscall(RPCCode code); + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + + auto* sess = GetServingSession(); + + // Return Copy Ack with the given data + auto fcopyack = [this](char* data_ptr, size_t num_bytes) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + this->SwitchToState(kRecvPacketNumBytes); + }; + + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) { + char* data_ptr = reinterpret_cast(handle) + offset; + fcopyack(data_ptr, num_bytes); + } else { + char* data_ptr = this->ArenaAlloc(num_bytes); + + auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](RPCCode status, + TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + this->SwitchToState(kRecvPacketNumBytes); + } else { + // endian aware handling + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes); + } + fcopyack(data_ptr, num_bytes); + } + }; + + this->SwitchToState(kWaitForAsyncCallback); + sess->AsyncCopyFromRemote(reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, + ctx, type_hint, on_copy_complete); + } + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + auto* sess = GetServingSession(); + + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && sess->IsLocalSession()) { + char* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); + } + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } else { + char* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes); + } + + auto on_copy_complete = [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + this->SwitchToState(kRecvPacketNumBytes); + } else { + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } + }; + + this->SwitchToState(kWaitForAsyncCallback); + sess->AsyncCopyToRemote(temp_data, 0, reinterpret_cast(handle), offset, num_bytes, ctx, + type_hint, on_copy_complete); + } + } + + // Handle for packed call. 
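+  // Wire format of a kCallFunc packet (after the leading packet_nbytes
+  // header): the remote function handle as a uint64, followed by the
+  // packed-seq encoding of the call arguments.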
+ void HandleNormalCallFunc() { + uint64_t call_handle; + + this->Read(&call_handle); + TVMArgs args = RecvPackedSeq(); + + this->SwitchToState(kWaitForAsyncCallback); + GetServingSession()->AsyncCallFunc(reinterpret_cast(call_handle), args.values, + args.type_codes, args.size(), + [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnPackedSeq(args); + } + this->SwitchToState(kRecvPacketNumBytes); + }); + } + + void HandleInitServer() { + std::string client_protocol_ver; + + uint64_t len; + this->Read(&len); + client_protocol_ver.resize(len); + this->Read(dmlc::BeginPtr(client_protocol_ver), len); + + TVMArgs args = RecvPackedSeq(); + + try { + CHECK(serving_session_ == nullptr) << "Server has already been initialized"; + + std::string server_protocol_ver = kRPCProtocolVer; + CHECK_EQ(client_protocol_ver, server_protocol_ver) + << "Server[" << name_ << "]: Client protocol version mismatch with the server " + << " server protocol=" << server_protocol_ver + << ", client protocol=" << client_protocol_ver; + + std::string constructor_name; + TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0); + + if (args.size() == 0) { + constructor_name = "rpc.LocalSession"; + serving_session_ = std::make_shared(); + } else { + constructor_name = args[0].operator std::string(); + constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1); + } + + auto* fconstructor = Registry::Get(constructor_name); + CHECK(fconstructor != nullptr) << " Cannot find session constructor " << constructor_name; + TVMRetValue con_ret; + + try { + fconstructor->CallPacked(constructor_args, &con_ret); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name << ":\n" + << e.what(); + } + + CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) + << "Server[" << name_ << "]:" + << " Constructor " << constructor_name << " need to return an RPCModule"; + Module mod = con_ret; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; + serving_session_ = RPCModuleGetSession(mod); + this->ReturnVoid(); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + + this->SwitchToState(kRecvPacketNumBytes); + } + + void HandleSyscallStreamSync() { + TVMArgs args = RecvPackedSeq(); + try { + TVMContext ctx = args[0]; + TVMStreamHandle handle = args[1]; + + this->SwitchToState(kWaitForAsyncCallback); + GetServingSession()->AsyncStreamWait(ctx, handle, [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnVoid(); + } + this->SwitchToState(kRecvPacketNumBytes); + }); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + } + } + + // Handler for special syscalls that have a specific RPCCode. 
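+  // The wrapped f is invoked as f(serving_session, args, &rv); the single
+  // return value goes back as a one-element packed-seq, and a thrown
+  // std::runtime_error is converted into an exception packet instead.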
+  template <typename F>
+  void SysCallHandler(F f) {
+    TVMArgs args = RecvPackedSeq();
+    try {
+      TVMRetValue rv;
+      f(GetServingSession(), args, &rv);
+      TVMValue ret_value;
+      int ret_tcode;
+      TVMArgsSetter setter(&ret_value, &ret_tcode);
+      setter(0, rv);
+
+      this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1));
+    } catch (const std::runtime_error& e) {
+      this->ReturnException(e.what());
+    }
+    this->SwitchToState(kRecvPacketNumBytes);
+  }
+
+ private:
+  RPCSession* GetServingSession() const {
+    CHECK(serving_session_ != nullptr)
+        << "Need to call InitRemoteSession first before any further actions";
+    CHECK(!serving_session_->IsAsync() || async_server_mode_)
+        << "Cannot host an async session in a non-Event driven server";
+
+    return serving_session_.get();
+  }
+  // Utility functions
+  // Internal read function, update pending_request_bytes_
+  size_t Read(void* data, size_t size) final {
+    CHECK_LE(size, pending_request_bytes_);
+    reader_->Read(data, size);
+    pending_request_bytes_ -= size;
+    return size;
+  }
+  // write the data to the channel.
+  void Write(const void* data, size_t size) final { writer_->Write(data, size); }
+  // Number of pending bytes requested
+  size_t pending_request_bytes_{0};
+  // The ring buffer to read data from.
+  support::RingBuffer* reader_;
+  // The ring buffer to write the reply to.
+  support::RingBuffer* writer_;
+  // The session used to serve the RPC requests.
+  std::shared_ptr<RPCSession> serving_session_;
+  // Name of endpoint.
+  std::string name_;
+  // remote key
+  std::string* remote_key_;
+  // function to flush the writer.
+  std::function<void()> flush_writer_;
+};
+
+RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) {
+  RPCCode code = RPCCode::kCallFunc;
+  while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) {
+    while (writer_.bytes_available() != 0) {
+      writer_.ReadWithCallback(
+          [this](const void* data, size_t size) { return channel_->Send(data, size); },
+          writer_.bytes_available());
+    }
+    size_t bytes_needed = handler_->BytesNeeded();
+    if (bytes_needed != 0) {
+      size_t n = reader_.WriteWithCallback(
+          [this](void* data, size_t size) { return channel_->Recv(data, size); }, bytes_needed);
+      if (n == 0) {
+        if (handler_->CanCleanShutdown()) {
+          return RPCCode::kShutdown;
+        } else {
+          LOG(FATAL) << "Channel closes before we get needed bytes";
+        }
+      }
+    }
+    code = handler_->HandleNextEvent(client_mode, false, setreturn);
+  }
+  return code;
+}
+
+void RPCEndpoint::Init() {
+  // callback to flush the writer.
+  auto flush_writer = [this]() {
+    while (writer_.bytes_available() != 0) {
+      size_t n = writer_.ReadWithCallback(
+          [this](const void* data, size_t size) { return channel_->Send(data, size); },
+          writer_.bytes_available());
+      if (n == 0) break;
+    }
+  };
+
+  // Event handler
+  handler_ = std::make_shared<EventHandler>(&reader_, &writer_, name_, &remote_key_, flush_writer);
+
+  // Quick function for remote syscalls.
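+  // Illustrative call shape (see SysCallRemote below): the first packed
+  // argument is the RPCCode, the rest are forwarded as-is, e.g.
+  //   syscall_remote_(static_cast<int>(RPCCode::kFreeHandle), handle, type_code);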
+ syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { + std::lock_guard lock(mutex_); + RPCCode code = static_cast(all_args[0].operator int()); + TVMArgs args(all_args.values + 1, all_args.type_codes + 1, all_args.num_args - 1); + + uint64_t packet_nbytes = sizeof(code) + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { + CHECK_EQ(args.size(), 1); + *rv = args[0]; + }); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + }); +} + +std::shared_ptr RPCEndpoint::Create(std::unique_ptr channel, + std::string name, std::string remote_key) { + std::shared_ptr endpt = std::make_shared(); + endpt->channel_ = std::move(channel); + endpt->name_ = std::move(name); + endpt->remote_key_ = std::move(remote_key); + endpt->Init(); + return endpt; +} + +RPCEndpoint::~RPCEndpoint() { this->Shutdown(); } + +void RPCEndpoint::Shutdown() { + if (channel_ != nullptr) { + RPCCode code = RPCCode::kShutdown; + uint64_t packet_nbytes = sizeof(code); + + handler_->Write(packet_nbytes); + handler_->Write(code); + + // flush all writing buffer to output channel. + try { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); + if (n == 0) break; + } + } catch (const dmlc::Error& e) { + } + channel_.reset(nullptr); + } +} + +void RPCEndpoint::ServerLoop() { + if (const auto* f = Registry::Get("tvm.rpc.server.start")) { + (*f)(); + } + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); + if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { + (*f)(); + } + channel_.reset(nullptr); +} + +int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { + RPCCode code = RPCCode::kNone; + if (in_bytes.length() != 0) { + reader_.Write(in_bytes.c_str(), in_bytes.length()); + code = handler_->HandleNextEvent(false, true, [](TVMArgs) {}); + } + if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { + writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); + } + CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + if (code == RPCCode::kShutdown) return 0; + if (writer_.bytes_available() != 0) return 2; + return 1; +} + +void RPCEndpoint::InitRemoteSession(TVMArgs args) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kInitServer; + std::string protocol_ver = kRPCProtocolVer; + uint64_t length = protocol_ver.length(); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(length) + length + + handler_->PackedSeqGetNumBytes(args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(length); + handler_->WriteArray(protocol_ver.data(), length); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [](TVMArgs args) {}); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +// Get remote function with name +void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_values, + const int* 
arg_type_codes, int num_args, + RPCSession::FEncodeReturn encode_return) { + std::lock_guard lock(mutex_); + + handler_->ValidateArguments(arg_values, arg_type_codes, num_args); + RPCCode code = RPCCode::kCallFunc; + uint64_t handle = reinterpret_cast(h); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(handle) + + handler_->PackedSeqGetNumBytes(arg_values, arg_type_codes, num_args, true); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true); + + code = HandleUntilReturnEvent(true, encode_return); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +void RPCEndpoint::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t data_size, TVMContext ctx_to, DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyToRemote; + uint64_t handle = reinterpret_cast(to); + uint64_t offset = static_cast(to_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) + + sizeof(ctx_to) + sizeof(type_hint) + data_size; + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_to); + handler_->Write(type_hint); + handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); + + CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn); +} + +void RPCEndpoint::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t data_size, TVMContext ctx_from, DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyFromRemote; + uint64_t handle = reinterpret_cast(from); + uint64_t offset = static_cast(from_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) + + sizeof(ctx_from) + sizeof(type_hint); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_from); + handler_->Write(type_hint); + + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck); + handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); + handler_->FinishCopyAck(); +} + +// SysCallEventHandler functions +void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + *rv = handler->GetFunction(name); +} + +void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + void* handle = args[0]; + int type_code = args[1]; + handler->FreeHandle(handle, type_code); +} + +void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + handler->GetDeviceAPI(ctx)->SetDevice(ctx); +} + +void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + DeviceAttrKind kind = static_cast(args[1].operator int()); + if (kind == kExist) { + DeviceAPI* api = handler->GetDeviceAPI(ctx, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, rv); + } else { + *rv = 0; + } + } else { + handler->GetDeviceAPI(ctx)->GetAttr(ctx, static_cast(kind), rv); + } +} + +void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + uint64_t nbytes = args[1]; + uint64_t alignment = args[2]; + DLDataType type_hint = args[3]; + 
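+  // The allocation is served by the serving session's local device API;
+  // only the resulting opaque handle travels back over the wire.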
void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint); + *rv = data; +} + +void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + void* ptr = args[1]; + handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); +} + +void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + void* from = args[0]; + uint64_t from_offset = args[1]; + void* to = args[2]; + uint64_t to_offset = args[3]; + uint64_t size = args[4]; + TVMContext ctx_from = args[5]; + TVMContext ctx_to = args[6]; + DLDataType type_hint = args[7]; + TVMStreamHandle stream = args[8]; + TVMContext ctx = ctx_from; + + if (ctx.device_type == kDLCPU) { + ctx = ctx_to; + } else { + CHECK(ctx_to.device_type == kDLCPU || ctx_to.device_type == ctx_from.device_type) + << "Can not copy across different ctx types directly"; + } + handler->GetDeviceAPI(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, + ctx_to, type_hint, stream); +} + +void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { + // Event handler sit at clean state at this point. + switch (code) { + // system functions + case RPCCode::kFreeHandle: + SysCallHandler(RPCFreeHandle); + break; + case RPCCode::kGetGlobalFunc: + SysCallHandler(RPCGetGlobalFunc); + break; + case RPCCode::kDevSetDevice: + SysCallHandler(RPCDevSetDevice); + break; + case RPCCode::kDevGetAttr: + SysCallHandler(RPCDevGetAttr); + break; + case RPCCode::kDevAllocData: + SysCallHandler(RPCDevAllocData); + break; + case RPCCode::kDevFreeData: + SysCallHandler(RPCDevFreeData); + break; + case RPCCode::kDevStreamSync: + this->HandleSyscallStreamSync(); + break; + case RPCCode::kCopyAmongRemote: + SysCallHandler(RPCCopyAmongRemote); + break; + default: + LOG(FATAL) << "Unknown event " << static_cast(code); + } + + if (state_ != kWaitForAsyncCallback) { + CHECK_EQ(state_, kRecvPacketNumBytes); + } +} + +/*! + * \brief RPC client session that proxies all calls to an endpoint. + */ +class RPCClientSession : public RPCSession, public DeviceAPI { + public: + /*! + * \brief param endpoint The client endpoint of the session. + */ + explicit RPCClientSession(std::shared_ptr endpoint) : endpoint_(endpoint) {} + + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final { + return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); + } + + void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, const FEncodeReturn& fencode_return) final { + endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return); + } + + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint) final { + endpoint_->CopyToRemote(from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); + } + + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint) final { + endpoint_->CopyFromRemote(from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); + } + + void FreeHandle(void* handle, int type_code) final { + endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); + } + + void SetDevice(TVMContext ctx) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); } + + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + if (ctx.device_type == kDLCPU && kind == kExist) { + // cpu always exists. 
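+      // answer the probe locally and avoid a remote round trip.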
+      *rv = 1;
+    } else {
+      *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, ctx, static_cast<int>(kind));
+    }
+  }
+
+  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
+                       DLDataType type_hint) final {
+    return endpoint_->SysCallRemote(RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
+  }
+
+  void FreeDataSpace(TVMContext ctx, void* ptr) final {
+    endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr);
+  }
+
+  void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+                      TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
+                      TVMStreamHandle stream) final {
+    endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, const_cast<void*>(from), from_offset, to,
+                             to_offset, size, ctx_from, ctx_to, type_hint, stream);
+  }
+
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+    endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream);
+  }
+
+  DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { return this; }
+
+  bool IsLocalSession() const final { return false; }
+
+ private:
+  std::shared_ptr<RPCEndpoint> endpoint_;
+};
+
+std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
+  return std::make_shared<RPCClientSession>(endpoint);
+}
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h
new file mode 100644
index 0000000000000..2b88cee15c014
--- /dev/null
+++ b/src/runtime/rpc/rpc_endpoint.h
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_endpoint.h
+ * \brief Communication endpoints to connect local and remote RPC sessions.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
+#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_
+
+#include <tvm/runtime/packed_func.h>
+
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+
+#include "../../support/ring_buffer.h"
+#include "rpc_channel.h"
+#include "rpc_protocol.h"
+#include "rpc_session.h"
+
+namespace tvm {
+namespace runtime {
+
+// Magic header for RPC data plane
+const int kRPCMagic = 0xff271;
+// magic header for RPC tracker (control plane)
+const int kRPCTrackerMagic = 0x2f271;
+// success response
+const int kRPCSuccess = kRPCMagic + 0;
+// cannot find a matched key in the server
+const int kRPCMismatch = kRPCMagic + 2;
+
+/*! \brief Enumeration code for the RPC tracker */
+enum class TrackerCode : int {
+  kFail = -1,
+  kSuccess = 0,
+  kPing = 1,
+  kStop = 2,
+  kPut = 3,
+  kRequest = 4,
+  kUpdateInfo = 5,
+  kSummary = 6,
+  kGetPendingMatchKeys = 7
+};
+
+/*!
+ * \brief Communication endpoints to connect local and remote RPC sessions.
+ * An endpoint can either be a client or a server.
+ */
+class RPCEndpoint {
+ public:
+  /*! \brief virtual destructor */
+  ~RPCEndpoint();
+  /*!
+   * \brief The server loop that the server runs to handle RPC calls.
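+   *
+   * The loop blocks on the underlying channel and processes packets until a
+   * kShutdown packet is received, invoking the optional
+   * "tvm.rpc.server.start" and "tvm.rpc.server.shutdown" registry hooks
+   * around it.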
+   */
+  void ServerLoop();
+  /*!
+   * \brief Message handling function for an async IO event driven server.
+   *
+   * Called when the server receives a message or an IO event update.
+   * Event driven handler will never call recv on the channel
+   * and always relies on the ServerIOEventHandler to receive the data.
+   *
+   * \param in_bytes The incoming bytes.
+   * \param event_flag 1: read_available, 2: write_available.
+   * \return State flag.
+   *     1: continue running, no need to write,
+   *     2: need to write
+   *     0: shutdown
+   */
+  int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag);
+
+  /*!
+   * \brief Initialize the session on the remote that will be used to back all the RPC requests.
+   *
+   * If no session constructor arguments are passed, LocalSession will be used in the remote.
+   * Otherwise the remote serving session will be constructed using the arguments
+   * specified in the session_constructor_args.
+   *
+   * The construction rule can be summarized as follows:
+   *
+   * \code
+   *
+   * auto args = session_constructor_args;
+   * int n = args.size();
+   * if (n != 0) {
+   *   std::string constructor = args[0];
+   *   server.serving_session_ = GetGlobalFunc(constructor)(
+   *       args[1], args[2] ... args[n - 1])
+   * } else {
+   *   server.serving_session_ = LocalSession();
+   * }
+   * \endcode
+   *
+   * \param session_constructor_args Optional arguments to the remote session constructor.
+   */
+  void InitRemoteSession(TVMArgs session_constructor_args);
+
+  /*!
+   * \brief Call into remote function
+   * \param handle The function handle
+   * \param arg_values The argument values.
+   * \param arg_type_codes the type codes of the argument.
+   * \param num_args Number of arguments.
+   * \param fencode_return The function to receive return value encodings.
+   */
+  void CallFunc(RPCSession::PackedFuncHandle handle, const TVMValue* arg_values,
+                const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return);
+  /*!
+   * \brief Copy bytes into remote array content.
+   * \param from The source host data.
+   * \param from_offset The byte offset in the from.
+   * \param to The target array.
+   * \param to_offset The byte offset in the to.
+   * \param nbytes The size of the memory in bytes.
+   * \param ctx_to The target context.
+   * \param type_hint Hint of content data type.
+   */
+  void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+                    TVMContext ctx_to, DLDataType type_hint);
+  /*!
+   * \brief Copy bytes from remote array content.
+   * \param from The source host data.
+   * \param from_offset The byte offset in the from.
+   * \param to The target array.
+   * \param to_offset The byte offset in the to.
+   * \param nbytes The size of the memory in bytes.
+   * \param ctx_from The source context.
+   * \param type_hint Hint of content data type.
+   */
+  void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+                      TVMContext ctx_from, DLDataType type_hint);
+
+  /*!
+   * \brief Call a remote defined system function with arguments.
+   * \param fcode The function code.
+   * \param args The arguments
+   * \return The returned remote value.
+   */
+  template <typename... Args>
+  inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args);
+  /*!
+   * \brief Create an RPC session with the given channel.
+   * \param channel The communication channel.
+   * \param name The local name of the session, used for debugging.
+   * \param remote_key The remote key of the session;
+   *   if remote_key equals "%toinit", we need to re-initialize
+   *   it by the event handler.
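+   *
+   * A minimal usage sketch (the channel construction is illustrative):
+   *
+   * \code
+   *
+   *   std::unique_ptr<RPCChannel> ch = ...;  // e.g. a CallbackChannel
+   *   auto endpoint = RPCEndpoint::Create(std::move(ch), "server", "%toinit");
+   *   endpoint->ServerLoop();
+   *
+   * \endcode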
+ */ + static std::shared_ptr Create(std::unique_ptr channel, std::string name, + std::string remote_key); + + private: + class EventHandler; + // Handle events until receives a return + // Also flushes channels so that the function advances. + RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn); + // Initalization + void Init(); + // Shutdown + void Shutdown(); + // Internal channel. + std::unique_ptr channel_; + // Internal mutex + std::mutex mutex_; + // Internal ring buffer. + support::RingBuffer reader_, writer_; + // Event handler. + std::shared_ptr handler_; + // syscall remote with specified function code. + PackedFunc syscall_remote_; + // The name of the session. + std::string name_; + // The remote key + std::string remote_key_; +}; + +/*! + * \brief Create an RPC client session from an RPC client endpoint. + * \param endpoint The endpoint. + * \return The created session. + */ +std::shared_ptr CreateClientSession(std::shared_ptr endpoint); + +// implementation of inline functions +template +inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { + return syscall_remote_(static_cast(code), std::forward(args)...); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 29adb0fed108d..f5b933fcf79fe 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,32 +19,32 @@ /*! * \file rpc_event_impl.cc - * \brief Event based RPC server implementation. + * \brief Event driven RPC server implementation. */ #include + #include -#include "rpc_session.h" + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" namespace tvm { namespace runtime { -PackedFunc CreateEventDrivenServer(PackedFunc fsend, - std::string name, - std::string remote_key) { +PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name, std::string remote_key) { static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) { LOG(FATAL) << "Do not allow explicit receive"; return 0; }); + std::unique_ptr ch(new CallbackChannel(fsend, frecv)); - std::shared_ptr sess = - RPCSession::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - int ret = sess->ServerEventHandler(args[0], args[1]); - *rv = ret; - }); + int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); + *rv = ret; + }); } -TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") -.set_body_typed(CreateEventDrivenServer); +TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc new file mode 100644 index 0000000000000..b35c62d255fc7 --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.cc @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_local_session.cc
+ * \brief Local session that directs requests to the local runtime API.
+ */
+#include "rpc_local_session.h"
+
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace runtime {
+
+RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) {
+  if (auto* fp = tvm::runtime::Registry::Get(name)) {
+    // return a raw handle because the remote needs to explicitly manage it.
+    return new PackedFunc(*fp);
+  } else {
+    return nullptr;
+  }
+}
+
+void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) {
+  int rv_tcode = rv.type_code();
+
+  // return value encoding.
+  TVMValue ret_value_pack[3];
+  int ret_tcode_pack[3];
+  TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack);
+  // the first location always encodes the type code.
+  set_arg(0, rv_tcode);
+
+  if (rv_tcode == kTVMNDArrayHandle) {
+    // We follow a special protocol to return NDArray to the client side.
+    // The first pack value is the NDArray handle as DLTensor.
+    // The second pack value is a customized deleter that deletes the NDArray.
+    rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
+    ret_tcode_pack[1] = kTVMDLTensorHandle;
+    ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
+    ret_tcode_pack[2] = kTVMOpaqueHandle;
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
+  } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
+    // MoveToCHost means rv no longer manages the object.
+    // return handle instead.
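+    // For reference, the encoded TVMArgs layout produced by this function is:
+    //   slot 0: the original type code of rv
+    //   slot 1: the payload (DLTensor handle, opaque handle, bytes, or plain value)
+    //   slot 2: NDArray only, an opaque handle the client later passes to FreeHandle.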
+    rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
+    ret_tcode_pack[1] = kTVMOpaqueHandle;
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
+  } else if (rv_tcode == kTVMBytes) {
+    TVMByteArray byte_arr;
+    auto* sptr = rv.ptr<std::string>();
+    byte_arr.data = sptr->data();
+    byte_arr.size = sptr->length();
+    set_arg(1, byte_arr);
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
+  } else {
+    set_arg(1, rv);
+    encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2));
+  }
+}
+
+void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values,
+                            const int* arg_type_codes, int num_args,
+                            const FEncodeReturn& encode_return) {
+  auto* pf = static_cast<PackedFunc*>(func);
+  TVMRetValue rv;
+  pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
+  this->EncodeReturn(std::move(rv), encode_return);
+}
+
+void LocalSession::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset,
+                                size_t nbytes, TVMContext ctx_to, DLDataType type_hint) {
+  TVMContext cpu_ctx;
+  cpu_ctx.device_type = kDLCPU;
+  cpu_ctx.device_id = 0;
+  this->GetDeviceAPI(ctx_to)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, cpu_ctx,
+                                             ctx_to, type_hint, nullptr);
+  // The copy can happen asynchronously;
+  // synchronize to make sure that the copy is completed.
+  this->GetDeviceAPI(ctx_to)->StreamSync(ctx_to, nullptr);
+}
+
+void LocalSession::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset,
+                                  size_t nbytes, TVMContext ctx_from, DLDataType type_hint) {
+  TVMContext cpu_ctx;
+  cpu_ctx.device_type = kDLCPU;
+  cpu_ctx.device_id = 0;
+
+  this->GetDeviceAPI(ctx_from)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, ctx_from,
+                                               cpu_ctx, type_hint, nullptr);
+  // The copy can happen asynchronously;
+  // synchronize to make sure that the copy is completed.
+  this->GetDeviceAPI(ctx_from)->StreamSync(ctx_from, nullptr);
+}
+
+void LocalSession::FreeHandle(void* handle, int type_code) {
+  TVMValue value;
+  value.v_handle = handle;
+  // the deleter is triggered once rv goes out of scope.
+  TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code);
+}
+
+DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) {
+  return DeviceAPI::Get(ctx, allow_missing);
+}
+
+TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() {
+  return CreateRPCSessionModule(std::make_shared<LocalSession>());
+});
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h
new file mode 100644
index 0000000000000..7a67ce86bf80e
--- /dev/null
+++ b/src/runtime/rpc/rpc_local_session.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_local_session.h
+ * \brief Local session that directs all requests to the local runtime API.
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
+#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "rpc_session.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief A local session that directly uses the handle representation of
+ *        the local TVM runtime objects in the same process.
+ */
+class LocalSession : public RPCSession {
+ public:
+  // function overrides
+  PackedFuncHandle GetFunction(const std::string& name) override;
+
+  void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes,
+                int num_args, const FEncodeReturn& fencode_return) override;
+
+  void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+                    TVMContext ctx_to, DLDataType type_hint) override;
+
+  void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes,
+                      TVMContext ctx_from, DLDataType type_hint) override;
+
+  void FreeHandle(void* handle, int type_code) override;
+
+  DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) override;
+
+  bool IsLocalSession() const override { return true; }
+
+ protected:
+  /*!
+   * \brief Internal encode return function.
+   * \param rv The return value.
+   * \param encode_return The encoding function.
+   */
+  void EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return);
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 0e48e6fb27089..8c462698f6483 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -18,69 +18,116 @@
  */

 /*!
- * \file rpc_device_api.cc
- * \brief RPC module.
+ * \file rpc_module.cc
+ * \brief RPC runtime module.
  */
+#include
 #include
-#include

 #include
+#include
+
+#include "rpc_endpoint.h"
 #include "rpc_session.h"

 namespace tvm {
 namespace runtime {

-// Wrapped remote function to packed func.
-class RPCWrappedFunc {
+/*!
+ * \brief A wrapped remote function as a PackedFunc.
+ */
+class RPCWrappedFunc : public Object {
  public:
-  RPCWrappedFunc(void* handle,
-                 std::shared_ptr sess)
-      : handle_(handle), sess_(sess) {
-    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
-      WrapRemote(sess, args, rv);
-    });
-  }
+  RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess) : handle_(handle), sess_(sess) {}
+
+  void operator()(TVMArgs args, TVMRetValue* rv) const {
+    std::vector<TVMValue> values(args.values, args.values + args.size());
+    std::vector<int> type_codes(args.type_codes, args.type_codes + args.size());
+    std::vector<std::unique_ptr<DLTensor>> temp_dltensors;

-  void operator()(TVMArgs args, TVMRetValue *rv) const {
-    sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_);
+    // scan and check whether we need to rewrite these arguments
+    // to their remote variant.
+    for (int i = 0; i < args.size(); ++i) {
+      int tcode = type_codes[i];
+
+      switch (tcode) {
+        case kTVMDLTensorHandle:
+        case kTVMNDArrayHandle: {
+          // Pass NDArray as DLTensor; NDArray and DLTensor
+          // are compatible with each other, only the type code needs to change.
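+          // Concretely (a sketch of the translation below): the local argument's
+          // data pointer is a RemoteSpace*, so the temporary DLTensor view gets
+          //   ctx  = RemoveSessMask(ctx)   // the device as seen on the remote
+          //   data = RemoteSpace::data     // the raw pointer valid on the remote
+          // and is kept alive in temp_dltensors for the duration of the call.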
+          type_codes[i] = kTVMDLTensorHandle;
+          // translate to a remote view of DLTensor
+          auto dptr = std::make_unique<DLTensor>(*static_cast<DLTensor*>(values[i].v_handle));
+          dptr->ctx = RemoveSessMask(dptr->ctx);
+          dptr->data = static_cast<RemoteSpace*>(dptr->data)->data;
+          values[i].v_handle = dptr.get();
+          temp_dltensors.emplace_back(std::move(dptr));
+          break;
+        }
+        case kTVMContext: {
+          values[i].v_ctx = RemoveSessMask(values[i].v_ctx);
+          break;
+        }
+        case kTVMPackedFuncHandle:
+        case kTVMModuleHandle: {
+          values[i].v_handle = UnwrapRemoteValueToHandle(TVMArgValue(values[i], tcode));
+          break;
+        }
+      }
+    }
+    auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); };
+    sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return);
   }
+
   ~RPCWrappedFunc() {
     try {
-      sess_->CallRemote(RPCCode::kFreeFunc, handle_);
+      sess_->FreeHandle(handle_, kTVMPackedFuncHandle);
     } catch (const dmlc::Error& e) {
       // fault tolerance to remote close
     }
   }

-  static void WrapRemote(std::shared_ptr sess,
-                         TVMArgs args,
-                         TVMRetValue* rv);
+ private:
+  // remote function handle
+  void* handle_{nullptr};
+  // pointer to the session.
+  std::shared_ptr<RPCSession> sess_;

-  static void* UnwrapRemote(int rpc_sess_table_index,
-                            const TVMArgValue& arg);
+  // unwrap a remote value to the underlying handle.
+  void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const;
+  // wrap a remote return via Set
+  void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const;
+
+  // remove a remote session mask
+  TVMContext RemoveSessMask(TVMContext ctx) const {
+    int dev_type = ctx.device_type;
+    CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1)
+        << "Cannot pass in a local context or a context from a different remote session";
+    ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
+    return ctx;
+  }

   // deleter of RPC remote array
   static void RemoteNDArrayDeleter(Object* obj) {
     auto* ptr = static_cast<NDArray::Container*>(obj);
     RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
-    space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
+    space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle);
     delete space;
     delete ptr;
   }
+
   // wrap return value as remote NDArray.
-  static NDArray WrapRemoteNDArray(std::shared_ptr sess,
-                                   DLTensor* tensor,
-                                   void* nd_handle) {
+  NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const {
     NDArray::Container* data = new NDArray::Container();
     data->manager_ctx = nd_handle;
     data->SetDeleter(RemoteNDArrayDeleter);
     RemoteSpace* space = new RemoteSpace();
-    space->sess = sess;
+    space->sess = sess_;
     space->data = tensor->data;
     data->dl_tensor.data = space;
     NDArray ret(GetObjectPtr<Object>(data));
     // RAII now in effect
-    data->shape_ = std::vector(
-        tensor->shape, tensor->shape + tensor->ndim);
+    data->shape_ = std::vector<int64_t>(tensor->shape, tensor->shape + tensor->ndim);
     data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
     data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
     // setup dtype
@@ -88,31 +135,25 @@ class RPCWrappedFunc {
     // setup ctx, encode as remote session
     data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
     data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
-        static_cast(tensor->ctx.device_type) +
-        kRPCSessMask * (sess->table_index() + 1));
+        static_cast<int>(tensor->ctx.device_type) + kRPCSessMask * (sess_->table_index() + 1));
     // check strides.
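     // Note: the RPC wire format only ships compact tensors (SendPackedSeq in
     // rpc_protocol.h rejects non-null strides), so the strides check below
     // mirrors that wire-level restriction when wrapping a remote NDArray.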
     CHECK(tensor->strides == nullptr);
     // setup byte offset
     data->dl_tensor.byte_offset = tensor->byte_offset;
     return ret;
   }
-
- private:
-  PackedFunc fwrap_;
-  void* handle_{nullptr};
-  std::shared_ptr sess_;
 };

 // RPC that represents a remote module session.
 class RPCModuleNode final : public ModuleNode {
  public:
   RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
-      : module_handle_(module_handle), sess_(sess) {
-  }
+      : module_handle_(module_handle), sess_(sess) {}
+
   ~RPCModuleNode() {
     if (module_handle_ != nullptr) {
       try {
-        sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
+        sess_->FreeHandle(module_handle_, kTVMModuleHandle);
       } catch (const dmlc::Error& e) {
         // fault tolerance to remote close
       }
@@ -120,177 +161,247 @@ class RPCModuleNode final : public ModuleNode {
     }
   }

-  const char* type_key() const final {
-    return "rpc";
-  }
+  const char* type_key() const final { return "rpc"; }

-  PackedFunc GetFunction(
-      const std::string& name,
-      const ObjectPtr& sptr_to_self) final {
-    RPCFuncHandle handle = GetFuncHandle(name);
-    return WrapRemote(handle);
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    if (module_handle_ == nullptr) {
+      return WrapRemoteFunc(sess_->GetFunction(name));
+    } else {
+      InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction");
+      return remote_mod_get_function_(GetRef<Module>(this), name, false);
+    }
   }

   std::string GetSource(const std::string& format) final {
-    if (module_handle_ != nullptr) {
-      std::string ret = sess_->CallRemote(
-          RPCCode::kModuleGetSource, module_handle_, format);
-    }
+    LOG(FATAL) << "GetSource for rpc Module is not supported";
     return "";
   }

-  std::shared_ptr& sess() {
-    return sess_;
+  PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat,
+                              int min_repeat_ms) {
+    InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator");
+    // Remove session mask because we pass ctx by parts.
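+    // Sketch of the mask arithmetic used here: a remote context stores
+    //   encoded = real_dev_type + kRPCSessMask * (table_index + 1),
+    // so encoded / kRPCSessMask identifies the session (checked below) and
+    // encoded % kRPCSessMask recovers the real device type.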
+    int dev_type = ctx.device_type;
+    CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1)
+        << "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator";
+    ctx.device_type = static_cast<DLDeviceType>(ctx.device_type % kRPCSessMask);
+
+    if (module_handle_ != nullptr) {
+      return remote_get_time_evaluator_(GetRef<Module>(this), name,
+                                        static_cast<int>(ctx.device_type), ctx.device_id, number,
+                                        repeat, min_repeat_ms);
+    } else {
+      return remote_get_time_evaluator_(Optional<Module>(nullptr), name,
+                                        static_cast<int>(ctx.device_type), ctx.device_id, number,
+                                        repeat, min_repeat_ms);
+    }
   }

-  PackedFunc GetTimeEvaluator(const std::string& name,
-                              TVMContext ctx,
-                              int number,
-                              int repeat,
-                              int min_repeat_ms) {
-    RPCFuncHandle handle = GetFuncHandle(name);
-    if (handle == nullptr) return PackedFunc();
-    handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms);
-    return WrapRemote(handle);
+  Module LoadModule(std::string name) {
+    InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module");
+    return remote_load_module_(name);
   }

-  void* module_handle() const {
-    return module_handle_;
+  void ImportModule(Module other) {
+    InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule");
+    remote_import_module_(GetRef<Module>(this), other);
   }

+  const std::shared_ptr<RPCSession>& sess() { return sess_; }
+
+  void* module_handle() const { return module_handle_; }
+
  private:
-  PackedFunc WrapRemote(RPCFuncHandle handle) {
+  template <typename FType>
+  void InitRemoteFunc(FType* func, const std::string& name) {
+    if (*func != nullptr) return;
+    RPCSession::PackedFuncHandle handle = sess_->GetFunction(name);
+    CHECK(handle != nullptr) << "Cannot find remote function " << name;
+    *func = WrapRemoteFunc(handle);
+  }
+
+  PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) {
     if (handle == nullptr) return PackedFunc();
     auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
-    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
-      return wf->operator()(args, rv);
-    });
+    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); });
   }

-  RPCFuncHandle GetFuncHandle(const std::string& name) {
-    RPCFuncHandle handle = nullptr;
-    if (module_handle_ == nullptr) {
-      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
-    } else {
-      handle = sess_->CallRemote(
-          RPCCode::kModuleGetFunc, module_handle_, name);
-    }
-    return handle;
-  }
   // The module handle
   void* module_handle_{nullptr};
   // The local channel
   std::shared_ptr<RPCSession> sess_;
-  // Wrap function to wrap remote module/function.
-  PackedFunc fwrap_;
+  // remote function to get time evaluator
+  TypedPackedFunc<PackedFunc(Optional<Module>, std::string, int, int, int, int, int)>
+      remote_get_time_evaluator_;
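+  // The typed members below are lazily initialized caches: the first call to
+  // InitRemoteFunc fetches the named global from the serving session and wraps
+  // it once, e.g. (sketch, "libfoo.so" is a hypothetical module name):
+  //   InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module");
+  //   Module m = remote_load_module_("libfoo.so");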
+  // remote function getter for modules.
+  TypedPackedFunc<PackedFunc(Module, std::string, bool)> remote_mod_get_function_;
+  // remote function getter for load module
+  TypedPackedFunc<Module(std::string)> remote_load_module_;
+  // remote function getter for import module
+  TypedPackedFunc<void(Module, Module)> remote_import_module_;
+};

-void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
-                                   const TVMArgValue& arg) {
+void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const {
   if (arg.type_code() == kTVMModuleHandle) {
     Module mod = arg;
     std::string tkey = mod->type_key();
-    CHECK_EQ(tkey, "rpc")
-        << "ValueError: Cannot pass a non-RPC module to remote";
+    CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote";
     auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
-    CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index)
+    CHECK(rmod->sess() == sess_)
         << "ValueError: Cannot pass a module into a different remote session";
     return rmod->module_handle();
   } else {
-    LOG(FATAL) << "ValueError: Cannot pass type "
-               << runtime::TypeCode2Str(arg.type_code())
+    LOG(FATAL) << "ValueError: Cannot pass type " << runtime::TypeCode2Str(arg.type_code())
                << " as an argument to the remote";
     return nullptr;
   }
 }

-void RPCWrappedFunc::WrapRemote(std::shared_ptr sess,
-                                TVMArgs args,
-                                TVMRetValue *rv) {
-  void* handle = args.values[0].v_handle;
-  int tcode = args.type_codes[0];
+void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const {
+  int tcode = args[0];

-  if (handle == nullptr) return;
+  if (tcode == kTVMNullptr) return;
   if (tcode == kTVMPackedFuncHandle) {
-    auto wf = std::make_shared(handle, sess);
-    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
-      return wf->operator()(args, rv);
-    });
+    CHECK_EQ(args.size(), 2);
+    void* handle = args[1];
+    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
+    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); });
   } else if (tcode == kTVMModuleHandle) {
+    CHECK_EQ(args.size(), 2);
+    void* handle = args[1];
-    auto n = make_object(handle, sess);
+    auto n = make_object<RPCModuleNode>(handle, sess_);
     *rv = Module(n);
   } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
-    CHECK_EQ(args.size(), 2);
-    DLTensor* tensor = args[0];
-    void* nd_handle = args[1];
-    *rv = WrapRemoteNDArray(sess, tensor, nd_handle);
+    CHECK_EQ(args.size(), 3);
+    DLTensor* tensor = args[1];
+    void* nd_handle = args[2];
+    *rv = WrapRemoteNDArray(tensor, nd_handle);
   } else {
-    LOG(FATAL) << "Cannot wrap tcode=" << tcode;
+    CHECK_EQ(args.size(), 2);
+    *rv = args[1];
   }
 }

-Module CreateRPCModule(std::shared_ptr sess) {
+Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess) {
   auto n = make_object<RPCModuleNode>(nullptr, sess);
+  RPCSession::InsertToSessionTable(sess);
   return Module(n);
 }

+std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod) {
+  std::string tkey = mod->type_key();
+  CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote";
+  auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
+  return rmod->sess();
+}
+
+PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat,
+                             int min_repeat_ms) {
+  CHECK(pf != nullptr);
+
+  if (static_cast<int>(ctx.device_type) == static_cast<int>(kDLMicroDev)) {
+    auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator");
+    CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled";
+    return (*get_micro_time_evaluator)(pf, ctx, number, repeat);
+  }
+
+  auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable {
+    TVMRetValue temp;
+    std::ostringstream os;
+    // Skip the first call to activate lazy compilation components.
+    pf.CallPacked(args, &temp);
+
+    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+
+    for (int i = 0; i < repeat; ++i) {
+      std::chrono::time_point<std::chrono::high_resolution_clock> tbegin,
+          tend;
+      double duration_ms = 0.0;
+
+      do {
+        if (duration_ms > 0.0) {
+          number = static_cast<int>(std::max((min_repeat_ms / (duration_ms / number) + 1),
+                                             number * 1.618));  // 1.618 is chosen at random
+        }
+
+        tbegin = std::chrono::high_resolution_clock::now();
+        // start timing
+        for (int i = 0; i < number; ++i) {
+          pf.CallPacked(args, &temp);
+        }
+        DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+        tend = std::chrono::high_resolution_clock::now();
+
+        duration_ms =
+            std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() * 1000;
+      } while (duration_ms < min_repeat_ms);
+
+      double speed =
+          std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() / number;
+      os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
+    }

-TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
-    std::string tkey = m->type_key();
-    CHECK_EQ(tkey, "rpc");
-    auto& sess = static_cast(m.operator->())->sess();
-    void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
-    auto n = make_object(mhandle, sess);
-    *rv = Module(n);
-  });
-
-TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module parent = args[0];
-    Module child = args[1];
-    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
-          !std::strcmp(child->type_key(), "rpc"));
-    auto* pmod = static_cast(parent.operator->());
-    auto* cmod = static_cast(child.operator->());
-    CHECK(pmod->sess().get() == cmod->sess().get())
-        << "Import of remote module need to belong to same session.";
-    pmod->sess()->CallRemote(RPCCode::kModuleImport,
-                             pmod->module_handle(),
-                             cmod->module_handle());
-  });
-
-TVM_REGISTER_GLOBAL("rpc._ModuleHandle")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
-    std::string tkey = m->type_key();
-    CHECK_EQ(tkey, "rpc");
-    *rv = static_cast(m.operator->())->module_handle();
-  });
-
-TVM_REGISTER_GLOBAL("rpc._SessTableIndex")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    Module m = args[0];
-    std::string tkey = m->type_key();
-    CHECK_EQ(tkey, "rpc");
-    *rv = static_cast(m.operator->())->sess()->table_index();
-  });
+
+    std::string blob = os.str();
+    TVMByteArray arr;
+    arr.size = blob.length();
+    arr.data = blob.data();
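+    // Worked example of the adaptive loop above (sketch): with min_repeat_ms = 100,
+    // number = 10 and a measured duration_ms = 20, the per-call estimate is 2 ms,
+    // so the next number = max(100 / 2 + 1, 10 * 1.618) = 51; the do-while keeps
+    // growing number until a whole repeat runs for at least min_repeat_ms.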
+    // return the time.
+    *rv = arr;
+  };
+  return PackedFunc(ftimer);
+}
+
+TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
+    .set_body_typed([](Optional<Module> opt_mod, std::string name, int device_type, int device_id,
+                       int number, int repeat, int min_repeat_ms) {
+      TVMContext ctx;
+      ctx.device_type = static_cast<DLDeviceType>(device_type);
+      ctx.device_id = device_id;
+      if (opt_mod.defined()) {
+        Module m = opt_mod.value();
+        std::string tkey = m->type_key();
+        if (tkey == "rpc") {
+          return static_cast<RPCModuleNode*>(m.operator->())
+              ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms);
+        } else {
+          return WrapTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms);
+        }
+      } else {
+        auto* pf = runtime::Registry::Get(name);
+        CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry";
+        return WrapTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms);
+      }
+    });
+
+// server function registration.
+TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) {
+  parent->Import(child);
+});
+
+TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction")
+    .set_body_typed([](Module parent, std::string name, bool query_imports) {
+      return parent->GetFunction(name, query_imports);
+    });
+
+// functions to access an RPC module.
+TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) {
+  std::string tkey = sess->type_key();
+  CHECK_EQ(tkey, "rpc");
+  return static_cast<RPCModuleNode*>(sess.operator->())->LoadModule(name);
+});
+
+TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) {
+  std::string tkey = parent->type_key();
+  CHECK_EQ(tkey, "rpc");
+  static_cast<RPCModuleNode*>(parent.operator->())->ImportModule(child);
+});
+
+TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body([](TVMArgs args, TVMRetValue* rv) {
+  Module m = args[0];
+  std::string tkey = m->type_key();
+  CHECK_EQ(tkey, "rpc");
+  *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
+});

 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc
new file mode 100644
index 0000000000000..2f42435749096
--- /dev/null
+++ b/src/runtime/rpc/rpc_pipe_impl.cc
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_pipe_impl.cc
+ * \brief Pipe-based RPC channel.
+ */
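+// A server process launched by CreatePipeClient below is handed the two pipe
+// file descriptors as its last two argv entries. A minimal counterpart might
+// look like this sketch (RunServerFromFds is a hypothetical helper, not part
+// of this file):
+//
+//   int main(int argc, char* argv[]) {
+//     int readfd = std::atoi(argv[argc - 2]);   // the child's read end
+//     int writefd = std::atoi(argv[argc - 1]);  // the child's write end
+//     RunServerFromFds(readfd, writefd);
+//     return 0;
+//   }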
+// Linux only for now, as Linux is the most common use case.
+#if defined(__linux__) || defined(__ANDROID__)
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include "../../support/pipe.h"
+#include "rpc_endpoint.h"
+#include "rpc_local_session.h"
+
+namespace tvm {
+namespace runtime {
+
+class PipeChannel final : public RPCChannel {
+ public:
+  explicit PipeChannel(int readfd, int writefd, pid_t child_pid)
+      : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {}
+
+  ~PipeChannel() { Close(); }
+
+  size_t Send(const void* data, size_t size) final {
+    ssize_t n = write(writefd_, data, size);
+    if (n == -1) {
+      LOG(FATAL) << "Pipe write error";
+    }
+    return static_cast<size_t>(n);
+  }
+
+  size_t Recv(void* data, size_t size) final {
+    ssize_t n = read(readfd_, data, size);
+    if (n == -1) {
+      LOG(FATAL) << "Pipe read error";
+    }
+    return static_cast<size_t>(n);
+  }
+
+  void Close() {
+    close(readfd_);
+    close(writefd_);
+    kill(child_pid_, SIGKILL);
+  }
+
+ private:
+  int readfd_;
+  int writefd_;
+  pid_t child_pid_;
+};
+
+Module CreatePipeClient(std::vector<std::string> cmd) {
+  int parent2child[2];
+  int child2parent[2];
+  CHECK_EQ(pipe(parent2child), 0);
+  CHECK_EQ(pipe(child2parent), 0);
+
+  int parent_read = child2parent[0];
+  int parent_write = parent2child[1];
+  int child_read = parent2child[0];
+  int child_write = child2parent[1];
+
+  pid_t pid = fork();
+  if (pid == 0) {
+    // child process
+    close(parent_read);
+    close(parent_write);
+    std::string sread_pipe = std::to_string(child_read);
+    std::string swrite_pipe = std::to_string(child_write);
+    std::vector<char*> argv;
+    for (auto& str : cmd) {
+      argv.push_back(dmlc::BeginPtr(str));
+    }
+    argv.push_back(dmlc::BeginPtr(sread_pipe));
+    argv.push_back(dmlc::BeginPtr(swrite_pipe));
+    argv.push_back(nullptr);
+    execvp(argv[0], &argv[0]);
+  }
+  // parent process
+  close(child_read);
+  close(child_write);
+
+  auto endpt = RPCEndpoint::Create(
+      std::unique_ptr<PipeChannel>(new PipeChannel(parent_read, parent_write, pid)), "pipe",
+      "pipe");
+  endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0));
+  return CreateRPCSessionModule(CreateClientSession(endpt));
+}
+
+TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::vector<std::string> cmd;
+  for (int i = 0; i < args.size(); ++i) {
+    cmd.push_back(args[i].operator std::string());
+  }
+  *rv = CreatePipeClient(cmd);
+});
+
+}  // namespace runtime
+}  // namespace tvm
+#endif
diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/rpc/rpc_protocol.h
new file mode 100644
index 0000000000000..3a0555d0cc6d4
--- /dev/null
+++ b/src/runtime/rpc/rpc_protocol.h
@@ -0,0 +1,475 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file rpc_protocol.h
+ * \brief Common header defining the communication code used in the RPC protocol.
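+ *
+ * As used by the helpers in this header, every packet on the wire is framed
+ * as (sketch):
+ *
+ * \code
+ * |packet_nbytes : uint64|RPCCode : int32|payload ...|
+ * \endcode
+ *
+ * where packet_nbytes counts the code and payload but not the length field
+ * itself (see ReturnException and ReturnPackedSeq below).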
+ */
+#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
+#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief The current RPC protocol version. */
+constexpr const char* kRPCProtocolVer = "0.7.0";
+
+/*! \brief The RPC code */
+enum class RPCCode : int {
+  kNone,
+  kShutdown,
+  kInitServer,
+  kCallFunc,
+  kReturn,
+  kException,
+  kCopyFromRemote,
+  kCopyToRemote,
+  kCopyAck,
+  // The following are syscall code that can send over CallRemote
+  kSyscallCodeStart,
+  kGetGlobalFunc = kSyscallCodeStart,
+  kFreeHandle,
+  kDevSetDevice,
+  kDevGetAttr,
+  kDevAllocData,
+  kDevFreeData,
+  kDevStreamSync,
+  kCopyAmongRemote,
+};
+
+/*!
+ * \brief List of potential error status during rpc communication.
+ */
+enum class RPCServerStatus : int {
+  kSuccess = 0,
+  kInvalidTypeCodeObject,
+  kInvalidTypeCodeNDArray,
+  kInvalidDLTensorFieldStride,
+  kInvalidDLTensorFieldByteOffset,
+  kUnknownTypeCode,
+  kUnknownRPCCode,
+  kRPCCodeNotSupported,
+  kUnknownRPCSyscall,
+  kCheckError,
+  kReadError,
+  kWriteError,
+  kAllocError
+};
+
+/*!
+ * \brief Convert RPC server status to string.
+ * \param status The status.
+ * \return The corresponding string.
+ */
+inline const char* RPCServerStatusToString(RPCServerStatus status) {
+  switch (status) {
+    case RPCServerStatus::kSuccess:
+      return "kSuccess";
+    case RPCServerStatus::kInvalidTypeCodeObject:
+      return "kInvalidTypeCodeObject";
+    case RPCServerStatus::kInvalidTypeCodeNDArray:
+      return "kInvalidTypeCodeNDArray";
+    case RPCServerStatus::kInvalidDLTensorFieldStride:
+      return "kInvalidDLTensorFieldStride";
+    case RPCServerStatus::kInvalidDLTensorFieldByteOffset: {
+      return "kInvalidDLTensorFieldByteOffset";
+    }
+    case RPCServerStatus::kUnknownTypeCode:
+      return "kUnknownTypeCode";
+    case RPCServerStatus::kUnknownRPCCode:
+      return "kUnknownRPCCode";
+    case RPCServerStatus::kRPCCodeNotSupported:
+      return "RPCCodeNotSupported";
+    case RPCServerStatus::kUnknownRPCSyscall:
+      return "kUnknownRPCSyscall";
+    case RPCServerStatus::kCheckError:
+      return "kCheckError";
+    case RPCServerStatus::kReadError:
+      return "kReadError";
+    case RPCServerStatus::kWriteError:
+      return "kWriteError";
+    case RPCServerStatus::kAllocError:
+      return "kAllocError";
+    default:
+      return "";
+  }
+}
+
+/*!
+ * \brief Reference implementation of the communication protocol.
+ *
+ * \note The implementation is intentionally written via template
+ *       so it can be used in a dependency free setting.
+ *
+ * \sa src/runtime/rpc/device/min_rpc_server.h
+ */
+struct RPCReference {
+  /*!
+   * \brief Auxiliary class to count the bytes of a packed sequence.
+   * \tparam TChannel The channel type, used to throw errors.
+   */
+  template <typename TChannel>
+  struct PackedSeqNumBytesGetter {
+   public:
+    explicit PackedSeqNumBytesGetter(TChannel* channel) : channel_(channel) {}
+
+    template <typename T>
+    void Write(const T& value) {
+      num_bytes_ += sizeof(T);
+    }
+
+    template <typename T>
+    void WriteArray(const T* value, size_t num) {
+      num_bytes_ += sizeof(T) * num;
+    }
+
+    void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); }
+
+    uint64_t num_bytes() const { return num_bytes_; }
+
+   private:
+    TChannel* channel_;
+    uint64_t num_bytes_{0};
+  };
+
+  /*!
+   * \brief Get the length of a C string.
+   * \param str The string.
+   * \return The length.
+   */
+  static uint64_t StrLength(const char* str) {
+    uint64_t len = 0;
+    while (str[len] != '\0') ++len;
+    return len;
+  }
+
+  /*!
+   * \brief Get the total nbytes to be sent in the packed sequence.
+   *
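+   * For example, ReturnPackedSeq below sizes a return packet with it (sketch):
+   *
+   * \code
+   *
+   * uint64_t packet_nbytes =
+   *     sizeof(code) + PackedSeqGetNumBytes(values, tcodes, num_args, false, channel);
+   *
+   * \endcode
+   *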
+   * \param arg_values The values to be sent over.
+   * \param type_codes The type codes to be sent over.
+   * \param num_args Number of arguments.
+   * \param client_mode Whether it is a client to server call.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   * \return The total number of bytes.
+   */
+  template <typename TChannel>
+  static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes,
+                                       int num_args, bool client_mode, TChannel* channel) {
+    PackedSeqNumBytesGetter<TChannel> getter(channel);
+    SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter);
+    return getter.num_bytes();
+  }
+
+  /*!
+   * \brief Send a packed argument sequence to the other peer.
+   *
+   * This function serves as the foundational communication primitive between peers.
+   *
+   * TVMValue sequence encoding protocol (according to the type):
+   *
+   * - int/float/uint/bytes/str: Serialize all content.
+   * - DLTensor: send meta-data, send data handle as opaque handle (via uint64_t)
+   * - OpaqueHandle: send as uint64_t
+   * - ModuleHandle, PackedFuncHandle: send as uint64_t;
+   *   support for Module/PackedFuncHandle is reserved for arguments
+   *   in the CallFunc from a client to server only.
+   *   Note that we cannot simply take these arguments out, as the handle
+   *   refers to a value on the remote (instead of the local) side.
+   *
+   * \param arg_values The values to be sent over.
+   * \param type_codes The type codes to be sent over.
+   * \param num_args Number of arguments.
+   * \param client_mode Whether it is a client to server call.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template <typename TChannel>
+  static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
+                            bool client_mode, TChannel* channel) {
+    channel->Write(num_args);
+    channel->WriteArray(type_codes, num_args);
+
+    // Argument packing.
+    for (int i = 0; i < num_args; ++i) {
+      int tcode = type_codes[i];
+      TVMValue value = arg_values[i];
+      switch (tcode) {
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat: {
+          channel->template Write<int64_t>(value.v_int64);
+          break;
+        }
+        case kTVMDataType: {
+          channel->Write(value.v_type);
+          // padding
+          int32_t padding = 0;
+          channel->template Write<int32_t>(padding);
+          break;
+        }
+        case kTVMContext: {
+          channel->Write(value.v_ctx);
+          break;
+        }
+
+        case kTVMPackedFuncHandle:
+        case kTVMModuleHandle: {
+          if (!client_mode) {
+            channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject);
+          }
+          // always send handle in 64 bit.
+          uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
+          channel->Write(handle);
+          break;
+        }
+        case kTVMOpaqueHandle: {
+          // always send handle in 64 bit.
+          uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
+          channel->Write(handle);
+          break;
+        }
+        case kTVMNDArrayHandle: {
+          channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray);
+          break;
+        }
+        case kTVMDLTensorHandle: {
+          DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
+          TVMContext ctx;
+          uint64_t data;
+          // When we return NDArray, we directly return
+          // the space and the context.
+          // The client will be further wrapping it.
+          ctx = arr->ctx;
+          data = reinterpret_cast<uint64_t>(arr->data);
+          channel->Write(data);
+          channel->Write(ctx);
+          channel->Write(arr->ndim);
+          channel->Write(arr->dtype);
+          channel->WriteArray(arr->shape, arr->ndim);
+          if (arr->strides != nullptr) {
+            channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
+          }
+          if (arr->byte_offset != 0) {
+            channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldByteOffset);
+          }
+          break;
+        }
+        case kTVMNullptr:
+          break;
+        case kTVMStr: {
+          const char* s = value.v_str;
+          uint64_t len = StrLength(s);
+          channel->Write(len);
+          channel->WriteArray(s, len);
+          break;
+        }
+        case kTVMBytes: {
+          TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
+          uint64_t len = bytes->size;
+          channel->Write(len);
+          channel->WriteArray(bytes->data, len);
+          break;
+        }
+        default: {
+          channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
+          break;
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Receive a packed sequence from the channel.
+   *
+   * \param out_values The values to be received.
+   * \param out_tcodes The type codes to be received.
+   * \param out_num_args Number of arguments.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   * \note The temporary space is allocated via an arena inside the channel.
+   */
+  template <typename TChannel>
+  static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args,
+                            TChannel* channel) {
+    // receive number of args
+    int num_args;
+    channel->Read(&num_args);
+    *out_num_args = num_args;
+
+    if (num_args == 0) {
+      *out_values = nullptr;
+      *out_tcodes = nullptr;
+      return;
+    }
+
+    TVMValue* values = channel->template ArenaAlloc<TVMValue>(num_args);
+    int* tcodes = channel->template ArenaAlloc<int>(num_args);
+    *out_values = values;
+    *out_tcodes = tcodes;
+
+    // receive type codes.
+    channel->ReadArray(tcodes, num_args);
+
+    // receive arguments
+    for (int i = 0; i < num_args; ++i) {
+      auto& value = values[i];
+      switch (tcodes[i]) {
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat: {
+          channel->template Read<int64_t>(&(value.v_int64));
+          break;
+        }
+        case kTVMDataType: {
+          channel->Read(&(value.v_type));
+          int32_t padding = 0;
+          channel->template Read<int32_t>(&padding);
+          break;
+        }
+        case kTVMContext: {
+          channel->Read(&(value.v_ctx));
+          break;
+        }
+        case kTVMPackedFuncHandle:
+        case kTVMModuleHandle:
+        case kTVMOpaqueHandle: {
+          // handles are always sent in 64 bit.
+          uint64_t handle;
+          channel->Read(&handle);
+          value.v_handle = reinterpret_cast<void*>(handle);
+          break;
+        }
+        case kTVMNullptr: {
+          value.v_handle = nullptr;
+          break;
+        }
+        case kTVMStr: {
+          uint64_t len;
+          channel->Read(&len);
+          char* str = channel->template ArenaAlloc<char>(len + 1);
+          str[len] = '\0';
+          channel->ReadArray(str, len);
+          value.v_str = str;
+          break;
+        }
+        case kTVMBytes: {
+          uint64_t len;
+          channel->Read(&len);
+          TVMByteArray* arr = channel->template ArenaAlloc<TVMByteArray>(1);
+          char* data = channel->template ArenaAlloc<char>(len);
+          arr->size = len;
+          arr->data = data;
+          channel->ReadArray(data, len);
+          value.v_handle = arr;
+          break;
+        }
+        case kTVMDLTensorHandle: {
+          uint64_t handle;
+          channel->Read(&handle);
+          DLTensor* arr = channel->template ArenaAlloc<DLTensor>(1);
+          DLTensor& tensor = *arr;
+          tensor.data = reinterpret_cast<void*>(handle);
+          channel->Read(&(tensor.ctx));
+          channel->Read(&(tensor.ndim));
+          channel->Read(&(tensor.dtype));
+          tensor.shape = channel->template ArenaAlloc<int64_t>(tensor.ndim);
+          channel->ReadArray(tensor.shape, tensor.ndim);
+          tensor.strides = nullptr;
+          tensor.byte_offset = 0;
+          value.v_handle = arr;
+          break;
+        }
+        default: {
+          channel->ThrowError(RPCServerStatus::kUnknownTypeCode);
+          break;
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Return an exception packet.
+   *
+   * \param msg The error message.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template <typename TChannel>
+  static void ReturnException(const char* msg, TChannel* channel) {
+    RPCCode code = RPCCode::kException;
+    int32_t num_args = 1;
+    int32_t tcode = kTVMStr;
+    uint64_t len = StrLength(msg);
+
+    uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len;
+
+    channel->Write(packet_nbytes);
+    channel->Write(code);
+    channel->Write(num_args);
+    channel->Write(tcode);
+    channel->Write(len);
+    channel->WriteArray(msg, len);
+  }
+
+  /*!
+   * \brief Return a normal packed sequence packet.
+   *
+   * \param arg_values The values to be returned.
+   * \param type_codes The type codes of the values.
+   * \param num_args Number of arguments.
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template <typename TChannel>
+  static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args,
+                              TChannel* channel) {
+    RPCCode code = RPCCode::kReturn;
+
+    uint64_t packet_nbytes =
+        sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel);
+
+    channel->Write(packet_nbytes);
+    channel->Write(code);
+    SendPackedSeq(arg_values, type_codes, num_args, false, channel);
+  }
+
+  /*!
+   * \brief Return a null (void) packet.
+   *
+   * \param channel The communication channel handler.
+   * \tparam TChannel The type of the communication channel.
+   */
+  template <typename TChannel>
+  static void ReturnVoid(TChannel* channel) {
+    int32_t num_args = 1;
+    int32_t tcode = kTVMNullptr;
+    RPCCode code = RPCCode::kReturn;
+
+    uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
+
+    channel->Write(packet_nbytes);
+    channel->Write(code);
+    channel->Write(num_args);
+    channel->Write(tcode);
+  }
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc
index f6a7fb60b5f4d..b999a48a376a4 100644
--- a/src/runtime/rpc/rpc_server_env.cc
+++ b/src/runtime/rpc/rpc_server_env.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,42 +22,40 @@
  * \brief Server environment of the RPC.
  */
 #include
+
 #include "../file_util.h"

 namespace tvm {
 namespace runtime {

 std::string RPCGetPath(const std::string& name) {
-  static const PackedFunc* f =
-      runtime::Registry::Get("tvm.rpc.server.workpath");
+  // do a live lookup every time, as the workpath can change.
+  const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath");
   CHECK(f != nullptr) << "require tvm.rpc.server.workpath";
   return (*f)(name);
 }

-TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").
-set_body([](TVMArgs args, TVMRetValue *rv) {
-    std::string file_name = RPCGetPath(args[0]);
-    std::string data = args[1];
-    SaveBinaryToFile(file_name, data);
-  });
-
-TVM_REGISTER_GLOBAL("tvm.rpc.server.download")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    std::string file_name = RPCGetPath(args[0]);
-    std::string data;
-    LoadBinaryFromFile(file_name, &data);
-    TVMByteArray arr;
-    arr.data = data.c_str();
-    arr.size = data.length();
-    LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
-    *rv = arr;
-  });
-
-TVM_REGISTER_GLOBAL("tvm.rpc.server.remove")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    std::string file_name = RPCGetPath(args[0]);
-    RemoveFile(file_name);
-  });
+TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::string file_name = RPCGetPath(args[0]);
+  std::string data = args[1];
+  SaveBinaryToFile(file_name, data);
+});
+
+TVM_REGISTER_GLOBAL("tvm.rpc.server.download").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::string file_name = RPCGetPath(args[0]);
+  std::string data;
+  LoadBinaryFromFile(file_name, &data);
+  TVMByteArray arr;
+  arr.data = data.c_str();
+  arr.size = data.length();
+  LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size;
+  *rv = arr;
+});
+
+TVM_REGISTER_GLOBAL("tvm.rpc.server.remove").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::string file_name = RPCGetPath(args[0]);
+  RemoveFile(file_name);
+});

 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc
index 43ca630f94963..9e05e5d1628d4 100644
--- a/src/runtime/rpc/rpc_session.cc
+++ b/src/runtime/rpc/rpc_session.cc
@@ -21,816 +21,84 @@
  * \file rpc_session.cc
  * \brief RPC session for remote function call.
  */
-#include
-#include
+#include "rpc_session.h"
+
 #include
-#include
-#include
-#include
+#include
+
 #include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "rpc_session.h"
-#include "../object_internal.h"
-#include "../../support/ring_buffer.h"
-#include "../../support/socket.h"

 namespace tvm {
 namespace runtime {

-// Temp buffer for data array
-struct RPCByteArrayBuffer {
-  TVMByteArray arr;
-  std::string data;
-};
-// Temp buffer for data array
-struct RPCDataArrayBuffer {
-  DLTensor tensor;
-  std::vector shape;
-};
-/*!
- * \brief Temporal argument buffer.
- */
-struct RPCArgBuffer {
-  // The argument values
-  std::vector value;
-  // The type codes.
-  std::vector tcode;
-  // Temporal resources.
- std::vector > temp_bytes; - // Temporal array - std::vector > temp_array; - // convert buffer as TVMArgs - TVMArgs AsTVMArgs() const { - return TVMArgs(value.data(), tcode.data(), static_cast(value.size())); - } -}; - -// Event handler for RPC events. -class RPCSession::EventHandler : public dmlc::Stream { - public: - EventHandler(support::RingBuffer* reader, - support::RingBuffer* writer, - int rpc_sess_table_index, - std::string name, - std::string* remote_key) - : reader_(reader), - writer_(writer), - rpc_sess_table_index_(rpc_sess_table_index), - name_(name), - remote_key_(remote_key) { - this->Clear(); - if (*remote_key == "%toinit") { - state_ = kInitHeader; - remote_key_->resize(0); - pending_request_bytes_ = sizeof(int32_t); - } - } - // Bytes needed to fulfill current request - size_t BytesNeeded() { - if (reader_->bytes_available() < pending_request_bytes_) { - return pending_request_bytes_ - reader_->bytes_available(); - } else { - return 0; - } - } - // Request number of bytes from reader. - void RequestBytes(size_t nbytes) { - pending_request_bytes_ += nbytes; - reader_->Reserve(pending_request_bytes_); - } - // Whether we are ready to handle next request. - bool Ready() { - return reader_->bytes_available() >= pending_request_bytes_; - } - bool CanCleanShutdown() const { - return state_ == kRecvCode; - } - void FinishCopyAck() { - this->SwitchToState(kRecvCode); - } - RPCCode HandleNextEvent(TVMRetValue* rv, - bool client_mode, - const PackedFunc* fwrap) { - std::swap(client_mode_, client_mode); - while (this->Ready()) { - switch (state_) { - case kInitHeader: HandleInitHeader(); break; - case kRecvCode: HandleRecvCode(); break; - case kRecvCallHandle: { - CHECK(this->Read(&call_handle_)); - this->SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case kRecvPackedSeqNumArgs: { - CHECK(this->Read(&num_packed_args_)); - arg_buf_.reset(new RPCArgBuffer()); - arg_buf_->value.resize(num_packed_args_); - arg_buf_->tcode.resize(num_packed_args_); - this->SwitchToState(kRecvPackedSeqTypeCode); - break; - } - case kRecvPackedSeqTypeCode: { - if (num_packed_args_ != 0) { - this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); - } - arg_index_ = 0; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kRecvPackedSeqArg: { - this->HandleRecvPackedSeqArg(); - break; - } - case kDoCopyFromRemote: { - this->HandleCopyFromRemote(); - break; - } - case kDoCopyToRemote: { - this->HandleCopyToRemote(); - break; - } - case kReturnReceived: { - CHECK_GE(arg_buf_->value.size(), 1U); +bool RPCSession::IsAsync() const { return false; } - TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; - if (argv.type_code() == kTVMPackedFuncHandle || - argv.type_code() == kTVMModuleHandle || - argv.type_code() == kTVMDLTensorHandle) { - CHECK(fwrap != nullptr) << "function/module wrapper not available"; - fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv); - } else { - CHECK_EQ(arg_buf_->value.size(), 1U); - *rv = argv; - } - arg_buf_.reset(); - this->SwitchToState(kRecvCode); - std::swap(client_mode_, client_mode); - return RPCCode::kReturn; - } - case kCopyAckReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kCopyAck; - } - case kShutdownReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kShutdown; - } - } - } - std::swap(client_mode_, client_mode); - return RPCCode::kNone; - } - // Reset and clear all states. 
- void Clear() { - state_ = kRecvCode; - pending_request_bytes_ = sizeof(RPCCode); - arg_recv_stage_ = 0; - arg_buf_.reset(); - } - // strip session on mask - TVMContext StripSessMask(TVMContext ctx) { - int dev_type = ctx.device_type; - CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1) - << "Can not pass in local context or context with a different remote session"; - ctx.device_type = static_cast(dev_type % kRPCSessMask); - return ctx; - } - // Send Packed sequence to writer. - // - // client_mode: whether we are in client mode. - // - // funwrap: auxiliary function to unwrap remote Object - // when it is provided, we need to unwrap objects. - // - // return_ndarray is a special flag to handle returning of ndarray - // In this case, we return the shape, context and data of the array, - // as well as a customized PackedFunc that handles deletion of - // the array in the remote. - void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - FUnwrapRemoteObject funwrap = nullptr, - bool return_ndarray = false) { - std::swap(client_mode_, client_mode); - - this->Write(num_args); - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - if (tcode == kTVMNDArrayHandle) tcode = kTVMDLTensorHandle; - this->Write(tcode); - } - - // Argument packing. - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - TVMValue value = arg_values[i]; - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Write(value.v_int64); - break; - } - case kTVMDataType: { - this->Write(value.v_type); - // padding - int32_t padding = 0; - this->Write(padding); - break; - } - case kTVMContext: { - value.v_ctx = StripSessMask(value.v_ctx); - this->Write(value.v_ctx); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: { - // always send handle in 64 bit. - uint64_t handle; - // allow pass module as argument to remote. - if (funwrap != nullptr) { - void* remote_handle = (*funwrap)( - rpc_sess_table_index_, - runtime::TVMArgValue(value, tcode)); - handle = reinterpret_cast(remote_handle); - } else { - CHECK(!client_mode_) - << "Cannot directly pass remote object as argument"; - handle = reinterpret_cast(value.v_handle); - } - this->Write(handle); - break; - } - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); - this->Write(handle); - break; - } - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* arr = static_cast(value.v_handle); - TVMContext ctx; - uint64_t data; - if (!return_ndarray) { - // in the client mode - // ctx contains the remote table index - // the space is wrapped by an RemoteSpace - // that holds reference to the session. 
- ctx = StripSessMask(arr->ctx); - data = reinterpret_cast( - static_cast(arr->data)->data); - } else { - // When we return NDArray, we directly return - // the space and the context - // The client will be further wrapping - ctx = arr->ctx; - data = reinterpret_cast(arr->data); - } - this->Write(data); - this->Write(ctx); - this->Write(arr->ndim); - this->Write(arr->dtype); - this->WriteArray(arr->shape, arr->ndim); - CHECK(arr->strides == nullptr) - << "Do not support strided remote array"; - CHECK_EQ(arr->byte_offset, 0) - << "Do not support send byte offset"; - break; - } - case kTVMNullptr: break; - case kTVMStr: { - const char* s = value.v_str; - uint64_t len = strlen(s); - this->Write(len); - this->WriteArray(s, len); - break; - } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(arg_values[i].v_handle); - uint64_t len = bytes->size; - this->Write(len); - this->WriteArray(bytes->data, len); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - std::swap(client_mode_, client_mode); - } - - // Endian aware IO handling - using Stream::Read; - using Stream::Write; - using Stream::ReadArray; - using Stream::WriteArray; - - inline bool Read(RPCCode* code) { - int cdata; - if (!this->Read(&cdata)) return false; - *code = static_cast(cdata); - return true; - } - inline void Write(RPCCode code) { - int cdata = static_cast(code); - this->Write(cdata); - } +void RPCSession::SendException(FAsyncCallback callback, const char* msg) { + TVMValue value; + value.v_str = msg; + int32_t tcode = kTVMStr; + callback(RPCCode::kException, TVMArgs(&value, &tcode, 1)); +} - protected: - enum State { - kInitHeader, - kRecvCode, - kRecvCallHandle, - kRecvPackedSeqNumArgs, - kRecvPackedSeqTypeCode, - kRecvPackedSeqArg, - kDoCopyFromRemote, - kDoCopyToRemote, - kReturnReceived, - kCopyAckReceived, - kShutdownReceived - }; - // Current state; - State state_; - // The RPCCode to be read. - RPCCode code_; - // Handle for the remote function call. - uint64_t call_handle_; - // Initialize remote header - bool init_header_step_{0}; - // Number of packed arguments. - int num_packed_args_; - // Current argument index. - int arg_index_; - // The stage of each argument receiver. - int arg_recv_stage_; - // Whether current handler is client or server mode. - bool client_mode_{false}; - // Argument buffer - std::unique_ptr arg_buf_; - // Temp byte buffer. - std::unique_ptr temp_bytes_; - // Temp array buffer. - std::unique_ptr temp_array_; - // Internal temporal data space. - std::string temp_data_; - // Temp variables for copy request state. - TVMContext copy_ctx_; - DLDataType copy_dtype_; - uint64_t copy_handle_, copy_offset_, copy_size_; - // State switcher - void SwitchToState(State state) { - // invariant - CHECK_EQ(pending_request_bytes_, 0U) - << "state=" << state; - state_ = state; - switch (state) { - case kInitHeader: { - LOG(FATAL) << "cannot switch to init header"; - break; - } - case kRecvCode: { - this->RequestBytes(sizeof(RPCCode)); - break; - } - case kRecvCallHandle: { - this->RequestBytes(sizeof(call_handle_)); - break; - } - case kRecvPackedSeqNumArgs: { - this->RequestBytes(sizeof(num_packed_args_)); - break; - } - case kRecvPackedSeqTypeCode: { - this->RequestBytes(sizeof(int) * num_packed_args_); - break; - } - case kRecvPackedSeqArg: { - CHECK_LE(arg_index_, num_packed_args_); - if (arg_index_ == num_packed_args_) { - // The function can change state_ again. 
- HandlePackedCall(); - } else { - RequestRecvPackedSeqArg(); - } - break; - } - case kDoCopyFromRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kDoCopyToRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kCopyAckReceived: - case kReturnReceived: - case kShutdownReceived: { - break; - } - } - } - // Requets bytes needed for next computation. - void RequestRecvPackedSeqArg() { - CHECK_EQ(arg_recv_stage_, 0); - int tcode = arg_buf_->tcode[arg_index_]; - static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant"); - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: - case kTVMDataType: - case kTVMOpaqueHandle: - case kTVMStr: - case kTVMBytes: - case kTVMModuleHandle: - case kTVMContext: { - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMPackedFuncHandle: { - CHECK(client_mode_) - << "Only client can receive remote functions"; - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMNullptr: break; - case kTVMDLTensorHandle: { - this->RequestBytes(sizeof(uint64_t)); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(int)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - // Handler for packed sequence argument receive. - void HandleRecvPackedSeqArg() { - CHECK_LT(arg_index_, num_packed_args_); - int tcode = arg_buf_->tcode[arg_index_]; - TVMValue& value = arg_buf_->value[arg_index_]; - if (arg_recv_stage_ == 0) { - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Read(&(value.v_int64)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMDataType: { - this->Read(&(value.v_type)); - int32_t padding = 0; - this->Read(&padding); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMContext: { - this->Read(&(value.v_ctx)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: - case kTVMOpaqueHandle: { - // always send handle in 64 bit. 
- uint64_t handle; - this->Read(&handle); - value.v_handle = reinterpret_cast(handle); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMNullptr: { - value.v_handle = nullptr; - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMStr: - case kTVMBytes: { - uint64_t len; - this->Read(&len); - temp_bytes_.reset( new RPCByteArrayBuffer()); - temp_bytes_->data.resize(len); - arg_recv_stage_ = 1; - this->RequestBytes(len); - break; - } - case kTVMDLTensorHandle: { - temp_array_.reset(new RPCDataArrayBuffer()); - uint64_t handle; - this->Read(&handle); - DLTensor& tensor = temp_array_->tensor; - tensor.data = reinterpret_cast(handle); - this->Read(&(tensor.ctx)); - this->Read(&(tensor.ndim)); - this->Read(&(tensor.dtype)); - temp_array_->shape.resize(tensor.ndim); - tensor.shape = temp_array_->shape.data(); - arg_recv_stage_ = 1; - tensor.strides = nullptr; - tensor.byte_offset = 0; - this->RequestBytes(sizeof(int64_t) * tensor.ndim); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } else { - CHECK_EQ(arg_recv_stage_, 1); - if (tcode == kTVMStr || tcode == kTVMBytes) { - if (temp_bytes_->data.size() != 0) { - this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); - } - if (tcode == kTVMStr) { - value.v_str = temp_bytes_->data.c_str(); - } else { - temp_bytes_->arr.size = static_cast(temp_bytes_->data.size()); - temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data); - value.v_handle = &(temp_bytes_->arr); - } - arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_)); - } else { - CHECK_EQ(tcode, kTVMDLTensorHandle); - DLTensor& tensor = temp_array_->tensor; - this->ReadArray(tensor.shape, tensor.ndim); - value.v_handle = &tensor; - arg_buf_->temp_array.emplace_back(std::move(temp_array_)); - } - ++arg_index_; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - } - } - // handler for initial header read - void HandleInitHeader() { - if (init_header_step_ == 0) { - int32_t len; - this->Read(&len); - remote_key_->resize(len); - init_header_step_ = 1; - this->RequestBytes(len); - return; - } else { - CHECK_EQ(init_header_step_, 1); - this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); - this->SwitchToState(kRecvCode); - } - } - // Handler for read code. - void HandleRecvCode() { - this->Read(&code_); - if (code_ > RPCCode::kSystemFuncStart) { - SwitchToState(kRecvPackedSeqNumArgs); - return; - } - // invariant. 
- CHECK_EQ(arg_recv_stage_, 0); - switch (code_) { - case RPCCode::kCallFunc: { - SwitchToState(kRecvCallHandle); - break; - } - case RPCCode::kException: - case RPCCode::kReturn: { - SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case RPCCode::kCopyFromRemote: { - SwitchToState(kDoCopyFromRemote); - break; - } - case RPCCode::kCopyToRemote: { - SwitchToState(kDoCopyToRemote); - break; - } - case RPCCode::kShutdown: { - SwitchToState(kShutdownReceived); - break; - } - case RPCCode::kCopyAck: { - SwitchToState(kCopyAckReceived); - break; - } - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } +void RPCSession::AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, FAsyncCallback callback) { + try { + this->CallFunc(func, arg_values, arg_type_codes, num_args, + [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); }); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } +} - void HandleCopyFromRemote() { - uint64_t handle, offset, num_bytes; - TVMContext ctx; - DLDataType type_hint; - this->Read(&handle); - this->Read(&offset); - this->Read(&num_bytes); - this->Read(&ctx); - this->Read(&type_hint); - size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; +void RPCSession::AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; - if (ctx.device_type == kDLCPU) { - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - char* dptr = reinterpret_cast(handle) + offset; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - temp_data_.resize(0); - temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes); - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - this->WriteArray(temp_data_.data(), num_bytes); - } else { - this->WriteArray(dptr, num_bytes); - } - } else { - temp_data_.resize(num_bytes + 1); - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(ctx)->CopyDataFromTo( - reinterpret_cast(handle), offset, - dmlc::BeginPtr(temp_data_), 0, - num_bytes, ctx, cpu_ctx, type_hint, nullptr); - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - } - this->WriteArray(&temp_data_[0], num_bytes); - } catch (const std::runtime_error &e) { - RPCCode code = RPCCode::kException; - this->Write(code); - TVMValue ret_value; - ret_value.v_str = e.what(); - int ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - this->SwitchToState(kRecvCode); + try { + this->CopyToRemote(local_from, local_from_offset, remote_to, remote_to_offset, nbytes, + remote_ctx_to, type_hint); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } +} - void HandleCopyToRemote() { - // use static variable to persist state. - // This only works if next stage is immediately after this. 
- if (arg_recv_stage_ == 0) { - CHECK(this->Read(©_handle_)); - CHECK(this->Read(©_offset_)); - CHECK(this->Read(©_size_)); - CHECK(this->Read(©_ctx_)); - CHECK(this->Read(©_dtype_)); - arg_recv_stage_ = 1; - CHECK_EQ(pending_request_bytes_, 0U); - this->RequestBytes(copy_size_); - } else { - CHECK_EQ(arg_recv_stage_, 1); - TVMValue ret_value; - ret_value.v_handle = nullptr; - int ret_tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; - std::string errmsg; +void RPCSession::AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, + TVMContext remote_ctx_from, DLDataType type_hint, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; - size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8; - if (copy_ctx_.device_type == kDLCPU) { - char* dptr = reinterpret_cast(copy_handle_) + copy_offset_; - this->ReadArray(dptr, copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes); - } - } else { - temp_data_.resize(copy_size_ + 1); - this->ReadArray(&temp_data_[0], copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes); - } - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(copy_ctx_)->CopyDataFromTo( - temp_data_.data(), 0, - reinterpret_cast(copy_handle_), copy_offset_, - copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr); - } catch (const std::runtime_error &e) { - code = RPCCode::kException; - errmsg = e.what(); - ret_value.v_str = errmsg.c_str(); - ret_tcode = kTVMStr; - } - } - this->Write(code); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - arg_recv_stage_ = 0; - this->SwitchToState(kRecvCode); - } + try { + this->CopyFromRemote(remote_from, remote_from_offset, local_to, local_to_offset, nbytes, + remote_ctx_from, type_hint); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } - // Handle for packed call. - void HandlePackedCall(); +} - template - void CallHandler(F f) { - TVMRetValue rv; - TVMValue ret_value; - int ret_tcode; - try { - // Need to move out, in case f itself need to call RecvPackedSeq - // Which will override argbuf again. - std::unique_ptr args = std::move(arg_buf_); - f(args->AsTVMArgs(), &rv); - RPCCode code = RPCCode::kReturn; - this->Write(code); - if (rv.type_code() == kTVMStr) { - ret_value.v_str = rv.ptr()->c_str(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMBytes) { - std::string* bytes = rv.ptr(); - TVMByteArray arr; - arr.data = bytes->c_str(); - arr.size = bytes->length(); - ret_value.v_handle = &arr; - ret_tcode = kTVMBytes; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMPackedFuncHandle || - rv.type_code() == kTVMModuleHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send function and module handle back."; - rv.MoveToCHost(&ret_value, &ret_tcode); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMNDArrayHandle) { - // always send handle in 64 bit. 
- CHECK(!client_mode_) - << "Only server can send NDArray back"; - // We follow a special protocol to return NDArray to client side - // The first pack value is the NDArray handle as DLTensor - // The second pack value is a customized deleter that deletes the NDArray. - TVMValue ret_value_pack[2]; - int ret_tcode_pack[2]; - rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; - ret_tcode_pack[1] = kTVMOpaqueHandle; - SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); - } else { - ret_value = rv.value(); - ret_tcode = rv.type_code(); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } catch (const std::runtime_error& e) { - RPCCode code = RPCCode::kException; - this->Write(code); - ret_value.v_str = e.what(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } +void RPCSession::AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; - private: - // Utility functions - // Internal read function, update pending_request_bytes_ - size_t Read(void* data, size_t size) final { - CHECK_LE(size, pending_request_bytes_); - reader_->Read(data, size); - pending_request_bytes_ -= size; - return size; - } - void Write(const void* data, size_t size) final { - writer_->Write(data, size); + try { + this->GetDeviceAPI(ctx)->StreamSync(ctx, stream); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } - // Number of pending bytes requests - size_t pending_request_bytes_; - // The ring buffer to read data from. - support::RingBuffer* reader_; - // The ringr buffer to write reply to. - support::RingBuffer* writer_; - // Session table index. - int rpc_sess_table_index_; - // Name of session. - std::string name_; - // remote key - std::string* remote_key_; -}; +} -struct RPCSessTable { +class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; // Get global singleton @@ -848,7 +116,8 @@ struct RPCSessTable { std::lock_guard lock(mutex_); for (int i = 0; i < kMaxRPCSession; ++i) { if (tbl_[i].lock() == nullptr) { - tbl_[i] = ptr; return i; + tbl_[i] = ptr; + return i; } } LOG(FATAL) << "maximum number of RPC session reached"; @@ -863,493 +132,13 @@ struct RPCSessTable { std::array, kMaxRPCSession> tbl_; }; -RPCCode RPCSession::HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) { - RPCCode code = RPCCode::kCallFunc; - while (code != RPCCode::kReturn && - code != RPCCode::kShutdown && - code != RPCCode::kCopyAck) { - while (writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - size_t bytes_needed = handler_->BytesNeeded(); - if (bytes_needed != 0) { - size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { - return channel_->Recv(data, size); - }, bytes_needed); - if (n == 0) { - if (handler_->CanCleanShutdown()) { - return RPCCode::kShutdown; - } else { - LOG(FATAL) << "Channel closes before we get neded bytes"; - } - } - } - code = handler_->HandleNextEvent(rv, client_mode, fwrap); - } - return code; -} - -void RPCSession::Init() { - // Event handler - handler_ = std::make_shared( - &reader_, &writer_, table_index_, name_, &remote_key_); - // Quick function to call remote. 
- call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); - RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); - }); -} - -std::shared_ptr RPCSession::Create( - std::unique_ptr channel, - std::string name, - std::string remote_key) { - std::shared_ptr sess = std::make_shared(); - sess->channel_ = std::move(channel); - sess->name_ = std::move(name); - sess->remote_key_ = std::move(remote_key); - sess->table_index_ = RPCSessTable::Global()->Insert(sess); - sess->Init(); - return sess; -} - std::shared_ptr RPCSession::Get(int table_index) { return RPCSessTable::Global()->Get(table_index); } -RPCSession::~RPCSession() { - this->Shutdown(); -} - -void RPCSession::Shutdown() { - if (channel_ != nullptr) { - RPCCode code = RPCCode::kShutdown; - handler_->Write(code); - // flush all writing buffer to output channel. - try { - while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - if (n == 0) break; - } - } catch (const dmlc::Error& e) { - } - channel_.reset(nullptr); - } -} - -void RPCSession::ServerLoop() { - std::lock_guard lock(mutex_); - if (const auto* f = Registry::Get("tvm.rpc.server.start")) { - (*f)(); - } - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); - if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { - (*f)(); - } - channel_.reset(nullptr); -} - -int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { - std::lock_guard lock(mutex_); - RPCCode code = RPCCode::kNone; - if (bytes.length() != 0) { - reader_.Write(bytes.c_str(), bytes.length()); - TVMRetValue rv; - code = handler_->HandleNextEvent(&rv, false, nullptr); - } - if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); - if (code == RPCCode::kShutdown) return 0; - if (writer_.bytes_available() != 0) return 2; - return 1; -} - -// Get remote function with name -void RPCSession::CallFunc(void* h, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap) { - std::lock_guard lock(mutex_); - - RPCCode code = RPCCode::kCallFunc; - handler_->Write(code); - uint64_t handle = reinterpret_cast(h); - handler_->Write(handle); - handler_->SendPackedSeq( - args.values, args.type_codes, args.num_args, true, funwrap); - code = HandleUntilReturnEvent(rv, true, fwrap); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); -} - -void RPCSession::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_to, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_to = handler_->StripSessMask(ctx_to); - RPCCode code = RPCCode::kCopyToRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(to); - handler_->Write(handle); - uint64_t offset = static_cast(to_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_to); - handler_->Write(type_hint); - handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, 
nullptr) == RPCCode::kReturn); -} - -void RPCSession::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_from, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_from = handler_->StripSessMask(ctx_from); - RPCCode code = RPCCode::kCopyFromRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(from); - handler_->Write(handle); - uint64_t offset = static_cast(from_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_from); - handler_->Write(type_hint); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); - reader_.Reserve(data_size); - handler_->RequestBytes(data_size); - while (!handler_->Ready()) { - size_t bytes_needed = handler_->BytesNeeded(); - reader_.WriteWithCallback([this](void* data, size_t size) { - size_t n = channel_->Recv(data, size); - CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; - return n; - }, bytes_needed); - } - handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); - handler_->FinishCopyAck(); -} - -RPCFuncHandle RPCSession::GetTimeEvaluator( - RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - return this->CallRemote( - RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms); -} - -// Event handler functions -void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) { - std::string name = args[0]; - auto *fp = tvm::runtime::Registry::Get(name); - if (fp != nullptr) { - *rv = static_cast(new tvm::runtime::PackedFunc(*fp)); - } else { - *rv = nullptr; - } -} - -void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - delete static_cast(handle); -} - -void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAPI::Get(ctx)->SetDevice(ctx); -} - -void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAttrKind kind = static_cast(args[1].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPI::Get(ctx, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, rv); - } else { - *rv = 0; - } - } else { - DeviceAPI::Get(ctx)->GetAttr( - ctx, static_cast(kind), rv); - } -} - -void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - uint64_t nbytes = args[1]; - uint64_t alignment = args[2]; - DLDataType type_hint = args[3]; - void* data = DeviceAPI::Get(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); - *rv = data; -} - -void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - void* ptr = args[1]; - DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr); -} - -void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - TVMStreamHandle handle = args[1]; - DeviceAPI::Get(ctx)->StreamSync(ctx, handle); -} - -void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { - void* from = args[0]; - uint64_t from_offset = args[1]; - void* to = args[2]; - uint64_t to_offset = args[3]; - uint64_t size = args[4]; - TVMContext ctx_from = args[5]; - TVMContext ctx_to = args[6]; - DLDataType type_hint = args[7]; - TVMStreamHandle stream = args[8]; - TVMContext ctx = ctx_from; - if (ctx.device_type == kDLCPU) { - ctx = ctx_to; - } else { - CHECK(ctx_to.device_type == kDLCPU || - ctx_to.device_type == ctx_from.device_type) - << "Can not copy across different ctx types directly"; - } - DeviceAPI::Get(ctx)->CopyDataFromTo( - from, from_offset, - to, 
to_offset, - size, ctx_from, ctx_to, type_hint, stream); -} - -void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { - static const PackedFunc* fsys_load_ = nullptr; - if (fsys_load_ == nullptr) { - fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module"); - CHECK(fsys_load_ != nullptr); - } - std::string file_name = args[0]; - TVMRetValue ret = (*fsys_load_)(file_name); - // pass via void* - TVMValue value; - int rcode; - ret.MoveToCHost(&value, &rcode); - CHECK_EQ(rcode, kTVMModuleHandle); - *rv = static_cast(value.v_handle); -} - -void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { - void* pmod = args[0]; - void* cmod = args[1]; - ObjectInternal::GetModuleNode(pmod)->Import( - GetRef(ObjectInternal::GetModuleNode(cmod))); -} - -void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - ObjectInternal::ObjectFree(mhandle); -} - -void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction( - args[1], false); - if (pf != nullptr) { - *rv = static_cast(new PackedFunc(pf)); - } else { - *rv = nullptr; - } -} - -void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - std::string fmt = args[1]; - *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt); -} - -void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - static_cast( - reinterpret_cast(handle))->DecRef(); -} - -void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { - PackedFunc *pf = static_cast(args[0].operator void*()); - void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4])); - delete pf; - *rv = fhandle; -} - -void RPCSession::EventHandler::HandlePackedCall() { - CHECK_EQ(pending_request_bytes_, 0U); - if (code_ == RPCCode::kReturn) { - state_ = kReturnReceived; return; - } - // reset state to clean init state - state_ = kRecvCode; - this->RequestBytes(sizeof(RPCCode)); - // Event handler sit at clean state at this point. 
- switch (code_) { - case RPCCode::kCallFunc: { - PackedFunc* pf = reinterpret_cast(call_handle_); - CallHandler([pf](TVMArgs args, TVMRetValue* rv) { - pf->CallPacked(args, rv); - }); - break; - } - case RPCCode::kException: { - CHECK_EQ(arg_buf_->value.size(), 1U); - CHECK_EQ(arg_buf_->tcode[0], kTVMStr); - std::ostringstream os; - os << "Except caught from RPC call: " << arg_buf_->value[0].v_str; - arg_buf_.reset(); - throw dmlc::Error(os.str()); - break; - } - // system functions - case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break; - case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; - case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; - case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; - case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break; - case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break; - case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; - case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; - case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; - case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break; - case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; - case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; - case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; - case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break; - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - CHECK_EQ(state_, kRecvCode); -} - -PackedFunc MicroTimeEvaluator( - PackedFunc pf, - TVMContext ctx, - int number, - int repeat) { - auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - for (int i = 0; i < repeat; ++i) { - double speed = 0.0; - for (int j = 0; j < number; ++j) { - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - speed += (temp.operator double()) / number; - } - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - -PackedFunc WrapTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { - return MicroTimeEvaluator(pf, ctx, number, repeat); - } - - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. 
- pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - - for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; - double duration_ms = 0.0; - - do { - if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - - tbegin = std::chrono::high_resolution_clock::now(); - // start timing - for (int i = 0; i < number; ++i) { - pf.CallPacked(args, &temp); - } - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - tend = std::chrono::high_resolution_clock::now(); - - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; - } while (duration_ms < min_repeat_ms); - - double speed = std::chrono::duration_cast >( - tend - tbegin).count() / number; - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - -size_t CallbackChannel::Send(const void* data, size_t size) { - TVMByteArray bytes; - bytes.data = static_cast(data); - bytes.size = size; - int64_t n = fsend_(bytes); - if (n == -1) { - support::Socket::Error("CallbackChannel::Send"); - } - return static_cast(n); -} - -size_t CallbackChannel::Recv(void* data, size_t size) { - TVMRetValue ret = frecv_(size); - - if (ret.type_code() != kTVMBytes) { - support::Socket::Error("CallbackChannel::Recv"); - } - std::string* bytes = ret.ptr(); - memcpy(static_cast(data), bytes->c_str(), bytes->length()); - return bytes->length(); +void RPCSession::InsertToSessionTable(std::shared_ptr sess) { + CHECK_EQ(sess->table_index_, 0); + sess->table_index_ = RPCSessTable::Global()->Insert(sess); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index db63be4be74da..6a7e6d6e41c1a 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,230 +24,253 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ -#include #include -#include -#include +#include + +#include #include -#include -#include "../../support/ring_buffer.h" +#include + +#include "rpc_protocol.h" namespace tvm { namespace runtime { -// Magic header for RPC data plane -const int kRPCMagic = 0xff271; -// magic header for RPC tracker(control plane) -const int kRPCTrackerMagic = 0x2f271; -// sucess response -const int kRPCSuccess = kRPCMagic + 0; -// cannot found matched key in server -const int kRPCMismatch = kRPCMagic + 2; - -/*! \brief Enumeration code for the RPC tracker */ -enum class TrackerCode : int { - kFail = -1, - kSuccess = 0, - kPing = 1, - kStop = 2, - kPut = 3, - kRequest = 4, - kUpdateInfo = 5, - kSummary = 6, - kGetPendingMatchKeys = 7 -}; -/*! \brief The remote functio handle */ -using RPCFuncHandle = void*; - -struct RPCArgBuffer; - -/*! \brief The RPC code */ -enum class RPCCode : int { - kNone, - kCallFunc, - kReturn, - kException, - kShutdown, - kCopyFromRemote, - kCopyToRemote, - kCopyAck, - // The following are code that can send over CallRemote - kSystemFuncStart, - kGetGlobalFunc, - kGetTimeEvaluator, - kFreeFunc, - kDevSetDevice, - kDevGetAttr, - kDevAllocData, - kDevFreeData, - kDevStreamSync, - kCopyAmongRemote, - kModuleLoad, - kModuleImport, - kModuleFree, - kModuleGetFunc, - kModuleGetSource, - kNDArrayFree -}; - -/*! 
- * \brief Function that unwraps a remote object to its handle.
- * \param rpc_sess_table_index RPC session table index for validation.
- * \param obj Handle to the object argument.
- * \return The corresponding handle.
- */
-typedef void* (*FUnwrapRemoteObject)(
-    int rpc_sess_table_index,
-    const TVMArgValue& obj);
-
 /*!
- * \brief Abstract channel interface used to create RPCSession.
+ * \brief The interface of all remote RPC sessions.
+ *
+ * It contains all the necessary interfaces to implement
+ * remote call and resource management.
+ *
+ * The interface is designed to allow easy proxy-chaining
+ * by forwarding requests to another RPCSession.
  */
-class RPCChannel {
+class RPCSession {
  public:
-  /*! \brief virtual destructor */
-  virtual ~RPCChannel() {}
+  /*! \brief PackedFunc Handle in the remote. */
+  using PackedFuncHandle = void*;
+
+  /*! \brief Module handle in the remote. */
+  using ModuleHandle = void*;
+
+  /*! \brief NDArray handle in the remote. */
+  using NDArrayHandle = void*;
+
   /*!
-   * \brief Send data over to the channel.
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes sent.
+   * \brief Callback to send encoded return values via encode_args.
+   *
+   * \param encode_args The arguments that we can encode the return values into.
+   *
+   * Encoding convention (as list of arguments):
+   * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention.
+   * - PackedFunc/Module: [tcode: int, handle: void*]
+   * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*]
+   *   DLTensor* contains the meta-data as well as the handle into the remote data.
+   *   nd_handle can be used for deletion.
    */
-  virtual size_t Send(const void* data, size_t size) = 0;
+  using FEncodeReturn = std::function<void(TVMArgs encoded_args)>;
+
   /*!
-   * \brief Recv data from channel.
+   * \brief Callback to send encoded return values or an exception via encode_args.
    *
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes received.
+   * \param status The return status, can be RPCCode::kReturn or RPCCode::kException.
+   * \param encode_args The arguments that we can encode the return values into.
    */
-  virtual size_t Recv(void* data, size_t size) = 0;
-};
+  using FAsyncCallback = std::function<void(RPCCode status, TVMArgs encoded_args)>;
+
+  /*! \brief Destructor. */
+  virtual ~RPCSession() {}
 
-// Bidirectional Communication Session of PackedRPC
-class RPCSession {
- public:
-  /*! \brief virtual destructor */
-  ~RPCSession();
   /*!
-   * \brief The server loop that server runs to handle RPC calls.
+   * \brief Get function in the session.
+   * \param name The name of the function.
+   * \return The function handle.
    */
-  void ServerLoop();
+  virtual PackedFuncHandle GetFunction(const std::string& name) = 0;
+
   /*!
-   * \brief Message handling function for event driven server.
-   *    Called when the server receives a message.
-   *    Event driven handler will never call recv on the channel
-   *    and always relies on the ServerEventHandler.
-   *    to receive the data.
+   * \brief Call into a remote Packed function.
    *
-   * \param in_bytes The incoming bytes.
-   * \param event_flag 1: read_available, 2: write_avaiable.
-   * \return State flag.
-   *     1: continue running, no need to write,
-   *     2: need to write
-   *     0: shutdown
-   */
-  int ServerEventHandler(const std::string& in_bytes,
-                         int event_flag);
-  /*!
-   * \brief Call into remote function
-   * \param handle The function handle
-   * \param args The arguments
-   * \param rv The return value.
- * \param funpwrap Function that takes a remote object and returns the raw handle.
- * \param fwrap Wrapper function to turn Function/Module handle into real return.
+   * Calling convention:
+   *
+   * - type_code follows the PackedFunc convention.
+   * - int/float/string/bytes follow the PackedFunc convention, all data are local.
+   * - PackedFunc/Module and future remote objects: pass the remote handle instead.
+   * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor
+   *   points to a remote data handle returned by the Device API.
+   *   The meta-data of the DLTensor stays local.
+   *
+   * The caller populates the arguments and manages these arguments.
+   *
+   * The callee can change the content of arg_values and arg_type_codes
+   * if they want to do inplace modification and forwarding.
+   *
+   * The callee needs to store the return value into ret_value.
+   * - PackedFunc/Module are stored as void*
+   * - NDArray is stored as a local NDArray, whose data field is a remote handle.
+   *   Notably the NDArray's deleter won't delete the remote handle.
+   *   It is up to the user of the RPCSession to do such wrapping.
+   * - In short, remote handles are "moved" as return values
+   *   and the callee needs to explicitly manage them by calling
+   *   the deleter functions when they are no longer needed.
+   *
+   * \param func The function handle.
+   * \param arg_values The argument values.
+   * \param arg_type_codes The type codes of the arguments.
+   * \param num_args Number of arguments.
+   * \param fencode_return The function to set the return value;
+   *                       if not called, the return value is null.
    */
-  void CallFunc(RPCFuncHandle handle,
-                TVMArgs args,
-                TVMRetValue* rv,
-                FUnwrapRemoteObject funwrap,
-                const PackedFunc* fwrap);
+  virtual void CallFunc(PackedFuncHandle func, const TVMValue* arg_values,
+                        const int* arg_type_codes, int num_args,
+                        const FEncodeReturn& fencode_return) = 0;
+
   /*!
    * \brief Copy bytes into remote array content.
-   * \param from The source host data.
-   * \param from_offset The byte offeset in the from.
-   * \param to The target array.
-   * \param to_offset The byte offset in the to.
+   * \param local_from The source host data.
+   * \param local_from_offset The byte offset in local_from.
+   * \param remote_to The target array.
+   * \param remote_to_offset The byte offset in remote_to.
    * \param nbytes The size of the memory in bytes.
-   * \param ctx_to The target context.
+   * \param remote_ctx_to The target context.
    * \param type_hint Hint of content data type.
    */
-  void CopyToRemote(void* from,
-                    size_t from_offset,
-                    void* to,
-                    size_t to_offset,
-                    size_t nbytes,
-                    TVMContext ctx_to,
-                    DLDataType type_hint);
+  virtual void CopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
+                            size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
+                            DLDataType type_hint) = 0;
 
   /*!
    * \brief Copy bytes from remote array content.
-   * \param from The source host data.
-   * \param from_offset The byte offeset in the from.
+   * \param remote_from The source data in the remote.
+   * \param remote_from_offset The byte offset in remote_from.
    * \param to The target array.
    * \param to_offset The byte offset in the to.
    * \param nbytes The size of the memory in bytes.
-   * \param ctx_from The source context.
+   * \param remote_ctx_from The source context in the remote.
    * \param type_hint Hint of content data type.
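+   *
+   * \note A minimal client-side sketch of this call; here sess is some RPCSession
+   *       implementation, and remote_ptr, remote_ctx and dtype are hypothetical
+   *       values previously obtained from the same session:
+   *
+   * \code
+   *   std::vector<char> host_buf(nbytes);
+   *   sess->CopyFromRemote(remote_ptr, 0, host_buf.data(), 0, nbytes, remote_ctx, dtype);
+   * \endcode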
 */
-  void CopyFromRemote(void* from,
-                      size_t from_offset,
-                      void* to,
-                      size_t to_offset,
-                      size_t nbytes,
-                      TVMContext ctx_from,
-                      DLDataType type_hint);
+  virtual void CopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
+                              size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from,
+                              DLDataType type_hint) = 0;
+
+  /*!
+   * \brief Free a remote handle.
+   * \param handle The remote handle; can be an NDArray/PackedFunc/Module handle.
+   * \param type_code The type code of the underlying type.
+   */
+  virtual void FreeHandle(void* handle, int type_code) = 0;
+
+  /*!
+   * \brief Get the device API that represents the actions
+   *        that can be taken on the remote.
+   *
+   *        The caller can then call into the Alloc/Free functions
+   *        to allocate and free spaces, taking the pointer as the handle.
+   *
+   *        The device API is guaranteed to be alive during the
+   *        lifetime of the Session.
+   *
+   * \param ctx The remote context.
+   * \param allow_missing Whether we can return nullptr if it is not available.
+   *
+   * \return The device API.
+   */
+  virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0;
+
+  /*!
+   * \brief Whether the session is a local session, so we can directly use
+   *        the data handle returned by the session and treat it as a pointer
+   *        to the local memory.
+   *
+   *        This information is useful for the RPC server to copy directly into
+   *        local memory without creating a temporary buffer.
+   *
+   * \return Whether it is a local session.
+   */
+  virtual bool IsLocalSession() const = 0;
+
+  // Asynchronous variants of the API.
+  // These APIs are used by the RPC server to allow sessions that
+  // have special implementations for the async functions.
+  //
+  // In the async APIs, an exception is returned by passing
+  // async_error=true, encode_args=[error_msg].
+
   /*!
-   * \brief Get a remote timer function on ctx.
-   *  This function consumes fhandle, caller should not call Free on fhandle.
+   * \brief Whether the session is async.
+   *
+   *        If the session is not async, its async implementations
+   *        simply call into their synchronous counterparts,
+   *        and the callback is guaranteed to be called before the async function finishes.
+   *
+   * \return The async state.
    *
-   * \param fhandle The function handle.
-   * \param ctx The ctx to run measurement on.
-   * \param number The number of times to run this function for taking average.
-      We call these runs as one `repeat` of measurement.
-   * \param repeat The number of times to repeat the measurement.
-      In total, the function will be invoked (1 + number x repeat) times,
-      where the first one is warm up and will be discarded.
-      The returned result contains `repeat` costs,
-      each of which is an average of `number` costs.
-   * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
-      By default, one `repeat` contains `number` runs. If this parameter is set,
-      the parameters `number` will be dynamically adjusted to meet the
-      minimum duration requirement of one `repeat`.
-      i.e., When the run time of one `repeat` falls below this time,
-      the `number` parameter will be automatically increased.
-   * \return A remote timer function
+   * \note We can only use an async session in an event-driven RPC server.
    */
-  RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
-                                 TVMContext ctx,
-                                 int number,
-                                 int repeat,
-                                 int min_repeat_ms);
+  virtual bool IsAsync() const;
+
   /*!
-   * \brief Call a remote defined system function with arguments.
-   * \param fcode The function code.
- * \param args The arguments
- * \return The returned remote value.
+   * \brief Asynchronously call func.
+   * \param func The function handle.
+   * \param arg_values The argument values.
+   * \param arg_type_codes The type codes of the arguments.
+   * \param num_args Number of arguments.
+   *
+   * \param callback The callback to pass the return value or exception.
    */
-  template
-  inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
+  virtual void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values,
+                             const int* arg_type_codes, int num_args, FAsyncCallback callback);
+
   /*!
-   * \return The session table index of the session.
+   * \brief Asynchronous version of CopyToRemote.
+   *
+   * \param local_from The source host data.
+   * \param local_from_offset The byte offset in local_from.
+   * \param remote_to The target array.
+   * \param remote_to_offset The byte offset in remote_to.
+   * \param nbytes The size of the memory in bytes.
+   * \param remote_ctx_to The target context.
+   * \param type_hint Hint of content data type.
+   *
+   * \param on_complete The callback to signal copy complete.
+   * \note All the allocated memory in local_from and remote_to
+   *       must stay alive until on_complete is called.
    */
-  int table_index() const {
-    return table_index_;
-  }
+  virtual void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to,
+                                 size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to,
+                                 DLDataType type_hint, FAsyncCallback on_complete);
+
   /*!
-   * \brief Create a RPC session with given channel.
-   * \param channel The communication channel.
-   * \param name The local name of the session, used for debug
-   * \param remote_key The remote key of the session
-   *   if remote_key equals "%toinit", we need to re-intialize
-   *   it by event handler.
+   * \brief Asynchronous version of CopyFromRemote.
+   *
+   * \param remote_from The source data in the remote.
+   * \param remote_from_offset The byte offset in remote_from.
+   * \param local_to The target array.
+   * \param local_to_offset The byte offset in local_to.
+   * \param nbytes The size of the memory in bytes.
+   * \param remote_ctx_from The source context in the remote.
+   * \param type_hint Hint of content data type.
+   *
+   * \param on_complete The callback to signal copy complete.
+   * \note All the allocated memory in remote_from and local_to
+   *       must stay alive until on_complete is called.
+   */
+  virtual void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to,
+                                   size_t local_to_offset, size_t nbytes,
+                                   TVMContext remote_ctx_from, DLDataType type_hint,
+                                   FAsyncCallback on_complete);
+  /*!
+   * \brief Asynchronously wait for all events in ctx and stream to complete.
+   * \param ctx The device context.
+   * \param stream The stream to wait on.
+   * \param on_complete The callback to signal completion.
    */
-  static std::shared_ptr Create(
-      std::unique_ptr channel,
-      std::string name,
-      std::string remote_key);
+  virtual void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete);
+
+  /*!
+   * \return The session table index of the session.
+   */
+  int table_index() const { return table_index_; }
+
   /*!
    * \brief Try get session from the global session table by table index.
    * \param table_index The table index of the session.
@@ -255,63 +278,33 @@ class RPCSession {
    */
   static std::shared_ptr Get(int table_index);
 
+ protected:
+  /*!
+   * \brief Send an exception to the callback.
+   * \param msg The exception message.
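+   * \param callback The callback that receives the encoded exception.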
+   */
+  void SendException(FAsyncCallback callback, const char* msg);
+
  private:
-  class EventHandler;
-  // Handle events until receives a return
-  // Also flushes channels so that the function advances.
-  RPCCode HandleUntilReturnEvent(
-      TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
-  // Initalization
-  void Init();
-  // Shutdown
-  void Shutdown();
-  // Internal channel.
-  std::unique_ptr channel_;
-  // Internal mutex
-  std::recursive_mutex mutex_;
-  // Internal ring buffer.
-  support::RingBuffer reader_, writer_;
-  // Event handler.
-  std::shared_ptr handler_;
-  // call remote with specified function code.
-  PackedFunc call_remote_;
-  // The index of this session in RPC session table.
+  /*! \brief Index of this session in the RPC session table. */
   int table_index_{0};
-  // The name of the session.
-  std::string name_;
-  // The remote key
-  std::string remote_key_;
+  /*! \brief Insert the current session into the session table. */
+  static void InsertToSessionTable(std::shared_ptr<RPCSession> sess);
+  // friend declaration
+  friend Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
 };
 
 /*!
- * \brief RPC channel which callback
- * frontend (Python/Java/etc.)'s send & recv function
+ * \brief Remote space handle cell used by the RPC runtime API.
+ *
+ *  When we allocate space using an RPC context, the data pointer
+ *  points to an allocated RemoteSpace.
  */
-class CallbackChannel final : public RPCChannel {
- public:
-  explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
-      : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
-
-  ~CallbackChannel() {}
-  /*!
-   * \brief Send data over to the channel.
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes sent.
-   */
-  size_t Send(const void* data, size_t size) final;
-  /*!
-   * \brief Recv data from channel.
-   *
-   * \param data The data pointer.
-   * \param size The size fo the data.
-   * \return The actual bytes received.
-   */
-  size_t Recv(void* data, size_t size) final;
-
- private:
-  PackedFunc fsend_;
-  PackedFunc frecv_;
+struct RemoteSpace {
+  /*! \brief The remote data handle. */
+  void* data;
+  /*! \brief Reference to the underlying RPC session. */
+  std::shared_ptr<RPCSession> sess;
 };
 
 /*!
@@ -319,24 +312,21 @@ class CallbackChannel final : public RPCChannel {
  * \param f The function argument.
  * \param ctx The context.
  * \param number The number of times to run this function for taking average.
-      We call these runs as one `repeat` of measurement.
+ *        We call these runs as one `repeat` of measurement.
 * \param repeat The number of times to repeat the measurement.
-      In total, the function will be invoked (1 + number x repeat) times,
-      where the first one is warm up and will be discarded.
-      The returned result contains `repeat` costs,
-      each of which is an average of `number` costs.
+ *        In total, the function will be invoked (1 + number x repeat) times,
+ *        where the first one is warm up and will be discarded.
+ *        The returned result contains `repeat` costs,
+ *        each of which is an average of `number` costs.
 * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
-      By default, one `repeat` contains `number` runs. If this parameter is set,
-      the parameters `number` will be dynamically adjusted to meet the
-      minimum duration requirement of one `repeat`.
-      i.e., When the run time of one `repeat` falls below this time,
-      the `number` parameter will be automatically increased.
+ *        By default, one `repeat` contains `number` runs. If this parameter is set,
+ *        the parameters `number` will be dynamically adjusted to meet the
+ *        minimum duration requirement of one `repeat`.
+ *        i.e., When the run time of one `repeat` falls below this time,
+ *        the `number` parameter will be automatically increased.
 * \return f_timer A timer function.
 */
-PackedFunc WrapTimeEvaluator(PackedFunc f,
-                             TVMContext ctx,
-                             int number,
-                             int repeat,
+PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat,
                              int min_repeat_ms);
 
 /*!
@@ -344,21 +334,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc f,
  * \param sess The RPC session of the global module.
  * \return The created module.
  */
-Module CreateRPCModule(std::shared_ptr sess);
+Module CreateRPCSessionModule(std::shared_ptr<RPCSession> sess);
 
-// Remote space pointer.
-struct RemoteSpace {
-  void* data;
-  std::shared_ptr sess;
-};
+/*!
+ * \brief Get the session module from an RPC session Module.
+ * \param mod The input module (must be an RPCModule).
+ * \return The internal RPCSession.
+ */
+std::shared_ptr<RPCSession> RPCModuleGetSession(Module mod);
 
-// implementation of inline functions
-template
-inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
-  std::lock_guard lock(mutex_);
-  writer_.Write(&code, sizeof(code));
-  return call_remote_(std::forward(args)...);
-}
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_RPC_RPC_SESSION_H_
diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc
index 642fbb8ec7f26..77a743be0de69 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -21,18 +21,22 @@
 * \file rpc_socket_impl.cc
 * \brief Socket based RPC implementation.
 */
+#include 
 #include 
+
 #include 
-#include "rpc_session.h"
+
 #include "../../support/socket.h"
+#include "rpc_endpoint.h"
+#include "rpc_local_session.h"
+#include "rpc_session.h"
 
 namespace tvm {
 namespace runtime {
 
 class SockChannel final : public RPCChannel {
  public:
-  explicit SockChannel(support::TCPSocket sock)
-      : sock_(sock) {}
+  explicit SockChannel(support::TCPSocket sock) : sock_(sock) {}
   ~SockChannel() {
     try {
       // BadSocket can throw
@@ -61,13 +65,12 @@ class SockChannel final : public RPCChannel {
   support::TCPSocket sock_;
 };
 
-std::shared_ptr
-RPCConnect(std::string url, int port, std::string key) {
+std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string key,
+                                        TVMArgs init_seq) {
   support::TCPSocket sock;
   support::SockAddr addr(url.c_str(), port);
   sock.Create(addr.ss_family());
-  CHECK(sock.Connect(addr))
-      << "Connect to " << addr.AsString() << " failed";
+  CHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed";
   // hand shake
   std::ostringstream os;
   int code = kRPCMagic;
@@ -80,12 +83,10 @@ RPCConnect(std::string url, int port, std::string key) {
   CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
   if (code == kRPCMagic + 2) {
     sock.Close();
-    LOG(FATAL) << "URL " << url << ":" << port
-               << " cannot find server that matches key=" << key;
+    LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key;
   } else if (code == kRPCMagic + 1) {
     sock.Close();
-    LOG(FATAL) << "URL " << url << ":" << port
-               << " server already have key=" << key;
+    LOG(FATAL) << "URL " << url << ":" << port << " server already has key=" << key;
   } else if (code != kRPCMagic) {
     sock.Close();
     LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
   }
@@ -96,42 +97,46 @@ RPCConnect(std::string url, int port, std::string key) {
     remote_key.resize(keylen);
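+    // Read back the session key that the matched server reports for this connection.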
     CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
   }
-  return RPCSession::Create(
-      std::unique_ptr(new SockChannel(sock)), key, remote_key);
+  auto endpt =
+      RPCEndpoint::Create(std::unique_ptr<RPCChannel>(new SockChannel(sock)), key, remote_key);
+  endpt->InitRemoteSession(init_seq);
+  return endpt;
 }
 
-Module RPCClientConnect(std::string url, int port, std::string key) {
-  return CreateRPCModule(RPCConnect(url, port, "client:" + key));
+Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) {
+  auto endpt = RPCConnect(url, port, "client:" + key, init_seq);
+  return CreateRPCSessionModule(CreateClientSession(endpt));
 }
 
 // TVM_DLL needed for MSVC
 TVM_DLL void RPCServerLoop(int sockfd) {
-  support::TCPSocket sock(
-      static_cast(sockfd));
-  RPCSession::Create(
-      std::unique_ptr(new SockChannel(sock)),
-      "SockServerLoop", "")->ServerLoop();
+  support::TCPSocket sock(static_cast<support::TCPSocket::SockType>(sockfd));
+  RPCEndpoint::Create(std::unique_ptr<RPCChannel>(new SockChannel(sock)), "SockServerLoop", "")
+      ->ServerLoop();
 }
 
 void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
-  RPCSession::Create(std::unique_ptr(
-      new CallbackChannel(fsend, frecv)),
-      "SockServerLoop", "")->ServerLoop();
+  RPCEndpoint::Create(std::unique_ptr<RPCChannel>(new CallbackChannel(fsend, frecv)),
+                      "SockServerLoop", "")
+      ->ServerLoop();
}
 
-TVM_REGISTER_GLOBAL("rpc._Connect")
-.set_body_typed(RPCClientConnect);
+TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::string url = args[0];
+  int port = args[1];
+  std::string key = args[2];
+  *rv = RPCClientConnect(url, port, key,
+                         TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
+});
+
+TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) {
+  if (args[0].type_code() == kDLInt) {
+    RPCServerLoop(args[0]);
+  } else {
+    RPCServerLoop(args[0].operator tvm::runtime::PackedFunc(),
+                  args[1].operator tvm::runtime::PackedFunc());
+  }
+});
 
-TVM_REGISTER_GLOBAL("rpc._ServerLoop")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    if (args.size() == 1) {
-      RPCServerLoop(args[0]);
-    } else {
-      CHECK_EQ(args.size(), 2);
-      RPCServerLoop(
-          args[0].operator tvm::runtime::PackedFunc(),
-          args[1].operator tvm::runtime::PackedFunc());
-    }
-  });
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h
index 84fc3c462c3d6..21601df1ad395 100644
--- a/src/runtime/runtime_base.h
+++ b/src/runtime/runtime_base.h
@@ -25,25 +25,37 @@
 #define TVM_RUNTIME_RUNTIME_BASE_H_
 
 #include 
+
 #include 
 
 /*! \brief macro to guard beginning and end section of all functions */
 #define API_BEGIN() try {
 /*! \brief every function starts with API_BEGIN();
     and finishes with API_END() or API_END_HANDLE_ERROR */
-#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0;  // NOLINT(*)
+#define API_END()                              \
+  }                                            \
+  catch (std::runtime_error & _except_) {      \
+    return TVMAPIHandleException(_except_);    \
+  }                                            \
+  return 0;  // NOLINT(*)
 /*!
 * \brief every function starts with API_BEGIN();
 *   and finishes with API_END() or API_END_HANDLE_ERROR
 *   The finally clause contains procedure to cleanup states when an error happens.
 */
-#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
+#define API_END_HANDLE_ERROR(Finalize)         \
+  }                                            \
+  catch (std::runtime_error & _except_) {      \
+    Finalize;                                  \
+    return TVMAPIHandleException(_except_);    \
+  }                                            \
+  return 0;  // NOLINT(*)
 
 /*!
* \brief handle exception throwed out * \param e the exception * \return the return value of API after exception is handled */ -int TVMAPIHandleException(const std::runtime_error &e); +int TVMAPIHandleException(const std::runtime_error& e); #endif // TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/src/runtime/stackvm/stackvm.cc b/src/runtime/stackvm/stackvm.cc index 0f17f9e4b4a23..042815b3d68b7 100644 --- a/src/runtime/stackvm/stackvm.cc +++ b/src/runtime/stackvm/stackvm.cc @@ -21,87 +21,88 @@ * Implementation stack VM. * \file stackvm.cc */ +#include "stackvm.h" + #include #include + #include -#include "stackvm.h" namespace tvm { namespace runtime { typedef dmlc::ThreadLocalStore StackVMStateStore; -StackVM::State* StackVM::ThreadLocalState() { - return StackVMStateStore::Get(); -} +StackVM::State* StackVM::ThreadLocalState() { return StackVMStateStore::Get(); } #define STACK_VM_BINOP(OP, FIELD) \ { \ stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \ - sp -= 1; pc += 1; \ + sp -= 1; \ + pc += 1; \ } #define STACK_VM_CMPOP(OP, FIELD) \ { \ stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \ - sp -= 1; pc += 1; \ + sp -= 1; \ + pc += 1; \ } -#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \ - { \ - int index = code[pc + 1].v_int; \ - stack[sp]FIELD = static_cast( \ - static_cast(stack[sp].v_handle)[index]); \ - pc += 2; \ +#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \ + { \ + int index = code[pc + 1].v_int; \ + stack[sp] FIELD = static_cast(static_cast(stack[sp].v_handle)[index]); \ + pc += 2; \ } -#define STACK_VM_STORE(FIELD, DST_TYPE) \ - { \ - int index = code[pc + 1].v_int; \ - static_cast(stack[sp - 1].v_handle)[index] = \ - static_cast(stack[sp]FIELD); \ - sp -= 2; pc += 2; \ +#define STACK_VM_STORE(FIELD, DST_TYPE) \ + { \ + int index = code[pc + 1].v_int; \ + static_cast(stack[sp - 1].v_handle)[index] = \ + static_cast(stack[sp] FIELD); \ + sp -= 2; \ + pc += 2; \ } -#define STACK_VM_PRINT_CODE0(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \ +#define STACK_VM_PRINT_CODE0(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << std::endl; \ + return pc + 1; \ } -#define STACK_VM_PRINT_CODE1(CODE) \ - case CODE: { \ +#define STACK_VM_PRINT_CODE1(CODE) \ + case CODE: { \ os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } -#define STACK_VM_PRINT_CODE2(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE \ - << " " << code[pc + 1].v_int \ - << " " << code[pc + 2].v_int << "\n" \ - << "[" << pc + 1 << "]" << std::endl \ - << "[" << pc + 2 << "]" << std::endl; \ - return pc + 3; \ +#define STACK_VM_PRINT_CODE2(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " << code[pc + 2].v_int \ + << "\n" \ + << "[" << pc + 1 << "]" << std::endl \ + << "[" << pc + 2 << "]" << std::endl; \ + return pc + 3; \ } -#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \ - << " " << heap_id_name[code[pc + 1].v_int] << "\n" \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ +#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " \ + << heap_id_name[code[pc + 1].v_int] << "\n" \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } -#define STACK_VM_PRINT_JUMP(CODE) \ 
- case CODE: { \ - os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \ - << " to " << pc + code[pc + 1].v_int << '\n' \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ +#define STACK_VM_PRINT_JUMP(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int << " to " \ + << pc + code[pc + 1].v_int << '\n' \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } - int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { switch (code[pc].op_code) { // int @@ -164,9 +165,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { int begin = code[pc + 2].v_int; int end = code[pc + 3].v_int; os << "[" << pc << "]\tCALL_PACKED_FUNC " - << " fid=" << call_fid - << " begin=" << begin - << " end=" << end; + << " fid=" << call_fid << " begin=" << begin << " end=" << end; os << '\n'; for (int i = 0; i < 3; ++i) { os << "[" << pc + 1 + i << "]" << std::endl; @@ -181,8 +180,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) int64_t pc = 0; const int64_t code_size = static_cast(vm.code.size()); - os << "Program dump: code-size=" << code_size << '\n' - << "----------begin-----------------\n"; + os << "Program dump: code-size=" << code_size << '\n' << "----------begin-----------------\n"; while (pc < code_size) { pc = vm.PrintCode(os, pc); } @@ -190,8 +188,7 @@ std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) return os; } -void StackVM::Run(const runtime::TVMArgs& args, - runtime::ModuleNode* mod_ctx) const { +void StackVM::Run(const runtime::TVMArgs& args, runtime::ModuleNode* mod_ctx) const { StackVM::State* s = StackVM::ThreadLocalState(); if (s->heap.size() < heap_size) { s->heap.resize(heap_size); @@ -199,7 +196,7 @@ void StackVM::Run(const runtime::TVMArgs& args, s->sp = 0; s->pc = 0; s->mod_ctx = mod_ctx; - s->heap[0].v_handle = (void*)args.values; // NOLINT(*) + s->heap[0].v_handle = (void*)args.values; // NOLINT(*) s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) s->heap[2].v_int64 = args.num_args; this->Run(s); @@ -207,16 +204,13 @@ void StackVM::Run(const runtime::TVMArgs& args, void StackVM::InitCache() { extern_func_cache_.clear(); - extern_func_cache_.resize( - extern_func_name.size(), PackedFunc(nullptr)); + extern_func_cache_.resize(extern_func_name.size(), PackedFunc(nullptr)); } void StackVM::Save(dmlc::Stream* strm) const { // to be endian invariant. std::vector code_copy(code.size()); - std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { - return c.v_int; - }); + std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { return c.v_int; }); strm->Write(code_copy); strm->Write(str_data); strm->Write(extern_func_name); @@ -225,14 +219,16 @@ void StackVM::Save(dmlc::Stream* strm) const { strm->Write(stack_size); } -bool StackVM::Load(dmlc::Stream* strm) { +bool StackVM::Load(dmlc::Stream* strm) { // to be endian invariant. 
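+  // Opcodes were saved as plain ints (see Save above); read them back and
+  // rebuild the Code values, keeping the serialized format endian invariant.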
std::vector code_copy; if (!strm->Read(&code_copy)) return false; code.resize(code_copy.size()); std::transform(code_copy.begin(), code_copy.end(), code.begin(), [](int v) { - Code code; code.v_int = v; return code; - }); + Code code; + code.v_int = v; + return code; + }); if (!strm->Read(&str_data)) return false; if (!strm->Read(&extern_func_name)) return false; if (!strm->Read(&heap_id_name)) return false; @@ -258,36 +254,92 @@ void StackVM::Run(State* s) const { const int64_t code_size = static_cast(code.size()); while (pc < code_size) { switch (code[pc].op_code) { - case ADD_I64: STACK_VM_BINOP(+, v_int64); break; - case SUB_I64: STACK_VM_BINOP(-, v_int64); break; - case MUL_I64: STACK_VM_BINOP(*, v_int64); break; - case DIV_I64: STACK_VM_BINOP(/, v_int64); break; - case MOD_I64: STACK_VM_BINOP(%, v_int64); break; - case EQ_I64: STACK_VM_CMPOP(==, v_int64); break; - case LT_I64: STACK_VM_CMPOP(<, v_int64); break; - case LE_I64: STACK_VM_CMPOP(<=, v_int64); break; - case ADD_F64: STACK_VM_BINOP(+, v_float64); break; - case SUB_F64: STACK_VM_BINOP(-, v_float64); break; - case MUL_F64: STACK_VM_BINOP(*, v_float64); break; - case DIV_F64: STACK_VM_BINOP(/, v_float64); break; - case EQ_F64: STACK_VM_CMPOP(==, v_float64); break; - case LT_F64: STACK_VM_CMPOP(<, v_float64); break; - case LE_F64: STACK_VM_CMPOP(<=, v_float64); break; - case EQ_HANDLE: STACK_VM_CMPOP(==, v_handle); break; + case ADD_I64: + STACK_VM_BINOP(+, v_int64); + break; + case SUB_I64: + STACK_VM_BINOP(-, v_int64); + break; + case MUL_I64: + STACK_VM_BINOP(*, v_int64); + break; + case DIV_I64: + STACK_VM_BINOP(/, v_int64); + break; + case MOD_I64: + STACK_VM_BINOP(%, v_int64); + break; + case EQ_I64: + STACK_VM_CMPOP(==, v_int64); + break; + case LT_I64: + STACK_VM_CMPOP(<, v_int64); + break; + case LE_I64: + STACK_VM_CMPOP(<=, v_int64); + break; + case ADD_F64: + STACK_VM_BINOP(+, v_float64); + break; + case SUB_F64: + STACK_VM_BINOP(-, v_float64); + break; + case MUL_F64: + STACK_VM_BINOP(*, v_float64); + break; + case DIV_F64: + STACK_VM_BINOP(/, v_float64); + break; + case EQ_F64: + STACK_VM_CMPOP(==, v_float64); + break; + case LT_F64: + STACK_VM_CMPOP(<, v_float64); + break; + case LE_F64: + STACK_VM_CMPOP(<=, v_float64); + break; + case EQ_HANDLE: + STACK_VM_CMPOP(==, v_handle); + break; // addressing - case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break; - case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break; - case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break; - case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break; - case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break; - case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break; + case ARRAY_LOAD_UINT32: + STACK_VM_LOAD(.v_int64, int64_t, uint32_t); + break; + case ARRAY_LOAD_INT32: + STACK_VM_LOAD(.v_int64, int64_t, int32_t); + break; + case ARRAY_LOAD_INT64: + STACK_VM_LOAD(.v_int64, int64_t, int64_t); + break; + case ARRAY_LOAD_FP64: + STACK_VM_LOAD(.v_float64, double, double); + break; + case ARRAY_LOAD_HANDLE: + STACK_VM_LOAD(.v_handle, void*, void*); + break; + case ARRAY_LOAD_TVMVALUE: + STACK_VM_LOAD(, TVMValue, TVMValue); + break; // store - case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break; - case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break; - case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break; - case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break; - case ARRAY_STORE_HANDLE: 
STACK_VM_STORE(.v_handle, void*); break; - case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break; + case ARRAY_STORE_UINT32: + STACK_VM_STORE(.v_int64, uint32_t); + break; + case ARRAY_STORE_INT32: + STACK_VM_STORE(.v_int64, int32_t); + break; + case ARRAY_STORE_INT64: + STACK_VM_STORE(.v_int64, int64_t); + break; + case ARRAY_STORE_FP64: + STACK_VM_STORE(.v_float64, double); + break; + case ARRAY_STORE_HANDLE: + STACK_VM_STORE(.v_handle, void*); + break; + case ARRAY_STORE_TVMVALUE: + STACK_VM_STORE(, TVMValue); + break; // add case ADDR_ADD: { stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*) @@ -365,9 +417,8 @@ void StackVM::Run(State* s) const { } case ASSERT_SP: { int64_t expected = code[pc + 1].v_int; - CHECK_EQ(sp, expected) - << "sp assertion failed, expected=" - << expected << " now=" << sp << ", pc=" << pc; + CHECK_EQ(sp, expected) << "sp assertion failed, expected=" << expected << " now=" << sp + << ", pc=" << pc; pc += 2; break; } @@ -379,11 +430,10 @@ void StackVM::Run(State* s) const { int begin = code[pc + 2].v_int; int end = code[pc + 3].v_int; int num_args = end - begin; - static_assert(sizeof(Code) == sizeof(int) && - alignof(Code) == alignof(int), "asusmption"); + static_assert(sizeof(Code) == sizeof(int) && alignof(Code) == alignof(int), "assumption"); runtime::TVMRetValue rv; - GetExtern(s, call_fid).CallPacked( - runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv); + GetExtern(s, call_fid) + .CallPacked(runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv); sp = sp - 1; stack[sp] = rv.value(); pc += 4; @@ -396,47 +446,55 @@ void StackVM::Run(State* s) const { DLTensor* arr = static_cast(stack[sp].v_handle); switch (kind) { case StackVM::kArrData: { - stack[sp].v_handle = arr[index].data; break; + stack[sp].v_handle = arr[index].data; + break; } case StackVM::kArrShape: { - stack[sp].v_handle = arr[index].shape; break; + stack[sp].v_handle = arr[index].shape; + break; } case StackVM::kArrStrides: { - stack[sp].v_handle = arr[index].strides; break; + stack[sp].v_handle = arr[index].strides; + break; } case StackVM::kArrNDim: { - stack[sp].v_int64 = arr[index].ndim; break; + stack[sp].v_int64 = arr[index].ndim; + break; } case StackVM::kArrTypeCode: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.code); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.code); + break; } case StackVM::kArrTypeBits: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.bits); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.bits); + break; } case StackVM::kArrTypeLanes: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.lanes); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.lanes); + break; } case StackVM::kArrByteOffset: { - stack[sp].v_int64 = static_cast( - arr[index].byte_offset); break; + stack[sp].v_int64 = static_cast(arr[index].byte_offset); + break; } case StackVM::kArrDeviceId: { - stack[sp].v_int64 = arr[index].ctx.device_id; break; + stack[sp].v_int64 = arr[index].ctx.device_id; + break; } case StackVM::kArrDeviceType: { - stack[sp].v_int64 = static_cast( - arr[index].ctx.device_type); break; + stack[sp].v_int64 = static_cast(arr[index].ctx.device_type); + break; } case StackVM::kArrAddr: { - stack[sp].v_handle = arr + index; break; + stack[sp].v_handle = arr + index; + break; } case StackVM::kTVMValueContent: { - stack[sp] = static_cast(stack[sp].v_handle)[index]; break; + stack[sp] = static_cast(stack[sp].v_handle)[index]; + break; } - default:
LOG(FATAL) << "unhandled get " << kind; + default: + LOG(FATAL) << "unhandled get " << kind; } pc = pc + 3; break; @@ -447,7 +505,8 @@ void StackVM::Run(State* s) const { DLTensor* arr = static_cast(stack[sp - 1].v_handle); switch (kind) { case StackVM::kArrData: { - arr[index].data = stack[sp].v_handle; break; + arr[index].data = stack[sp].v_handle; + break; } case StackVM::kArrShape: { arr[index].shape = static_cast(stack[sp].v_handle); @@ -486,9 +545,11 @@ void StackVM::Run(State* s) const { break; } case StackVM::kTVMValueContent: { - static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; break; + static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; + break; } - default: LOG(FATAL) << "unhandled tvm_struct_set " << kind; + default: + LOG(FATAL) << "unhandled tvm_struct_set " << kind; } sp -= 2; pc += 3; @@ -511,8 +572,8 @@ void StackVM::Run(State* s) const { size_t nbytes = static_cast(stack[sp - 2].v_int64); int dtype_code_hint = static_cast(stack[sp - 1].v_int64); int dtype_bits_hint = static_cast(stack[sp].v_int64); - void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, - dtype_code_hint, dtype_bits_hint); + void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, + dtype_bits_hint); stack[sp - 4].v_handle = ptr; sp = sp - 4; pc = pc + 1; @@ -543,8 +604,7 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const { // allow race write in this, since write is idempotent PackedFunc& f = extern_func_cache_[fid]; if (f == nullptr) { - CHECK(s->mod_ctx != nullptr) - << "No local context is set in stackvm"; + CHECK(s->mod_ctx != nullptr) << "No local context is set in stackvm"; const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]); CHECK(pf != nullptr); f = *pf; diff --git a/src/runtime/stackvm/stackvm.h b/src/runtime/stackvm/stackvm.h index f36e171cdf3e9..09581a6d0b62a 100644 --- a/src/runtime/stackvm/stackvm.h +++ b/src/runtime/stackvm/stackvm.h @@ -29,8 +29,9 @@ #define TVM_RUNTIME_STACKVM_STACKVM_H_ #include -#include #include +#include + #include #include @@ -339,7 +340,7 @@ class StackVM { * \param pc The pc * \return the pc to next instruction. */ - int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*) + int64_t PrintCode(std::ostream& os, int64_t pc) const; // NOLINT(*) /*! \brief Get thread local state of the stack VM */ static State* ThreadLocalState(); // The code below are programs @@ -362,15 +363,26 @@ class StackVM { */ static OpCode CodeI64ToF64(OpCode code) { switch (code) { - case ADD_I64: return ADD_F64; - case SUB_I64: return SUB_F64; - case MUL_I64: return MUL_F64; - case DIV_I64: return DIV_F64; - case EQ_I64: return EQ_F64; - case LT_I64: return LT_F64; - case LE_I64: return LE_F64; - case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64; - default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64; + case ADD_I64: + return ADD_F64; + case SUB_I64: + return SUB_F64; + case MUL_I64: + return MUL_F64; + case DIV_I64: + return DIV_F64; + case EQ_I64: + return EQ_F64; + case LT_I64: + return LT_F64; + case LE_I64: + return LE_F64; + case MOD_I64: + LOG(FATAL) << "cannot handle mod for float"; + return ADD_F64; + default: + LOG(FATAL) << "cannot handle op " << code; + return ADD_F64; } } /*! 
@@ -383,16 +395,20 @@ class StackVM { if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE; if (t.code == kDLInt) { switch (t.bits) { - case 32 : return ARRAY_LOAD_INT32; - case 64 : return ARRAY_LOAD_INT64; + case 32: + return ARRAY_LOAD_INT32; + case 64: + return ARRAY_LOAD_INT64; } } else if (t.code == kDLUInt) { switch (t.bits) { - case 32 : return ARRAY_LOAD_UINT32; + case 32: + return ARRAY_LOAD_UINT32; } } else if (t.code == kDLFloat) { switch (t.bits) { - case 64 : return ARRAY_LOAD_FP64; + case 64: + return ARRAY_LOAD_FP64; } } LOG(FATAL) << "Cannot load type " << t; @@ -408,16 +424,20 @@ class StackVM { if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE; if (t.code == kDLInt) { switch (t.bits) { - case 32 : return ARRAY_STORE_INT32; - case 64 : return ARRAY_STORE_INT64; + case 32: + return ARRAY_STORE_INT32; + case 64: + return ARRAY_STORE_INT64; } } else if (t.code == kDLUInt) { switch (t.bits) { - case 32 : return ARRAY_STORE_UINT32; + case 32: + return ARRAY_STORE_UINT32; } } else if (t.code == kDLFloat) { switch (t.bits) { - case 64 : return ARRAY_STORE_FP64; + case 64: + return ARRAY_STORE_FP64; } } LOG(FATAL) << "Cannot store type " << t; diff --git a/src/runtime/stackvm/stackvm_module.cc b/src/runtime/stackvm/stackvm_module.cc index 8b30b750e7143..9e1f1f515f4a4 100644 --- a/src/runtime/stackvm/stackvm_module.cc +++ b/src/runtime/stackvm/stackvm_module.cc @@ -20,13 +20,16 @@ /*! * \file stackvm_module.cc */ -#include -#include +#include "stackvm_module.h" + #include +#include +#include + #include -#include #include -#include "stackvm_module.h" +#include + #include "../file_util.h" namespace tvm { @@ -34,13 +37,9 @@ namespace runtime { class StackVMModuleNode : public runtime::ModuleNode { public: - const char* type_key() const { - return "stackvm"; - } + const char* type_key() const { return "stackvm"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == runtime::symbol::tvm_module_main) { return GetFunction(entry_func_, sptr_to_self); } @@ -48,9 +47,8 @@ class StackVMModuleNode : public runtime::ModuleNode { if (it == fmap_.end()) return PackedFunc(); const StackVM& vm = it->second; // capture sptr_to_self to keep module node alive. 
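The "capture sptr_to_self" comment just below is the whole lifetime story: the returned function object owns a reference back to its module, so the module node cannot be destroyed while any returned function is still callable. A stand-alone sketch of the pattern using std::function and shared_ptr rather than TVM's actual PackedFunc/ObjectPtr types:

```cpp
// Keep-alive capture sketch: the closure holds a shared_ptr to the node,
// so the node outlives the scope that created it.
#include <functional>
#include <iostream>
#include <memory>

struct ModuleNode : std::enable_shared_from_this<ModuleNode> {
  void Run() const { std::cout << "running\n"; }
  std::function<void()> GetFunction() {
    auto self = shared_from_this();    // plays the role of sptr_to_self
    return [self]() { self->Run(); };  // closure keeps the node alive
  }
};

int main() {
  std::function<void()> f;
  {
    auto mod = std::make_shared<ModuleNode>();
    f = mod->GetFunction();
  }     // `mod` goes out of scope here...
  f();  // ...but the captured shared_ptr keeps the node valid
}
```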
- return PackedFunc([vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - vm.Run(args, this); - }); + return PackedFunc( + [vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { vm.Run(args, this); }); } std::string GetSource(const std::string& format) final { @@ -62,8 +60,7 @@ class StackVMModuleNode : public runtime::ModuleNode { return os.str(); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string data, mblob; dmlc::MemoryStringStream writer(&data); dmlc::Stream* strm = &writer; @@ -74,8 +71,7 @@ class StackVMModuleNode : public runtime::ModuleNode { strm->Write(num_imports); for (runtime::Module im : imports_) { - CHECK_EQ(im->imports().size(), 0U) - << "Only support simply one-level hierarchy"; + CHECK_EQ(im->imports().size(), 0U) << "Only a simple one-level hierarchy is supported"; std::string tkey = im->type_key(); strm->Write(tkey); LOG(INFO) << "save " << tkey; @@ -85,8 +81,7 @@ class StackVMModuleNode : public runtime::ModuleNode { SaveBinaryToFile(file_name, data); } - static Module Create(std::unordered_map fmap, - std::string entry_func) { + static Module Create(std::unordered_map fmap, std::string entry_func) { auto n = make_object(); n->fmap_ = std::move(fmap); n->entry_func_ = std::move(entry_func); @@ -108,17 +103,14 @@ class StackVMModuleNode : public runtime::ModuleNode { CHECK(strm->Read(&tkey)); std::string fkey = "runtime.module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not present."; Module m = (*f)(static_cast(strm)); n->imports_.emplace_back(std::move(m)); } return Module(n); } - static Module LoadFromFile(std::string file_name, - std::string format) { + static Module LoadFromFile(std::string file_name, std::string format) { std::string data; LoadBinaryFromFile(file_name, &data); dmlc::MemoryStringStream reader(&data); @@ -132,13 +124,12 @@ class StackVMModuleNode : public runtime::ModuleNode { std::string entry_func_; }; -Module StackVMModuleCreate(std::unordered_map fmap, - std::string entry_func) { +Module StackVMModuleCreate(std::unordered_map fmap, std::string entry_func) { return StackVMModuleNode::Create(fmap, entry_func); } TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm") -.set_body_typed(StackVMModuleNode::LoadFromFile); + .set_body_typed(StackVMModuleNode::LoadFromFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/stackvm/stackvm_module.h b/src/runtime/stackvm/stackvm_module.h index c84eb6fe49458..6ae4ae47a92c9 100644 --- a/src/runtime/stackvm/stackvm_module.h +++ b/src/runtime/stackvm/stackvm_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,10 @@ #define TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_ #include + #include #include + #include "stackvm.h" namespace tvm { @@ -38,8 +40,7 @@ namespace runtime { * \param entry_func The entry function name.
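The load path above dispatches on a registry key built from each import's type key ("runtime.module.loadbinary_" + tkey). The same shape, reduced to a local table of loader functions; the names here are illustrative, not TVM's registry API:

```cpp
// String-keyed loader registry sketch mirroring the loadbinary_ dispatch.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Module {
  std::string kind;
};

using Loader = std::function<std::unique_ptr<Module>(const std::string& blob)>;

std::unordered_map<std::string, Loader>& Registry() {
  static std::unordered_map<std::string, Loader> reg;
  return reg;
}

std::unique_ptr<Module> LoadBinary(const std::string& tkey, const std::string& blob) {
  auto it = Registry().find("runtime.module.loadbinary_" + tkey);
  if (it == Registry().end()) return nullptr;  // mirrors the CHECK(f != nullptr)
  return it->second(blob);
}

int main() {
  Registry()["runtime.module.loadbinary_stackvm"] = [](const std::string&) {
    return std::make_unique<Module>(Module{"stackvm"});
  };
  auto m = LoadBinary("stackvm", "");
  std::cout << (m ? m->kind : "missing") << "\n";  // prints "stackvm"
}
```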
* \return The created module */ -Module StackVMModuleCreate(std::unordered_map fmap, - std::string entry_func); +Module StackVMModuleCreate(std::unordered_map fmap, std::string entry_func); } // namespace runtime } // namespace tvm diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index 3eb7b1c46b45d..fe29146d8b7b3 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -21,10 +21,12 @@ * \file system_library.cc * \brief Create library module that directly gets symbols from the system lib. */ -#include -#include #include +#include +#include + #include + #include "library_module.h" namespace tvm { @@ -48,10 +50,8 @@ class SystemLibrary : public Library { std::lock_guard lock(mutex_); auto it = tbl_.find(name); if (it != tbl_.end() && ptr != it->second) { - LOG(WARNING) - << "SystemLib symbol " << name - << " get overriden to a different address " - << ptr << "->" << it->second; + LOG(WARNING) << "SystemLib symbol " << name << " gets overridden to a different address " << ptr + << "->" << it->second; } tbl_[name] = ptr; } @@ -68,11 +68,9 @@ class SystemLibrary : public Library { std::unordered_map tbl_; }; -TVM_REGISTER_GLOBAL("runtime.SystemLib") -.set_body_typed([]() { - static auto mod = CreateModuleFromLibrary( - SystemLibrary::Global()); - return mod; +TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_typed([]() { + static auto mod = CreateModuleFromLibrary(SystemLibrary::Global()); + return mod; }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 00f089b86b0f3..0cc881ceb7f28 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -21,26 +21,26 @@ * \file thread_pool.cc * \brief Threadpool for multi-threading runtime. */ -#include +#include +#include #include -#include #include +#include #include -#include -#include #if TVM_THREADPOOL_USE_OPENMP #include #endif -#include -#include -#include -#include #include -#include -#include +#include +#include #include #include +#include #include +#include +#include +#include const constexpr int kL1CacheBytes = 64; @@ -69,10 +69,7 @@ constexpr int kSyncStride = 64 / sizeof(std::atomic); class ParallelLauncher { public: // Reset the task request. - void Init(FTVMParallelLambda flambda, - void* cdata, - int num_task, - bool need_sync) { + void Init(FTVMParallelLambda flambda, void* cdata, int num_task, bool need_sync) { num_pending_.store(num_task); this->cdata = cdata; this->flambda = flambda; @@ -88,17 +85,14 @@ class ParallelLauncher { } if (need_sync) { for (int i = 0; i < num_task; ++i) { - sync_counter_[i * kSyncStride].store( - 0, std::memory_order_relaxed); + sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed); } this->env.sync_handle = sync_counter_; } else { this->env.sync_handle = nullptr; } } - ~ParallelLauncher() { - delete[] sync_counter_; - } + ~ParallelLauncher() { delete[] sync_counter_; } // Wait for n jobs to finish int WaitForJobs() { while (num_pending_.load() != 0) { @@ -122,13 +116,9 @@ class ParallelLauncher { has_error_.store(true); } // Signal that one job has finished. - void SignalJobFinish() { - num_pending_.fetch_sub(1); - } + void SignalJobFinish() { num_pending_.fetch_sub(1); } // Get thread local version of the store.
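The `i * kSyncStride` indexing in ParallelLauncher spaces each task's atomic counter one cache line apart (kSyncStride = 64 / sizeof(std::atomic<int>), matching kL1CacheBytes), so two spinning workers never invalidate each other's line. A sketch of that layout, with the 64-byte line size stated as an assumption:

```cpp
// Cache-line-strided counter array sketch to avoid false sharing.
#include <atomic>

constexpr int kCacheLineBytes = 64;  // assumed L1 line size, as in the patch
constexpr int kStride = kCacheLineBytes / sizeof(std::atomic<int>);

class SyncCounters {
 public:
  explicit SyncCounters(int num_task)
      : counters_(new std::atomic<int>[num_task * kStride]) {
    for (int i = 0; i < num_task; ++i) {
      counters_[i * kStride].store(0, std::memory_order_relaxed);
    }
  }
  ~SyncCounters() { delete[] counters_; }
  // Each task touches only its own cache line.
  std::atomic<int>& ForTask(int task_id) { return counters_[task_id * kStride]; }

 private:
  std::atomic<int>* counters_;
};
```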
- static ParallelLauncher* ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); - } + static ParallelLauncher* ThreadLocal() { return dmlc::ThreadLocalStore::Get(); } // The parallel lambda FTVMParallelLambda flambda; // The closure data @@ -159,15 +149,9 @@ class SpscTaskQueue { int32_t task_id; }; - SpscTaskQueue() : - buffer_(new Task[kRingSize]), - head_(0), - tail_(0) { - } + SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {} - ~SpscTaskQueue() { - delete[] buffer_; - } + ~SpscTaskQueue() { delete[] buffer_; } /*! * \brief Push a task into the queue and notify the consumer if it is on wait. @@ -198,9 +182,7 @@ class SpscTaskQueue { } if (pending_.fetch_sub(1) == 0) { std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { - return pending_.load() >= 0 || exit_now_.load(); - }); + cv_.wait(lock, [this] { return pending_.load() >= 0 || exit_now_.load(); }); } if (exit_now_.load(std::memory_order_relaxed)) { return false; @@ -275,7 +257,7 @@ class SpscTaskQueue { // The thread pool class ThreadPool { public: - ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) { + ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) { for (int i = 0; i < num_workers_; ++i) { // The SpscTaskQueue only hosts ONE item at a time queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); } threads_ = std::unique_ptr( new tvm::runtime::threading::ThreadGroup( - num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, - exclude_worker0_ /* include_main_thread */)); + num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, + exclude_worker0_ /* include_main_thread */)); num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); } ~ThreadPool() { @@ -296,10 +278,7 @@ class ThreadPool { } threads_.reset(); } - int Launch(FTVMParallelLambda flambda, - void* cdata, - int num_task, - int need_sync) { + int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) { ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); CHECK(!launcher->is_worker) << "Cannot launch parallel job inside worker, consider fuse then parallel"; @@ -332,15 +311,12 @@ class ThreadPool { return res; } - static ThreadPool* ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); - } + static ThreadPool* ThreadLocal() { return dmlc::ThreadLocalStore::Get(); } void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) { // this will also reset the affinity of the ThreadGroup // may use less than the MaxConcurrency number of workers - num_workers_used_ = threads_->Configure(mode, nthreads, - exclude_worker0_); + num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_); // if MaxConcurrency restricted the number of workers (e.g., due to // hyperthreading), respect the restriction num_workers_used_ = std::min(num_workers_, num_workers_used_); @@ -376,33 +352,25 @@ class ThreadPool { std::unique_ptr threads_; }; -TVM_REGISTER_GLOBAL("runtime.config_threadpool") -.set_body([](TVMArgs args, TVMRetValue* rv) { - threading::ThreadGroup::AffinityMode mode =\ - static_cast(\ - static_cast(args[0])); - int nthreads = args[1]; - ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); +TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRetValue* rv) { + threading::ThreadGroup::AffinityMode mode = + static_cast(static_cast(args[0])); + int nthreads = args[1]; +
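SpscTaskQueue is a single-producer single-consumer ring: each side exclusively owns one index, so the fast path needs no lock. A reduced model of that idea (the condition-variable slow path and the pending_ counter are omitted; the default capacity of one item mirrors the "only hosts ONE item at a time" comment above):

```cpp
// Lock-free SPSC ring sketch: producer owns tail_, consumer owns head_.
#include <atomic>
#include <cstdint>
#include <optional>

template <typename T, uint32_t kRingSize = 2>  // capacity = kRingSize - 1
class SpscRing {
 public:
  bool Push(const T& item) {  // producer thread only
    uint32_t tail = tail_.load(std::memory_order_relaxed);
    if (((tail + 1) % kRingSize) == head_.load(std::memory_order_acquire)) {
      return false;  // full
    }
    buffer_[tail] = item;
    tail_.store((tail + 1) % kRingSize, std::memory_order_release);
    return true;
  }

  std::optional<T> Pop() {  // consumer thread only
    uint32_t head = head_.load(std::memory_order_relaxed);
    if (head == tail_.load(std::memory_order_acquire)) return std::nullopt;
    T item = buffer_[head];
    head_.store((head + 1) % kRingSize, std::memory_order_release);
    return item;
  }

 private:
  T buffer_[kRingSize];
  std::atomic<uint32_t> head_{0};
  std::atomic<uint32_t> tail_{0};
};
```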
ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); }); - } // namespace runtime } // namespace tvm - -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { #if !TVM_THREADPOOL_USE_OPENMP - int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch( - flambda, cdata, num_task, 1); + int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1); return res; #else int num_workers = tvm::runtime::threading::MaxConcurrency(); if (num_task == 0) num_task = num_workers; omp_set_num_threads(num_workers); - #pragma omp parallel num_threads(num_workers) +#pragma omp parallel num_threads(num_workers) { TVMParallelGroupEnv env; env.num_task = num_task; @@ -414,18 +382,15 @@ int TVMBackendParallelLaunch( int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { #if TVM_THREADPOOL_USE_OPENMP - #pragma omp barrier +#pragma omp barrier #else using tvm::runtime::kSyncStride; int num_task = penv->num_task; - std::atomic* sync_counter = - reinterpret_cast*>(penv->sync_handle); - int old_counter = sync_counter[task_id * kSyncStride].fetch_add( - 1, std::memory_order_release); + std::atomic* sync_counter = reinterpret_cast*>(penv->sync_handle); + int old_counter = sync_counter[task_id * kSyncStride].fetch_add(1, std::memory_order_release); for (int i = 0; i < num_task; ++i) { if (i != task_id) { - while (sync_counter[i * kSyncStride].load( - std::memory_order_relaxed) <= old_counter) { + while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) { tvm::runtime::threading::Yield(); } } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 3e6fd781023c5..92e12b5f3a38a 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
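TVMBackendParallelBarrier above is a counter barrier: each task publishes its arrival with a release fetch_add on its own counter, then spins with relaxed loads until every peer's counter has passed the value it saw for itself. A condensed model (the per-task cache-line padding of the real code is dropped for brevity; thread counts are illustrative):

```cpp
// Counter-based barrier sketch matching the fetch_add/spin pattern above.
#include <atomic>
#include <thread>
#include <vector>

void Barrier(std::vector<std::atomic<int>>& counters, int task_id) {
  // Publish arrival; `old` is this task's count before the increment.
  int old = counters[task_id].fetch_add(1, std::memory_order_release);
  for (size_t i = 0; i < counters.size(); ++i) {
    if (static_cast<int>(i) == task_id) continue;
    // Wait until every other task has also passed its own increment.
    while (counters[i].load(std::memory_order_relaxed) <= old) {
      std::this_thread::yield();
    }
  }
}

int main() {
  const int num_task = 4;
  std::vector<std::atomic<int>> counters(num_task);
  for (auto& c : counters) c.store(0);
  std::vector<std::thread> workers;
  for (int t = 0; t < num_task; ++t) {
    workers.emplace_back([&, t] { Barrier(counters, t); });
  }
  for (auto& w : workers) w.join();  // all tasks crossed the barrier
}
```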
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #include + #include #include @@ -64,9 +65,12 @@ enum class StorageRank { */ inline StorageRank DefaultStorageRank(int thread_scope_rank) { switch (thread_scope_rank) { - case -1: return StorageRank::kGlobal; - case 0: return StorageRank::kShared; - case 1: return StorageRank::kLocal; + case -1: + return StorageRank::kGlobal; + case 0: + return StorageRank::kShared; + case 1: + return StorageRank::kLocal; default: { LOG(FATAL) << "unknown rank"; return StorageRank::kGlobal; @@ -84,20 +88,27 @@ struct StorageScope { inline bool operator==(const StorageScope& other) const { return rank == other.rank && tag == other.tag; } - inline bool operator!=(const StorageScope& other) const { - return !(*this == other); - } + inline bool operator!=(const StorageScope& other) const { return !(*this == other); } inline std::string to_string() const { std::string ret; switch (rank) { - case StorageRank::kGlobal: return "global" + tag; - case StorageRank::kShared: return "shared" + tag; - case StorageRank::kWarp: return "warp" + tag; - case StorageRank::kLocal: return "local" + tag; - case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag; - case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag; - case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag; - default: LOG(FATAL) << "unknown storage scope"; return ""; + case StorageRank::kGlobal: + return "global" + tag; + case StorageRank::kShared: + return "shared" + tag; + case StorageRank::kWarp: + return "warp" + tag; + case StorageRank::kLocal: + return "local" + tag; + case StorageRank::kWMMAMatrixA: + return "wmma.matrix_a" + tag; + case StorageRank::kWMMAMatrixB: + return "wmma.matrix_b" + tag; + case StorageRank::kWMMAAccumulator: + return "wmma.accumulator" + tag; + default: + LOG(FATAL) << "unknown storage scope"; + return ""; } } /*! @@ -107,7 +118,7 @@ struct StorageScope { */ static StorageScope make(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { @@ -165,7 +176,6 @@ struct ThreadScope { } }; - /*! \brief workload specification */ struct ThreadWorkLoad { // array, first three are thread configuration. @@ -174,22 +184,17 @@ struct ThreadWorkLoad { * \param i The block dimension. * \return i-th block dim */ - inline size_t block_dim(size_t i) const { - return work_size[i + 3]; - } + inline size_t block_dim(size_t i) const { return work_size[i + 3]; } /*! * \param i The grid dimension. * \return i-th grid dim */ - inline size_t grid_dim(size_t i) const { - return work_size[i]; - } + inline size_t grid_dim(size_t i) const { return work_size[i]; } }; /*! 
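StorageScope::make above parses by keyword prefix ("global", "shared", "warp", "local", plus the wmma scopes) and keeps whatever follows as a tag. A small sketch of the same parse, with ".dyn" used purely as a hypothetical tag:

```cpp
// Prefix-based scope parse sketch; a subset of the ranks in the patch.
#include <cassert>
#include <string>

enum class Rank { kGlobal, kShared, kWarp, kLocal };

struct Scope {
  Rank rank;
  std::string tag;
};

Scope ParseScope(const std::string& s) {
  if (s.compare(0, 6, "global") == 0) return {Rank::kGlobal, s.substr(6)};
  if (s.compare(0, 6, "shared") == 0) return {Rank::kShared, s.substr(6)};
  if (s.compare(0, 4, "warp") == 0) return {Rank::kWarp, s.substr(4)};
  if (s.compare(0, 5, "local") == 0) return {Rank::kLocal, s.substr(5)};
  assert(false && "unknown storage scope");
  return {Rank::kGlobal, ""};
}

int main() {
  Scope sc = ParseScope("shared.dyn");
  assert(sc.rank == Rank::kShared && sc.tag == ".dyn");
}
```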
\brief Thread axis configuration */ class ThreadAxisConfig { public: - void Init(size_t base, - const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& thread_axis_tags) { base_ = base; std::vector filled(6, false); for (size_t i = 0; i < thread_axis_tags.size(); ++i) { @@ -210,15 +215,12 @@ class ThreadAxisConfig { ThreadWorkLoad w; std::fill(w.work_size, w.work_size + 6, 1); for (size_t i = 0; i < arg_index_map_.size(); ++i) { - w.work_size[arg_index_map_[i]] = - static_cast(x.values[base_ + i].v_int64); + w.work_size[arg_index_map_[i]] = static_cast(x.values[base_ + i].v_int64); } return w; } // return the work dim - size_t work_dim() const { - return work_dim_; - } + size_t work_dim() const { return work_dim_; } private: /*! \brief base axis */ diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 9d14d3a14d033..2e781eaf38879 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,10 +21,11 @@ * \file threading_backend.cc * \brief Native threading backend */ -#include #include -#include +#include + #include +#include #if defined(__linux__) || defined(__ANDROID__) #include #include @@ -40,12 +41,9 @@ namespace threading { class ThreadGroup::Impl { public: - Impl(int num_workers, - std::function worker_callback, - bool exclude_worker0) + Impl(int num_workers, std::function worker_callback, bool exclude_worker0) : num_workers_(num_workers) { - CHECK_GE(num_workers, 1) - << "Requested a non-positive number of worker threads."; + CHECK_GE(num_workers, 1) << "Requested a non-positive number of worker threads."; for (int i = exclude_worker0; i < num_workers_; ++i) { threads_.emplace_back([worker_callback, i] { worker_callback(i); }); } @@ -79,15 +77,14 @@ class ThreadGroup::Impl { // ones. 
num_workers_used = std::min(num_workers_, num_workers_used); - const char *val = getenv("TVM_BIND_THREADS"); + const char* val = getenv("TVM_BIND_THREADS"); if (val == nullptr || atoi(val) == 1) { // Do not set affinity if there are more workers than found cores if (sorted_order_.size() >= static_cast(num_workers_)) { - SetAffinity(exclude_worker0, mode == kLittle); + SetAffinity(exclude_worker0, mode == kLittle); } else { - LOG(WARNING) - << "The thread affinity cannot be set when the number of workers" - << "is larger than the number of available cores in the system."; + LOG(WARNING) << "The thread affinity cannot be set when the number of workers " + << "is larger than the number of available cores in the system."; } } return num_workers_used; @@ -101,15 +98,14 @@ class ThreadGroup::Impl { #if defined(__ANDROID__) #ifndef CPU_SET #define CPU_SETSIZE 1024 -#define __NCPUBITS (8 * sizeof (uint64_t)) +#define __NCPUBITS (8 * sizeof(uint64_t)) typedef struct { uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; } cpu_set_t; #define CPU_SET(cpu, cpusetp) \ - ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) -#define CPU_ZERO(cpusetp) \ - memset((cpusetp), 0, sizeof(cpu_set_t)) + ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t)) #endif #endif #if defined(__linux__) || defined(__ANDROID__) @@ -128,8 +124,7 @@ class ThreadGroup::Impl { #if defined(__ANDROID__) sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); #else - pthread_setaffinity_np(threads_[i].native_handle(), - sizeof(cpu_set_t), &cpuset); + pthread_setaffinity_np(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); #endif } if (exclude_worker0) { // master thread run task @@ -182,27 +177,27 @@ class ThreadGroup::Impl { void InitSortedOrder() { unsigned int threads = std::thread::hardware_concurrency(); - std::vector > max_freqs; + std::vector > max_freqs; for (unsigned int i = 0; i < threads; ++i) { int64_t cur_freq = 0; - #if defined(__linux__) || defined(__ANDROID__) - std::ostringstream filepath; - filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq"; - std::ifstream ifs(filepath.str()); - if (!ifs.fail()) { - if (!(ifs >> cur_freq)) { - cur_freq = -1; - } - ifs.close(); +#if defined(__linux__) || defined(__ANDROID__) + std::ostringstream filepath; + filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq"; + std::ifstream ifs(filepath.str()); + if (!ifs.fail()) { + if (!(ifs >> cur_freq)) { + cur_freq = -1; } - #endif + ifs.close(); + } +#endif max_freqs.push_back(std::make_pair(i, cur_freq)); } - auto fcmpbyfreq = [] (const std::pair &a, - const std::pair &b) { - return a.second == b.second ? a.first < b.first : a.second > b.second; + auto fcmpbyfreq = [](const std::pair& a, + const std::pair& b) { + return a.second == b.second ?
a.first < b.first : a.second > b.second; }; std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq); int64_t big_freq = max_freqs.begin()->second; @@ -228,10 +223,9 @@ class ThreadGroup::Impl { int little_count_ = 0; }; -ThreadGroup::ThreadGroup(int num_workers, - std::function worker_callback, +ThreadGroup::ThreadGroup(int num_workers, std::function worker_callback, bool exclude_worker0) - : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} + : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} ThreadGroup::~ThreadGroup() { delete impl_; } void ThreadGroup::Join() { impl_->Join(); } @@ -239,13 +233,11 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0 return impl_->Configure(mode, nthreads, exclude_worker0); } -void Yield() { - std::this_thread::yield(); -} +void Yield() { std::this_thread::yield(); } int MaxConcurrency() { int max_concurrency = 1; - const char *val = getenv("TVM_NUM_THREADS"); + const char* val = getenv("TVM_NUM_THREADS"); if (val == nullptr) { val = getenv("OMP_NUM_THREADS"); } @@ -255,12 +247,22 @@ int MaxConcurrency() { max_concurrency = std::thread::hardware_concurrency(); #if defined(_M_X64) || defined(__x86_64__) max_concurrency /= 2; // ignore hyper-threading +#elif defined(__hexagon__) + // With unsigned PDs, getting the number of available hardware threads + // is not supported in earlier versions of QuRT. In such cases assume 4. + // If running on simulator, set max_concurrency to 1. + if (max_concurrency == 0) { + if (dlsym(RTLD_DEFAULT, "running_in_sim_dev_17bc90206f6cf5a7")) { + max_concurrency = 1; + } else { + max_concurrency = 4; + } + } #endif } return std::max(max_concurrency, 1); } - } // namespace threading } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index c2036da46e099..c72e70fd6f662 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -28,9 +28,9 @@ #include #include -#include -#include #include +#include +#include #include #include #include @@ -50,24 +50,17 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr); // Helper to deserialize a serialized vm instruction. 
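MaxConcurrency above resolves the worker count from TVM_NUM_THREADS first, then OMP_NUM_THREADS, then hardware_concurrency (halved on x86-64 to skip hyper-threads), never dropping below one. A condensed model of that policy (the Hexagon fallback branch is elided):

```cpp
// Worker-count policy sketch: env vars win, hardware count is the fallback.
#include <algorithm>
#include <cstdlib>
#include <thread>

int MaxConcurrencySketch() {
  int max_concurrency = 1;
  const char* val = std::getenv("TVM_NUM_THREADS");
  if (val == nullptr) val = std::getenv("OMP_NUM_THREADS");
  if (val != nullptr) {
    max_concurrency = std::atoi(val);
  } else {
    max_concurrency = static_cast<int>(std::thread::hardware_concurrency());
#if defined(_M_X64) || defined(__x86_64__)
    max_concurrency /= 2;  // count physical cores, not hyper-threads
#endif
  }
  return std::max(max_concurrency, 1);  // never fewer than one worker
}
```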
Instruction DeserializeInstruction(const VMInstructionSerializer& instr); -PackedFunc Executable::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); }); } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Save(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); }); } else if (name == "get_function_arity") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0]; @@ -172,7 +165,8 @@ std::string Executable::Stats() const { // Get the number of globals and the name of each of them. oss << " Globals (#" << global_map.size() << "): ["; for (const auto& it : global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + oss << "(\"" << it.first << "\", " << it.second << ")" + << ", "; } if (!global_map.empty()) oss.seekp(-2, oss.cur); oss << "]" << std::endl; @@ -232,8 +226,7 @@ TVMByteArray Executable::Save() { void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector > globals(this->global_map.begin(), this->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { + auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; std::sort(globals.begin(), globals.end(), comp); @@ -364,8 +357,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); + fields.insert(fields.end(), instr.datatype_fields, instr.datatype_fields + instr.num_fields); break; } case Opcode::AllocClosure: { @@ -373,15 +365,12 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); + fields.insert(fields.end(), instr.free_vars, instr.free_vars + instr.num_freevar); break; } case Opcode::If: { // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, + fields.assign({instr.if_op.test, instr.if_op.target, instr.if_op.true_offset, instr.if_op.false_offset}); break; } @@ -399,8 +388,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.closure, instr.num_closure_args, instr.dst}); // Save the args. 
- fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); + fields.insert(fields.end(), instr.closure_args, instr.closure_args + instr.num_closure_args); break; } case Opcode::LoadConst: { @@ -441,9 +429,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { strm->Write(static_cast(this->functions.size())); for (const auto& func : this->functions) { // Save the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), + VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), func.params); func_format.Save(strm); @@ -523,8 +509,7 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { // Extract the `cnt` number of fields started at `start` from the list // `instr_fields`. -inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, +inline std::vector ExtractFields(const std::vector& instr_fields, Index start, Index cnt) { CHECK_LE(static_cast(start + cnt), instr_fields.size()); std::vector ret; @@ -634,11 +619,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { RegName dst = instr.fields[5]; - return Instruction::AllocStorage( - allocation_size, - alignment, - dtype, - dst); + return Instruction::AllocStorage(allocation_size, alignment, dtype, dst); } case Opcode::If: { // Number of fields = 4 @@ -727,9 +708,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } // Create the VM function. - VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, + VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, loaded_func.register_file_size); auto it = this->global_map.find(loaded_func.name); CHECK(it != this->global_map.end()); @@ -738,24 +717,21 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } } -TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); *rv = static_cast(exec->global_map.size()); }); -TVM_REGISTER_GLOBAL("runtime.GetGlobalFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); int idx = args[1]; std::vector > globals(exec->global_map.begin(), exec->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { + auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; std::sort(globals.begin(), globals.end(), comp); @@ -763,17 +739,14 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields") *rv = globals[idx].first; }); -TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); *rv = static_cast(exec->primitive_map.size()); }); - -TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); @@ -790,11 +763,9 @@ 
TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields") }); TVM_REGISTER_GLOBAL("runtime.Load_Executable") -.set_body_typed([]( - std::string code, - runtime::Module lib) { - return Executable::Load(code, lib); -}); + .set_body_typed([](std::string code, runtime::Module lib) { + return Executable::Load(code, lib); + }); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 3e6140ed38304..c0fd441bb0ca8 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -21,9 +21,11 @@ * \file tvm/runtime/vm/memory_manager.cc * \brief Allocate and manage memory for the runtime. */ -#include -#include #include "memory_manager.h" + +#include +#include + #include "naive_allocator.h" #include "pooled_allocator.h" @@ -35,8 +37,7 @@ static void BufferDeleter(Object* obj) { auto* ptr = static_cast(obj); CHECK(ptr->manager_ctx != nullptr); Buffer* buffer = reinterpret_cast(ptr->manager_ctx); - MemoryManager::Global()->GetAllocator(buffer->ctx)-> - Free(*(buffer)); + MemoryManager::Global()->GetAllocator(buffer->ctx)->Free(*(buffer)); delete buffer; delete ptr; } @@ -93,7 +94,7 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa // RAII in effect, now run the check. // TODO(@jroesch): generalize later to non-overlapping allocations. CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; + << "size mistmatch required " << needed_size << " found " << this->buffer.size; return ret; } @@ -106,8 +107,8 @@ MemoryManager* MemoryManager::Global() { Allocator* MemoryManager::GetAllocator(TVMContext ctx) { std::lock_guard lock(mu_); if (allocators_.find(ctx) == allocators_.end()) { - DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" - << ctx.device_id << ")"; + DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id + << ")"; std::unique_ptr alloc(new NaiveAllocator(ctx)); allocators_.emplace(ctx, std::move(alloc)); } @@ -120,7 +121,7 @@ NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLContext container->SetDeleter(BufferDeleter); size_t size = GetDataSize(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor); - Buffer *buffer = new Buffer; + Buffer* buffer = new Buffer; *buffer = this->Alloc(size, alignment, dtype); container->manager_ctx = reinterpret_cast(buffer); container->dl_tensor.data = buffer->data; diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index b4453524d9961..f59d584fcfbab 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -27,6 +27,7 @@ #include #include #include + #include #include #include @@ -73,15 +74,13 @@ class Allocator { * \param ctx The context where the array is allocated. * \return The empty NDArray. */ - NDArray Empty(std::vector shape, - DLDataType dtype, - DLContext ctx); + NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! \brief Allocate a buffer given a size, alignment and type. * \param nbytes The size of the buffer. * \param alignment The alignment of the buffer. * \param type_hint A type hint to the allocator. * \return A sized allocation in the form of a buffer. - */ + */ virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0; /*! \brief Free a buffer allocated by the allocator. * \param buffer The buffer to free. @@ -115,9 +114,7 @@ class StorageObj : public Object { Buffer buffer; /*! 
\brief Allocate an NDArray from a given piece of storage. */ - NDArray AllocNDArray(size_t offset, - std::vector shape, - DLDataType dtype); + NDArray AllocNDArray(size_t offset, std::vector shape, DLDataType dtype); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ static void Deleter(Object* ptr); diff --git a/src/runtime/vm/naive_allocator.h b/src/runtime/vm/naive_allocator.h index db47a62a7c399..5ac2ca61817e5 100644 --- a/src/runtime/vm/naive_allocator.h +++ b/src/runtime/vm/naive_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_VM_NAIVE_ALLOCATOR_H_ #include + #include #include "memory_manager.h" @@ -52,9 +53,7 @@ class NaiveAllocator final : public Allocator { DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; } - size_t UsedMemory() const override { - return used_memory_.load(std::memory_order_relaxed); - } + size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); } private: std::atomic used_memory_; diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index 5965a4e8cf233..e09628f72e97b 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_ #include + #include #include #include diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 4dac66e50a820..6e4682d1ab96e 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -22,6 +22,8 @@ * \brief The Relay debug virtual machine. */ +#include "vm.h" + #include #include @@ -34,27 +36,24 @@ #include #include -#include "vm.h" - namespace tvm { namespace runtime { namespace vm { -PackedFunc VirtualMachineDebug::GetFunction( - const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { if (name == "get_stat") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 1U); std::vector> op_acc_time; for (auto kv : op_durations_) { - auto val = std::make_pair( - kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); + auto val = + std::make_pair(kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); op_acc_time.push_back(val); } bool sort_by_time = args[0]; if (sort_by_time) { - auto comp = [](const std::pair& lhs, - const std::pair& rhs) { + auto comp = [](const std::pair& lhs, const std::pair& rhs) { return lhs.second > rhs.second; }; std::sort(op_acc_time.begin(), op_acc_time.end(), comp); @@ -74,9 +73,9 @@ PackedFunc VirtualMachineDebug::GetFunction( auto min_value = *std::min_element(vals.begin(), vals.end()); auto max_value = *std::max_element(vals.begin(), vals.end()); - os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" - << std::setw(10) << std::left << op_invokes_[kv.first] << "\t" - << sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl; + os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" << std::setw(10) + << std::left << op_invokes_[kv.first] << "\t" << sum << "/" << mean << "/" << min_value + << "/" << max_value << std::endl; total_duration += sum; total_packed_funcs += op_invokes_[kv.first]; @@ -104,10 +103,8 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { } } -void VirtualMachineDebug::InvokePacked(Index packed_index, - const PackedFunc& func, Index arg_count, - Index output_size, - const std::vector& args) { +void 
VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args) { CHECK(exec_); auto ctx = this->GetParamsContext(); // warmup @@ -119,9 +116,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_end = std::chrono::high_resolution_clock::now(); double op_duration = - std::chrono::duration_cast >(op_end - - op_begin) - .count(); + std::chrono::duration_cast>(op_end - op_begin).count(); op_durations_[packed_index].push_back(op_duration * 1e6); op_invokes_[packed_index] += 1; @@ -133,8 +128,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) { return runtime::Module(vm); } -TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec) << "Virtual machine has not been defined yet." diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index f0a407fd7266b..c286828231b0d 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -40,16 +40,15 @@ class VirtualMachineDebug : public VirtualMachine { public: VirtualMachineDebug() : VirtualMachine() {} - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void LoadExecutable(const Executable* exec) final; ~VirtualMachineDebug() {} private: - void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, - Index output_size, const std::vector& args) final; + void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, + const std::vector& args) final; std::unordered_map packed_index_map_; std::unordered_map> op_durations_; diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h index 3423f7a941673..8bd1f86f88874 100644 --- a/src/runtime/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -60,9 +60,7 @@ struct VMFunctionSerializer { VMFunctionSerializer() = default; - VMFunctionSerializer(const std::string& name, - Index register_file_size, - size_t num_instructions, + VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params) : name(name), register_file_size(register_file_size), @@ -87,7 +85,7 @@ struct VMFunctionSerializer { } /*! - * \brief Save the VM function header into the serialized form. + * \brief Save the VM function header into the serialized form. * \param strm The stream used to save data. */ void Save(dmlc::Stream* strm) const { @@ -108,11 +106,11 @@ struct VMInstructionSerializer { VMInstructionSerializer() = default; - VMInstructionSerializer(Index opcode, const std::vector& fields) : - opcode(opcode), fields(fields) {} + VMInstructionSerializer(Index opcode, const std::vector& fields) + : opcode(opcode), fields(fields) {} /*! - * \brief Compute the hash of the serialized instruction. + * \brief Compute the hash of the serialized instruction. * \return The hash that combines the opcode and all fields of the VM * instruction. 
*/ @@ -139,13 +137,12 @@ struct VMInstructionSerializer { } Index hash = Hash(); - CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " - << opcode << "\n"; + CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " << opcode << "\n"; return true; } /*! - * \brief Save the instruction into the serialized form. + * \brief Save the instruction into the serialized form. * \param strm The stream used to save data. */ void Save(dmlc::Stream* strm) const { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index fedbbe9bb0833..0714709a07181 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,11 +23,11 @@ */ #include -#include #include -#include #include #include +#include +#include #include #include @@ -56,8 +56,7 @@ inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint // We could put cache in here, from ctx to storage allocator. auto storage_obj = SimpleObjAllocator().make_object(); auto alloc = MemoryManager::Global()->GetAllocator(ctx); - DCHECK(alloc != nullptr) - << "allocator must not null"; + DCHECK(alloc != nullptr) << "allocator must not null"; storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint); return Storage(storage_obj); } @@ -87,8 +86,8 @@ Instruction::Instruction(const Instruction& instr) { case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; - this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, - instr.alloc_tensor.ndim); + this->alloc_tensor.shape = + Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); this->alloc_tensor.dtype = instr.alloc_tensor.dtype; return; case Opcode::AllocTensorReg: @@ -151,7 +150,7 @@ Instruction::Instruction(const Instruction& instr) { } } -template +template static inline void FreeIf(T* t) { if (t != nullptr) { delete t; @@ -177,8 +176,8 @@ Instruction& Instruction::operator=(const Instruction& instr) { case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; - this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, - instr.alloc_tensor.ndim); + this->alloc_tensor.shape = + Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); this->alloc_tensor.dtype = instr.alloc_tensor.dtype; return *this; case Opcode::AllocTensorReg: @@ -294,9 +293,7 @@ Instruction Instruction::Fatal() { return instr; } -Instruction Instruction::InvokePacked(Index packed_index, - Index arity, - Index output_size, +Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size, const std::vector& args) { Instruction instr; instr.op = Opcode::InvokePacked; @@ -310,10 +307,8 @@ Instruction Instruction::InvokePacked(Index packed_index, return instr; } -Instruction Instruction::AllocTensor( - RegName storage, - const std::vector& shape, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensor(RegName storage, const std::vector& shape, + DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; @@ -327,10 +322,8 @@ Instruction Instruction::AllocTensor( return instr; } -Instruction Instruction::AllocTensorReg( - RegName storage, - RegName shape_register, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, + Index dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; @@ -340,9 +333,7 @@ Instruction Instruction::AllocTensorReg( 
return instr; } -Instruction Instruction::AllocStorage(RegName size, - Index alignment, - DLDataType dtype_hint, +Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, Index dst) { Instruction instr; instr.op = Opcode::AllocStorage; @@ -354,7 +345,7 @@ Instruction Instruction::AllocStorage(RegName size, } Instruction Instruction::AllocADT(Index tag, Index num_fields, - const std::vector& datatype_fields, Index dst) { + const std::vector& datatype_fields, Index dst) { Instruction instr; instr.op = Opcode::AllocADT; instr.dst = dst; @@ -486,7 +477,7 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { } } -template +template std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") { if (cnt == 0) { return ""; @@ -515,26 +506,21 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { } case Opcode::InvokePacked: { os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $" - << StrJoin(instr.packed_args, 0, - instr.arity - instr.output_size, ", $") + << StrJoin(instr.packed_args, 0, instr.arity - instr.output_size, ", $") << ", out: $" - << StrJoin(instr.packed_args, instr.arity - instr.output_size, - instr.output_size, ", $") + << StrJoin(instr.packed_args, instr.arity - instr.output_size, instr.output_size, + ", $") << ")"; break; } case Opcode::AllocTensor: { - os << "alloc_tensor $" << instr.dst << " $" - << instr.alloc_tensor.storage << " [" - << StrJoin(instr.alloc_tensor.shape, 0, - instr.alloc_tensor.ndim) - << "] "; + os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " [" + << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); break; } case Opcode::AllocTensorReg: { - os << "alloc_tensor_reg $" << instr.dst << " $" - << instr.alloc_tensor_reg.storage << " $" + os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; @@ -545,26 +531,24 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocClosure: { - os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index - << "]($" << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") - << ")"; + os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($" + << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") << ")"; break; } case Opcode::If: { - os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " " - << instr.if_op.true_offset << " " << instr.if_op.false_offset; + os << "if " + << "$" << instr.if_op.test << " $" << instr.if_op.target << " " << instr.if_op.true_offset + << " " << instr.if_op.false_offset; break; } case Opcode::Invoke: { os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" - << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") - << ")"; + << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") << ")"; break; } case Opcode::InvokeClosure: { os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" - << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") - << ")"; + << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") << ")"; break; } case Opcode::LoadConst: { @@ -576,8 +560,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::GetField: { - os << "get_field $" << instr.dst << " $" << 
instr.object << "[" - << instr.field_index << "]"; + os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]"; break; } case Opcode::GetTag: { @@ -589,11 +572,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocStorage: { - os << "alloc_storage $" << - instr.dst << " $" << - instr.alloc_storage.allocation_size << " $" << - instr.alloc_storage.alignment << " " << - DLDataType2String(instr.alloc_storage.dtype_hint); + os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " $" + << instr.alloc_storage.alignment << " " + << DLDataType2String(instr.alloc_storage.dtype_hint); break; } default: @@ -637,14 +618,14 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, std::string func_name = args[0]; auto git = exec_->global_map.find(func_name); CHECK(git != exec_->global_map.end()) - << "Cannot find function " << func_name << " in the executable"; + << "Cannot find function " << func_name << " in the executable"; auto func = exec_->functions[git->second]; if (func.params.empty()) { *rv = Invoke(func, {}); } else { auto it = inputs_.find(func_name); CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name; - const std::vector &func_args = it->second; + const std::vector& func_args = it->second; *rv = Invoke(func, func_args); } }); @@ -672,8 +653,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const auto& param_names = vm_func.params; // TODO(icemelon9): For heterogeneous execution, get input device information TVMContext ctx = ctxs_[0]; - CHECK_EQ(args.size() - 1, param_names.size()) << - "The number of provided parameters doesn't match the number of arguments"; + CHECK_EQ(args.size() - 1, param_names.size()) + << "The number of provided parameters doesn't match the number of arguments"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { ObjectRef obj = CopyTo(args[i], ctx); @@ -745,16 +726,14 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { CHECK(exec_) << "The executable has not been created yet."; auto it = exec_->global_map.find(name); - CHECK(it != exec_->global_map.end()) - << "Cannot find function " << name << " in the executable"; + CHECK(it != exec_->global_map.end()) << "Cannot find function " << name << " in the executable"; auto func_index_ = it->second; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_; return Invoke(exec_->functions[func_index_], args); } -void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, - Index arg_count, Index output_size, - const std::vector& args) { +void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args) { size_t arity = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* obj = args[i].as()) { @@ -806,10 +785,7 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } - -void VirtualMachine::Init(const std::vector& ctxs) { - ctxs_ = ctxs; -} +void VirtualMachine::Init(const std::vector& ctxs) { ctxs_ = ctxs; } inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames_.back().register_file[r] = val; @@ -893,13 +869,13 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::InvokePacked: { - DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity; + DLOG(INFO) << "InvokedPacked " + << "arity=" << instr.arity; const auto& func = 
packed_funcs_[instr.packed_index]; const auto& arity = instr.arity; std::vector args; for (Index i = 0; i < arity; ++i) { - DLOG(INFO) << - "arg" << i << " $" << instr.packed_args[i]; + DLOG(INFO) << "arg" << i << " $" << instr.packed_args[i]; auto arg = ReadRegister(instr.packed_args[i]); args.push_back(arg); } @@ -1022,10 +998,8 @@ void VirtualMachine::RunLoop() { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = LoadScalarInt(instr.alloc_storage.alignment); - DLOG(INFO) << - "AllocStorage: allocation_size=" << size << - "alignment=" << alignment << - "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); + DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment + << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]); WriteRegister(instr.dst, storage); @@ -1057,8 +1031,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) { return runtime::Module(vm); } -TVM_REGISTER_GLOBAL("runtime._VirtualMachine") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec) << "The virtual machine executable has not been defined yet."; diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 4e2f8cbcc0bf2..207a86a71c929 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -17,21 +17,19 @@ * under the License. */ -#include #include #include #include #include +#include #include #include - #include "../file_util.h" #include "../pack_args.h" #include "../thread_storage_scope.h" #include "../workspace_pool.h" - #include "vulkan_common.h" #include "vulkan_module.h" #include "vulkan_shader.h" @@ -117,9 +115,7 @@ class VulkanDeviceAPI final : public DeviceAPI { } void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { const auto& vctx = context(ctx.device_id); VkBufferCreateInfo info; @@ -628,9 +624,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() { #ifdef USE_VULKAN_IMMEDIATE_MODE if (has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template")) { - ctx.descriptor_template_khr_functions = - std::unique_ptr( - new VulkanDescriptorTemplateKHRFunctions()); + ctx.descriptor_template_khr_functions = std::unique_ptr( + new VulkanDescriptorTemplateKHRFunctions()); ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR = CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( ctx.device, "vkCreateDescriptorUpdateTemplateKHR")); @@ -672,9 +667,7 @@ class VulkanModuleNode; // a wrapped function class to get packed func. 
class VulkanWrappedFunc { public: - void Init(VulkanModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, + void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { m_ = m; @@ -710,13 +703,12 @@ class VulkanWrappedFunc { class VulkanModuleNode final : public runtime::ModuleNode { public: explicit VulkanModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) + std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} const char* type_key() const final { return "vulkan"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); @@ -751,7 +743,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { } std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, - size_t num_pack_args) { + size_t num_pack_args) { const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; @@ -776,6 +768,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::vector arg_binding; std::vector arg_template; uint32_t num_pod = 0, num_buffer = 0; + { auto fit = fmap_.find(func_name); CHECK(fit != fmap_.end()); @@ -931,8 +924,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { } private: - // the binary data - std::vector data_; // function information table. std::unordered_map smap_; // function information table. 
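The GetPipeline method in this module is, at its core, a mutex-guarded two-level memo table: compiled pipelines are cached per device and per function name, so each shader is turned into a pipeline object at most once. A minimal sketch of that caching pattern, with Pipeline and Compile as hypothetical stand-ins for the Vulkan-specific types:

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Pipeline {};  // stand-in for the compiled GPU pipeline

// Hypothetical compile step; in the real module this builds the shader
// module, descriptor set layout, and compute pipeline.
std::shared_ptr<Pipeline> Compile(size_t device_id, const std::string& name) {
  return std::make_shared<Pipeline>();
}

class PipelineCache {
 public:
  std::shared_ptr<Pipeline> Get(size_t device_id, const std::string& name) {
    std::lock_guard<std::mutex> lock(mutex_);    // one lock serializes lookups and inserts
    auto& slot = cache_[device_id][name];        // default-constructs an empty slot
    if (!slot) slot = Compile(device_id, name);  // compile once, reuse afterwards
    return slot;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<size_t, std::unordered_map<std::string, std::shared_ptr<Pipeline>>> cache_;
};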
@@ -1021,8 +1012,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - const ArgUnion* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; CHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 9242d3d6d6806..780b11184931e 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -22,8 +22,8 @@ #include #include #include - #include + #include #include #include @@ -140,7 +140,6 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } }; - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/vulkan/vulkan_shader.h index 1b2e45458f9ce..d56ca61e91cbd 100644 --- a/src/runtime/vulkan/vulkan_shader.h +++ b/src/runtime/vulkan/vulkan_shader.h @@ -18,7 +18,6 @@ */ #pragma once - #include #include #include diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index 1a24d2873a60f..388cacc577b08 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -20,12 +20,11 @@ #include #include -#include #include +#include #include "vulkan_common.h" - namespace tvm { namespace runtime { namespace vulkan { @@ -44,8 +43,7 @@ struct VulkanStreamToken { class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx) - : vctx_(vctx), state_(new VulkanStreamState()) { + explicit VulkanStream(const VulkanContext* vctx) : vctx_(vctx), state_(new VulkanStreamState()) { // create command pool VkCommandPoolCreateInfo cmd_pool_cinfo; cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; diff --git a/src/runtime/workspace_pool.cc b/src/runtime/workspace_pool.cc index fc316cdeded10..8ee905e4ea846 100644 --- a/src/runtime/workspace_pool.cc +++ b/src/runtime/workspace_pool.cc @@ -21,9 +21,10 @@ * \file workspace_pool.h * \brief Workspace pool utility. 
 */
-#include
 #include "workspace_pool.h"
+#include
+
 namespace tvm {
 namespace runtime {

@@ -67,7 +68,8 @@ class WorkspacePool::Pool {
     if (free_list_.back().size >= nbytes) {
       // find smallest fit
       auto it = free_list_.end() - 2;
-      for (; it->size >= nbytes; --it) {}
+      for (; it->size >= nbytes; --it) {
+      }
       e = *(it + 1);
       free_list_.erase(it + 1);
     } else {
@@ -91,7 +93,8 @@ class WorkspacePool::Pool {
       allocated_.pop_back();
     } else {
       int index = static_cast<int>(allocated_.size()) - 2;
-      for (; index > 0 && allocated_[index].data != data; --index) {}
+      for (; index > 0 && allocated_[index].data != data; --index) {
+      }
       CHECK_GT(index, 0) << "trying to free things that have not been allocated";
       e = allocated_[index];
       allocated_.erase(allocated_.begin() + index);
@@ -132,8 +135,7 @@ class WorkspacePool::Pool {
 };

 WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
-    : device_type_(device_type), device_(device) {
-}
+    : device_type_(device_type), device_(device) {}

 WorkspacePool::~WorkspacePool() {
   for (size_t i = 0; i < array_.size(); ++i) {
@@ -158,8 +160,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) {
 }

 void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
-  CHECK(static_cast<size_t>(ctx.device_id) < array_.size() &&
-        array_[ctx.device_id] != nullptr);
+  CHECK(static_cast<size_t>(ctx.device_id) < array_.size() && array_[ctx.device_id] != nullptr);
   array_[ctx.device_id]->Free(ptr);
 }

diff --git a/src/runtime/workspace_pool.h b/src/runtime/workspace_pool.h
index 72613caffb8e9..288da7d104838 100644
--- a/src/runtime/workspace_pool.h
+++ b/src/runtime/workspace_pool.h
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,8 +25,9 @@
 #define TVM_RUNTIME_WORKSPACE_POOL_H_

 #include
-#include
+
 #include
+#include

 namespace tvm {
 namespace runtime {

diff --git a/src/support/arena.h b/src/support/arena.h
index 744ff4f121883..cb08db93641d6 100644
--- a/src/support/arena.h
+++ b/src/support/arena.h
@@ -26,42 +26,107 @@
 #ifndef TVM_SUPPORT_ARENA_H_
 #define TVM_SUPPORT_ARENA_H_

-#include
+#ifndef TVM_ARENA_HAS_DESTRUCTOR
+#define TVM_ARENA_HAS_DESTRUCTOR 1
+#endif
+
+#include
 #include
+#include

 namespace tvm {
 namespace support {

-const constexpr int kArenaPageSize = 16 << 10;
+/*!
+ * \brief An arena page header.
+ */
+struct ArenaPageHeader {
+  /*! \brief points to the next page. */
+  ArenaPageHeader* next;
+  /*!
+   * \brief Total size of the page.
+   */
+  size_t size;
+  /*! \brief memory allocator offset inside page. */
+  size_t offset;
+};
+
+/*!
+ * \brief Simple page allocator that uses new and delete.
+ */
+class SimplePageAllocator {
+ public:
+  /*!
+   * \brief Allocate a new page.
+   * \param min_size Minimum size of the page.
+   * \return The allocated page.
+   * \note This function can return a bigger page to meet the min_size requirement.
+   */
+  ArenaPageHeader* allocate(size_t min_size) {
+    size_t npages = ((min_size + kPageSize - 1) / kPageSize);
+    ArenaPageHeader* header = reinterpret_cast<ArenaPageHeader*>(new Page[npages]);
+    header->size = npages * kPageSize;
+    header->offset = sizeof(ArenaPageHeader);
+    return header;
+  }
+  /*!
+   * \brief De-allocate an allocated page.
+   * \param page The page to be de-allocated.
+   */
+  void deallocate(ArenaPageHeader* page) { delete[] reinterpret_cast<Page*>(page); }
+
+  static const constexpr int kPageSize = 16 << 10;
+  static const constexpr int kPageAlign = 1024;
+
+ private:
+  // page size 16 KB
+  // The page data type.
+  using Page = std::aligned_storage<kPageSize, kPageAlign>::type;
+};

 /*!
  * \brief Arena allocator that allocates memory from continuous
  * chunks and frees them all only during destruction.
  */
-class Arena {
+template <typename PageAllocator>
+class GenericArena {
  public:
-  Arena() {
+  explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) {
     // eagerly allocate the first page.
-    head_ = reinterpret_cast<PageHeader*>(new Page());
+    head_ = tail_ = alloc_.allocate(1);
     head_->next = nullptr;
-    head_->ptr = sizeof(PageHeader);
   }
-  ~Arena() {
-    // delete all the allocated pages.
-    while (head_ != nullptr) {
-      Page* page = reinterpret_cast<Page*>(head_);
-      head_ = head_->next;
-      delete page;
-    }
+
+#if TVM_ARENA_HAS_DESTRUCTOR
+  ~GenericArena() { this->FreeAll(); }
+#endif
+
+  /*! \brief Free all pages. */
+  void FreeAll() {
+    FreePageList(&head_);
+    FreePageList(&free_list_);
+  }
+  /*! \brief Recycle all the pages in the arena */
+  void RecycleAll() {
+    // chain the existing free list after the current pages.
+    tail_->next = free_list_;
+    // everything after the head page becomes the new free list.
+    free_list_ = head_->next;
+    head_->next = nullptr;
+    // Reset the head.
+    head_->offset = sizeof(ArenaPageHeader);
+    tail_ = head_;
   }
   /*!
   * \brief Allocate a space from Arena for type T
   * \param T the data type to be allocated
+   * \param count Number of elements
   * \note The space of T is not initialized.
   */
-  template<typename T>
-  T* allocate_() {
-    return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
+  template <typename T>
+  T* allocate_(int count = 1) {
+    static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "Too large alignment");
+    return static_cast<T*>(Alloc(sizeof(T) * count, alignof(T)));
   }
   /*!
   * \brief Create a new instance of type T.
@@ -74,7 +139,7 @@ class Arena {
   * memory allocated from the same arena.
   * Otherwise the destructor needs to be called explicitly.
   */
-  template<typename T, typename... Args>
+  template <typename T, typename... Args>
   T* make(Args&&... args) {
     T* ptr = allocate_<T>();
     new (ptr) T(std::forward<Args>(args)...);
@@ -82,25 +147,21 @@
   }

  private:
-  // page size 16 KB
-  // The page data type.
-  using Page = std::aligned_storage::type;
-  /*! \brief Page header */
-  struct PageHeader {
-    /*! \brief points to the next page */
-    PageHeader* next;
-    /*! \brief memory allocator ptr inside page */
-    size_t ptr;
-  };
-  /* \brief The page header */
-  PageHeader* head_{nullptr};
+  /*! \brief internal page allocator. */
+  PageAllocator alloc_;
+  /* \brief The head of the allocated list. */
+  ArenaPageHeader* head_{nullptr};
+  /*! \brief The tail of the allocated list. */
+  ArenaPageHeader* tail_{nullptr};
+  /* \brief List of free pages. */
+  ArenaPageHeader* free_list_{nullptr};
   /*!
   * \brief Align ptr by upper bound.
-   * \param ptr The pointer value.
+   * \param offset The offset value.
   * \param align The alignment requirement.
   */
-  size_t UpperAlign(size_t ptr, size_t align) {
-    return ptr + (align - (ptr % align)) % align;
+  size_t UpperAlign(size_t offset, size_t align) {
+    return offset + (align - (offset % align)) % align;
   }
   /*!
   * \brief Internal aligned alloc function.
@@ -108,27 +169,46 @@
   * \param align The alignment requirement.
*/ void* Alloc(size_t size, size_t align) { - size_t ptr = UpperAlign(head_->ptr, align); - if (ptr + size <= kArenaPageSize) { - head_->ptr = ptr + size; - return reinterpret_cast(head_) + ptr; + size_t offset = UpperAlign(head_->offset, align); + if (offset + size <= head_->size) { + head_->offset = offset + size; + return reinterpret_cast(head_) + offset; } else { - PageHeader* new_head = reinterpret_cast(new Page()); + ArenaPageHeader* new_head; + offset = UpperAlign(sizeof(ArenaPageHeader), align); + if (free_list_ != nullptr && offset + size <= free_list_->size) { + new_head = free_list_; + free_list_ = free_list_->next; + } else { + new_head = alloc_.allocate(offset + size); + } new_head->next = head_; - ptr = UpperAlign(sizeof(PageHeader), align); - CHECK_LE(ptr + size, kArenaPageSize); - new_head->ptr = ptr + size; + new_head->offset = offset + size; head_ = new_head; - return reinterpret_cast(head_) + ptr; + return reinterpret_cast(head_) + offset; + } + } + /*! + * \brief Free all the pages in the list. + * \param ptr The head ptr. + */ + void FreePageList(ArenaPageHeader** ptr) { + // delete all the allocated pages. + while (ptr[0] != nullptr) { + ArenaPageHeader* temp = ptr[0]; + ptr[0] = ptr[0]->next; + alloc_.deallocate(temp); } } }; +using Arena = GenericArena; + /*! * \brief Link list node * \tparam T the content data type */ -template +template struct LinkNode { /*! \brief The content value */ T value; @@ -141,7 +221,7 @@ struct LinkNode { * \note This is a simple data structure that can be used together with the arena. * \sa LinkNode */ -template +template struct LinkedList { /*! \brief Head pointer */ LinkNode* head{nullptr}; diff --git a/src/support/base64.h b/src/support/base64.h index c85b268fd7b66..9849542471c29 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -27,7 +27,7 @@ #define TVM_SUPPORT_BASE64_H_ #include -#include + #include #include #include @@ -38,18 +38,16 @@ namespace support { namespace base64 { // decoding table const char DecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, - 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' }; // encoding table static const char EncodeTable[] = @@ -62,14 +60,12 @@ static const char EncodeTable[] = */ class StreamBufferReader { public: - explicit StreamBufferReader(size_t buffer_size) { - buffer_.resize(buffer_size); - } + explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); } /*! 
* \brief set input stream * \param stream The stream to be set */ - void set_stream(dmlc::Stream *stream) { + void set_stream(dmlc::Stream* stream) { stream_ = stream; read_len_ = read_ptr_ = 1; } @@ -88,13 +84,11 @@ class StreamBufferReader { } } /*! \return whether we are reaching the end of file */ - bool AtEnd() const { - return read_len_ == 0; - } + bool AtEnd() const { return read_len_ == 0; } private: /*! \brief the underlying stream */ - dmlc::Stream *stream_{nullptr}; + dmlc::Stream* stream_{nullptr}; /*! \brief buffer to hold data */ std::string buffer_; /*! \brief length of valid data in buffer */ @@ -106,11 +100,9 @@ class StreamBufferReader { /*! * \brief Input stream from base64 encoding */ -class Base64InStream: public dmlc::Stream { +class Base64InStream : public dmlc::Stream { public: - explicit Base64InStream(dmlc::Stream *fs) : reader_(256) { - reader_.set_stream(fs); - } + explicit Base64InStream(dmlc::Stream* fs) : reader_(256) { reader_.set_stream(fs); } /*! * \brief initialize the stream position to beginning of next base64 stream * \note call this function before actually start read @@ -122,16 +114,14 @@ class Base64InStream: public dmlc::Stream { } while (isspace(temp_ch_)); } /*! \brief whether current position is end of a base64 stream */ - bool IsEOF(void) const { - return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); - } + bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); } // override read function. - virtual size_t Read(void *ptr, size_t size) { + virtual size_t Read(void* ptr, size_t size) { using base64::DecodeTable; if (size == 0) return 0; // use tlen to record left size size_t tlen = size; - unsigned char *cptr = static_cast(ptr); + unsigned char* cptr = static_cast(ptr); // if anything left, load from previous buffered result if (num_prev_ != 0) { if (num_prev_ == 2) { @@ -142,13 +132,16 @@ class Base64InStream: public dmlc::Stream { num_prev_ = 0; } else { // assert tlen == 1 - *cptr++ = buf_prev[0]; --tlen; + *cptr++ = buf_prev[0]; + --tlen; buf_prev[0] = buf_prev[1]; num_prev_ = 1; } } else { // assert num_prev_ == 1 - *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0; + *cptr++ = buf_prev[0]; + --tlen; + num_prev_ = 0; } } if (tlen == 0) return size; @@ -163,8 +156,9 @@ class Base64InStream: public dmlc::Stream { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; nvalue |= DecodeTable[temp_ch_] << 12; - *cptr++ = (nvalue >> 16) & 0xFF; --tlen; - } + *cptr++ = (nvalue >> 16) & 0xFF; + --tlen; + } { // third byte temp_ch_ = reader_.GetChar(); @@ -174,13 +168,13 @@ class Base64InStream: public dmlc::Stream { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ == '=') << "invalid base64 format"; temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ == EOF || isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_] << 6; if (tlen) { - *cptr++ = (nvalue >> 8) & 0xFF; --tlen; + *cptr++ = (nvalue >> 8) & 0xFF; + --tlen; } else { buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF; } @@ -188,19 +182,18 @@ class Base64InStream: public dmlc::Stream { { // fourth byte temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ == EOF || isspace(temp_ch_)) - << "invalid base64 format"; + 
CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_]; if (tlen) { - *cptr++ = nvalue & 0xFF; --tlen; + *cptr++ = nvalue & 0xFF; + --tlen; } else { - buf_prev[num_prev_ ++] = nvalue & 0xFF; + buf_prev[num_prev_++] = nvalue & 0xFF; } } // get next char @@ -211,7 +204,7 @@ class Base64InStream: public dmlc::Stream { } return size - tlen; } - virtual void Write(const void *ptr, size_t size) { + virtual void Write(const void* ptr, size_t size) { LOG(FATAL) << "Base64InStream do not support write"; } @@ -228,17 +221,17 @@ class Base64InStream: public dmlc::Stream { /*! * \brief Stream to write to base64 format. */ -class Base64OutStream: public dmlc::Stream { +class Base64OutStream : public dmlc::Stream { public: - explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) { - } - virtual void Write(const void *ptr, size_t size) { + explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {} + virtual void Write(const void* ptr, size_t size) { using base64::EncodeTable; size_t tlen = size; - const unsigned char *cptr = static_cast(ptr); + const unsigned char* cptr = static_cast(ptr); while (tlen) { - while (buf__top_ < 3 && tlen != 0) { - buf_[++buf__top_] = *cptr++; --tlen; + while (buf__top_ < 3 && tlen != 0) { + buf_[++buf__top_] = *cptr++; + --tlen; } if (buf__top_ == 3) { // flush 4 bytes out @@ -250,7 +243,7 @@ class Base64OutStream: public dmlc::Stream { } } } - virtual size_t Read(void *ptr, size_t size) { + virtual size_t Read(void* ptr, size_t size) { LOG(FATAL) << "Base64OutStream do not support read"; return 0; } @@ -280,12 +273,11 @@ class Base64OutStream: public dmlc::Stream { private: static constexpr size_t kBufferSize = 256; - dmlc::Stream *fp_{nullptr}; + dmlc::Stream* fp_{nullptr}; int buf__top_{0}; unsigned char buf_[4]; std::string out_buf_; - void PutChar(char ch) { out_buf_ += ch; if (out_buf_.length() >= kBufferSize) Flush(); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 622e28ef170a4..2b944cb2ddd61 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -17,15 +17,15 @@ * under the License. */ - /*! +/*! * FFI registration code used for frontend testing purposes. 
* \file ffi_testing.cc */ -#include -#include -#include #include #include +#include +#include +#include namespace tvm { // Attrs used to python API @@ -36,16 +36,10 @@ struct TestAttrs : public AttrsNode { TypedEnvFunc func; TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") { - TVM_ATTR_FIELD(axis) - .set_default(10) - .set_lower_bound(1) - .set_upper_bound(10) - .describe("axis field"); - TVM_ATTR_FIELD(name) - .describe("name"); - TVM_ATTR_FIELD(padding) - .describe("padding of input") - .set_default(Array({0, 0})); + TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( + "axis field"); + TVM_ATTR_FIELD(name).describe("name"); + TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array({0, 0})); TVM_ATTR_FIELD(func) .describe("some random env function") .set_default(TypedEnvFunc(nullptr)); @@ -54,49 +48,37 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_REGISTER_GLOBAL("testing.nop") -.set_body([](TVMArgs args, TVMRetValue *ret) { - }); +TVM_REGISTER_GLOBAL("testing.nop").set_body([](TVMArgs args, TVMRetValue* ret) {}); -TVM_REGISTER_GLOBAL("testing.echo") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0]; - }); +}); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc pf = args[0]; - *ret = runtime::TypedPackedFunc([pf](){ - pf(); - }); - }); +TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf]() { pf(); }); +}); TVM_REGISTER_GLOBAL("testing.test_raise_error_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string msg = args[0]; - *ret = runtime::TypedPackedFunc([msg](){ - LOG(FATAL) << msg; - }); - }); - -TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string msg = args[0]; - *ret = runtime::TypedPackedFunc([msg](int x, int y){ - CHECK_EQ(x, y) << msg; - }); - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::string msg = args[0]; + *ret = runtime::TypedPackedFunc([msg]() { LOG(FATAL) << msg; }); + }); -TVM_REGISTER_GLOBAL("testing.context_test") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLContext ctx = args[0]; - int dtype = args[1]; - int did = args[2]; - CHECK_EQ(static_cast(ctx.device_type), dtype); - CHECK_EQ(static_cast(ctx.device_id), did); - *ret = ctx; - }); +TVM_REGISTER_GLOBAL("testing.test_check_eq_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + std::string msg = args[0]; + *ret = + runtime::TypedPackedFunc([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); +}); +TVM_REGISTER_GLOBAL("testing.context_test").set_body([](TVMArgs args, TVMRetValue* ret) { + DLContext ctx = args[0]; + int dtype = args[1]; + int did = args[2]; + CHECK_EQ(static_cast(ctx.device_type), dtype); + CHECK_EQ(static_cast(ctx.device_id), did); + *ret = ctx; +}); // in src/api_test.cc void ErrorTest(int x, int y) { @@ -108,15 +90,13 @@ void ErrorTest(int x, int y) { } } -TVM_REGISTER_GLOBAL("testing.ErrorTest") -.set_body_typed(ErrorTest); +TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); // internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.object_use_count") -.set_body([](TVMArgs args, TVMRetValue *ret) { - runtime::ObjectRef obj = args[0]; - // substract the current one because we always copy - // and get another 
value. - *ret = (obj.use_count() - 1); - }); +TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRetValue* ret) { + runtime::ObjectRef obj = args[0]; + // substract the current one because we always copy + // and get another value. + *ret = (obj.use_count() - 1); +}); } // namespace tvm diff --git a/src/support/pipe.h b/src/support/pipe.h index 120bbdb95e77c..dcebd0ddf32f3 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -24,16 +24,17 @@ #ifndef TVM_SUPPORT_PIPE_H_ #define TVM_SUPPORT_PIPE_H_ -#include #include +#include #ifdef _WIN32 #include #else -#include #include -#include +#include + #include +#include #endif namespace tvm { @@ -48,12 +49,9 @@ class Pipe : public dmlc::Stream { using PipeHandle = int; #endif /*! \brief Construct a pipe from system handle. */ - explicit Pipe(int64_t handle) - : handle_(static_cast(handle)) {} + explicit Pipe(int64_t handle) : handle_(static_cast(handle)) {} /*! \brief destructor */ - ~Pipe() { - Flush(); - } + ~Pipe() { Flush(); } using Stream::Read; using Stream::Write; /*! @@ -62,18 +60,16 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - size_t Read(void *ptr, size_t size) final { + size_t Read(void* ptr, size_t size) final { if (size == 0) return 0; #ifdef _WIN32 DWORD nread; - CHECK(ReadFile(handle_, static_cast(ptr), - &nread, nullptr)) + CHECK(ReadFile(handle_, static_cast(ptr), &nread, nullptr)) << "Read Error: " << GetLastError(); #else ssize_t nread; nread = read(handle_, ptr, size); - CHECK_GE(nread, 0) - << "Write Error: " << strerror(errno); + CHECK_GE(nread, 0) << "Write Error: " << strerror(errno); #endif return static_cast(nread); } @@ -83,19 +79,17 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - void Write(const void *ptr, size_t size) final { + void Write(const void* ptr, size_t size) final { if (size == 0) return; #ifdef _WIN32 DWORD nwrite; - CHECK(WriteFile(handle_, static_cast(ptr), - &nwrite, nullptr) && + CHECK(WriteFile(handle_, static_cast(ptr), &nwrite, nullptr) && static_cast(nwrite) == size) << "Write Error: " << GetLastError(); #else ssize_t nwrite; nwrite = write(handle_, ptr, size); - CHECK_EQ(static_cast(nwrite), size) - << "Write Error: " << strerror(errno); + CHECK_EQ(static_cast(nwrite), size) << "Write Error: " << strerror(errno); #endif } /*! diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index e6e3b04ec7a98..a3938491f1d15 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -24,9 +24,9 @@ #ifndef TVM_SUPPORT_RING_BUFFER_H_ #define TVM_SUPPORT_RING_BUFFER_H_ -#include -#include #include +#include +#include namespace tvm { namespace support { @@ -41,41 +41,48 @@ class RingBuffer { /*! \brief constructor */ RingBuffer() : ring_(kInitCapacity) {} /*! \return number of bytes available in buffer. */ - size_t bytes_available() const { - return bytes_available_; - } + size_t bytes_available() const { return bytes_available_; } /*! \return Current capacity of buffer. */ - size_t capacity() const { - return ring_.size(); - } + size_t capacity() const { return ring_.size(); } /*! - * Reserve capacity to be at least n. - * Will only increase capacity if n is bigger than current capacity. + * Reserve capacity to be at least n. + * Will only increase capacity if n is bigger than current capacity. + * + * The effect of Reserve only lasts before the next call to Reserve. + * Other functions in the ring buffer can also call into the reserve. 
+ *
+ * \param n The size of capacity.
   * \param n The size of capacity.
   */
  void Reserve(size_t n) {
    if (ring_.size() < n) {
-      size_t old_size = ring_.size();
-      size_t new_size = static_cast<size_t>(n * 1.2);
-      ring_.resize(new_size);
-      if (head_ptr_ + bytes_available_ > old_size) {
-        // copy the ring overflow part into the tail.
-        size_t ncopy = head_ptr_ + bytes_available_ - old_size;
-        memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
-      }
-    } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) {
-      // shrink too large temporary buffer to avoid out of memory on some embedded devices
+      size_t old_size = ring_.size();
+      size_t new_size = static_cast<size_t>(n * 1.2);
+      ring_.resize(new_size);
+      if (head_ptr_ + bytes_available_ > old_size) {
+        // copy the ring overflow part into the tail.
+        size_t ncopy = head_ptr_ + bytes_available_ - old_size;
+        memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
+      }
+    } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) {
+      // shrink too large temporary buffer to
+      // avoid out of memory on some embedded devices
+      if (bytes_available_ != 0) {
+        // move existing bytes to the head.
        size_t old_bytes = bytes_available_;
-        std::vector<char> tmp(old_bytes);
-        Read(&tmp[0], old_bytes);
-        ring_.resize(kInitCapacity);
-        ring_.shrink_to_fit();
        memcpy(&ring_[0], &tmp[0], old_bytes);
-        head_ptr_ = 0;
        bytes_available_ = old_bytes;
+      }
+      // shrink the ring.
+      size_t new_size = kInitCapacity;
+      new_size = std::max(new_size, n);
+      new_size = std::max(new_size, bytes_available_);
+
+      ring_.resize(new_size);
+      ring_.shrink_to_fit();
+      head_ptr_ = 0;
    }
  }
@@ -90,8 +97,7 @@ class RingBuffer {
    size_t ncopy = std::min(size, ring_.size() - head_ptr_);
    memcpy(data, &ring_[0] + head_ptr_, ncopy);
    if (ncopy < size) {
-      memcpy(reinterpret_cast<char*>(data) + ncopy,
-             &ring_[0], size - ncopy);
+      memcpy(reinterpret_cast<char*>(data) + ncopy, &ring_[0], size - ncopy);
    }
    head_ptr_ = (head_ptr_ + size) % ring_.size();
    bytes_available_ -= size;
@@ -103,7 +109,7 @@
   * \param max_nbytes Maximum number of bytes to read.
   * \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size);
   */
-  template<typename FSend>
+  template <typename FSend>
  size_t ReadWithCallback(FSend fsend, size_t max_nbytes) {
    size_t size = std::min(max_nbytes, bytes_available_);
    CHECK_NE(size, 0U);
@@ -137,13 +143,13 @@
    bytes_available_ += size;
  }
  /*!
-   * \brief Writen data into the buffer by give it a non-blocking callback function.
+   * \brief Write data into the buffer by giving it a non-blocking callback function.
   *
   * \param frecv A receive function handle
   * \param max_nbytes Maximum number of bytes to write.
   * \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size);
   */
-  template<typename FRecv>
+  template <typename FRecv>
  size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) {
    this->Reserve(bytes_available_ + max_nbytes);
    size_t nbytes = max_nbytes;
@@ -168,9 +174,9 @@
 private:
  // buffer head
  size_t head_ptr_{0};
-  // number of bytes in the buffer.
+  // number of bytes occupied in the buffer.
  size_t bytes_available_{0};
-  // The internald ata ring.
+  // The internal data ring.
std::vector ring_; }; } // namespace support diff --git a/src/support/socket.h b/src/support/socket.h index aeb4626b5d471..3ccfaaab5ab59 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -35,26 +35,27 @@ using ssize_t = int; #pragma comment(lib, "Ws2_32.lib") #endif #else +#include +#include #include #include -#include -#include -#include #include -#include -#include #include +#include +#include +#include #endif #include -#include + #include -#include +#include #include +#include + #include "../support/util.h" #if defined(_WIN32) -static inline int poll(struct pollfd *pfd, int nfds, - int timeout) { +static inline int poll(struct pollfd* pfd, int nfds, int timeout) { return WSAPoll(pfd, nfds, timeout); } #else @@ -68,7 +69,8 @@ namespace support { * \return The hostname. */ inline std::string GetHostName() { - std::string buf; buf.resize(256); + std::string buf; + buf.resize(256); CHECK_NE(gethostname(&buf[0], 256), -1); return std::string(buf.c_str()); } @@ -100,16 +102,14 @@ struct SockAddr { * \param url The url of the address * \param port The port of the address. */ - SockAddr(const char *url, int port) { - this->Set(url, port); - } + SockAddr(const char* url, int port) { this->Set(url, port); } /*! - * \brief SockAddr Get the socket address from tracker. - * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) - * \return SockAddr parsed from url. - */ - explicit SockAddr(const std::string &url) { + * \brief SockAddr Get the socket address from tracker. + * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) + * \return SockAddr parsed from url. + */ + explicit SockAddr(const std::string& url) { size_t sep = url.find(","); std::string host = url.substr(2, sep - 3); std::string port = url.substr(sep + 1, url.length() - 1); @@ -125,31 +125,28 @@ struct SockAddr { * \param host the url of the address * \param port the port of address */ - void Set(const char *host, int port) { + void Set(const char* host, int port) { addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = PF_UNSPEC; hints.ai_flags = AI_PASSIVE; hints.ai_socktype = SOCK_STREAM; - addrinfo *res = NULL; + addrinfo* res = NULL; int sig = getaddrinfo(host, NULL, &hints, &res); - CHECK(sig == 0 && res != NULL) - << "cannot obtain address of " << host; + CHECK(sig == 0 && res != NULL) << "cannot obtain address of " << host; switch (res->ai_family) { case AF_INET: { - sockaddr_in *addr4 = reinterpret_cast(&addr); - memcpy(addr4, res->ai_addr, res->ai_addrlen); - addr4->sin_port = htons(port); - addr4->sin_family = AF_INET; - } - break; + sockaddr_in* addr4 = reinterpret_cast(&addr); + memcpy(addr4, res->ai_addr, res->ai_addrlen); + addr4->sin_port = htons(port); + addr4->sin_family = AF_INET; + } break; case AF_INET6: { - sockaddr_in6 *addr6 = reinterpret_cast(&addr); - memcpy(addr6, res->ai_addr, res->ai_addrlen); - addr6->sin6_port = htons(port); - addr6->sin6_family = AF_INET6; - } - break; + sockaddr_in6* addr6 = reinterpret_cast(&addr); + memcpy(addr6, res->ai_addr, res->ai_addrlen); + addr6->sin6_port = htons(port); + addr6->sin6_family = AF_INET6; + } break; default: CHECK(false) << "cannot decode address"; } @@ -157,35 +154,34 @@ struct SockAddr { } /*! \brief return port of the address */ int port() const { - return ntohs((addr.ss_family == AF_INET6)? \ - reinterpret_cast(&addr)->sin6_port : \ - reinterpret_cast(&addr)->sin_port); + return ntohs((addr.ss_family == AF_INET6) + ? 
reinterpret_cast(&addr)->sin6_port + : reinterpret_cast(&addr)->sin_port); } /*! \brief return the ip address family */ - int ss_family() const { - return addr.ss_family; - } + int ss_family() const { return addr.ss_family; } /*! \return a string representation of the address */ std::string AsString() const { - std::string buf; buf.resize(256); + std::string buf; + buf.resize(256); - const void *sinx_addr = nullptr; - if (addr.ss_family == AF_INET6) { - const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; - sinx_addr = reinterpret_cast(&addr6); - } else if (addr.ss_family == AF_INET) { - const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; - sinx_addr = reinterpret_cast(&addr4); - } else { - CHECK(false) << "illegal address"; - } + const void* sinx_addr = nullptr; + if (addr.ss_family == AF_INET6) { + const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; + sinx_addr = reinterpret_cast(&addr6); + } else if (addr.ss_family == AF_INET) { + const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; + sinx_addr = reinterpret_cast(&addr4); + } else { + CHECK(false) << "illegal address"; + } #ifdef _WIN32 - const char *s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) + const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) &buf[0], buf.length()); #else - const char *s = inet_ntop(addr.ss_family, sinx_addr, - &buf[0], static_cast(buf.length())); + const char* s = + inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast(buf.length())); #endif CHECK(s != nullptr) << "cannot decode address"; std::ostringstream os; @@ -238,10 +234,10 @@ class Socket { * \brief bind the socket to an address * \param addr The address to be binded */ - void Bind(const SockAddr &addr) { + void Bind(const SockAddr& addr) { if (bind(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == -1) { + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == + -1) { Socket::Error("Bind"); } } @@ -256,8 +252,8 @@ class Socket { for (int port = start_port; port < end_port; ++port) { SockAddr addr(host.c_str(), port); if (bind(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == 0) { + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == + 0) { return port; } else { LOG(WARNING) << "Bind failed to " << host << ":" << port; @@ -278,7 +274,7 @@ class Socket { int GetSockError() const { int error = 0; socklen_t len = sizeof(error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { Error("GetSockError"); } return error; @@ -291,9 +287,7 @@ class Socket { return false; } /*! \brief check if socket is already closed */ - bool IsClosed() const { - return sockfd == INVALID_SOCKET; - } + bool IsClosed() const { return sockfd == INVALID_SOCKET; } /*! \brief close the socket */ void Close() { if (sockfd != INVALID_SOCKET) { @@ -354,7 +348,7 @@ class Socket { * \brief Report an socket error. * \param msg The error message. */ - static void Error(const char *msg) { + static void Error(const char* msg) { int errsv = GetLastError(); #ifdef _WIN32 LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; @@ -364,8 +358,7 @@ class Socket { } protected: - explicit Socket(SockType sockfd) : sockfd(sockfd) { - } + explicit Socket(SockType sockfd) : sockfd(sockfd) {} }; /*! 
@@ -373,22 +366,20 @@ class Socket { */ class TCPSocket : public Socket { public: - TCPSocket() : Socket(INVALID_SOCKET) { - } + TCPSocket() : Socket(INVALID_SOCKET) {} /*! * \brief construct a TCP socket from existing descriptor * \param sockfd The descriptor */ - explicit TCPSocket(SockType sockfd) : Socket(sockfd) { - } + explicit TCPSocket(SockType sockfd) : Socket(sockfd) {} /*! * \brief enable/disable TCP keepalive * \param keepalive whether to set the keep alive option on */ void SetKeepAlive(bool keepalive) { int opt = static_cast(keepalive); - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&opt), sizeof(opt)) < 0) { + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(opt)) < + 0) { Socket::Error("SetKeepAlive"); } } @@ -406,9 +397,7 @@ class TCPSocket : public Socket { * \brief perform listen of the socket * \param backlog backlog parameter */ - void Listen(int backlog = 16) { - listen(sockfd, backlog); - } + void Listen(int backlog = 16) { listen(sockfd, backlog); } /*! * \brief get a new connection * \return The accepted socket connection. @@ -421,14 +410,13 @@ class TCPSocket : public Socket { return TCPSocket(newfd); } /*! - * \brief get a new connection - * \param addr client address from which connection accepted - * \return The accepted socket connection. - */ - TCPSocket Accept(SockAddr *addr) { + * \brief get a new connection + * \param addr client address from which connection accepted + * \return The accepted socket connection. + */ + TCPSocket Accept(SockAddr* addr) { socklen_t addrlen = sizeof(addr->addr); - SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), - &addrlen); + SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), &addrlen); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } @@ -453,10 +441,10 @@ class TCPSocket : public Socket { * \param addr the address to connect to * \return whether connect is successful */ - bool Connect(const SockAddr &addr) { - return connect(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == 0; + bool Connect(const SockAddr& addr) { + return connect( + sockfd, reinterpret_cast(&addr.addr), + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0; } /*! * \brief send data using the socket @@ -466,8 +454,8 @@ class TCPSocket : public Socket { * \return size of data actually sent * return -1 if error occurs */ - ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); + ssize_t Send(const void* buf_, size_t len, int flag = 0) { + const char* buf = reinterpret_cast(buf_); return send(sockfd, buf, static_cast(len), flag); } /*! @@ -478,8 +466,8 @@ class TCPSocket : public Socket { * \return size of data actually received * return -1 if error occurs */ - ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); + ssize_t Recv(void* buf_, size_t len, int flags = 0) { + char* buf = reinterpret_cast(buf_); return recv(sockfd, buf, static_cast(len), flags); } /*! 
@@ -489,10 +477,10 @@ class TCPSocket : public Socket { * \param len the size of the buffer * \return size of data actually sent */ - size_t SendAll(const void *buf_, size_t len) { - const char *buf = reinterpret_cast(buf_); + size_t SendAll(const void* buf_, size_t len) { + const char* buf = reinterpret_cast(buf_); size_t ndone = 0; - while (ndone < len) { + while (ndone < len) { ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); if (ret == -1) { if (LastErrorWouldBlock()) return ndone; @@ -510,14 +498,13 @@ class TCPSocket : public Socket { * \param len length of data to recv * \return size of data actually sent */ - size_t RecvAll(void *buf_, size_t len) { - char *buf = reinterpret_cast(buf_); + size_t RecvAll(void* buf_, size_t len) { + char* buf = reinterpret_cast(buf_); size_t ndone = 0; - while (ndone < len) { - ssize_t ret = recv(sockfd, buf, - static_cast(len - ndone), MSG_WAITALL); + while (ndone < len) { + ssize_t ret = recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); if (ret == -1) { - if (LastErrorWouldBlock()) { + if (LastErrorWouldBlock()) { LOG(FATAL) << "would block"; return ndone; } @@ -612,7 +599,7 @@ struct PollHelper { * \param timeout the timeout counter, can be negative, which means wait until the event happen * \return 1 if success, 0 if timeout, and -1 if error occurs */ - inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) + inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) pollfd pfd; pfd.fd = fd; pfd.events = POLLPRI; diff --git a/src/support/str_escape.h b/src/support/str_escape.h index fd25c019e6dce..65eec682086e2 100644 --- a/src/support/str_escape.h +++ b/src/support/str_escape.h @@ -25,8 +25,8 @@ #ifndef TVM_SUPPORT_STR_ESCAPE_H_ #define TVM_SUPPORT_STR_ESCAPE_H_ -#include #include +#include namespace tvm { namespace support { @@ -76,9 +76,7 @@ inline std::string StrEscape(const char* data, size_t size) { * \param size The size of the string. * \return the Result string. */ -inline std::string StrEscape(const std::string& val) { - return StrEscape(val.data(), val.length()); -} +inline std::string StrEscape(const std::string& val) { return StrEscape(val.data(), val.length()); } } // namespace support } // namespace tvm diff --git a/src/support/util.h b/src/support/util.h index 9a477e6f81f2c..859b372bd761c 100644 --- a/src/support/util.h +++ b/src/support/util.h @@ -26,16 +26,16 @@ #include #ifndef _WIN32 -#include #include +#include #endif -#include -#include -#include #include #include #include #include +#include +#include +#include namespace tvm { namespace support { @@ -92,15 +92,14 @@ inline int TVMWexitstatus(int status) { #endif } - /*! * \brief IsNumber check whether string is a number. * \param str input string * \return result of operation. */ inline bool IsNumber(const std::string& str) { - return !str.empty() && std::find_if(str.begin(), - str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); + return !str.empty() && + std::find_if(str.begin(), str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); } /*! 
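The IsNumber helper above accepts a string only when it is non-empty and every character is a decimal digit; signs, whitespace, and decimal points are all rejected. A self-contained check of those assumed semantics (a sketch, not TVM test code):

#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

static bool IsNumber(const std::string& str) {
  // The lambda takes unsigned char so std::isdigit stays defined for negative chars.
  return !str.empty() && std::find_if(str.begin(), str.end(), [](unsigned char c) {
                           return !std::isdigit(c);
                         }) == str.end();
}

int main() {
  assert(IsNumber("42"));     // all digits
  assert(!IsNumber(""));      // empty string rejected
  assert(!IsNumber("-1"));    // leading sign is not a digit
  assert(!IsNumber("3.14"));  // decimal point is not a digit
  return 0;
}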
diff --git a/src/target/build_common.h b/src/target/build_common.h index 93687c2578acc..ec5b522397ed3 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -24,27 +24,27 @@ #ifndef TVM_TARGET_BUILD_COMMON_H_ #define TVM_TARGET_BUILD_COMMON_H_ -#include -#include -#include #include -#include +#include +#include +#include #include +#include #include -#include + #include +#include + #include "../runtime/meta_data.h" namespace tvm { namespace codegen { -inline std::unordered_map -ExtractFuncInfo(const IRModule& mod) { +inline std::unordered_map ExtractFuncInfo(const IRModule& mod) { std::unordered_map fmap; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); runtime::FunctionInfo info; diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 848d27f42575f..e3890cac51591 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -21,23 +21,22 @@ * \file codegen.cc * \brief Common utilities to generated C style code. */ +#include +#include +#include +#include +#include +#include #include #include - -#include -#include #include +#include -#include -#include -#include -#include -#include -#include -#include #include -#include #include +#include +#include +#include namespace tvm { namespace codegen { @@ -50,17 +49,14 @@ runtime::Module Build(IRModule mod, const Target& target) { std::string build_f_name = "target.build." + target->target_name; // the build function. const PackedFunc* bf = runtime::Registry::Get(build_f_name); - CHECK(bf != nullptr) - << "target.build." << target << " is not enabled"; + CHECK(bf != nullptr) << "target.build." << target << " is not enabled"; return (*bf)(mod, target->str()); } /*! \brief Helper class to serialize module */ class ModuleSerializer { public: - explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { - Init(); - } + explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } void SerializeModule(dmlc::Stream* stream) { // Only have one DSO module and it is in the root, then @@ -109,8 +105,8 @@ class ModuleSerializer { // invariance: root module is always at location 0. 
// The module order is collected via DFS void CreateModuleIndex() { - std::unordered_set visited {mod_.operator->()}; - std::vector stack {mod_.operator->()}; + std::unordered_set visited{mod_.operator->()}; + std::vector stack{mod_.operator->()}; uint64_t module_index = 0; while (!stack.empty()) { @@ -139,8 +135,7 @@ class ModuleSerializer { } bool DSOExportable(const runtime::ModuleNode* mod) { - return !std::strcmp(mod->type_key(), "llvm") || - !std::strcmp(mod->type_key(), "c"); + return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c"); } runtime::Module mod_; @@ -148,21 +143,21 @@ class ModuleSerializer { std::unordered_map mod2index_; // index -> module std::vector mod_vec_; - std::vector import_tree_row_ptr_ {0}; + std::vector import_tree_row_ptr_{0}; std::vector import_tree_child_indices_; }; namespace { - std::string SerializeModule(const runtime::Module& mod) { - std::string bin; - dmlc::MemoryStringStream ms(&bin); - dmlc::Stream* stream = &ms; +std::string SerializeModule(const runtime::Module& mod) { + std::string bin; + dmlc::MemoryStringStream ms(&bin); + dmlc::Stream* stream = &ms; - ModuleSerializer module_serializer(mod); - module_serializer.SerializeModule(stream); + ModuleSerializer module_serializer(mod); + module_serializer.SerializeModule(stream); - return bin; - } + return bin; +} } // namespace std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { @@ -180,8 +175,8 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { << "#endif\n"; os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n"; uint64_t nbytes = bin.length(); - os << "const unsigned char " << runtime::symbol::tvm_dev_mblob - << "[" << bin.length() + sizeof(nbytes) << "] = {\n "; + os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "[" + << bin.length() + sizeof(nbytes) << "] = {\n "; os << std::hex; size_t nunit = 80 / 4; for (size_t i = 0; i < sizeof(nbytes); ++i) { @@ -214,8 +209,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { return os.str(); } -runtime::Module PackImportsToLLVM(const runtime::Module& mod, - bool system_lib, +runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, const std::string& target_triple) { std::string bin = SerializeModule(mod); @@ -233,19 +227,16 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, std::string codegen_f_name = "codegen.codegen_blob"; // the codegen function. const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name); - CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; + CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; return (*codegen_f)(blob_byte_array, system_lib, target_triple); } -TVM_REGISTER_GLOBAL("target.Build") -.set_body_typed(Build); +TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export two auxiliary function to the runtime namespace. 
-TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") -.set_body_typed(PackImportsToC); +TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM") -.set_body_typed(PackImportsToLLVM); +TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index c16182da3674c..99d6bee60975b 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -16,34 +16,32 @@ * specific language governing permissions and limitations * under the License. */ -#include #include "registry.h" +#include + namespace tvm { namespace datatype { using runtime::TVMArgs; using runtime::TVMRetValue; -TVM_REGISTER_GLOBAL("runtime._datatype_register") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0]); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Registry::Global()->GetTypeName(args[0].operator int()); }); TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); + }); Registry* Registry::Global() { static Registry inst; diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index 919409f6e4f30..c04359208e64f 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -22,6 +22,7 @@ #include #include + #include #include @@ -69,7 +70,7 @@ class Registry { * \param type_name The type name * \return The type code */ - uint8_t GetTypeCode(const std::string &type_name); + uint8_t GetTypeCode(const std::string& type_name); /*! 
* \brief Get type name from type code diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 44d017f4ac5b3..9ad9f56f7c58d 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -20,14 +20,12 @@ * \file src/target/generic_func.cc */ #include - -#include -#include #include #include -#include -#include +#include #include +#include +#include #include #include @@ -43,8 +41,7 @@ struct GenericFunc::Manager { // mutex std::mutex mutex; - Manager() { - } + Manager() {} static Manager* Global() { static Manager inst; @@ -76,25 +73,23 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) m->fmap[name] = func; } -GenericFunc& GenericFunc::set_default(const PackedFunc value, - bool allow_override) { +GenericFunc& GenericFunc::set_default(const PackedFunc value, bool allow_override) { auto node = static_cast(operator->()); if (!allow_override) { CHECK(node->generic_func_ == nullptr) - << "Generic function already registered for " << node->name_; + << "Generic function already registered for " << node->name_; } node->generic_func_ = value; return *this; } GenericFunc& GenericFunc::register_func(const std::vector& tags, - const PackedFunc value, - bool allow_override) { - for (auto &t : tags) { + const PackedFunc value, bool allow_override) { + for (auto& t : tags) { if (!allow_override) { auto iter = (*this)->dispatch_dict_.find(t); CHECK(iter == (*this)->dispatch_dict_.end()) - << "Tag " << t << " already registered for schedule factory " << (*this)->name_; + << "Tag " << t << " already registered for schedule factory " << (*this)->name_; } (*this)->dispatch_dict_[t] = value; } @@ -107,7 +102,7 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { PackedFunc func; if (target.defined()) { - for (auto &k : target->keys()) { + for (auto& k : target->keys()) { auto iter = node->dispatch_dict_.find(k); if (iter != node->dispatch_dict_.end()) { func = iter->second; @@ -124,30 +119,25 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { func.CallPacked(args, ret); } -TVM_REGISTER_GLOBAL("target.GenericFuncCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = GenericFunc(make_object()); - }); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVMRetValue* ret) { std::string func_name = args[0]; *ret = GenericFunc::Get(func_name); - }); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); bool allow_override = args[2]; - generic_func - .set_default(*func, allow_override); - }); + generic_func.set_default(*func, allow_override); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); @@ -159,17 +149,14 @@ 
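The `GenericFunc` hunks below reformat the dispatch logic in `CallPacked`: walk the current target's keys in order, use the first implementation registered for a key, and otherwise fall back to the default. A simplified stand-in for that pattern (types invented for the sketch, not TVM's API):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using Impl = std::function<int(int)>;

struct MyGenericFunc {
  Impl generic_func_;                                    // default, from set_default
  std::unordered_map<std::string, Impl> dispatch_dict_;  // per-target-key, from register_func

  int Call(const std::vector<std::string>& target_keys, int arg) const {
    for (const auto& k : target_keys) {
      auto it = dispatch_dict_.find(k);
      if (it != dispatch_dict_.end()) return it->second(arg);  // first matching key wins
    }
    return generic_func_(arg);  // no key matched: fall back to the default
  }
};

int main() {
  MyGenericFunc f;
  f.generic_func_ = [](int x) { return x + 1; };
  f.dispatch_dict_["cpu"] = [](int x) { return x * 2; };
  std::cout << f.Call({"gpu", "cpu"}, 21) << "\n";  // dispatches on "cpu": 42
  std::cout << f.Call({"fpga"}, 21) << "\n";        // default: 22
}
```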
TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") tags_vector.push_back(tag); } - generic_func - .register_func(tags_vector, *func, allow_override); - }); + generic_func.register_func(tags_vector, *func, allow_override); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); - generic_func - .CallPacked(func_args, ret); - }); + generic_func.CallPacked(func_args, ret); +}); } // namespace tvm diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 3a226e17daeeb..37855fb391794 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -21,108 +21,99 @@ * \file intrin_rule_default.cc * \brief Default intrinsic rules. */ -#include #include "intrin_rule.h" +#include + namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") 
-.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / sqrt(call->args[0]); - }); + auto one = make_const(call->args[0].dtype(), 1); + *rv = one / sqrt(call->args[0]); + }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / (one + exp(-call->args[0])); - }); + auto one = make_const(call->args[0].dtype(), 1); + *rv = one / (one + exp(-call->args[0])); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); - *rv = isfinite(call->args[0]); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isfinite(call->args[0]); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); - *rv = isinf(call->args[0]); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isinf(call->args[0]); + }); } // namespace intrin } // namespace codegen diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 091474254114c..8a5a44038df9d 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -24,9 +24,9 @@ #ifndef TVM_TARGET_INTRIN_RULE_H_ #define TVM_TARGET_INTRIN_RULE_H_ -#include -#include #include +#include + #include namespace tvm { @@ -49,21 +49,18 @@ struct FloatSuffix { // Return the intrinsic name struct Direct { - std::string operator()(DataType t, std::string name) const { - return name; - } + std::string operator()(DataType t, std::string name) const { return name; } }; // 
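Two of the rules in this hunk are lowered to arithmetic rather than dispatched to an extern function, because libm has no direct counterpart: `rsqrt(x)` becomes `1 / sqrt(x)` and `sigmoid(x)` becomes `1 / (1 + exp(-x))`. A scalar model of exactly those two rewrites:

```cpp
#include <cmath>
#include <cstdio>

// rsqrt rule above: one / sqrt(call->args[0])
float rsqrt_lowered(float x) { return 1.0f / std::sqrt(x); }

// sigmoid rule above: one / (one + exp(-call->args[0]))
float sigmoid_lowered(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  std::printf("%f %f\n", rsqrt_lowered(4.0f), sigmoid_lowered(0.0f));  // 0.5 0.5
}
```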
Call pure extern function. -template +template inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); std::string name = T()(call->dtype, call->name); if (name.length() != 0) { - *rv = CallNode::make( - call->dtype, name, call->args, CallNode::PureExtern); + *rv = CallNode::make(call->dtype, name, call->args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 8935809bc6f25..280c9998a4b0e 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -23,12 +23,13 @@ */ #ifdef TVM_LLVM_VERSION -#include #include +#include #include -#include "codegen_llvm.h" -#include "../build_common.h" + #include "../../runtime/rocm/rocm_module.h" +#include "../build_common.h" +#include "codegen_llvm.h" namespace tvm { namespace codegen { @@ -45,8 +46,8 @@ static inline int DetectROCMmaxThreadsPerBlock() { TVMRetValue val; api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)-> - GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, + &val); return val.operator int(); } } @@ -73,8 +74,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* buf = nullptr; int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { @@ -88,9 +88,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
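`DispatchExtern<T>` above rewrites an intrinsic call into a `PureExtern` call whose name is produced by the functor `T`; an empty name means "leave the expression alone". A plausible model of the naming policies in this header, assuming `Direct` returns the name unchanged and the float-suffix policy picks the C math-library variant by bit width (this is a sketch of the policy, not TVM's exact code):

```cpp
#include <iostream>
#include <string>

std::string FloatSuffixName(int bits, const std::string& name) {
  if (bits == 32) return name + "f";  // float32 -> e.g. "expf"
  if (bits == 64) return name;        // float64 -> e.g. "exp"
  return "";                          // unsupported width: caller keeps the original expr
}

int main() {
  std::cout << FloatSuffixName(32, "exp") << " " << FloatSuffixName(64, "exp") << "\n";
}
```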
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 alloca->setAlignment(llvm::Align(info.alignment)); @@ -104,12 +103,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); #else @@ -119,8 +117,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -132,18 +129,32 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; if (ts.rank == 1) { switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break; - case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break; - case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break; - default: LOG(FATAL) << "unknown workitem idx"; + case 0: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; + break; + default: + LOG(FATAL) << "unknown workitem idx"; } } else { CHECK_EQ(ts.rank, 0); switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break; - case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break; - case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break; - default: LOG(FATAL) << "unknown workgroup idx"; + case 0: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; + break; + default: + LOG(FATAL) << "unknown workgroup idx"; } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); @@ -155,9 +166,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { if (sync == "warp") { return nullptr; } else if (sync == "shared") { - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), - ::llvm::Intrinsic::amdgcn_s_barrier); + llvm::Function* f = + llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::amdgcn_s_barrier); return builder_->CreateCall(f, {}); } else { LOG(FATAL) << "Do not support sync " << sync; @@ -169,9 
+179,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Additional optimization hook to tweak the builder. } - unsigned GetGlobalAddressSpace() const final { - return 1; - } + unsigned GetGlobalAddressSpace() const final { return 1; } protected: void InitTarget(llvm::TargetMachine* tm) final { @@ -211,13 +219,10 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { // issue #4087 for a discussion #endif InitializeLLVM(); - CHECK(target.length() >= 4 && - target.substr(0, 4) == "rocm"); + CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm"); std::ostringstream config; - config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" - << DetectROCMComputeVersion(target) - << " -mattr=-code-object-v3 " - << target.substr(4, target.length() - 4); + config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target) + << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4); std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr ctx(new llvm::LLVMContext()); // careful: cg will hold a naked pointer reference to ctx, so it should @@ -226,18 +231,16 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); cg->AddFunction(f); } - const auto *find_rocm_bitcodes = - tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); + const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); Array bitcode_files = (*find_rocm_bitcodes)(); - for (auto &bitcode_path : bitcode_files) { + for (auto& bitcode_path : bitcode_files) { std::string path = bitcode_path; llvm::SMDiagnostic err; std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); @@ -248,7 +251,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { } mlib->setTargetTriple(tm->getTargetTriple().str()); mlib->setDataLayout(tm->createDataLayout()); - for (llvm::Function &f : mlib->functions()) { + for (llvm::Function& f : mlib->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); } cg->AddLinkModule(std::move(mlib)); @@ -271,33 +274,28 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile( - pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm->addPassesToEmitFile( - pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm->addPassesToEmitFile( - pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*mObj); std::string obj(dataObj.begin(), dataObj.end()); llvm::legacy::PassManager passAsm; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile(passAsm, destAsm, - 
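`BuildAMDGPU` above concatenates one LLVM option string from the `"rocm"` target: the fixed triple, a `gfx` compute version, and whatever attributes trail the `rocm` prefix. A sketch of the same concatenation, with a hypothetical compute version standing in for `DetectROCMComputeVersion()`:

```cpp
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string target = "rocm -mattr=+foo";  // illustrative target string
  int compute_version = 900;                // assumed result, e.g. gfx900
  std::ostringstream config;
  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << compute_version
         << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4);
  std::cout << config.str() << "\n";
}
```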
llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #else - CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, - llvm::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif passAsm.run(*mAsm); @@ -315,8 +313,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_REGISTER_GLOBAL("target.build.rocm") -.set_body_typed(BuildAMDGPU); +TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 73d849a7b3d12..ba4511595f0ee 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -47,8 +47,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { - llvm::Intrinsic::ID id = static_cast( - Downcast(op->args[0])->value); + llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -57,21 +56,21 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { return CodeGenCPU::CreateIntrinsic(op); } -PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { +PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { using namespace tir; const PrimExpr& e = call->args[2]; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; // Fallback to default llvm lowering rule if input type not a full vector or half vector length - int total_size = call->dtype.bits() * call->dtype.lanes(); + int total_size = call->dtype.bits() * call->dtype.lanes(); if (!call->dtype.is_vector() || call->dtype.bits() == 8 || - (total_size != 128 && total_size != 64)) { + (total_size != 128 && total_size != 64)) { Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -80,12 +79,11 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // to return back to original input type // Dvisions are always divisible (number of bits = 64 or 128) - DataType uint8_type = DataType( - e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); - DataType uint16_type = DataType( - uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); - DataType uint32_type = DataType( - uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); + DataType uint8_type = DataType(e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); + DataType uint16_type = + DataType(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); + DataType uint32_type = + DataType(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); // 
Interpret input as vector of 8bit values PrimExpr input8 = reinterpret(uint8_type, e); @@ -96,16 +94,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::CallNode::make( - uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = + tir::CallNode::make(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::CallNode::make( - uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = + tir::CallNode::make(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -115,8 +113,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::CallNode::make( - uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = + tir::CallNode::make(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -126,15 +124,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::CallNode::make( - call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); + return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - CodeGenLLVM* cg = new CodeGenARM(); - *rv = static_cast(cg); - }); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenARM(); + *rv = static_cast(cg); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index be8ef9262765f..b7c48c7790734 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -21,17 +21,17 @@ * \file codegen_blob.cc */ #ifdef TVM_LLVM_VERSION +#include "codegen_blob.h" + #include + #include -#include "codegen_blob.h" namespace tvm { namespace codegen { -std::pair, - std::shared_ptr> CodeGenBlob(const std::string& data, - bool system_lib, - const std::string& target_triple) { +std::pair, std::shared_ptr> CodeGenBlob( + const std::string& data, bool system_lib, const std::string& target_triple) { InitializeLLVM(); auto tm = GetLLVMTargetMachine(std::string("-target ") + target_triple); auto triple = tm->getTargetTriple(); @@ -41,10 +41,9 @@ std::pair, module->setTargetTriple(triple.str()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); - auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true, - llvm::GlobalValue::ExternalLinkage, blob_value, - runtime::symbol::tvm_dev_mblob, nullptr, - llvm::GlobalVariable::NotThreadLocal, 0); + auto* tvm_dev_mblob = new llvm::GlobalVariable( + *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value, + 
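The popcount lowering reformatted above counts bits per 8-bit lane (`vcnt`), then repeatedly pairwise-adds adjacent lanes while doubling the lane width (`vpaddlu`) until the requested element width is reached. A scalar model of that chain for one 64-bit element, written out lane by lane:

```cpp
#include <bitset>
#include <cstdint>
#include <cstdio>

uint64_t popcount64_by_lanes(uint64_t x) {
  // vcnt: per-byte popcounts (8 lanes of u8)
  uint16_t c8[8];
  for (int i = 0; i < 8; ++i) c8[i] = std::bitset<8>((x >> (i * 8)) & 0xff).count();
  // vpaddlu u8 -> u16: pairwise add into 4 wider lanes
  uint32_t c16[4];
  for (int i = 0; i < 4; ++i) c16[i] = c8[2 * i] + c8[2 * i + 1];
  // vpaddlu u16 -> u32: 2 lanes
  uint64_t c32[2] = {c16[0] + c16[1], c16[2] + c16[3]};
  // vpaddlu u32 -> u64: the final 64-bit count
  return c32[0] + c32[1];
}

int main() {
  // 4 set bits per byte * 8 bytes = 32
  std::printf("%llu\n", (unsigned long long)popcount64_by_lanes(0xF0F0F0F0F0F0F0F0ull));
}
```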
runtime::symbol::tvm_dev_mblob, nullptr, llvm::GlobalVariable::NotThreadLocal, 0); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob->setAlignment(llvm::Align(1)); @@ -64,11 +63,9 @@ std::pair, auto int8_ptr_ty = int8_ty->getPointerTo(0); llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty); - auto* tvm_dev_mblob_reg = - new llvm::GlobalVariable(*module, int32_ty, - false, llvm::GlobalValue::InternalLinkage, - constant_zero, - std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); + auto* tvm_dev_mblob_reg = new llvm::GlobalVariable( + *module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero, + std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment)); @@ -80,11 +77,9 @@ std::pair, llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1); auto* tvm_dev_mblob_string_value = llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true); - auto* tvm_dev_mblob_string = - new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty, - true, llvm::GlobalValue::PrivateLinkage, - tvm_dev_mblob_string_value, - std::string(runtime::symbol::tvm_dev_mblob) + ".str"); + auto* tvm_dev_mblob_string = new llvm::GlobalVariable( + *module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage, + tvm_dev_mblob_string_value, std::string(runtime::symbol::tvm_dev_mblob) + ".str"); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob_string->setAlignment(llvm::Align(1)); #else @@ -92,33 +87,30 @@ std::pair, #endif // Global init function - llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), - llvm::GlobalValue::InternalLinkage, - llvm::Twine("_GLOBAL__sub_I_", module_name), - module.get()); + llvm::Function* init_fn = llvm::Function::Create( + llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, + llvm::Twine("_GLOBAL__sub_I_", module_name), module.get()); // Create variable initialization function. 
- llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), - llvm::GlobalValue::InternalLinkage, - llvm::Twine("__cxx_global_var_init"), - module.get()); + llvm::Function* var_init_fn = llvm::Function::Create( + llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, + llvm::Twine("__cxx_global_var_init"), module.get()); // Create TVMBackendRegisterSystemLibSymbol function llvm::Function* tvm_backend_fn = llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), llvm::GlobalValue::ExternalLinkage, - llvm::Twine("TVMBackendRegisterSystemLibSymbol"), - module.get()); + llvm::Twine("TVMBackendRegisterSystemLibSymbol"), module.get()); // Set necessary fn sections auto get_static_init_section_specifier = [&triple]() -> std::string { - if (triple.isOSLinux()) { - return ".text.startup"; - } else if (triple.isOSDarwin()) { - return "__TEXT,__StaticInit,regular,pure_instructions"; - } else { - return ""; - } + if (triple.isOSLinux()) { + return ".text.startup"; + } else if (triple.isOSDarwin()) { + return "__TEXT,__StaticInit,regular,pure_instructions"; + } else { + return ""; + } }; auto static_init_section_specifier = get_static_init_section_specifier(); @@ -144,11 +136,9 @@ std::pair, llvm::Constant* indices[] = {constant_zero, constant_zero}; llvm::SmallVector args; args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty, - tvm_dev_mblob_string, - indices)); - args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), - tvm_dev_mblob, - indices)); + tvm_dev_mblob_string, indices)); + args.push_back( + llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_dev_mblob, indices)); auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args); ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg); ir_builder.CreateRetVoid(); diff --git a/src/target/llvm/codegen_blob.h b/src/target/llvm/codegen_blob.h index a394f77a66383..2821f44ebd3c3 100644 --- a/src/target/llvm/codegen_blob.h +++ b/src/target/llvm/codegen_blob.h @@ -24,9 +24,10 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_BLOB_H_ #define TVM_TARGET_LLVM_CODEGEN_BLOB_H_ #ifdef TVM_LLVM_VERSION -#include #include #include +#include + #include "llvm_common.h" namespace tvm { @@ -40,10 +41,8 @@ namespace codegen { * * \return LLVM module and LLVM context */ -std::pair, - std::shared_ptr> CodeGenBlob(const std::string& data, - bool system_lib, - const std::string& target_triple); +std::pair, std::shared_ptr> CodeGenBlob( + const std::string& data, bool system_lib, const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index e474b9cf6bcf6..03b5496c244cc 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -22,20 +22,19 @@ */ #ifdef TVM_LLVM_VERSION +#include "codegen_cpu.h" + #include #include + #include #include -#include "codegen_cpu.h" namespace tvm { namespace codegen { -void CodeGenCPU::Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) { +void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); @@ -46,53 +45,34 @@ void 
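In C++ terms, the IR that `CodeGenBlob` emits for the system-lib case amounts to a static initializer that hands the embedded blob to the runtime before `main()` runs. A sketch of that shape, with a stub standing in for the real runtime entry point (the registration function's signature follows TVM's C backend API; the blob contents here are placeholders):

```cpp
#include <cstdio>

// Stub with the shape of TVM's backend hook; the real symbol lives in libtvm_runtime.
extern "C" int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
  std::printf("registered %s at %p\n", name, ptr);
  return 0;
}

// Placeholder blob: little-endian length prefix followed by the payload.
const unsigned char __tvm_dev_mblob[] = {0x03, 0, 0, 0, 0, 0, 0, 0, 'a', 'b', 'c'};

namespace {
struct BlobRegistrar {
  BlobRegistrar() {
    // Equivalent of the emitted __cxx_global_var_init body.
    TVMBackendRegisterSystemLibSymbol("__tvm_dev_mblob", (void*)__tvm_dev_mblob);
  }
} reg_;  // runs at static-init time, like the emitted _GLOBAL__sub_I_<module>
}  // namespace

int main() {}  // by the time main runs, the blob is already registered
```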
CodeGenCPU::Init(const std::string& module_name, t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}); t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); t_tvm_func_handle_ = t_void_p_; - t_tvm_array_ = llvm::StructType::create( - {t_void_p_, - t_tvm_context_, - t_int_, - t_tvm_type_, - t_tvm_shape_index_->getPointerTo(), - t_tvm_shape_index_->getPointerTo(), - t_int64_}); + t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_context_, t_int_, t_tvm_type_, + t_tvm_shape_index_->getPointerTo(), + t_tvm_shape_index_->getPointerTo(), t_int64_}); t_tvm_value_ = llvm::StructType::create({t_float64_}); - t_tvm_parallel_group_env_ = llvm::StructType::create({ - t_int32_->getPointerTo(), t_int32_}); + t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_}); ftype_tvm_parallel_lambda_ = llvm::FunctionType::get( - t_int_, - {t_int_, - t_tvm_parallel_group_env_->getPointerTo(), - t_void_p_}, false); + t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false); md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_); // Runtime functions. - ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, { - t_tvm_func_handle_, - t_tvm_value_->getPointerTo(), - t_int_->getPointerTo(), + ftype_tvm_func_call_ = llvm::FunctionType::get( t_int_, - t_tvm_value_->getPointerTo(), - t_int_->getPointerTo()}, false); - ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, { - t_void_p_, - t_char_->getPointerTo(), - t_tvm_func_handle_->getPointerTo()}, false); - ftype_tvm_api_set_last_error_ = llvm::FunctionType::get( - t_void_, {t_char_->getPointerTo()}, false); - ftype_tvm_parallel_launch_ = - llvm::FunctionType::get(t_int_, { - ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_} - , false); + {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, + t_tvm_value_->getPointerTo(), t_int_->getPointerTo()}, + false); + ftype_tvm_get_func_from_env_ = llvm::FunctionType::get( + t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false); + ftype_tvm_api_set_last_error_ = + llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false); + ftype_tvm_parallel_launch_ = llvm::FunctionType::get( + t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false); ftype_tvm_parallel_barrier_ = - llvm::FunctionType::get(t_int_, { - t_int_, t_tvm_parallel_group_env_->getPointerTo()} - , false); - ftype_tvm_static_init_callback_ = - llvm::FunctionType::get(t_int_, {t_void_p_}, false); + llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false); + ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false); ftype_tvm_static_init_ = - llvm::FunctionType::get(t_int_, { - t_void_p_->getPointerTo(), - ftype_tvm_static_init_callback_->getPointerTo(), - t_void_p_, t_int_} - , false); + llvm::FunctionType::get(t_int_, + {t_void_p_->getPointerTo(), + ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_}, + false); // initialize TVM runtime API if (system_lib) { // We will need this in environment for backward registration. 
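The LLVM struct types assembled in `Init()` mirror the runtime's C ABI; in particular `t_tvm_array_` corresponds field-for-field to DLPack's `DLTensor`. Sketched here from the LLVM types in the hunk above (see `dlpack.h` for the canonical definition; field names are paraphrased):

```cpp
#include <cstdint>

struct MyDLDataType { uint8_t code; uint8_t bits; uint16_t lanes; };  // t_tvm_type_
struct MyDLContext  { int32_t device_type; int32_t device_id; };      // t_tvm_context_

struct MyDLTensor {     // t_tvm_array_
  void* data;           // t_void_p_
  MyDLContext ctx;      // t_tvm_context_
  int32_t ndim;         // t_int_
  MyDLDataType dtype;   // t_tvm_type_
  int64_t* shape;       // t_tvm_shape_index_->getPointerTo()
  int64_t* strides;     // t_tvm_shape_index_->getPointerTo()
  int64_t byte_offset;  // t_int64_ (declared uint64_t in dlpack.h)
};

static_assert(sizeof(MyDLDataType) == 4, "packed as i8,i8,i16");
int main() {}
```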
@@ -103,21 +83,20 @@ void CodeGenCPU::Init(const std::string& module_name, f_tvm_register_system_symbol_ = nullptr; } if (dynamic_lookup || system_lib) { - f_tvm_func_call_ = llvm::Function::Create( - ftype_tvm_func_call_, - llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get()); - f_tvm_get_func_from_env_ = llvm::Function::Create( - ftype_tvm_get_func_from_env_, - llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get()); - f_tvm_api_set_last_error_ = llvm::Function::Create( - ftype_tvm_api_set_last_error_, - llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); - f_tvm_parallel_launch_ = llvm::Function::Create( - ftype_tvm_parallel_launch_, - llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get()); - f_tvm_parallel_barrier_ = llvm::Function::Create( - ftype_tvm_parallel_barrier_, - llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); + f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage, + "TVMFuncCall", module_.get()); + f_tvm_get_func_from_env_ = + llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage, + "TVMBackendGetFuncFromEnv", module_.get()); + f_tvm_api_set_last_error_ = + llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage, + "TVMAPISetLastError", module_.get()); + f_tvm_parallel_launch_ = + llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage, + "TVMBackendParallelLaunch", module_.get()); + f_tvm_parallel_barrier_ = + llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, + "TVMBackendParallelBarrier", module_.get()); } this->InitGlobalContext(dynamic_lookup); } @@ -152,22 +131,13 @@ void CodeGenCPU::AddDebugInformation(llvm::Function* function) { #if TVM_LLVM_VERSION >= 80 auto* DIFunction = dbg_info_->di_builder_->createFunction( - dbg_info_->file_, function->getName(), "", - dbg_info_->file_, - 0 /* line number */, - DIFunctionTy, - false /* internal linkage */); + dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */, + DIFunctionTy, false /* internal linkage */); #else auto* DIFunction = dbg_info_->di_builder_->createFunction( - dbg_info_->file_, function->getName(), "", - dbg_info_->file_, - 0 /* line number */, - DIFunctionTy, - false, /* internal linkage */ - true, - 0 /* line number */, - llvm::DINode::FlagPrototyped, - true /* isOptimized */); + dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */, + DIFunctionTy, false, /* internal linkage */ + true, 0 /* line number */, llvm::DINode::FlagPrototyped, true /* isOptimized */); #endif CHECK(DIFunction); @@ -236,9 +206,8 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { llvm::Function* f = module_->getFunction(entry_func_name); CHECK(f) << "Function " << entry_func_name << "does not in module"; llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, - runtime::symbol::tvm_module_main); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, runtime::symbol::tvm_module_main); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -254,8 +223,8 @@ std::unique_ptr CodeGenCPU::Finish() { } return CodeGenLLVM::Finish(); } -llvm::Value* CodeGenCPU::CreateStructRefPtr( - 
DataType t, llvm::Value* buf, llvm::Value* index, int kind) { +llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, + int kind) { if (kind < intrinsic::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -280,27 +249,22 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); } case intrinsic::kArrTypeCode: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(0)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); } case intrinsic::kArrTypeBits: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(1)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); } case intrinsic::kArrTypeLanes: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(2)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); } case intrinsic::kArrByteOffset: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); } case intrinsic::kArrDeviceId: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(1), ConstInt32(1)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); } case intrinsic::kArrDeviceType: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(1), ConstInt32(0)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); } case intrinsic::kTVMValueContent: { CHECK_EQ(t.lanes(), 1); @@ -318,7 +282,9 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); } } - default: LOG(FATAL) << "unknown field code"; return nullptr; + default: + LOG(FATAL) << "unknown field code"; + return nullptr; } } @@ -331,8 +297,8 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { for (llvm::Value* v : arg_values) { arg_types.push_back(v->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get( - GetLLVMType(GetRef(op)), arg_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_types, false); // Check if it is available in global function table as injected function. 
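Each case in `CreateStructRefPtr` above computes a member address inside that `TVMArray` struct: the GEP index lists are `{array index, field, subfield}` paths. A C++ rendering of a few of the cases (the field-kind names are paraphrased from the intrinsic enum; treat the mapping as illustrative):

```cpp
#include <cstdint>

struct DType { uint8_t code; uint8_t bits; uint16_t lanes; };
struct Ctx { int32_t device_type; int32_t device_id; };
struct Arr {  // field order matches t_tvm_array_
  void* data; Ctx ctx; int32_t ndim; DType dtype;
  int64_t* shape; int64_t* strides; int64_t byte_offset;
};

enum FieldKind { kData, kTypeCode, kTypeBits, kDeviceId, kByteOffset };

void* ArrFieldPtr(Arr* buf, int index, FieldKind kind) {
  Arr* a = &buf[index];  // the GEP's leading `index`
  switch (kind) {
    case kData:       return &a->data;           // GEP {index, 0}
    case kTypeCode:   return &a->dtype.code;     // GEP {index, 3, 0}
    case kTypeBits:   return &a->dtype.bits;     // GEP {index, 3, 1}
    case kDeviceId:   return &a->ctx.device_id;  // GEP {index, 1, 1}
    case kByteOffset: return &a->byte_offset;    // GEP {index, 6}
  }
  return nullptr;
}

int main() {}
```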
auto it = gv_func_map_.find(op->name); if (it != gv_func_map_.end()) { @@ -349,8 +315,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } else { llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -361,12 +326,9 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } } -llvm::GlobalVariable* CodeGenCPU::InitContextPtr( - llvm::Type* p_type, std::string name) { +llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { llvm::GlobalVariable* gv = new llvm::GlobalVariable( - *module_, p_type, false, - llvm::GlobalValue::LinkOnceAnyLinkage, 0, - name); + *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, 0, name); #if TVM_LLVM_VERSION >= 100 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type))); #else @@ -384,9 +346,8 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif - faddr->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + faddr->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); return faddr; } @@ -399,16 +360,15 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_)); } else { if (!dynamic_lookup) { - gv_tvm_func_call_ = InitContextPtr( - ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall"); - gv_tvm_get_func_from_env_ = InitContextPtr( - ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv"); - gv_tvm_api_set_last_error_ = InitContextPtr( - ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError"); - gv_tvm_parallel_launch_ = InitContextPtr( - ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch"); - gv_tvm_parallel_barrier_ = InitContextPtr( - ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier"); + gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall"); + gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(), + "__TVMBackendGetFuncFromEnv"); + gv_tvm_api_set_last_error_ = + InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError"); + gv_tvm_parallel_launch_ = + InitContextPtr(ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch"); + gv_tvm_parallel_barrier_ = InitContextPtr(ftype_tvm_parallel_barrier_->getPointerTo(), + "__TVMBackendParallelBarrier"); // Mark as context functions gv_func_map_["TVMBackendAllocWorkspace"] = nullptr; gv_func_map_["TVMBackendFreeWorkspace"] = nullptr; @@ -419,12 +379,9 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { // create emit codes that checks and load the function. 
using llvm::BasicBlock; - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "call_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "call_end", function_); - llvm::Value* succ = builder_->CreateICmpEQ( - retcode, llvm::ConstantInt::get(t_int_, 0)); + BasicBlock* fail_block = BasicBlock::Create(*ctx_, "call_fail", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "call_end", function_); + llvm::Value* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); builder_->SetInsertPoint(fail_block); // return the code. @@ -448,20 +405,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { arg_values.push_back(value); arg_types.push_back(value->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(t_int_, arg_types, false); - llvm::Function* fcompute = - llvm::Function::Create(ftype, - llvm::Function::PrivateLinkage, - op->value.as()->value, - module_.get()); - BasicBlock* compute_call_end = CheckCallSuccess( - builder_->CreateCall(fcompute, arg_values)); + llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false); + llvm::Function* fcompute = llvm::Function::Create( + ftype, llvm::Function::PrivateLinkage, op->value.as()->value, module_.get()); + BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); // setup compute fuinction. std::unordered_map new_vmap; size_t idx = 0; - for (auto it = fcompute->arg_begin(); - it != fcompute->arg_end(); ++it, ++idx) { + for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) { llvm::Argument* v = &(*it); const Var& var = vargs[idx]; new_vmap[var.get()] = v; @@ -478,7 +429,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } std::swap(function_, fcompute); std::swap(new_vmap, var_map_); - BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_); + BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); @@ -503,48 +454,41 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* nu llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); llvm::Value* zero = ConstInt32(0); for (size_t i = 0; i < vfields.size(); ++i) { - builder_->CreateStore( - var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); + builder_->CreateStore(var_map_.at(vfields[i].get()), + builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); } *num_bytes = data_layout_->getTypeAllocSize( llvm::cast(cdata->getType())->getElementType()); return cdata; } -void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, - const Array& vfields, +void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { (*vmap)[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP( - cdata, {ConstInt32(0), ConstInt32(i)})); + builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)})); } } void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { using llvm::BasicBlock; // closure data - llvm::Function* f = llvm::Function::Create( - ftype_tvm_parallel_lambda_, - llvm::Function::PrivateLinkage, - "__tvm_parallel_lambda", module_.get()); + llvm::Function* f = + llvm::Function::Create(ftype_tvm_parallel_lambda_, 
llvm::Function::PrivateLinkage, + "__tvm_parallel_lambda", module_.get()); // allocate and setup the closure, call the closure. Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; llvm::Value* cdata = PackClosureData(vfields, &nbytes); #if TVM_LLVM_VERSION >= 90 - auto launch_callee = llvm::FunctionCallee( - ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); + auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); #else auto launch_callee = RuntimeTVMParallelLaunch(); #endif - BasicBlock* par_launch_end = CheckCallSuccess( - builder_->CreateCall( - launch_callee, - {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); + BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall( + launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. - BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); @@ -558,9 +502,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { par_env.task_id = Var("task_id", DataType::Int(32)); par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; - new_vmap[par_env.num_task.get()] = builder_->CreateLoad( - builder_->CreateInBoundsGEP( - penv, {ConstInt32(0), ConstInt32(1)})); + new_vmap[par_env.num_task.get()] = + builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)})); par_env.penv = penv; std::swap(function_, f); std::swap(parallel_env_, par_env); @@ -571,16 +514,13 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { std::swap(var_map_, new_vmap); std::swap(parallel_env_, par_env); std::swap(function_, f); - CHECK_NE(par_env.parallel_loop_count, 0) - << "Cannot find parallel loop within parallel launch"; + CHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch"; builder_->SetInsertPoint(par_launch_end); } llvm::Value* CodeGenCPU::CreateStaticHandle() { llvm::GlobalVariable* gv = new llvm::GlobalVariable( - *module_, t_void_p_, false, - llvm::GlobalValue::PrivateLinkage, 0, - "__tvm_static_handle"); + *module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage, 0, "__tvm_static_handle"); #if TVM_LLVM_VERSION >= 100 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_))); #else @@ -593,26 +533,23 @@ llvm::Value* CodeGenCPU::CreateStaticHandle() { void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) { using llvm::BasicBlock; // closure data - llvm::Function* f = llvm::Function::Create( - ftype_tvm_static_init_callback_, - llvm::Function::PrivateLinkage, - "__tvm_static_init_lambda", module_.get()); + llvm::Function* f = + llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage, + "__tvm_static_init_lambda", module_.get()); llvm::Value* gv = CreateStaticHandle(); llvm::Function* finit = module_->getFunction(init_fname); if (finit == nullptr) { - finit = llvm::Function::Create( - ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get()); + finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage, + init_fname, module_.get()); } // allocate and setup the closure, call the closure. 
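`CreateParallelLaunch` and `CreateStaticInit` above share one closure convention: pack the body's free variables into a stack struct (`PackClosureData`), pass it through `void*`, and unpack it inside the generated lambda (`UnpackClosureData`). A hand-written equivalent with invented names; note that in the real lambda signature the second argument is a `TVMParallelGroupEnv*` and `num_task` is loaded from it, which is simplified away here:

```cpp
#include <cstdio>

struct Closure { float* data; int n; };  // PackClosureData: one field per free var

// Shape of ftype_tvm_parallel_lambda_, simplified: (task_id, num_task, cdata) -> int
static int parallel_lambda(int task_id, int num_task, void* cdata) {
  Closure* c = static_cast<Closure*>(cdata);  // UnpackClosureData
  for (int i = task_id; i < c->n; i += num_task) c->data[i] += 1.0f;
  return 0;  // nonzero would signal failure to the launcher
}

int main() {
  float buf[8] = {0};
  Closure c{buf, 8};
  // The runtime (TVMBackendParallelLaunch) would run the lambda on worker
  // threads; calling it serially here shows the same work partitioning.
  for (int t = 0; t < 2; ++t) parallel_lambda(t, 2, &c);
  std::printf("%f %f\n", buf[0], buf[7]);
}
```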
uint64_t nbytes; Array vfields = tir::UndefinedVars(body, {}); llvm::Value* cdata = PackClosureData(vfields, &nbytes); - BasicBlock* init_end = CheckCallSuccess( - builder_->CreateCall( - finit, - {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); + BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( + finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. - BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); @@ -642,9 +579,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { if (it == func_handle_map_.end()) { // create global location for the handle // create the function handle - hptr = new llvm::GlobalVariable( - *module_, t_tvm_func_handle_, false, - llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname); + hptr = + new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false, + llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname); #if TVM_LLVM_VERSION >= 100 hptr->setAlignment(llvm::Align(align)); #else @@ -657,42 +594,34 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { } // create emit codes that checks and load the function. BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* init_block = BasicBlock::Create( - *ctx_, "handle_init", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "handle_init_end", function_); + BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif - llvm::Value* handle_not_null = builder_->CreateICmpNE( - handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); - builder_->CreateCondBr( - handle_not_null, end_block, init_block, md_very_likely_branch_); + llvm::Value* handle_not_null = + builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); + builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_); // Initialize the handle if needed. 
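The control flow `GetPackedFuncHandle` emits is the classic lazy-initialization pattern: a zero-initialized global cache per function name, a likely-taken fast path when the handle is already set, and a slow path that resolves the function from the module environment once. A C++ rendering of the same flow (the lookup is a stand-in for `TVMBackendGetFuncFromEnv`, and the emitted IR does not carry C++'s thread-safe-static guarantees):

```cpp
#include <cstdio>

using FuncHandle = void*;

FuncHandle LookupFromEnv(const char* name) {  // stands in for TVMBackendGetFuncFromEnv
  std::printf("resolving %s\n", name);
  return reinterpret_cast<FuncHandle>(0x1);   // dummy non-null handle
}

FuncHandle GetPackedFuncHandle(const char* fname) {
  static FuncHandle handle = nullptr;    // the .tvm_func.<name> global
  if (handle != nullptr) return handle;  // handle_init_end: very likely branch
  handle = LookupFromEnv(fname);         // handle_init: slow path, runs once
  return handle;                         // the phi node merges both paths
}

int main() {
  GetPackedFuncHandle("mylib.add");  // resolves
  GetPackedFuncHandle("mylib.add");  // cached
}
```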
builder_->SetInsertPoint(init_block); - llvm::Value* out = WithFunctionEntry([&]() { - return builder_->CreateAlloca(t_tvm_func_handle_); - }); + llvm::Value* out = + WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad( - gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = + builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); #else - llvm::LoadInst* ctx = builder_->CreateAlignedLoad( - gv_mod_ctx_, gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif - ctx->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + ctx->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); #if TVM_LLVM_VERSION >= 90 - auto env_callee = llvm::FunctionCallee( - ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); + auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); #else auto env_callee = RuntimeTVMGetFuncFromEnv(); #endif - llvm::Value* retcode = builder_->CreateCall( - env_callee, {ctx, GetConstString(fname), out}); + llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); @@ -710,38 +639,33 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -llvm::BasicBlock * -CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, - llvm::Value **ret_tcode, const DataType &r_type, - const int64_t begin, const int64_t end) { +llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm::Value** rvalue, + llvm::Value** ret_tcode, const DataType& r_type, + const int64_t begin, const int64_t end) { using llvm::BasicBlock; std::string func_name = args[0].as()->value; - llvm::Value *handle = GetPackedFuncHandle(func_name); + llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function int64_t nargs = end - begin; CHECK_GE(nargs, 0); - llvm::Value *stack_value = MakeValue(args[1]); - llvm::Value *stack_tcode = MakeValue(args[2]); - llvm::Value *arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), - ConstInt32(begin)); - llvm::Value *arg_tcode = - CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); - llvm::Value *ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), - ConstInt32(end)); + llvm::Value* stack_value = MakeValue(args[1]); + llvm::Value* stack_tcode = MakeValue(args[2]); + llvm::Value* arg_value = builder_->CreateInBoundsGEP( + builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); + llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + llvm::Value* ret_value = builder_->CreateInBoundsGEP( + builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock *end_block = 
CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), - ret_value, *ret_tcode})); + BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( + call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = builder_->CreatePointerCast( - ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Value* load_ptr = + builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); #if TVM_LLVM_VERSION >= 110 *rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); #else @@ -751,47 +675,44 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue, return end_block; } -llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) { +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { CHECK_EQ(op->args.size(), 5U); - llvm::Value *rvalue = nullptr; - llvm::Value *ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, - op->args[3].as<IntImmNode>()->value, + llvm::Value* rvalue = nullptr; + llvm::Value* ret_tcode = nullptr; + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value, op->args[4].as<IntImmNode>()->value); return rvalue; } -llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) { +llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { using llvm::BasicBlock; CHECK_EQ(op->args.size(), 6U); - llvm::Value *rvalue = nullptr; - llvm::Value *ret_tcode = nullptr; - BasicBlock *end_block = MakeCallPacked( - op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value, - op->args[4].as<IntImmNode>()->value); + llvm::Value* rvalue = nullptr; + llvm::Value* ret_tcode = nullptr; + BasicBlock* end_block = + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value, + op->args[4].as<IntImmNode>()->value); // Get traced value. - llvm::Value *traced_value = MakeValue(op->args[5]); + llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock *update_block = - BasicBlock::Create(*ctx_, "update_block", function_); + BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - BasicBlock *continue_block = - BasicBlock::Create(*ctx_, "continue_block", function_); + BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); #else - llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); + llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); #endif // Check the ret_type_code and create cmp instruction. - llvm::Value *cmp = builder_->CreateICmpNE( - ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + llvm::Value* cmp = + builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from.
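The update_block/continue_block pair above is a two-way select on control flow; a compact standalone rendering of the same pattern with LLVM's IRBuilder (EmitTraceMerge and its parameter names are illustrative, not TVM API):

#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>

// If `cmp` holds (tcode != kTVMNullptr) control passes through update_block;
// otherwise it falls straight through. The PHI then yields the packed call's
// rvalue or the original traced value depending on the predecessor block.
llvm::Value* EmitTraceMerge(llvm::LLVMContext& ctx, llvm::Function* fn,
                            llvm::IRBuilder<>& builder, llvm::Value* cmp,
                            llvm::Value* rvalue, llvm::Value* traced_value) {
  llvm::BasicBlock* end_block = builder.GetInsertBlock();
  llvm::BasicBlock* update_block = llvm::BasicBlock::Create(ctx, "update_block", fn);
  llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(ctx, "continue_block", fn);
  builder.CreateCondBr(cmp, update_block, continue_block);
  builder.SetInsertPoint(update_block);
  builder.CreateBr(continue_block);
  builder.SetInsertPoint(continue_block);
  llvm::PHINode* phi = builder.CreatePHI(traced_value->getType(), 2);
  phi->addIncoming(rvalue, update_block);
  phi->addIncoming(traced_value, end_block);
  return phi;
}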
- llvm::PHINode *phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); + llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); phi_rvalue->addIncoming(rvalue, update_block); phi_rvalue->addIncoming(traced_value, end_block); return phi_rvalue; @@ -823,17 +744,14 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { void CodeGenCPU::AddStartupFunction() { if (export_system_symbols_.size() != 0) { llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false); - function_ = llvm::Function::Create( - ftype, - llvm::Function::InternalLinkage, - "__tvm_module_startup", module_.get()); + function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, + "__tvm_module_startup", module_.get()); llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(startup_entry); for (const auto& kv : export_system_symbols_) { llvm::Value* name = GetConstString(kv.first); - builder_->CreateCall( - f_tvm_register_system_symbol_, { - name, builder_->CreateBitCast(kv.second, t_void_p_)}); + builder_->CreateCall(f_tvm_register_system_symbol_, + {name, builder_->CreateBitCast(kv.second, t_void_p_)}); } llvm::appendToGlobalCtors(*module_, function_, 65535); builder_->CreateRet(nullptr); @@ -853,9 +771,8 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; - llvm::Value* ref = this->CreateStructRefPtr( - op->dtype, MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + llvm::Value* ref = + this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == intrinsic::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); } else { @@ -865,13 +782,11 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); - llvm::Value* ref = this->CreateStructRefPtr( - op->args[3].dtype(), MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + MakeValue(op->args[1]), kind); CHECK(kind != intrinsic::kArrAddr); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast( - value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); } builder_->CreateStore(value, ref); return ConstInt32(0); @@ -879,22 +794,22 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { - const int64_t* pval = as_const_int(op->args[1]); - CHECK(pval) << "require stack alloca to contain constant value"; - llvm::Value* num = ConstInt32(pval[0]); - if (type == "shape") { - return builder_->CreateAlloca(t_tvm_shape_index_, num); - } else if (type == "arg_value") { - return builder_->CreateAlloca(t_tvm_value_, num); - } else if (type == "arg_tcode") { - return builder_->CreateAlloca(t_int_, num); - } else if (type == "array") { - return builder_->CreateAlloca(t_tvm_array_, num); - } else { - LOG(FATAL) << "Unknown stack alloca type " << type; - return nullptr; - } - }); + const int64_t* pval = as_const_int(op->args[1]); + CHECK(pval) << "require stack alloca to contain constant value"; + llvm::Value* num = ConstInt32(pval[0]); + 
if (type == "shape") { + return builder_->CreateAlloca(t_tvm_shape_index_, num); + } else if (type == "arg_value") { + return builder_->CreateAlloca(t_tvm_value_, num); + } else if (type == "arg_tcode") { + return builder_->CreateAlloca(t_int_, num); + } else if (type == "array") { + return builder_->CreateAlloca(t_tvm_array_, num); + } else { + LOG(FATAL) << "Unknown stack alloca type " << type; + return nullptr; + } + }); } else { return CodeGenLLVM::CreateIntrinsic(op); } @@ -909,16 +824,14 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { os << ", " << op->message.as()->value; } llvm::Value* msg = GetConstString(os.str()); - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "assert_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "assert_end", function_); + BasicBlock* fail_block = BasicBlock::Create(*ctx_, "assert_fail", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "assert_end", function_); builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); #if TVM_LLVM_VERSION >= 90 - auto err_callee = llvm::FunctionCallee( - ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); + auto err_callee = + llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); #else auto err_callee = RuntimeTVMAPISetLastError(); #endif @@ -932,7 +845,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::coproc_uop_scope) { this->CreateStaticInit(op->value.as()->value, op->body); - } else if (op->attr_key == tir::attr::compute_scope) { + } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); } else if (tir::attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { @@ -943,20 +856,18 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == "pragma_parallel_launch_point") { CreateParallelLaunch(op->body, 0); } else if (op->attr_key == "pragma_parallel_barrier_when_finish") { - CHECK(parallel_env_.penv != nullptr) - << "Cannot run barrier without parallel environment"; + CHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment"; CHECK(!parallel_env_.in_parallel_loop) << "Cannot not place within parallel loop as the workload may differ, " << " place it between parallel and parallel_launch_point"; this->VisitStmt(op->body); #if TVM_LLVM_VERSION >= 90 - auto bar_callee = llvm::FunctionCallee( - ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); + auto bar_callee = + llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); #else auto bar_callee = RuntimeTVMParallelBarrier(); #endif - builder_->CreateCall( - bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); + builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); } else if (op->attr_key == tir::attr::pragma_import_llvm) { const StringImmNode* value = op->value.as(); CHECK(value != nullptr); @@ -973,15 +884,13 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { void CodeGenCPU::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); - if (op->for_type == ForType::Serial || - op->for_type == ForType::Unrolled) { + if (op->for_type == ForType::Serial || op->for_type == ForType::Unrolled) { CodeGenLLVM::VisitStmt_(op); } else if (op->for_type == ForType::Parallel) { if (parallel_env_.penv == nullptr) { 
CreateParallelLaunch( - ForNode::make( - op->loop_var, op->min, op->extent, - op->for_type, op->device_api, op->body), 0); + ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), + 0); } else { // already in parallel env. CHECK(parallel_env_.task_id.defined()); @@ -994,20 +903,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; if (parallel_env_.stride_pattern) { - CreateSerialFor(MakeValue(task_id), - MakeValue(op->extent), - MakeValue(num_task), - op->loop_var, - op->body); + CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), + op->loop_var, op->body); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; PrimExpr begin = MinNode::make(task_id * step, op->extent); PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); - CreateSerialFor(MakeValue(begin), - MakeValue(end), - llvm::ConstantInt::getSigned(GetLLVMType(end), 1), - op->loop_var, - op->body); + CreateSerialFor(MakeValue(begin), MakeValue(end), + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } parallel_env_.in_parallel_loop = false; ++parallel_env_.parallel_loop_count; diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index aa8371c39a5ca..7a14b8fdc959b 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -24,11 +24,12 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ -#include -#include #include #include #include +#include +#include + #include "codegen_llvm.h" namespace tvm { @@ -37,11 +38,8 @@ namespace codegen { // CPU host code generation class CodeGenCPU : public CodeGenLLVM { public: - void Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) override; + void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, + bool system_lib, bool dynamic_lookup) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -95,20 +93,18 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - llvm::Value* PackClosureData(const Array& fields, uint64_t *num_bytes); + llvm::Value* PackClosureData(const Array& fields, uint64_t* num_bytes); llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(llvm::Value*cdata, - const Array& fields, + void UnpackClosureData(llvm::Value* cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. - llvm::BasicBlock *MakeCallPacked(const Array &args, - llvm::Value **rvalue, - llvm::Value **ret_tcode, const DataType &r_type, + llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, + llvm::Value** ret_tcode, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. llvm::Value* CreateCallPacked(const CallNode* op); // Create trace call into tvm packed function. 
- llvm::Value* CreateCallTracePacked(const CallNode *op); + llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization void CreateStaticInit(const std::string& init_fname, const Stmt& body); // Create parallel launch diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 86cd5a3acf61e..f664532b2dc16 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -22,20 +22,21 @@ */ #ifdef TVM_LLVM_VERSION // Part of the code are adapted from Halide's CodeGen_LLVM -#include +#include "codegen_llvm.h" + #include +#include #include #include -#include "codegen_llvm.h" -#include "codegen_cpu.h" #include "../../arith/pattern_match.h" #include "../build_common.h" +#include "codegen_cpu.h" namespace tvm { namespace codegen { -std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine *tm) { +std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { std::string target = tm->getTarget().getName(); std::string factory_name = "tvm.codegen.llvm.target_" + target; const PackedFunc* f = runtime::Registry::Get(factory_name); @@ -47,11 +48,8 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine *tm) { } } -void CodeGenLLVM::Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) { +void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { InitializeLLVM(); ctx_ = ctx; builder_.reset(new IRBuilder(*ctx_)); @@ -68,7 +66,7 @@ void CodeGenLLVM::Init(const std::string& module_name, t_int64_ = llvm::Type::getInt64Ty(*ctx_); t_float64_ = llvm::Type::getDoubleTy(*ctx_); // meta data - md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1); + md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_); this->InitTarget(tm); @@ -96,9 +94,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::AddFunction(const PrimFunc& f) { - this->AddFunctionInternal(f, false); -} +void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } void CodeGenLLVM::InitFuncState() { var_map_.clear(); @@ -108,7 +104,6 @@ void CodeGenLLVM::InitFuncState() { analyzer_.reset(new arith::Analyzer()); } - void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { this->InitFuncState(); @@ -126,8 +121,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { // TODO(tvm-team): // Update the function type to respect the ret_type field of f. // Once we allow more flexibility in the PrimFunc. - llvm::FunctionType* ftype = llvm::FunctionType::get( - ret_void ? t_void_ : t_int_, param_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(ret_void ? 
t_void_ : t_int_, param_types, false); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) @@ -135,9 +130,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; - function_ = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, - global_symbol.value().operator std::string(), module_.get()); + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); @@ -169,7 +163,6 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } - std::unique_ptr CodeGenLLVM::Finish() { this->AddStartupFunction(); for (size_t i = 0; i < link_modules_.size(); ++i) { @@ -182,13 +175,11 @@ std::unique_ptr CodeGenLLVM::Finish() { return std::move(module_); } - void CodeGenLLVM::HandleImport(const std::string& code) { std::unique_ptr mlib; llvm::SMDiagnostic err; if (code.length() >= 3 && - (code.substr(code.length() - 3) == ".ll" || - code.substr(code.length() - 3) == ".bc")) { + (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) { mlib = llvm::parseIRFile(code, err, *ctx_); if (mlib.get() == nullptr) { std::string msg = std::string(err.getMessage()); @@ -196,20 +187,19 @@ void CodeGenLLVM::HandleImport(const std::string& code) { << "line " << err.getLineNo() << ":" << msg; } } else { - std::unique_ptr buf = - llvm::MemoryBuffer::getMemBuffer(code); + std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); mlib = llvm::parseIR(*buf, err, *ctx_); if (mlib.get() == nullptr) { std::string msg = std::string(err.getMessage()); LOG(FATAL) << "Fail to load llvm ir " - << "line " << err.getLineNo() << ":" << msg - << "\ncontent:\n" << code; + << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n" + << code; } } mlib->setTargetTriple(target_machine_->getTargetTriple().str()); mlib->setDataLayout(target_machine_->createDataLayout()); // mark all the functions as force inline - for (llvm::Function &f : mlib->functions()) { + for (llvm::Function& f : mlib->functions()) { f.removeFnAttr(llvm::Attribute::NoInline); f.addFnAttr(llvm::Attribute::AlwaysInline); f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage); @@ -238,35 +228,27 @@ llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { class FPassManager : public llvm::legacy::FunctionPassManager { public: - explicit FPassManager(llvm::Module* m) - : llvm::legacy::FunctionPassManager(m) {} + explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {} // override add to allow messaging - void add(llvm::Pass* p) final { - llvm::legacy::FunctionPassManager::add(p); - } + void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); } }; class MPassManager : public llvm::legacy::PassManager { public: // override add to allow messaging - void add(llvm::Pass* p) final { - llvm::legacy::PassManager::add(p); - } + void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); } }; -void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) { -} +void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {} void CodeGenLLVM::Optimize() { // pass manager FPassManager fpass(module_.get()); MPassManager 
mpass; mpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : - llvm::TargetIRAnalysis())); + target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); fpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : - llvm::TargetIRAnalysis())); + target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); // place optimization pass llvm::PassManagerBuilder builder; @@ -300,24 +282,32 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co return native_vector_bits_; } -unsigned CodeGenLLVM::GetGlobalAddressSpace() const { - return 0; -} +unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; } llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { if (dtype.is_handle()) { CHECK_EQ(dtype.lanes(), 1); return t_void_p_; } + if (dtype.is_void()) { + return t_void_; + } llvm::Type* etype = nullptr; if (dtype.is_int() || dtype.is_uint()) { etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); } else if (dtype.is_float()) { switch (dtype.bits()) { - case 16: etype = llvm::Type::getHalfTy(*ctx_); break; - case 32: etype = llvm::Type::getFloatTy(*ctx_); break; - case 64: etype = llvm::Type::getDoubleTy(*ctx_); break; - default: LOG(FATAL) << "do not support " << dtype; + case 16: + etype = llvm::Type::getHalfTy(*ctx_); + break; + case 32: + etype = llvm::Type::getFloatTy(*ctx_); + break; + case 64: + etype = llvm::Type::getDoubleTy(*ctx_); + break; + default: + LOG(FATAL) << "do not support " << dtype; } } if (dtype.lanes() != 1) { @@ -352,16 +342,12 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const { // // This trick comes from Halide's CodeGen_LLVM // -void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, - const VarNode* buffer, - PrimExpr index, +void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index, DataType type) { if (alias_var_set_.count(buffer) != 0) { // Mark all possibly aliased pointer as same type. 
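As context for the TBAA tagging below: lumping every possibly-aliased buffer under one metadata node tells LLVM nothing may be assumed about those accesses, while distinct scalar-type nodes let it disambiguate the rest. A standalone sketch of the tagging step using LLVM's MDBuilder (TagNoAlias is an assumed name, not a TVM helper):

#include <llvm/IR/Instruction.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/MDBuilder.h>

// Attach a catch-all TBAA tag: accesses tagged this way stay in one alias set,
// so the optimizer never proves them independent of each other.
void TagNoAlias(llvm::LLVMContext& ctx, llvm::Instruction* inst) {
  llvm::MDBuilder md(ctx);
  llvm::MDNode* root = md.createTBAARoot("tvm-tbaa");
  llvm::MDNode* alias_set = md.createTBAANode("tvm-alias", root);
  inst->setMetadata("tbaa", md.createTBAAStructTagNode(alias_set, alias_set, 0));
}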
llvm::MDNode* meta = md_tbaa_alias_set_; - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); + inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); return; } @@ -402,16 +388,11 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); } } - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); + inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); } -void CodeGenLLVM::GetAlignment(DataType t, - const VarNode* buf_var, - const PrimExpr& index, - int* p_alignment, - int* p_native_bits) { +void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, + int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { @@ -427,11 +408,9 @@ void CodeGenLLVM::GetAlignment(DataType t, int64_t coeff = me->coeff; int align_bits = t.bits(); - while (align_bits < max_align_bits && - base % 2 == 0 && - coeff % 2 == 0) { - base = base / 2; - coeff = coeff / 2; + while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) { + base = base / 2; + coeff = coeff / 2; align_bits *= 2; } if (align_bits < 8) { @@ -440,8 +419,7 @@ void CodeGenLLVM::GetAlignment(DataType t, *p_alignment = align_bits / 8; } -std::unique_ptr<CodeGenLLVM::DebugInfo> -CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { +std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique<DebugInfo>(); debug_info->di_builder_ = std::make_unique<llvm::DIBuilder>(*module); @@ -460,8 +438,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { } llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { - llvm::Constant* undef = llvm::UndefValue::get( - llvm::VectorType::get(value->getType(), lanes)); + llvm::Constant* undef = llvm::UndefValue::get(llvm::VectorType::get(value->getType(), lanes)); llvm::Constant* zero = ConstInt32(0); value = builder_->CreateInsertElement(undef, value, zero); #if TVM_LLVM_VERSION >= 110 @@ -503,8 +480,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { - llvm::Value* mask = llvm::UndefValue::get( - DTypeToLLVMType(DataType::Int(32, target_lanes))); + llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements(); if (num_elems == target_lanes) return vec; CHECK_LT(num_elems, target_lanes); @@ -555,28 +531,21 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) { return CreateVecSlice(vecs[0], 0, total_lanes); } - -void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, - llvm::Value* end, - llvm::Value* stride, - const Var& loop_var, - const Stmt& body) { +void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, + const Var& loop_var, const Stmt& body) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* for_begin = BasicBlock::Create( - *ctx_, "for_begin", function_); - BasicBlock* for_body = BasicBlock::Create( - *ctx_, "for_body", function_); - BasicBlock* for_end = BasicBlock::Create( - *ctx_, "for_end", function_); + BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_); + BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_); + BasicBlock* for_end =
BasicBlock::Create(*ctx_, "for_end", function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); loop_value->addIncoming(begin, pre_block); CHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), - for_body, for_end, md_very_likely_branch_); + builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, + md_very_likely_branch_); builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); @@ -588,7 +557,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { - llvm::Type * target = DTypeToLLVMType(to); + llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); @@ -625,8 +594,8 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -635,14 +604,12 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); llvm::Constant* zero = ConstInt32(0); llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr( - type, global, indices); + llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices); str_map_[str] = ptr; return ptr; } -llvm::Value* CodeGenLLVM::CreateBufferPtr( - DataType t, llvm::Value* buffer, llvm::Value* index) { +llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_EQ(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); @@ -654,13 +621,11 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( return builder_->CreateInBoundsGEP(buffer, index); } -llvm::Value* CodeGenLLVM::CreateBufferVecPtr( - DataType t, llvm::Value* buffer, llvm::Value* index) { +llvm::Value* CodeGenLLVM::CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_GT(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); - llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo( - btype->getAddressSpace()); + llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); if (btype != ptype) { buffer = builder_->CreatePointerCast(buffer, ptype); } @@ -680,21 +645,18 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get( - GetLLVMType(GetRef(op)), arg_type, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, - 
op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; } -llvm::Function* CodeGenLLVM::GetIntrinsicDecl( - llvm::Intrinsic::ID id, llvm::Type* ret_type, - llvm::ArrayRef<llvm::Type*> arg_types) { +llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, + llvm::ArrayRef<llvm::Type*> arg_types) { llvm::Module* module = module_.get(); if (!llvm::Intrinsic::isOverloaded(id)) { @@ -709,8 +671,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) { overload_types.clear(); llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos); - auto match = - llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref); if (error) { @@ -745,7 +706,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( // Failed to identify the type. return nullptr; -#else // TVM_LLVM_VERSION +#else // TVM_LLVM_VERSION llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos); // matchIntrinsicType returns true on error. if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) { @@ -763,9 +724,8 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); - llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>( - Downcast<IntImm>(op->args[0])->value); - int64_t num_signature = Downcast<IntImm>(op->args[1])->value; + llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value); + int64_t num_signature = Downcast<IntImm>(op->args[1])->value; std::vector<llvm::Value*> arg_value; std::vector<llvm::Type*> arg_type; for (size_t i = 2; i < op->args.size(); ++i) { @@ -781,9 +741,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // mismatch will have to be treated specially here. // TODO(kparzysz-quic): fix this once TVM prefetch uses the same // type as LLVM. - llvm::Type *return_type = (id != llvm::Intrinsic::prefetch) - ? GetLLVMType(GetRef<PrimExpr>(op)) - : llvm::Type::getVoidTy(*ctx_); + llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ?
GetLLVMType(GetRef<PrimExpr>(op)) + : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " @@ -808,22 +767,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as<LoadNode>(); + const LoadNode* l = op->args[0].as<LoadNode>(); CHECK(op->args.size() == 1 && l); - const RampNode *r = l->index.as<RampNode>(); + const RampNode* r = l->index.as<RampNode>(); llvm::Value* ptr; unsigned addrspace; if (!r) { - ptr = CreateBufferPtr( - l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); - addrspace = llvm::dyn_cast<llvm::PointerType>( - ptr->getType())->getAddressSpace(); + ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); + addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace(); } else { - PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - ptr = CreateBufferVecPtr( - l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - addrspace = llvm::dyn_cast<llvm::PointerType>( - ptr->getType())->getAddressSpace(); + PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); + ptr = CreateBufferVecPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); + addrspace = llvm::dyn_cast<llvm::PointerType>(ptr->getType())->getAddressSpace(); } return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) { @@ -837,15 +792,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - CHECK_EQ(op->args[0].dtype().lanes(), 1) - << "if_then_else can only take scalar condition"; + CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); + BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); + BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->SetInsertPoint(then_block); llvm::Value* then_value = MakeValue(op->args[1]); @@ -861,23 +812,23 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(else_value, else_value_block); return value; } else if (op->is_intrinsic(CallNode::reinterpret)) { - llvm::Type * target = DTypeToLLVMType(op->dtype); + llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); } else if (op->is_intrinsic(CallNode::isnan)) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); return builder_->CreateFCmpUNO(a, a); } else if (op->is_intrinsic("vectorlow")) { - llvm::Value *v = MakeValue(op->args[0]); + llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements(); - return CreateVecSlice(v, 0, l/2); + return CreateVecSlice(v, 0, l / 2); } else if (op->is_intrinsic("vectorhigh")) { - llvm::Value *v =
MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); - return CreateVecSlice(v, l/2, l/2); + return CreateVecSlice(v, l / 2, l / 2); } else if (op->is_intrinsic("vectorcombine")) { - llvm::Value *v0 = MakeValue(op->args[0]); - llvm::Value *v1 = MakeValue(op->args[1]); + llvm::Value* v0 = MakeValue(op->args[0]); + llvm::Value* v1 = MakeValue(op->args[1]); int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; #if TVM_LLVM_VERSION >= 110 std::vector indices; @@ -894,8 +845,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } } -void CodeGenLLVM::Scalarize(const PrimExpr& e, - std::function f) { +void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { PrimExpr offset = ramp->base + (ramp->stride * i); @@ -909,11 +859,8 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, } } - // Visitors -llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { - return GetVarValue(op); -} +llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); } llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) { return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); @@ -926,52 +873,48 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { - return GetConstString(op->value); -} - -#define DEFINE_CODEGEN_BINARY_OP(Op) \ - llvm::Value* CodeGenLLVM::Create ## Op( \ - DataType t, llvm::Value* a, llvm::Value *b) { \ - if (t.is_int()) { \ - if (t.bits() >= 32) { \ - return builder_->CreateNSW ## Op (a, b); \ - } else { \ - return builder_->Create ## Op (a, b); \ - } \ - } else if (t.is_uint()) { \ - if (t.bits() >= 32) { \ - return builder_->CreateNUW ## Op (a, b); \ - } else { \ - return builder_->Create ## Op (a, b); \ - } \ - } else { \ - CHECK(t.is_float()); \ - return builder_->CreateF ## Op (a, b); \ - } \ - } \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ - return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ +llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); } + +#define DEFINE_CODEGEN_BINARY_OP(Op) \ + llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_int()) { \ + if (t.bits() >= 32) { \ + return builder_->CreateNSW##Op(a, b); \ + } else { \ + return builder_->Create##Op(a, b); \ + } \ + } else if (t.is_uint()) { \ + if (t.bits() >= 32) { \ + return builder_->CreateNUW##Op(a, b); \ + } else { \ + return builder_->Create##Op(a, b); \ + } \ + } else { \ + CHECK(t.is_float()); \ + return builder_->CreateF##Op(a, b); \ + } \ + } \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_BINARY_OP(Add); DEFINE_CODEGEN_BINARY_OP(Sub); DEFINE_CODEGEN_BINARY_OP(Mul); -#define DEFINE_CODEGEN_CMP_OP(Op) \ - llvm::Value* CodeGenLLVM::Create ## Op( \ - DataType t, llvm::Value* a, llvm::Value* b) { \ - if (t.is_int()) { \ - return builder_->CreateICmpS ## Op (a, b); \ - } else if (t.is_uint()) { \ - return builder_->CreateICmpU ## Op (a, b); \ - } else { \ - CHECK(t.is_float()); \ - return builder_->CreateFCmpO ## Op (a, b); \ - } \ -} \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ - return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ 
+#define DEFINE_CODEGEN_CMP_OP(Op) \ + llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_int()) { \ + return builder_->CreateICmpS##Op(a, b); \ + } else if (t.is_uint()) { \ + return builder_->CreateICmpU##Op(a, b); \ + } else { \ + CHECK(t.is_float()); \ + return builder_->CreateFCmpO##Op(a, b); \ + } \ + } \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_CMP_OP(LT); @@ -1050,10 +993,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { - return builder_->CreateSelect( - MakeValue(op->condition), - MakeValue(op->true_value), - MakeValue(op->false_value)); + return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value), + MakeValue(op->false_value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { @@ -1074,8 +1015,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); llvm::Value* ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); #else llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); #endif @@ -1083,20 +1023,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { return load; } else { // vector load - unsigned addrspace = llvm::dyn_cast( - buffer->getType())->getAddressSpace(); + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); CHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr( - t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast( - ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = + builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); #else llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); #endif @@ -1111,11 +1048,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { auto f = [&](int i, llvm::Value* index) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, llvm::Align(basic_align), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, basic_align, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t); @@ -1125,16 +1060,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (op->call_type == 
CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { return CreateIntrinsic(op); - } else if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { return CreateCallExtern(op); } else { - LOG(FATAL) << "Unknown call type " << - "name= " << op->name << - " call_type= " << op->call_type; + LOG(FATAL) << "Unknown call type " + << "name= " << op->name << " call_type= " << op->call_type; return nullptr; } } @@ -1143,14 +1075,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( - vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), - ConstInt32(i)); + vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); } return vec; } llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { - std::vector vecs(op->vectors.size()); + std::vector vecs(op->vectors.size()); int total_lanes = 0; for (int i = 0, e = op->vectors.size(); i < e; ++i) { vecs[i] = VisitExpr(op->vectors[i]); @@ -1159,9 +1090,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { llvm::Value* v0 = CreateVecConcat(vecs); std::vector idx(op->indices.size()); for (int i = 0, e = op->indices.size(); i < e; ++i) { - const int64_t *val = as_const_int(op->indices[i]); - CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " - << "but get " << op->indices[i] << "\n"; + const int64_t* val = as_const_int(op->indices[i]); + CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " + << "but get " << op->indices[i] << "\n"; idx[i] = *val; } llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx); @@ -1195,15 +1126,13 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { return; } else { // vector store - unsigned addrspace = llvm::dyn_cast( - buffer->getType())->getAddressSpace(); + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); CHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr( - t.element_of(), buffer, MakeValue(ramp->base)); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = @@ -1223,12 +1152,10 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), - ptr, llvm::Align(basic_align), is_volatile); + builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), - ptr, basic_align, is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), + ptr, basic_align, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), 
op->value.dtype()); }; @@ -1245,21 +1172,16 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { CHECK(op->for_type == ForType::Serial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), - op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); } - void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { using llvm::BasicBlock; llvm::Value* cond = MakeValue(op->condition); - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); + BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_); if (op->else_case.defined()) { - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); + BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); @@ -1276,39 +1198,35 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { builder_->SetInsertPoint(end_block); } - void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; - } - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the NV devices + if (info.alignment > 16) { + info.alignment = 16; + } + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - } - info.alignment = alloca->getAlignment(); - buf = alloca; + } + info.alignment = alloca->getAlignment(); + buf = alloca; buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -1331,8 +1249,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); CHECK(v); - 
alloc_storage_info_[v].alignment = - static_cast<int>(op->value.as<IntImmNode>()->value); + alloc_storage_info_[v].alignment = static_cast<int>(op->value.as<IntImmNode>()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as<VarNode>(); CHECK(v); @@ -1364,9 +1281,7 @@ void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { - MakeValue(op->value); -} +void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e851f37901a20..4522c150b39bc 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -25,27 +25,27 @@ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #ifdef TVM_LLVM_VERSION +#include #include #include -#include +#include #include -#include -#include #include +#include +#include #include -#include - #include -#include -#include #include #include #include -#include "llvm_common.h" -#include "../../runtime/thread_storage_scope.h" +#include +#include + #include "../../arith/compute_expr.h" +#include "../../runtime/thread_storage_scope.h" #include "../../tir/transforms/ir_util.h" +#include "llvm_common.h" namespace tvm { namespace codegen { @@ -55,9 +55,8 @@ using namespace tir; /*! * \brief A base class to generate a LLVM. */ -class CodeGenLLVM : - public ExprFunctor<llvm::Value*(const PrimExpr&)>, - public StmtFunctor<void(const Stmt&)> { +class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>, + public StmtFunctor<void(const Stmt&)> { public: /*! * \brief Create new code generator based on target machine. @@ -74,11 +73,8 @@ class CodeGenLLVM : * \param dynamic_lookup Whether dynamically lookup runtime function * or use the runtime function table passed by caller. */ - virtual void Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup); + virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, + bool system_lib, bool dynamic_lookup); /*! * \brief Compile and add function f to the current module. * \param f The function to be added. @@ -104,9 +100,7 @@ class CodeGenLLVM : * \param e The expression to be created value for. * \return created value. */ - llvm::Value* MakeValue(const PrimExpr& e) { - return VisitExpr(e); - } + llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); } @@ -170,7 +164,7 @@ class CodeGenLLVM : * \tparam F The function to be executed. * \return The result. */ - template<typename F> + template <typename F> llvm::AllocaInst* WithFunctionEntry(F falloca) { llvm::BasicBlock* current = builder_->GetInsertBlock(); llvm::BasicBlock* entry = &(function_->getEntryBlock()); @@ -191,8 +185,7 @@ class CodeGenLLVM : virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); // Scalarize by iterating elements of e. // f is a callback that takes index and v. - virtual void Scalarize(const PrimExpr& e, - std::function<void(int i, llvm::Value* v)> f); + virtual void Scalarize(const PrimExpr& e, std::function<void(int i, llvm::Value* v)> f); // Initialize target virtual void InitTarget(llvm::TargetMachine* tm); // Add module startup function if needed.
@@ -205,8 +198,7 @@ class CodeGenLLVM : virtual unsigned GetGlobalAddressSpace() const; void AddFunctionInternal(const PrimFunc& f, bool ret_void); // Create extern call - llvm::CallInst* CreateCallExtern(llvm::Type* ret, - const std::string& name, + llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, const std::vector& value); /*! * \brief Get the LLVM Type for a given runtime type. @@ -243,20 +235,18 @@ class CodeGenLLVM : * could not be generated (e.g. if the argument/return types do not * match). */ - llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, - llvm::Type* ret_type, + llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); // initialize the function state. void InitFuncState(); // Get alignment given index. - void GetAlignment( - DataType t, const VarNode* buf_var, const PrimExpr& index, - int* p_alignment, int* p_native_bits); + void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, + int* p_native_bits); // Get constant string llvm::Value* GetConstString(const std::string& str); // do a scalarize call with f - llvm::Value* CreateScalarizedCall( - const CallNode* op, llvm::Function* f, const std::vector& args); + llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, + const std::vector& args); // handle module import void HandleImport(const std::string& code); // cast operatpr @@ -279,9 +269,7 @@ class CodeGenLLVM : llvm::Value* CreateVecConcat(std::vector vecs); llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); // Create serial for - void CreateSerialFor(llvm::Value* begin, - llvm::Value* end, - llvm::Value* stride, + void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, const Var& loop_var, const Stmt& body); // add alias information. 
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 40dc653f742b8..a0687b9b3b400 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -24,9 +24,10 @@ #ifdef TVM_LLVM_VERSION #include -#include "codegen_llvm.h" -#include "../build_common.h" + #include "../../runtime/cuda/cuda_module.h" +#include "../build_common.h" +#include "codegen_llvm.h" namespace tvm { namespace codegen { @@ -39,10 +40,9 @@ class CodeGenNVPTX : public CodeGenLLVM { CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function module_->getOrInsertNamedMetadata("nvvm.annotations") - ->addOperand(llvm::MDNode::get(*ctx_, { - llvm::ValueAsMetadata::get(function_), - llvm::MDString::get(*ctx_, "kernel"), - llvm::ValueAsMetadata::get(ConstInt32(1)) })); + ->addOperand(llvm::MDNode::get( + *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), + llvm::ValueAsMetadata::get(ConstInt32(1))})); } void VisitStmt_(const AllocateNode* op) final { @@ -50,8 +50,7 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Value* buf = nullptr; int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); @@ -65,9 +64,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
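In the allocation path below, CodeGenNVPTX maps CUDA shared memory to a module-level array in NVPTX address space 3; a minimal standalone sketch of that step (MakeSharedBuffer and the undef initializer are illustrative choices rather than TVM's exact code):

#include <cstdint>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/GlobalVariable.h>
#include <llvm/IR/Module.h>

// "Allocating" shared memory == declaring a per-block array in addrspace(3).
llvm::GlobalVariable* MakeSharedBuffer(llvm::Module& m, llvm::Type* elem_ty, uint64_t n) {
  const unsigned kSharedAddrSpace = 3;  // NVPTX shared memory address space
  llvm::Type* arr_ty = llvm::ArrayType::get(elem_ty, n);
  return new llvm::GlobalVariable(m, arr_ty, /*isConstant=*/false,
                                  llvm::GlobalValue::PrivateLinkage,
                                  llvm::UndefValue::get(arr_ty), ".shared.example",
                                  /*InsertBefore=*/nullptr,
                                  llvm::GlobalValue::NotThreadLocal, kSharedAddrSpace);
}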
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 alloca->setAlignment(llvm::Align(info.alignment)); @@ -81,12 +79,11 @@ class CodeGenNVPTX : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); #else @@ -96,8 +93,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -109,18 +105,32 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; if (ts.rank == 1) { switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break; - case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break; - case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break; - default: LOG(FATAL) << "unknown thread idx"; + case 0: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; + break; + default: + LOG(FATAL) << "unknown thread idx"; } } else { CHECK_EQ(ts.rank, 0); switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break; - case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break; - case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break; - default: LOG(FATAL) << "unknown thread idx"; + case 0: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; + break; + default: + LOG(FATAL) << "unknown thread idx"; } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); @@ -133,9 +143,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // TODO(tqchen) warp sync in CUDA9 return nullptr; } else if (sync == "shared") { - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), - ::llvm::Intrinsic::nvvm_barrier0); + llvm::Function* f = + llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::nvvm_barrier0); return builder_->CreateCall(f, {}); } else { LOG(FATAL) 
<< "Do not support sync " << sync; @@ -174,11 +183,9 @@ inline int DetectCUDAComputeVersion() { tvm_ctx.device_type = kDLGPU; tvm_ctx.device_id = 0; TVMRetValue val; - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( - tvm_ctx, tvm::runtime::kExist, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( - tvm_ctx, tvm::runtime::kComputeVersion, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val); std::string version = val; std::istringstream is(version); double ver; @@ -191,12 +198,10 @@ inline int DetectCUDAComputeVersion() { runtime::Module BuildNVPTX(IRModule mod, std::string target) { InitializeLLVM(); - CHECK(target.length() >= 5 && - target.substr(0, 5) == "nvptx"); + CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx"); int compute_ver = DetectCUDAComputeVersion(); std::ostringstream config; - config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" - << compute_ver + config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver << target.substr(5, target.length() - 5); std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr cg(new CodeGenNVPTX()); @@ -204,15 +209,13 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); cg->AddFunction(f); } - const auto* flibdevice_path = - tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); + const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { @@ -239,16 +242,14 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { // emit ptx llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*module); @@ -256,8 +257,7 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_REGISTER_GLOBAL("target.build.nvptx") -.set_body_typed(BuildNVPTX); +TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 570bb0d67672c..d0038b8da8455 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -24,8 +24,8 @@ #ifdef TVM_LLVM_VERSION #include -#include "codegen_cpu.h" +#include "codegen_cpu.h" 
#include "llvm/MC/MCSubtargetInfo.h" namespace tvm { @@ -89,14 +89,12 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic)), - MakeValue( - tir::BroadcastNode::make( - FloatImm(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), + MakeValue(tir::CallNode::make(DataType::Int(16, from.lanes()), + tir::CallNode::reinterpret, {op->value}, + tir::CallNode::PureIntrinsic)), + MakeValue(tir::BroadcastNode::make(FloatImm(DataType::Float(32), 0), from.lanes())), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } @@ -105,12 +103,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); if (from.lanes() >= 8 && has_f16c) { - return CallVectorIntrin( - ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, - DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic))}); + return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8, + DTypeToLLVMType(DataType::Float(32, from.lanes())), + {MakeValue(tir::CallNode::make( + DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {op->value}, tir::CallNode::PureIntrinsic))}); } #endif } @@ -150,10 +147,10 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - CodeGenLLVM* cg = new CodeGenX86_64(); - *rv = static_cast(cg); - }); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenX86_64(); + *rv = static_cast(cg); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 58bfb371c577d..d0bef465efa6f 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -22,153 +22,148 @@ */ #ifdef TVM_LLVM_VERSION -#include #include "intrin_rule_llvm.h" +#include + namespace tvm { namespace codegen { namespace llvm { TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); + .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = tir::CallNode::make( - x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using 
tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr ln10 = make_const(x.dtype(), 2.302585093); + PrimExpr ret = + tir::CallNode::make(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + *rv = ret; + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1); - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_two = make_const(x.dtype(), -2); - - PrimExpr exp_neg2x = tir::CallNode::make( - x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = tir::CallNode::make( - x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); - - PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); - PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = tir::SelectNode::make( - x >= make_zero(x.dtype()), tanh_pos, tanh_neg); -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1); + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_two = make_const(x.dtype(), -2); + + PrimExpr 
exp_neg2x = + tir::CallNode::make(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_pos2x = + tir::CallNode::make(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); + + PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + *rv = tir::SelectNode::make(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr sin_x = tir::CallNode::make( - x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); - PrimExpr cos_x = tir::CallNode::make( - x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); + PrimExpr sin_x = tir::CallNode::make(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); + PrimExpr cos_x = tir::CallNode::make(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); PrimExpr tan_x = sin_x / cos_x; *rv = tan_x; }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::CallNode::make( - x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make( - x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); - PrimExpr ret = (exp_posx + exp_negx) / two; - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = + tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx + exp_negx) / two; + *rv = ret; + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::CallNode::make( - x.dtype(), "exp", {neg_one * x}, 
tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make( - x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); - PrimExpr ret = (exp_posx - exp_negx) / two; - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = + tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx - exp_negx) / two; + *rv = ret; + }); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index bb9ff66c9cb55..8c5053bb68f95 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -25,17 +25,18 @@ #define TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_ #ifdef TVM_LLVM_VERSION -#include #include - #include +#include + #include + #include "llvm_common.h" namespace tvm { namespace codegen { // num_signature means number of arguments used to query signature -template +template inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -48,11 +49,10 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); } -template +template inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -64,8 +64,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); + *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 0dc1272d7d497..ffe35ca9d4c74 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -22,9 +22,9 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include #include +#include + #include namespace tvm { @@ -39,77 +39,54 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { std::ostringstream intrinsic_name; intrinsic_name << "__nv_" << call->name; if (call->dtype.bits() == 32) intrinsic_name << "f"; - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, - CallNode::PureExtern); + *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round") -.set_body(DispatchExternLibDevice); 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 3699c9f691b32..52447a1699c59 100644 
--- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -22,9 +22,8 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include #include +#include #include @@ -38,77 +37,54 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, - CallNode::PureExtern); + *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML); 
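The DispatchExternOCML hunks in this file preserve the pre-format behavior: a TVM intrinsic call is redirected to the ROCm device library by prefixing __ocml_ and suffixing the float width, while the NVPTX file above performs the analogous __nv_ mapping (with an "f" suffix for 32-bit). A minimal sketch of that mangling, assuming the op name and bit width come from the matched CallNode; the helper OCMLName is hypothetical:

#include <sstream>
#include <string>

// OCMLName("exp", 32)  -> "__ocml_exp_f32"
// OCMLName("sqrt", 64) -> "__ocml_sqrt_f64"
std::string OCMLName(const std::string& op_name, int bits) {
  std::ostringstream intrinsic_name;
  intrinsic_name << "__ocml_" << op_name << "_f" << bits;
  return intrinsic_name.str();
}

The per-op TVM_REGISTER_GLOBAL lines that follow simply bind each rocm intrinsic rule to this one dispatcher.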
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 29e4db3c91da3..5534a643676c2 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -22,11 +22,13 @@ */ #ifdef TVM_LLVM_VERSION +#include "llvm_common.h" + #include + #include -#include #include -#include "llvm_common.h" +#include namespace tvm { namespace codegen { @@ -56,15 +58,11 @@ void InitializeLLVM() { } } -void ParseLLVMTargetOptions(const std::string& target_str, - std::string* triple, - std::string* mcpu, - std::string* mattr, - llvm::TargetOptions* options) { +void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, + std::string* mattr, llvm::TargetOptions* options) { // setup target triple size_t start = 0; - if (target_str.length() >= 4 && - target_str.substr(0, 4) == "llvm") { + if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") { start = 4; } // simple parser @@ -82,16 +80,13 @@ void ParseLLVMTargetOptions(const std::string& target_str, } size_t pos = key.find('='); if (pos != std::string::npos) { - CHECK_GE(key.length(), pos + 1) - << "invalid argument " << key; + CHECK_GE(key.length(), pos + 1) << "invalid argument " << key; value = key.substr(pos + 1, key.length() - 1); key = key.substr(0, pos); } else { - CHECK(is >> value) - << "Unspecified value for option " << key; + CHECK(is >> value) << "Unspecified value for option " << key; } - if (key == "-target" || - key == "-mtriple") { + if (key == "-target" || key == "-mtriple") { *triple = value; } else if (key == "-mcpu") { *mcpu = value; @@ -115,16 +110,15 @@ void ParseLLVMTargetOptions(const std::string& target_str, } } - if (triple->length() == 0 || - *triple == "default") { + if (triple->length() == 0 || *triple == "default") { *triple = llvm::sys::getDefaultTargetTriple(); } // set target option llvm::TargetOptions& opt = *options; opt = llvm::TargetOptions(); - #if TVM_LLVM_VERSION < 50 +#if TVM_LLVM_VERSION < 50 opt.LessPreciseFPMADOption = true; - #endif +#endif opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; @@ -136,21 +130,14 @@ void ParseLLVMTargetOptions(const std::string& target_str, } } - -std::unique_ptr -GetLLVMTargetMachine(const std::string& target_str, - bool allow_null) { +std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, + bool allow_null) { std::string target_triple, mcpu, mattr; llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_str, - &target_triple, - &mcpu, - &mattr, - &opt); + ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt); - if (target_triple.length() == 0 || - target_triple == "default") { + if (target_triple.length() == 0 || target_triple == "default") { target_triple = llvm::sys::getDefaultTargetTriple(); } if (mcpu.length() == 0) { @@ -158,14 +145,13 
@@ GetLLVMTargetMachine(const std::string& target_str, } std::string err; - const llvm::Target* target = - llvm::TargetRegistry::lookupTarget(target_triple, err); + const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err); if (target == nullptr) { CHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } - llvm::TargetMachine* tm = target->createTargetMachine( - target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + llvm::TargetMachine* tm = + target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); return std::unique_ptr(tm); } diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 85ee1ee97495e..49389fe82ac0a 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -25,14 +25,12 @@ #define TVM_TARGET_LLVM_LLVM_COMMON_H_ #ifdef TVM_LLVM_VERSION -#include - #include #include -#include - -#include +#include #include +#include +#include #if TVM_LLVM_VERSION >= 100 #include #include @@ -42,43 +40,41 @@ #include #include #include -#include #include +#include #include #include #include #include +#include +#include #include #include -#include #include - -#include +#include +#include #include #include -#include -#include #if TVM_LLVM_VERSION >= 100 #include #endif +#include +#include +#include +#include #include #include #include -#include -#include #include #include +#include #include #include -#include -#include - -#include -#include -#include #include +#include +#include namespace tvm { namespace codegen { @@ -97,11 +93,8 @@ void InitializeLLVM(); * \param options the options * \param mattr The attributes */ -void ParseLLVMTargetOptions(const std::string& target_str, - std::string* triple, - std::string* mcpu, - std::string* mattr, - llvm::TargetOptions* options); +void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, + std::string* mattr, llvm::TargetOptions* options); /*! * \brief Get target machine from target_str string. @@ -109,8 +102,8 @@ void ParseLLVMTargetOptions(const std::string& target_str, * \param allow_null Whether allow null to be returned. 
* \return target machine */ -std::unique_ptr -GetLLVMTargetMachine(const std::string& target_str, bool allow_null = false); +std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, + bool allow_null = false); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index d1a244d01ff40..1151b33536b5d 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -23,23 +23,25 @@ */ #ifdef TVM_LLVM_VERSION +#include #include #include -#include #include + #include -#include "llvm_common.h" -#include "codegen_llvm.h" -#include "codegen_blob.h" + #include "../../runtime/file_util.h" #include "../../runtime/library_module.h" +#include "codegen_blob.h" +#include "codegen_llvm.h" +#include "llvm_common.h" namespace tvm { namespace codegen { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; class LLVMModuleNode final : public runtime::ModuleNode { public: @@ -51,24 +53,15 @@ class LLVMModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { - return "llvm"; - } + const char* type_key() const { return "llvm"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "__tvm_is_system_module") { - bool flag = - (mptr_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) { - * rv = flag; - }); + bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr); + return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); } else if (name == "_get_target_triple") { std::string target_triple = tm_->getTargetTriple().str(); - return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) { - *rv = target_triple; - }); + return PackedFunc([target_triple](TVMArgs args, TVMRetValue* rv) { *rv = target_triple; }); } if (ee_ == nullptr) LazyInitJIT(); @@ -76,8 +69,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_module_main)); + const char* entry_name = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); CHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; faddr = reinterpret_cast(GetFunctionAddr(entry_name)); @@ -88,13 +81,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { return WrapPackedFunc(faddr, sptr_to_self); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); std::error_code ecode; llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); - CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name - << " " << ecode.message(); + CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); if (fmt == "o" || fmt == "obj") { #if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(mptr_); @@ -104,16 +95,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::legacy::PassManager pass; CHECK(tm_); #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) + 
CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*m); @@ -126,16 +115,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::legacy::PassManager pass; CHECK(tm_); #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_AssemblyFile"; #else - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif pass.run(*m); @@ -148,8 +135,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::WriteBitcodeToFile(*mptr_, dest); #endif } else { - LOG(FATAL) << "Do not know how to save file " - << file_name << " with format=\'"<< format << "\'"; + LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format + << "\'"; } dest.close(); } @@ -165,28 +152,26 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::raw_svector_ostream rso(str); if (fmt == "s" || fmt == "asm") { - #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(mptr_); - #else - std::unique_ptr m = llvm::CloneModule(*mptr_); - #endif - llvm::legacy::PassManager pass; - CHECK(tm_); - #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #else - CHECK(tm_->addPassesToEmitFile( - pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #endif - pass.run(*m); - return rso.str().str(); +#if TVM_LLVM_VERSION <= 60 + std::unique_ptr m = llvm::CloneModule(mptr_); +#else + std::unique_ptr m = llvm::CloneModule(*mptr_); +#endif + llvm::legacy::PassManager pass; + CHECK(tm_); +#if TVM_LLVM_VERSION <= 60 + CHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#elif TVM_LLVM_VERSION <= 90 + CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) + << "Cannot emit target CGFT_AssemblyFile"; +#else + CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#endif + pass.run(*m); + return rso.str().str(); } else if (fmt == "" || fmt == "ll") { std::string type_str; 
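The SaveToFile and GetSource hunks above repeat the same three-way TVM_LLVM_VERSION ladder around addPassesToEmitFile. As a sketch only (the real code inlines the #if blocks at each call site; the helper AddEmitPasses is hypothetical and assumes the usual LLVM headers are in scope), the pattern folds to:

// Returns true when the emission passes were added successfully;
// addPassesToEmitFile itself returns true on failure, hence the == 0 idiom
// used throughout the diff.
static bool AddEmitPasses(llvm::TargetMachine* tm, llvm::legacy::PassManager* pass,
                          llvm::raw_pwrite_stream* dest, bool assembly) {
#if TVM_LLVM_VERSION <= 60
  auto ft = assembly ? llvm::TargetMachine::CGFT_AssemblyFile
                     : llvm::TargetMachine::CGFT_ObjectFile;
  return tm->addPassesToEmitFile(*pass, *dest, ft) == 0;
#elif TVM_LLVM_VERSION <= 90
  auto ft = assembly ? llvm::TargetMachine::CGFT_AssemblyFile
                     : llvm::TargetMachine::CGFT_ObjectFile;
  return tm->addPassesToEmitFile(*pass, *dest, nullptr, ft) == 0;
#else
  auto ft = assembly ? llvm::CGFT_AssemblyFile : llvm::CGFT_ObjectFile;
  return tm->addPassesToEmitFile(*pass, *dest, nullptr, ft) == 0;
#endif
}

The version split mirrors LLVM's own API history: LLVM 7 added the dwarf output-stream parameter, and LLVM 10 moved CGFT_* from TargetMachine scope to the llvm namespace.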
llvm::raw_string_ostream rso(type_str); @@ -194,8 +179,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_->print(rso, nullptr); return rso.str(); } else { - LOG(FATAL) << "Do not know how to get source code with format: " - << format << "\'"; + LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; } return ""; } @@ -209,9 +193,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::vector funcs; std::string entry_func; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); @@ -251,8 +234,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_ = module_.get(); } - void Init(std::unique_ptr module, - std::shared_ptr ctx) { + void Init(std::unique_ptr module, std::shared_ptr ctx) { InitializeLLVM(); ctx_ = ctx; llvm::SMDiagnostic err; @@ -319,20 +301,17 @@ class LLVMModuleNode final : public runtime::ModuleNode { CHECK(layout == mptr_->getDataLayout()) << "Data layout mismatch between module(" << mptr_->getDataLayout().getStringRepresentation() << ")" - << " and ExecutionEngine (" - << layout.getStringRepresentation() << ")"; + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; ee_ = builder.create(tm.release()); - CHECK(ee_ != nullptr) - << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); + CHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); ee_->runStaticConstructorsDestructors(false); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { *ctx_addr = this; } - runtime::InitContextFunctions([this](const char *name) { - return reinterpret_cast(GetGlobalAddr(name)); - }); + runtime::InitContextFunctions( + [this](const char* name) { return reinterpret_cast(GetGlobalAddr(name)); }); } // Get global address from execution engine. uint64_t GetGlobalAddr(const std::string& name) const { @@ -357,7 +336,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { // JIT lock std::mutex mutex_; // execution engine - llvm::ExecutionEngine *ee_{nullptr}; + llvm::ExecutionEngine* ee_{nullptr}; // The raw pointer to the module. 
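The hunk above ends just as GetGlobalAddr begins, so its body is not shown in this diff. A plausible shape for it, consistent with the null checks at its call sites in GetFunction and LazyInitJIT, is the following; this is an assumption for illustration, not the verbatim TVM implementation:

// Sketch: resolve a module-level global through the JIT engine,
// returning 0 (which callers treat as "absent") when the symbol
// does not exist in the module.
uint64_t GetGlobalAddr(llvm::ExecutionEngine* ee, llvm::Module* m,
                       const std::string& name) {
  // Check the module first so a missing symbol yields 0 instead of
  // forcing the engine into a failed lookup.
  if (m->getGlobalVariable(name) != nullptr) {
    return ee->getGlobalValueAddress(name);
  }
  return 0;
}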
llvm::Module* mptr_{nullptr}; // The target machine @@ -372,17 +351,13 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { return llvm::Function::lookupIntrinsicID(name); } - -TVM_REGISTER_GLOBAL("target.build.llvm") -.set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) { auto n = make_object(); n->Init(mod, target); return runtime::Module(n); }); - -TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); auto target = args[0].operator std::string(); auto module_name = args[1].operator std::string(); @@ -403,35 +378,29 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") *rv = runtime::Module(n); }); -TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = static_cast(LookupLLVMIntrinsic(args[0])); - }); - -TVM_REGISTER_GLOBAL("target.llvm_version_major") -.set_body([](TVMArgs args, TVMRetValue* rv) { - int major = TVM_LLVM_VERSION / 10; - *rv = major; - }); - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->LoadIR(args[0]); - *rv = runtime::Module(n); - }); - -TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") -.set_body([](TVMArgs args, TVMRetValue* rv) { - InitializeLLVM(); - *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); - }); +TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = static_cast(LookupLLVMIntrinsic(args[0])); +}); + +TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) { + int major = TVM_LLVM_VERSION / 10; + *rv = major; +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->LoadIR(args[0]); + *rv = runtime::Module(n); +}); + +TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) { + InitializeLLVM(); + *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); +}); -TVM_REGISTER_GLOBAL("codegen.codegen_blob") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); - auto p = CodeGenBlob(args[0].operator std::string(), - args[1].operator bool(), + auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(), args[2].operator std::string()); n->Init(std::move(p.first), p.second); *rv = runtime::Module(n); diff --git a/src/target/opt/build_aocl_off.cc b/src/target/opt/build_aocl_off.cc index 2585ac23b9615..9f9d098b7a975 100644 --- a/src/target/opt/build_aocl_off.cc +++ b/src/target/opt/build_aocl_off.cc @@ -20,17 +20,14 @@ /*! 
* Optional module when build aocl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "AOCL runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "aocl"); } diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index 4f941a504f938..893eb67a268f4 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,11 +24,9 @@ namespace tvm { namespace runtime { -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { LOG(FATAL) << "CUDA is not enabled"; return Module(); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 99dc5ad421a98..c9471d1bfa8df 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -27,30 +27,26 @@ #include #endif #include - #include + #include -#include "../build_common.h" -#include "../source/codegen_cuda.h" #include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_module.h" - +#include "../build_common.h" +#include "../source/codegen_cuda.h" namespace tvm { namespace codegen { -#define NVRTC_CALL(x) \ - { \ - nvrtcResult result = x; \ - if (result != NVRTC_SUCCESS) { \ - LOG(FATAL) \ - << "NvrtcError: " #x " failed with error: " \ - << nvrtcGetErrorString(result); \ - } \ +#define NVRTC_CALL(x) \ + { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + LOG(FATAL) << "NvrtcError: " #x " failed with error: " << nvrtcGetErrorString(result); \ + } \ } - std::string FindCUDAIncludePath() { #if defined(_WIN32) const std::string delimiter = "\\"; @@ -78,7 +74,6 @@ std::string FindCUDAIncludePath() { return cuda_include_path; } - std::string NVRTCCompile(const std::string& code, bool include_path = false) { std::vector compile_params; std::vector param_cstrings{}; @@ -104,16 +99,15 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { } for (const auto& string : compile_params) { - param_cstrings.push_back(string.c_str()); + param_cstrings.push_back(string.c_str()); } - NVRTC_CALL(nvrtcCreateProgram( - &prog, code.c_str(), nullptr, 0, nullptr, nullptr)); - nvrtcResult compile_res = - nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); + NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); + nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size; NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); - std::string log; log.resize(log_size); + std::string log; + 
log.resize(log_size); NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); CHECK_EQ(compile_res, NVRTC_SUCCESS) << log; size_t ptx_size; @@ -133,9 +127,8 @@ runtime::Module BuildCUDA(IRModule mod, std::string target) { CodeGenCUDA cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenCUDA: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -161,7 +154,6 @@ runtime::Module BuildCUDA(IRModule mod, std::string target) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.cuda") -.set_body_typed(BuildCUDA); +TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index ce06700222aeb..c734eeceed6d9 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -23,9 +23,8 @@ namespace tvm { namespace runtime { Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) { LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index ff796d818b22a..3cfe1316e7ceb 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -20,16 +20,14 @@ /*! * Optional module when build metal is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/metal/metal_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module MetalModuleCreate(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal"); } diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 6e796b1edc62e..2367500eca927 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -20,17 +20,14 @@ /*! * Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } diff --git a/src/target/opt/build_opengl_off.cc b/src/target/opt/build_opengl_off.cc index 781bf51c2cc03..2e860ceb422a5 100644 --- a/src/target/opt/build_opengl_off.cc +++ b/src/target/opt/build_opengl_off.cc @@ -20,14 +20,13 @@ /*! 
* Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opengl/opengl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, +Module OpenGLModuleCreate(std::unordered_map shaders, std::string fmt, std::unordered_map fmap) { LOG(WARNING) << "OpenGL runtime not enabled, return a source module..."; auto data = ToJSON(shaders); diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc index 64ab759a9a249..476e5a88fc6ff 100644 --- a/src/target/opt/build_rocm_off.cc +++ b/src/target/opt/build_rocm_off.cc @@ -20,19 +20,15 @@ /*! * Optional module when build rocm is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/rocm/rocm_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string rocm_source, - std::string assembly) { - +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string rocm_source, + std::string assembly) { LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; auto fget_source = [rocm_source, assembly](const std::string& format) { if (format.length() == 0) return assembly; @@ -40,8 +36,7 @@ Module ROCMModuleCreate( if (format == "asm") return assembly; return std::string(""); }; - return codegen::DeviceSourceModuleCreate( - data, fmt, fmap, "hsaco", fget_source); + return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hsaco", fget_source); } } // namespace runtime diff --git a/src/target/opt/build_sdaccel_off.cc b/src/target/opt/build_sdaccel_off.cc index 8c58c3f45b785..0de305c2a37c8 100644 --- a/src/target/opt/build_sdaccel_off.cc +++ b/src/target/opt/build_sdaccel_off.cc @@ -20,17 +20,14 @@ /*! * Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "OpenCL runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "sdaccel"); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 64674e3360dd2..2b77869d48193 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -21,28 +21,27 @@ * \file codegen_aocl.cc */ #include -#include + #include -#include "codegen_opencl.h" -#include "../build_common.h" -#include "../../runtime/opencl/aocl/aocl_module.h" +#include + #include "../../runtime/file_util.h" +#include "../../runtime/opencl/aocl/aocl_module.h" +#include "../build_common.h" +#include "codegen_opencl.h" namespace tvm { namespace codegen { -runtime::Module BuildAOCL(IRModule mod, - std::string target_str, - bool emulation) { +runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) { // Get code. 
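The build_*_off.cc files in this run all follow one shape, visible across the aocl, metal, opencl, opengl, rocm, and sdaccel stubs: when a backend's runtime is compiled out, the ModuleCreate entry point keeps its real signature but degrades to a source-only module so generated code remains inspectable. A hypothetical FooModuleCreate illustrating that shared pattern (the name "foo" and its symbol key are placeholders, and TVM's Module/FunctionInfo types are assumed in scope):

Module FooModuleCreate(std::string data, std::string fmt,
                       std::unordered_map<std::string, FunctionInfo> fmap,
                       std::string source) {
  // Same signature as the enabled backend; only the behavior degrades.
  LOG(WARNING) << "Foo runtime not enabled, return a source module...";
  return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "foo");
}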
using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodegenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -80,15 +79,13 @@ runtime::Module BuildAOCL(IRModule mod, return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.aocl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], false); - }); +TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAOCL(args[0], args[1], false); +}); -TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], true); - }); +TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAOCL(args[0], args[1], true); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 84604b8a0aed8..a99285144c235 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -20,20 +20,20 @@ /*! * \file codegen_c.cc */ -#include -#include #include "codegen_c.h" -#include "../../arith/pattern_match.h" + +#include +#include + #include "../../arith/compute_expr.h" +#include "../../arith/pattern_match.h" namespace tvm { namespace codegen { using namespace tir; -void CodeGenC::Init(bool output_ssa) { - print_ssa_form_ = output_ssa; -} +void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); @@ -79,8 +79,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { ReserveKeywordsAsUnique(); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); @@ -94,7 +93,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); - stream << ' '; } PrintType(GetType(v), stream); @@ -125,16 +123,11 @@ void CodeGenC::AddFunction(const PrimFunc& f) { this->stream << "}\n\n"; } -void CodeGenC::PrintFuncPrefix() { - stream << "void"; -} +void CodeGenC::PrintFuncPrefix() { stream << "void"; } -void CodeGenC::PrintFinalReturn() { -} +void CodeGenC::PrintFinalReturn() {} -std::string CodeGenC::Finish() { - return decl_stream.str() + stream.str(); -} +std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { @@ -146,12 +139,10 @@ void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) } } -void CodeGenC::PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) { +void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) { PrintType(t, stream); stream << ' ' << target << " = "; - if (src.length() > 3 && - src[0] == '(' && 
src[src.length() - 1] == ')') { + if (src.length() > 3 && src[0] == '(' && src[src.length() - 1] == ')') { stream << src.substr(1, src.length() - 2); } else { stream << src; @@ -160,8 +151,7 @@ void CodeGenC::PrintSSAAssign( } // Print a reference expression to a buffer. -std::string CodeGenC::GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { std::ostringstream os; std::string vid = GetVarID(buffer); std::string scope; @@ -179,7 +169,6 @@ std::string CodeGenC::GetBufferRef( if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, os); } - os << ' '; PrintType(t, os); os << "*)" << vid << ')'; } else { @@ -188,8 +177,7 @@ std::string CodeGenC::GetBufferRef( os << "[("; PrintExpr(index, os); os << ")"; - if (t.bits() == 4 || - (t.bits() == 1 && t.is_int())) { + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { os << " / " << (32 / t.bits()); } os << ']'; @@ -200,8 +188,7 @@ std::string CodeGenC::GetBufferRef( // optimize for constant access if (auto* ptr = index.as()) { int64_t offset = ptr->value; - CHECK_EQ(offset % t.lanes(), 0) - << "Find unaligned vector load to a vector type"; + CHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type"; os << vid << '[' << (offset / t.lanes()) << ']'; return os.str(); } @@ -213,7 +200,6 @@ std::string CodeGenC::GetBufferRef( if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, os); } - os << ' '; PrintType(t, os); os << "*)("; if (!HandleTypeMatch(buffer, t.element_of())) { @@ -221,15 +207,13 @@ std::string CodeGenC::GetBufferRef( if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, os); } - os << ' '; PrintType(t.element_of(), os); os << "*)"; } os << vid << " + ("; PrintExpr(index, os); os << ")"; - if (t.bits() == 4 || - (t.bits() == 1 && t.is_int())) { + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { os << " / " << (32 / t.bits()); } os << "))[0]"; @@ -238,8 +222,8 @@ std::string CodeGenC::GetBufferRef( } // Print a reference expression to a buffer. -std::string CodeGenC::GetStructRef( - DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { +std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, + int kind) { if (kind < intrinsic::kArrKindBound_) { std::ostringstream os; os << "(((DLTensor*)"; @@ -256,17 +240,38 @@ std::string CodeGenC::GetStructRef( os << "]."; // other case: get fields. 
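  // The switch below maps the tvm_struct_get/set field codes onto DLTensor
  // members. As an illustrative sketch only (the buffer name "arg0" and the
  // index 0 are hypothetical, not taken from this patch):
  //   GetStructRef(t, arg0, 0, intrinsic::kArrShape)
  // prints the C expression
  //   (((DLTensor*)arg0)[0].shape)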
switch (kind) { - case intrinsic::kArrData: os << "data"; break; - case intrinsic::kArrShape: os << "shape"; break; - case intrinsic::kArrStrides: os << "strides"; break; - case intrinsic::kArrNDim: os << "ndim"; break; - case intrinsic::kArrTypeCode: os << "dtype.code"; break; - case intrinsic::kArrTypeBits: os << "dtype.bits"; break; - case intrinsic::kArrByteOffset: os << "byte_offset"; break; - case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break; - case intrinsic::kArrDeviceId: os << "ctx.device_id"; break; - case intrinsic::kArrDeviceType: os << "ctx.device_type"; break; - default: LOG(FATAL) << "unknown field code"; + case intrinsic::kArrData: + os << "data"; + break; + case intrinsic::kArrShape: + os << "shape"; + break; + case intrinsic::kArrStrides: + os << "strides"; + break; + case intrinsic::kArrNDim: + os << "ndim"; + break; + case intrinsic::kArrTypeCode: + os << "dtype.code"; + break; + case intrinsic::kArrTypeBits: + os << "dtype.bits"; + break; + case intrinsic::kArrByteOffset: + os << "byte_offset"; + break; + case intrinsic::kArrTypeLanes: + os << "dtype.lanes"; + break; + case intrinsic::kArrDeviceId: + os << "ctx.device_id"; + break; + case intrinsic::kArrDeviceType: + os << "ctx.device_type"; + break; + default: + LOG(FATAL) << "unknown field code"; } os << ')'; return os.str(); @@ -301,32 +306,26 @@ void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; } else { - CHECK(it->second == t) - << "conflicting buf var type"; + CHECK(it->second == t) << "conflicting buf var type"; } } -void CodeGenC::PrintVecElemLoad(const std::string& vec, - DataType t, int i, +void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << ".s" << std::hex << i << std::dec; } -void CodeGenC::PrintVecElemStore(const std::string& vec, - DataType t, int i, +void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); - stream << vec << ".s" << std::hex << i - << " = " << value << ";\n" << std::dec; + stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec; } -std::string CodeGenC::GetVecLoad( - DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { return GetBufferRef(t, buffer, base); } -void CodeGenC::PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, +void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); @@ -342,49 +341,58 @@ std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType targ return os.str(); } -void CodeGenC::BindThreadIndex(const IterVar& iv) { - LOG(FATAL) << "not implemented"; -} +void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; } -void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) +void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) } -void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_EQ(scope, "global"); } void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) - CHECK_EQ(t.lanes(), 1) - << "do not yet support vector types"; + CHECK_EQ(t.lanes(), 1) << "do not yet support vector types"; 
if (t.is_handle()) { - os << "void*"; return; + os << "void*"; + return; } if (t.is_float()) { if (t.bits() == 32) { - os << "float"; return; + os << "float"; + return; } if (t.bits() == 64) { - os << "double"; return; + os << "double"; + return; } } else if (t.is_uint()) { switch (t.bits()) { - case 8: case 16: case 32: case 64: { - os << "uint" << t.bits() << "_t"; return; + case 8: + case 16: + case 32: + case 64: { + os << "uint" << t.bits() << "_t"; + return; } - case 1: os << "int"; return; + case 1: + os << "int"; + return; } } else if (t.is_int()) { switch (t.bits()) { - case 8: case 16: case 32: case 64: { - os << "int" << t.bits() << "_t"; return; + case 8: + case 16: + case 32: + case 64: { + os << "int" << t.bits() << "_t"; + return; } } } LOG(FATAL) << "Cannot convert type " << t << " to C type"; } - -void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) if (auto* ptr = type.as()) { return PrintType(ptr->dtype, os); } else if (auto* ptr = type.as()) { @@ -397,8 +405,7 @@ void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) } } - -inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) if (op->dtype == DataType::Int(32)) { std::ostringstream temp; temp << op->value; @@ -411,8 +418,8 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } - -inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, + CodeGenC* p) { // NOLINT(*) if (dtype == DataType::UInt(32)) { std::ostringstream temp; temp << val << "U"; @@ -425,9 +432,10 @@ inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeG } } -inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) switch (op->dtype.bits()) { - case 64: case 32: { + case 64: + case 32: { std::ostringstream temp; temp << std::scientific << op->value; if (op->dtype.bits() == 32) temp << 'f'; @@ -438,10 +446,11 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { case 16: { os << '('; p->PrintType(op->dtype, os); - os << ')' << std::scientific <value << 'f'; + os << ')' << std::scientific << op->value << 'f'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -449,16 +458,15 @@ void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "\"" << op->value << "\""; } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -480,10 +488,9 @@ inline void PrintBinaryExpr(const T* op, } } 
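// For reference, the scalar mapping CodeGenC::PrintType implements above
// (vector lanes != 1 still abort for plain C, and the bool-like UInt(1)
// prints as plain "int"):
//
//   DataType::Float(32) -> "float"      DataType::Float(64) -> "double"
//   DataType::Int(8)    -> "int8_t"     DataType::Int(64)   -> "int64_t"
//   DataType::UInt(16)  -> "uint16_t"   DataType::Handle()  -> "void*"
//
// A hedged usage sketch (any concrete CodeGenC subclass behaves the same):
//   std::ostringstream os;
//   cg.PrintType(DataType::UInt(16), os);  // os.str() == "uint16_t"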
-inline void PrintBinaryIntrinsic(const CallNode* op, - const char* opstr, - std::ostream& os, // NOLINT(*) - CodeGenC* p) { +inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr, + std::ostream& os, // NOLINT(*) + CodeGenC* p) { if (op->dtype.lanes() == 1) { CHECK_EQ(op->args.size(), 2U); os << '('; @@ -554,8 +561,7 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { + if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { os << op->name << "("; for (size_t i = 0; i < op->args.size(); i++) { this->PrintExpr(op->args[i], os); @@ -594,19 +600,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[2], os); os << ")"; } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); os << "(("; this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) - << " + "; + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; this->PrintExpr(l->index, os); os << ')'; } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); - os << GetStructRef( - op->dtype, op->args[0], op->args[1], - op->args[2].as()->value); + os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); os << "("; @@ -626,19 +629,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) this->PrintExpr(op->args[0], os); os << ")"; } else { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->dtype; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } } } -void CodeGenC::PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) { // NOLINT(*) if (isalpha(op[0])) { os << op << "("; this->PrintExpr(lhs, os); @@ -646,7 +646,7 @@ void CodeGenC::PrintVecBinaryOp( this->PrintExpr(rhs, os); os << ")"; } else { - os <<"("; + os << "("; this->PrintExpr(lhs, os); os << ' ' << op << ' '; this->PrintExpr(rhs, os); @@ -661,8 +661,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index); HandleVolatileLoads(ref, op, os); } else { - CHECK(is_one(op->predicate)) - << "predicated load is not supported"; + CHECK(is_one(op->predicate)) << "predicated load is not supported"; arith::PVar base; if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) { @@ -681,7 +680,6 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) auto it = alloc_storage_scope_.find(op->buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, value_temp); - value_temp << ' '; } } PrintType(elem_type, value_temp); @@ -703,12 
+701,11 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { DataType t = op->value.dtype(); if (t.lanes() == 1) { std::string value = this->PrintExpr(op->value); - std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); + std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } else { - CHECK(is_one(op->predicate)) - << "Predicated store is not supported"; + CHECK(is_one(op->predicate)) << "Predicated store is not supported"; arith::PVar base; if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { std::string value = this->PrintExpr(op->value); @@ -731,7 +728,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { auto it = alloc_storage_scope_.find(op->buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); - stream << ' '; } } PrintType(elem_type, stream); @@ -762,9 +758,9 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) CHECK_EQ(op->base.dtype(), DataType::Int(32)); os << "((int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; - if (i != op->lanes - 1) - os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; } os << "))"; } @@ -773,7 +769,7 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { LOG(FATAL) << "Shuffle: not supported "; } -void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } @@ -794,19 +790,14 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { var_idmap_[op->var.get()] = value; } else { PrintIndent(); - if (op->var.dtype() == DataType::Handle() && - handle_data_type_.count(op->var.get())) { + if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) { PrintType(handle_data_type_.at(op->var.get()), stream); - stream << "* " - << AllocVarID(op->var.get()) - << " = ("; + stream << "* " << AllocVarID(op->var.get()) << " = ("; PrintType(handle_data_type_.at(op->var.get()), stream); - stream << "*)" << value << ";\n"; + stream << "*)" << value << ";\n"; } else { PrintType(op->var.dtype(), this->stream); - this->stream << ' ' - << AllocVarID(op->var.get()) - << " = " << value << ";\n"; + this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; } } PrintStmt(op->body); @@ -816,17 +807,14 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); - this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - stream << ' '; - PrintType(op->dtype, stream); - stream << ' '<< vid << '[' - << constant_size << "];\n"; + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + const VarNode* buffer = op->buffer_var.as(); + std::string scope = alloc_storage_scope_.at(buffer); + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + stream << ' ' << vid 
<< '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); @@ -848,6 +836,10 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = op->node.as(); CHECK(v); volatile_buf_.insert(v); + } else if (op->attr_key == tir::attr::pragma_import_c) { + const StringImmNode* value = op->value.as(); + CHECK(value != nullptr); + decl_stream << value->value; } this->PrintStmt(op->body); } @@ -871,9 +863,7 @@ void CodeGenC::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " - << vid << " < " << extent - << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); @@ -915,15 +905,13 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { const CallNode* call = op->value.as(); if (call) { if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { - this->PrintStorageSync(call); return; + this->PrintStorageSync(call); + return; } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); - std::string ref = GetStructRef( - call->args[3].dtype(), - call->args[0], - call->args[1], - call->args[2].as()->value); + std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], + call->args[2].as()->value); this->PrintIndent(); this->stream << ref << " = " << value << ";\n"; return; @@ -936,8 +924,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { } } -void CodeGenC::PrintVecElemLoadExpr( - DataType t, int i, const std::string& value, std::ostream& os) { +void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { CHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (i != 0) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index db655beded02a..309eb06816076 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,16 +24,18 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ +#include +#include #include -#include #include +#include #include -#include -#include + #include -#include #include #include +#include + #include "codegen_source_base.h" namespace tvm { @@ -50,10 +52,9 @@ using namespace tir; * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`. */ -class CodeGenC : - public ExprFunctor, - public StmtFunctor, - public CodeGenSourceBase { +class CodeGenC : public ExprFunctor, + public StmtFunctor, + public CodeGenSourceBase { public: /*! * \brief Initialize the code generator. @@ -75,9 +76,7 @@ class CodeGenC : * \brief Print the Stmt n to CodeGenC->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt& n) { - VisitStmt(n); - } + void PrintStmt(const Stmt& n) { VisitStmt(n); } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. @@ -99,11 +98,11 @@ class CodeGenC : * * Example: stream << "void"; */ - virtual void PrintFuncPrefix(); // NOLINT(*) + virtual void PrintFuncPrefix(); // NOLINT(*) /*! * \brief Print the final return at the end the function. */ - virtual void PrintFinalReturn(); // NOLINT(*) + virtual void PrintFinalReturn(); // NOLINT(*) /*! 
* \brief Insert statement before function body. * \param f The function to be compiled. @@ -115,33 +114,33 @@ class CodeGenC : */ virtual void InitFuncState(const PrimFunc& f); // expression - void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) + void 
VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmtNode* op) override; @@ -158,36 +157,34 @@ class CodeGenC : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) /*! * Print Type represetnation of type type. * \param type The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) + virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) /*! * \brief Print expr representing the thread tag * \param IterVar iv The thread index to be binded; */ - virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) - virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) - virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) + virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) + virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) + virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) // Binary vector op. - virtual void PrintVecBinaryOp( - const std::string&op, DataType op_type, - PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) + virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, + std::ostream& os); // NOLINT(*) // print vector load virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base); // print vector store - virtual void PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, + virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value); // NOLINT(*) // print load of single element - virtual void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os); // NOLINT(*) + virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os); // NOLINT(*) // print store of single element. 
- virtual void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value); + virtual void PrintVecElemStore(const std::string& vec, DataType t, int i, + const std::string& value); // Get a cast type from to virtual std::string CastFromTo(std::string value, DataType from, DataType target); // Get load of single element with expression @@ -195,11 +192,9 @@ class CodeGenC : protected: // Print reference to struct location - std::string GetStructRef( - DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); + std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // Print reference to a buffer as type t in index. - virtual std::string GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index); + virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index); /*! * \brief Handle volatile loads. @@ -209,8 +204,7 @@ class CodeGenC : * does not implement volatile member functions. CUDA codegen will cast * away volatile qualifier from CUDA __half types. */ - virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, - std::ostream& os) { + virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) { // By default, do nothing but print the loaded value. os << value; } @@ -223,9 +217,7 @@ class CodeGenC : * or "__constant__" is not part of type but a storage class (like * C/C++ static). */ - virtual bool IsScopePartOfType() const { - return true; - } + virtual bool IsScopePartOfType() const { return true; } /*! * \brief If buffer is allocated as type t. @@ -240,15 +232,12 @@ class CodeGenC : */ void RegisterHandleType(const VarNode* buf_var, DataType t); // override - void PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) final; + void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); /*! \brief Check if buf_var is volatile or not. */ - bool IsVolatile(const VarNode *buf_var) const { - return volatile_buf_.count(buf_var) != 0; - } + bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; } /*! \brief restrict keyword */ std::string restrict_keyword_{""}; @@ -257,29 +246,6 @@ class CodeGenC : /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; - /*! - * \brief A RAII utility class for emitting code in a scoped region. - */ - class EnterScopeRAII { - // The codegen context. - CodeGenC* cg; - - // The new scope level. - int scope; - - public: - explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) { - cg->PrintIndent(); - cg->stream << "{\n"; - scope = cg->BeginScope(); - } - ~EnterScopeRAII() { - cg->EndScope(scope); - cg->PrintIndent(); - cg->stream << "}\n"; - } - }; - private: /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index cbdec62017425..b11b3d8fc5f98 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -20,24 +20,26 @@ /*! 
* \file codegen_c_host.cc */ +#include "codegen_c_host.h" + #include -#include + #include -#include "codegen_c_host.h" +#include + #include "../build_common.h" namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { - module_name_ = GetUniqueName("__tvm_module_ctx"); -} +CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); } void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) { emit_asserts_ = emit_asserts; + declared_globals_.clear(); decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; - decl_stream << "extern void* " << module_name_ << " = NULL;\n"; + decl_stream << "void* " << module_name_ << " = NULL;\n"; CodeGenC::Init(output_ssa); } @@ -56,12 +58,13 @@ void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "does not support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "does not support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { @@ -69,37 +72,55 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) case 16: os << "half"; break; - case 32: os << "float"; break; + case 32: + os << "float"; + break; case 64: os << "double"; break; - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { os << 'u'; } switch (t.bits()) { - case 8: os << "int8_t"; break; - case 16: os << "int16_t"; break; - case 32: os << "int32_t"; break; - case 64: os << "int64_t"; break; - case 1: os << "int32_t"; break; - default: fail = true; break; + case 8: + os << "int8_t"; + break; + case 16: + os << "int16_t"; + break; + case 32: + os << "int32_t"; + break; + case 64: + os << "int64_t"; + break; + case 1: + os << "int32_t"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to C type"; } -void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; PrintType(op->dtype, os); @@ -117,9 +138,8 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name, this->stream << "if (" << packed_func_name << " == NULL) {\n"; int packed_func_if_scope = this->BeginScope(); this->PrintIndent(); - this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ - << ", \"" << func_name << "\"" - << ", &" << packed_func_name << ") != 0) {\n"; + this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" << func_name << "\"" + << ", &" << packed_func_name << ") != 0) {\n"; int get_func_env_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -140,9 +160,12 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "int " << ret_type_code << ";\n"; this->PrintIndent(); this->stream << "if (TVMFuncCall(" << packed_func_name << ", " - << "(TVMValue*) stack_value" << ", " << "(int*) stack_tcode" << 
", " - << num_args << ", " << "&" << ret_val << ", " << "&" - << ret_type_code << ") != 0) {\n"; + << "(TVMValue*) stack_value" + << ", " + << "(int*) stack_tcode" + << ", " << num_args << ", " + << "&" << ret_val << ", " + << "&" << ret_type_code << ") != 0) {\n"; int func_call_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -151,7 +174,7 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "}\n"; } -void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { std::string stack_name = GetUniqueName("stack"); const std::string& type = op->args[0].as()->value; @@ -182,8 +205,15 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( int64_t num_args = end - begin; CHECK_GE(num_args, 0); std::string func_name = s->value; - std::string packed_func_name = GetUniqueName(func_name + "_packed"); - decl_stream << "static void* " << packed_func_name << " = NULL;\n"; + // NOTE: cannot rely on GetUnique for global decl_stream declarations + // because it is reset between AddFunction(). + std::string packed_func_name = func_name + "_packed"; + if (declared_globals_.insert(packed_func_name).second) { + // Still reserve the name among unique names. + CHECK(GetUniqueName(packed_func_name) == packed_func_name) + << "Expected name " << packed_func_name << " to not be taken"; + decl_stream << "static void* " << packed_func_name << " = NULL;\n"; + } this->PrintGetFuncFromBackend(func_name, packed_func_name); this->PrintFuncCall(packed_func_name, num_args); } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { @@ -194,7 +224,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( } } -void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) +void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*) if (emit_asserts_) { std::string cond = PrintExpr(op->condition); PrintIndent(); @@ -211,18 +241,17 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) this->PrintStmt(op->body); } -void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, "<", os); } -void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, ">", os); } template -inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, - const char* compare, - std::ostream& os) { // NOLINT(*) +inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, + std::ostream& os) { // NOLINT(*) std::ostringstream temp_a; VisitExpr(op->a, temp_a); std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); @@ -241,9 +270,8 @@ runtime::Module BuildCHost(IRModule mod) { CodeGenCHost cg; cg.Init(output_ssa, emit_asserts); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodegenCHost: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); cg.AddFunction(f); } @@ -252,9 +280,8 @@ runtime::Module BuildCHost(IRModule mod) { return CSourceModuleCreate(code, "c"); } -TVM_REGISTER_GLOBAL("target.build.c") 
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = BuildCHost(args[0]);
-  });
+TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = BuildCHost(args[0]);
+});

 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h
index 4f9a0a74511fc..94a76faabd787 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -24,10 +24,12 @@
 #ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_
 #define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_

-#include <tvm/target/codegen.h>
-#include <tvm/tir/expr.h>
+#include <set>
 #include <string>
+
 #include "codegen_c.h"
+#include "tvm/target/codegen.h"
+#include "tvm/tir/expr.h"

 namespace tvm {
 namespace codegen {
@@ -37,22 +39,24 @@ class CodeGenCHost final : public CodeGenC {
   CodeGenCHost();
   void Init(bool output_ssa, bool emit_asserts);

-  void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
-  void PrintFuncPrefix() final; // NOLINT(*)
-  void PrintFinalReturn() final; // NOLINT(*)
+  void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
+  void PrintFuncPrefix() final;                        // NOLINT(*)
+  void PrintFinalReturn() final;                       // NOLINT(*)

   // overload visitor functions
-  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const CallNode* op, std::ostream& os) final;       // NOLINT(*)
   // overload min and max to use the ternary operator, so we don't rely on the
   // standard library implementations
-  void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const MinNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const MaxNode* op, std::ostream& os) final;  // NOLINT(*)

-  void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
+  void VisitStmt_(const AssertStmtNode* op) final;  // NOLINT(*)

  private:
  std::string module_name_;
+  /* \brief tracks declared global variables which live despite GetUniqueName */
+  std::set<std::string> declared_globals_;
  /*!
\brief whether to emit asserts in the resulting C code */ bool emit_asserts_; @@ -67,8 +71,7 @@ class CodeGenCHost final : public CodeGenC { * \param os stream reference to print into */ template - inline void PrintTernaryCondExpr(const T* op, - const char* compare, + inline void PrintTernaryCondExpr(const T* op, const char* compare, std::ostream& os); // NOLINT(*) }; diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 02b5b413562ea..cf7a74f1dcc05 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -21,21 +21,21 @@ * \file codegen_cuda.cc */ +#include "codegen_cuda.h" + #include #include +#include #include #include -#include + #include "literal/cuda_half_t.h" -#include "codegen_cuda.h" namespace tvm { namespace codegen { -CodeGenCUDA::CodeGenCUDA() { - restrict_keyword_ = "__restrict__"; -} +CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); @@ -44,10 +44,7 @@ void CodeGenCUDA::Init(bool output_ssa) { CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } - -void CodeGenCUDA::PrintFuncPrefix() { - stream << "extern \"C\" __global__ void"; -} +void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; } std::string CodeGenCUDA::Finish() { if (enable_fp16_) { @@ -64,6 +61,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << _cuda_half_util; } + if (enable_warp_shuffle_) { + decl_stream << _cuda_warp_intrinsic_util; + } + if (enable_int8_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n"; decl_stream << "#include \n"; @@ -92,16 +93,15 @@ void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); - var_idmap_[iv->var.get()] = - CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); } void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } bool fail = false; if (t.is_float()) { @@ -126,22 +126,31 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; } break; - case 32: os << "float"; break; - case 64: os << "double"; break; - default: fail = true; break; + case 32: + os << "float"; + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; } if (!fail && (lanes == 1 || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } else if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } else if (t.is_vector_bool()) { // CUDA does not support bool vectors. // Use ushort vectors to represent instead. 
int n = t.lanes(); if (n <= 4) { - os << "ushort" << n; return; + os << "ushort" << n; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -154,31 +163,41 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) switch (t.bits()) { case 1: { if (t.lanes() == 1) { - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 8) { - os << "int8_t"; return; + os << "int8_t"; + return; } else if (t.lanes() == 16) { - os << "int16_t"; return; + os << "int16_t"; + return; } else if (t.lanes() == 32) { - os << "int"; return; + os << "int"; + return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } } case 4: { if (t.lanes() == 1) { - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 4) { - os << "int16_t"; return; + os << "int16_t"; + return; } else if (t.lanes() == 8) { // directly 8 4-bit int in integer. - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 16) { - os << "int2"; return; + os << "int2"; + return; } else if (t.lanes() == 32) { - os << "int4"; return; + os << "int4"; + return; } else if (t.lanes() == 64) { - os << "int8"; return; + os << "int8"; + return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } @@ -191,59 +210,71 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // We use int for int8x4 instead of char4 because using char4 is // likely to produce extra instructions to pack four int8 elements // into 32-bit data. - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 8) { enable_int8_ = true; - os << "int2"; return; + os << "int2"; + return; } else if (t.lanes() == 16) { enable_int8_ = true; - os << "int4"; return; + os << "int4"; + return; } else if (!t.is_uint() && t.lanes() == 1) { - os << "signed char"; break; + os << "signed char"; + break; } else { - os << "char"; break; + os << "char"; + break; } } - case 16: os << "short"; break; - case 32: os << "int"; break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; case 64: { - if (sizeof(long) != 8) { // NOLINT(*) + if (sizeof(long) != 8) { // NOLINT(*) if (t.lanes() == 1) { - os << "long long"; break; + os << "long long"; + break; } else if (t.lanes() == 2) { - os << "longlong"; break; + os << "longlong"; + break; } else { // No longlong3, longlong4 LOG(FATAL) << "Cannot convert type " << t << " to CUDA type on a L32 platform"; break; } } else { - os << "long"; break; + os << "long"; + break; } } - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) { return; } if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } -void CodeGenCUDA::PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) { // NOLINT(*) // Delcare the result. std::string sret = GetUniqueName("_"); this->PrintIndent(); this->PrintType(t, stream); stream << ' ' << sret << ";\n"; { - EnterScopeRAII scope(this); - // Unpack into individual ops. 
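    // With the EnterScopeRAII wrapper removed, the per-lane unpacking below
    // emits straight into the enclosing scope. For a hypothetical float2 add
    // the generated CUDA reads roughly (SSA names invented for illustration):
    //   float2 _1;
    //   float2 _2 = (lhs);
    //   float2 _3 = (rhs);
    //   _1.x = (_2.x + _3.x);
    //   _1.y = (_2.y + _3.y);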
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); @@ -269,37 +300,54 @@ void CodeGenCUDA::PrintVecBinaryOp( os << sret; } -void CodeGenCUDA::PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); if ((t.is_int()) && t.bits() == 8) { - os << "((char)(" << vec << " >> " << i * 8 << "))"; + if (t.lanes() == 2 || t.lanes() == 3) { + os << vec << "." << access[i % t.lanes()]; + } else { + os << "((char)(" << vec << " >> " << i * 8 << "))"; + } } else if ((t.is_uint()) && t.bits() == 8) { - os << "((unsigned char)(" << vec << " >> " << i * 8 << "))"; + if (t.lanes() == 2 || t.lanes() == 3) { + os << vec << "." << access[i % t.lanes()]; + } else { + os << "((unsigned char)(" << vec << " >> " << i * 8 << "))"; + } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2]; + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else { os << vec << "." << access[i]; } } -void CodeGenCUDA::PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) { +void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, + const std::string& value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - stream << vec << "="; - // Do not read the first undef lane. - if (i != 0) { - stream << vec << " & ~(0x000000ff << " << i * 8 << ") |"; + if (t.lanes() == 2 || t.lanes() == 3) { + stream << vec << '.' << access[i % t.lanes()] << "=" + << "(" << value << ");\n"; + } else { + stream << vec << "="; + // Do not read the first undef lane. + if (i != 0) { + stream << vec << " & ~(0x000000ff << " << i * 8 << ") |"; + } + stream << "(" << value << " << " << i * 8 << ");\n"; } - stream << "(" << value << " << " << i * 8 << ");\n"; } else if (t.is_float16()) { - stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2] << " = " << value << ";\n"; + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " + << value << ";\n"; } else { stream << vec << "." 
<< access[i] << " = " << value << ";\n"; } @@ -315,8 +363,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { } else if (sync == "global") { if (!need_global_barrier_) { need_global_barrier_ = true; - this->decl_stream << "extern \"C\" __device__ unsigned " - << vid_global_barrier_state_ << ";\n"; + this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_ + << ";\n"; } // global synchronizer std::string is_load = PrintExpr(op->args[1]); @@ -324,33 +372,31 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { this->PrintIndent(); // In theory only threadfence is needed // but we observed problems with only threadfence - this->stream <<"__threadfence_system();\n"; + this->stream << "__threadfence_system();\n"; this->PrintIndent(); - this->stream <<"if (" << is_load << ") {\n"; + this->stream << "if (" << is_load << ") {\n"; int wb = this->BeginScope(); this->PrintIndent(); this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; this->PrintIndent(); std::string ptr = GetUniqueName("pf"); - this->stream << "volatile unsigned* " - << ptr << " = &" << vid_global_barrier_state_<< ";\n"; + this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n"; this->PrintIndent(); this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; this->PrintIndent(); - this->stream <<"while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; + this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; this->EndScope(wb); this->PrintIndent(); - this->stream <<"}\n"; + this->stream << "}\n"; this->PrintIndent(); - this->stream <<"__syncthreads();\n"; + this->stream << "__syncthreads();\n"; } } -void CodeGenCUDA::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_NE(scope, "global"); if (scope == "shared") { - os << "__shared__"; + os << "__shared__ "; } } @@ -360,8 +406,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { CHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. - if (from_ty.is_scalar()) - return CodeGenC::VisitExpr_(op, os); + if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); // We could emit make_float4 like calls, but the emitted code looks // too compact to read. Emit this as vectorized unary ops. @@ -370,7 +415,6 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { this->PrintType(target_ty, stream); stream << ' ' << sret << ";\n"; { - EnterScopeRAII scope(this); std::string src = SSAGetID(PrintExpr(op->value), from_ty); for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; @@ -385,7 +429,14 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } -void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { +void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { + // This is only for backward compatibility with __shfl_{up/down}. + // A macro will be used to replace *_sync calls to legacy ones. 
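  // The _cuda_warp_intrinsic_util literal itself lives elsewhere in the tree
  // and is not shown in this diff. As a hedged sketch only, a shim for
  // toolchains that lack the *_sync variants could plausibly read:
  //   #define __shfl_sync(mask, var, lane) __shfl((var), (lane))
  //   #define __shfl_up_sync(mask, var, delta) __shfl_up((var), (delta))
  //   #define __shfl_down_sync(mask, var, delta) __shfl_down((var), (delta))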
+ if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") || + op->is_intrinsic("__shfl_down_sync")) { + enable_warp_shuffle_ = true; + } + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 6U); @@ -419,7 +470,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintExpr(op->args[4], os); os << "], "; this->PrintExpr(op->args[6], os); - if (const StringImmNode *str = op->args[7].as()) { + if (const StringImmNode* str = op->args[7].as()) { os << ", nvcuda::wmma::mem_" << str->value; } else { LOG(FATAL) << "Invalid parameters"; @@ -433,7 +484,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); - os << "]" << ((i < 3) ? ", ": ")"); + os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) { need_mma_h_ = true; @@ -443,7 +494,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); - os << "]" << ((i < 3) ? ", ": ")"); + os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) { // @@ -470,8 +521,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintType(op->dtype, stream); stream << ' ' << sret << ";\n"; { - EnterScopeRAII scope(this); - // Load arguments. std::vector sargs; for (size_t i = 0; i < op->args.size(); ++i) { @@ -484,8 +533,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { std::ostringstream scall; scall << op->name << "("; for (size_t j = 0; j < op->args.size(); ++j) { - if (j > 0) - scall << ", "; + if (j > 0) scall << ", "; PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall); } scall << ")"; @@ -517,46 +565,39 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; const VarNode* buffer = op->buffer_var.as(); std::string scope = alloc_storage_scope_.at(buffer); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Int(8) || - op->dtype == DataType::UInt(8) || - op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) - << "Matrix_a and matrix_b only support half or char or unsigned char " - << "or uint4 or int4 or int1 type for now"; + CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || + op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) + << "Matrix_a and matrix_b only support half or char or unsigned char " + << "or uint4 or int4 or int1 type for now"; } else { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Float(32) || + CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32)) - << "Accumulator only support half, float and int type for now"; + << "Accumulator only support half, float and int type for now"; } constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } 
else { PrintStorageScope(scope, stream); - stream << ' '; PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && scope == "shared") { + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { constant_size = constant_size / (32 / op->dtype.bits()); } - stream << ' '<< vid << '[' - << constant_size << "];\n"; + stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } -void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) { +void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { @@ -576,17 +617,17 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) { void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { os << "((make_int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; - if (i != op->lanes - 1) - os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; } os << "))"; } -void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { // make_int8x4 - const int64_t *p = as_const_int(op->value); + const int64_t* p = as_const_int(op->value); CHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; @@ -605,7 +646,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << '('; for (int i = 0; i < op->lanes / 2; ++i) { if (i != 0) os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + os << "__pack_half2(" << v << ", " << v << ")"; } os << ')'; return; @@ -622,7 +663,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << ')'; } -void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) { +void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) { std::vector to_shuffle(op->vectors.size()); for (int i = 0, e = op->vectors.size(); i < e; ++i) { CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; @@ -632,15 +673,15 @@ void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) { PrintType(op->dtype, os); os << '('; for (int i = 0, e = op->indices.size(); i < e; ++i) { - const int64_t *val = as_const_int(op->indices[i]); - CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size()); + const int64_t* val = as_const_int(op->indices[i]); + CHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size()); if (i != 0) os << ", "; os << to_shuffle[*val]; } os << ')'; } -void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { +void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { // Non-vector cases. if (!op->dtype.is_vector()) { CodeGenC::VisitExpr_(op, os); @@ -648,8 +689,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { } // Codegen vector condition case by serializing the select op. 
- CHECK(op->false_value->dtype == op->dtype && - op->true_value->dtype == op->dtype && + CHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && op->dtype.lanes() == op->condition.dtype().lanes()); std::string r_var = GetUniqueName("_"); @@ -657,8 +697,6 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { this->PrintType(op->dtype, stream); stream << ' ' << r_var << ";\n"; { - EnterScopeRAII scope(this); - std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); @@ -682,9 +720,10 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { os << r_var; } -inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) switch (op->dtype.bits()) { - case 64: case 32: { + case 64: + case 32: { std::ostringstream temp; if (std::isinf(op->value)) { if (op->value < 0) { @@ -708,17 +747,17 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) os << '(' << std::scientific << op->value << 'f' << ')'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } - -void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, - const VarNode* variable, std::ostream &os) { +void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + std::ostream& os) { std::stringstream type; PrintType(t, type); std::string shape_str = fragment_shapes[variable]; @@ -743,22 +782,22 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, if (scope == "wmma.matrix_a") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.matrix_b") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.accumulator") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } } -int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, - const VarNode* variable, int32_t size) { +int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, + int32_t size) { std::string shape_str = fragment_shapes[variable]; size_t m, n, k; size_t last_pos = 0, pos = 0; @@ -779,8 +818,8 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, return 0; } -void CodeGenCUDA::HandleVolatileLoads(const std::string& value, - const LoadNode* op, std::ostream& os) { +void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op, + std::ostream& os) { // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// @@ -793,15 +832,17 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, } } -void CodeGenCUDA::PrintVecElemLoadExpr( - DataType t, int i, const std::string& value, std::ostream& os) { +void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) { CHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - if (i != 0) { - os << "|"; + if (!(t.lanes() == 2 || t.lanes() == 3)) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; } - os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; - return; } if (t.is_float16()) { diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index d1db7047b1b66..f9ab0ade2cf20 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -26,8 +26,10 @@ #include #include + #include #include + #include "codegen_c.h" namespace tvm { @@ -46,37 +48,32 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) - void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) final; // NOLINT(*) + void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; // overload visitor - void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; - void VisitExpr_(const CallNode *op, std::ostream& os) final; + void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; + void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CastNode* op, std::ostream& os) final; - void VisitStmt_(const EvaluateNode *op) final; - void VisitStmt_(const AllocateNode *op) final; - void VisitStmt_(const AttrStmtNode *op) final; + void VisitStmt_(const EvaluateNode* op) final; + void VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const AttrStmtNode* op) final; private: // Handle volatile loads - void HandleVolatileLoads(const 
std::string& value, const LoadNode* op, - std::ostream& os) final; + void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. - bool IsScopePartOfType() const final { - return false; - } + bool IsScopePartOfType() const final { return false; } // Whether global barrier is needed. bool need_global_barrier_{false}; @@ -88,6 +85,8 @@ class CodeGenCUDA final : public CodeGenC { bool enable_fp16_{false}; // whether enable int8 bool enable_int8_{false}; + // whether enable warp shuffle intrinsics + bool enable_warp_shuffle_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; // whether need mma.h @@ -96,10 +95,9 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope( - const std::string& scope, DataType t, const VarNode* variable, std::ostream& os); - int32_t GetWmmaFragmentSize( - const std::string &scope, const VarNode* variable, int32_t size); + void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + std::ostream& os); + int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); }; } // namespace codegen diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ea49d33351a06..e381afb4db84c 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -20,13 +20,15 @@ /*! * \file codegen_metal.cc */ -#include -#include -#include #include "codegen_metal.h" -#include "../build_common.h" + +#include +#include +#include + #include "../../runtime/metal/metal_module.h" #include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { @@ -57,8 +59,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; @@ -67,14 +68,13 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t num_buffer = 0; for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { Var v = f->params[i]; - if (!v.dtype().is_handle()) break; + if (!v.dtype().is_handle()) break; stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); } - stream << ' '; PrintType(GetType(v), stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the @@ -84,17 +84,15 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { RegisterHandleType(v.get(), prim->dtype); } } - stream << ' ' << vid - << " [[ buffer(" << i << ") ]],\n"; + stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. 
size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = - static_cast(global_symbol.value()) + "_args_t"; - stream << " constant " << arg_buf_type << "& " << varg - << " [[ buffer(" << num_buffer << ") ]],\n"; + std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; + stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer + << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < f->params.size(); ++i) { @@ -121,8 +119,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>( - tir::attr::kDeviceThreadAxis).value(); + auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -165,23 +162,31 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { switch (t.bits()) { - case 16: os << "half"; break; - case 32: os << "float"; break; - default: fail = true; break; + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -189,18 +194,30 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. 
- os << "int"; return; + os << "int"; + return; } switch (t.bits()) { - case 8: os << "char"; break; - case 16: os << "short"; break; - case 32: os << "int"; break; - case 1: os << "bool"; break; - default: fail = true; break; + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 1: + os << "bool"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; @@ -219,32 +236,29 @@ void CodeGenMetal::PrintStorageSync(const CallNode* op) { } } -void CodeGenMetal::PrintVecElemLoad(const std::string& vec, - DataType t, int i, +void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenMetal::PrintVecElemStore(const std::string& vec, - DataType t, int i, +void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" << " = " << value << ";\n"; } -void CodeGenMetal::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { - os << "device"; + os << "device "; } else if (scope == "shared") { - os << "threadgroup"; + os << "threadgroup "; } else { - os << "thread"; + os << "thread "; } } -void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); PrintType(op->dtype, os); os << "("; @@ -274,9 +288,8 @@ runtime::Module BuildMetal(IRModule mod) { CodeGenMetal cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenMetal: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -295,9 +308,8 @@ runtime::Module BuildMetal(IRModule mod) { return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); } -TVM_REGISTER_GLOBAL("target.build.metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildMetal(args[0]); - }); +TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildMetal(args[0]); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 644c962ab2d68..26abe34d998ee 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -25,7 +25,9 @@ #define TVM_TARGET_SOURCE_CODEGEN_METAL_H_ #include + #include + #include "codegen_c.h" namespace tvm { @@ -36,22 +38,21 @@ class CodeGenMetal final : public CodeGenC { CodeGenMetal(); // override print thread tag. 
void PrintArgUnionDecl(); - void AddFunction(const PrimFunc& f); // NOLINT(*) + void AddFunction(const PrimFunc& f); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // print load of single element - void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) final; // NOLINT(*) // print store of single element. - void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) // overload visitor - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) // reuse parent's function. using CodeGenC::PrintType; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 0cb7422804645..746d418b6a377 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -20,20 +20,20 @@ /*! 
 * \file codegen_opencl.cc
 */
+#include "codegen_opencl.h"
+
 #include
-#include
 #include
-#include "codegen_opencl.h"
-#include "../build_common.h"
-#include "../../runtime/thread_storage_scope.h"
+#include
+
 #include "../../runtime/opencl/opencl_module.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "../build_common.h"

 namespace tvm {
 namespace codegen {

-CodeGenOpenCL::CodeGenOpenCL() {
-  restrict_keyword_ = "restrict";
-}
+CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; }

 void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
@@ -44,34 +44,30 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
   }
 }

-void CodeGenOpenCL::PrintFuncPrefix() {
-  stream << "__kernel void";
-}
+void CodeGenOpenCL::PrintFuncPrefix() { stream << "__kernel void"; }

 std::string CodeGenOpenCL::Finish() {
   // inject extension enable pragma for fp16 and fp64
   if (enable_fp16_) {
-    decl_stream
-        << "#ifdef cl_khr_fp16\n"
-        "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
-        "#elif defined(cl_amd_fp16)\n"
-        "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
-        "#else\n"
-        "#error \"Half precision floating point not supported"
-        "by OpenCL implementation on your device.\" \n"
-        "#endif\n\n";
+    decl_stream << "#ifdef cl_khr_fp16\n"
+                   "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
+                   "#elif defined(cl_amd_fp16)\n"
+                   "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n"
+                   "#else\n"
+                   "#error \"Half precision floating point not supported "
+                   "by OpenCL implementation on your device.\" \n"
+                   "#endif\n\n";
   }
   if (enable_fp64_) {
-    decl_stream
-        << "#ifdef cl_khr_fp64\n"
-        "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
-        "#elif defined(cl_amd_fp64)\n"
-        "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
-        "#else\n"
-        "#error \"Double precision floating point not supported"
-        "by OpenCL implementation on your device.\" \n"
-        "#endif\n\n";
+    decl_stream << "#ifdef cl_khr_fp64\n"
+                   "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
+                   "#elif defined(cl_amd_fp64)\n"
+                   "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n"
+                   "#else\n"
+                   "#error \"Double precision floating point not supported "
+                   "by OpenCL implementation on your device.\" \n"
+                   "#endif\n\n";
   }

   return CodeGenC::Finish();
@@ -86,19 +82,19 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
   } else {
     os << "get_group_id(" << ts.dim_index << ")";
   }
-  var_idmap_[iv->var.get()] =
-      CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype());
+  var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype());
 }

 void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   int lanes = t.lanes();
   if (t.is_handle()) {
-    CHECK_EQ(lanes, 1)
-        << "do not yet support vector types";
-    os << "void*"; return;
+    CHECK_EQ(lanes, 1) << "do not yet support vector types";
+    os << "void*";
+    return;
   }
   if (t == DataType::Bool()) {
-    os << "bool"; return;
+    os << "bool";
+    return;
   }
   bool fail = false;
   if (t.is_float()) {
@@ -107,16 +103,21 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       os << "half";
       enable_fp16_ = true;
       break;
-      case 32: os << "float"; break;
+      case 32:
+        os << "float";
+        break;
       case 64:
         os << "double";
         enable_fp64_ = true;
         break;
-      default: fail = true; break;
+      default:
+        fail = true;
+        break;
     }
     if (!fail && lanes == 1) return;
     if (!fail && (lanes >= 2 && lanes <= 16)) {
-      os << lanes; return;
+      os << lanes;
+      return;
     }
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
@@ -124,41 +125,53 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream&
os) { // NOLINT(*) } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. - os << "int"; return; + os << "int"; + return; } switch (t.bits()) { - case 8: os << "char"; break; - case 16: os << "short"; break; - case 32: os << "int"; break; - case 64: os << "long"; break; - case 1: os << "int"; break; - default: fail = true; break; + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 64: + os << "long"; + break; + case 1: + os << "int"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } -void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, - PrimExpr base, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { os << '('; auto it = alloc_storage_scope_.find(buffer); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - os << ' '; PrintType(t.element_of(), os); os << "*)"; } os << GetVarID(buffer) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad( - DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -166,8 +179,7 @@ std::string CodeGenOpenCL::GetVecLoad( return os.str(); } -void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; @@ -188,12 +200,11 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { } } -void CodeGenOpenCL::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { - os << "__global"; + os << "__global "; } else if (scope == "shared") { - os << "__local"; + os << "__local "; } } @@ -213,7 +224,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType return os.str(); } -void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; PrintType(op->dtype, os); @@ -225,7 +236,7 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << "))"; } -void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) if (std::isinf(op->value)) { if (op->value < 0) { os << "-"; @@ -244,9 +255,8 @@ runtime::Module BuildOpenCL(IRModule mod, std::string target) { CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == 
CallingConv::kDeviceKernelLaunch) @@ -261,7 +271,6 @@ runtime::Module BuildOpenCL(IRModule mod, std::string target) { return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.opencl") -.set_body_typed(BuildOpenCL); +TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index cc1fe994739f3..32a98e4d87ea4 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -25,7 +25,9 @@ #define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_ #include + #include + #include "codegen_c.h" namespace tvm { @@ -38,24 +40,22 @@ class CodeGenOpenCL final : public CodeGenC { // override print thread tag. void InitFuncState(const PrimFunc& f) final; - void PrintFuncPrefix() final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const VarNode* buffer, - PrimExpr base) final; - void PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, + void PrintFuncPrefix() final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; + void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const VarNode* buffer, DataType t, - PrimExpr base, std::ostream& os); // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + std::ostream& os); // NOLINT(*) + std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) // overload visitor - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 0b85e26160152..fd5c3ba811e04 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -23,19 +23,20 @@ * We are targeting OpenGL 3.3. The reason of not targeting a recent version * of OpenGL is to have better compatibility of WebGL 2. 
*/ -#include +#include "codegen_opengl.h" + #include -#include #include -#include "codegen_opengl.h" -#include "../build_common.h" +#include +#include + #include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { -CodeGenOpenGL::CodeGenOpenGL() - : output_(nullptr), output_iter_var_(nullptr) {} +CodeGenOpenGL::CodeGenOpenGL() : output_(nullptr), output_iter_var_(nullptr) {} void CodeGenOpenGL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); @@ -160,20 +161,16 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( - this->decl_stream.str() + this->stream.str(), - std::move(arg_names), std::move(arg_kinds), - this->thread_extent_var_); + shaders_[static_cast(global_symbol.value())] = + runtime::OpenGLShader(this->decl_stream.str() + this->stream.str(), std::move(arg_names), + std::move(arg_kinds), this->thread_extent_var_); } -std::unordered_map CodeGenOpenGL::Finish() { - return shaders_; -} +std::unordered_map CodeGenOpenGL::Finish() { return shaders_; } void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) { CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x"; - CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end()) - << "Only support one thread iter var"; + CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end()) << "Only support one thread iter var"; CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var"; var_idmap_[iv->var.get()] = iv->thread_tag; @@ -211,8 +208,7 @@ std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, PrimExpr index) { // Print a reference expression to a buffer. // Format: texelFetch(buffer, index, 0).r -std::string CodeGenOpenGL::GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenOpenGL::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; @@ -274,11 +270,10 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) { // Doesn't support store to vector. auto type = value.dtype(); - CHECK_EQ(type.lanes(), 1) - << "Vectorized store not implemented, type = " << type; + CHECK_EQ(type.lanes(), 1) << "Vectorized store not implemented, type = " << type; CHECK(inputs_.find(buffer) == inputs_.cend()) - << "Texture has been read from before. Must not store to it."; + << "Texture has been read from before. Must not store to it."; if (output_ == nullptr) { output_ = buffer; // Record that this texture is the output. 
} else { @@ -294,9 +289,8 @@ runtime::Module BuildOpenGL(IRModule mod, std::string target) { CodeGenOpenGL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenGL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenGL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -308,8 +302,7 @@ runtime::Module BuildOpenGL(IRModule mod, std::string target) { return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod)); } -TVM_REGISTER_GLOBAL("target.build.opengl") -.set_body_typed(BuildOpenGL); +TVM_REGISTER_GLOBAL("target.build.opengl").set_body_typed(BuildOpenGL); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opengl.h b/src/target/source/codegen_opengl.h index 954806bbca59a..2748ae28cfd53 100644 --- a/src/target/source/codegen_opengl.h +++ b/src/target/source/codegen_opengl.h @@ -25,11 +25,13 @@ #define TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_ #include + #include -#include #include -#include "codegen_c.h" +#include + #include "../../runtime/opengl/opengl_module.h" +#include "codegen_c.h" namespace tvm { namespace codegen { @@ -45,11 +47,11 @@ class CodeGenOpenGL final : public CodeGenC { void VisitStmt_(const StoreNode* op) final; std::string TexelFetch(const VarNode* buffer, PrimExpr index); std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final; - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) // Codegen for immediate values - void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) // Match glsl_texture_store Call. 
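Note on the build entry points touched above: `BuildMetal`, `BuildOpenCL`, `BuildOpenGL`, and `BuildSDAccel` (below) all share one driver shape — walk the module's functions, verify each is a TIR PrimFunc carrying the device-kernel-launch calling convention, hand it to the code generator, and wrap the emitted source in a runtime module. A minimal sketch of that shared pattern, with the template arguments that the rendering above drops (`IsInstance<...>`, `Downcast<...>`, `GetAttr<...>`) restored as they most plausibly read in the 0.7-era API; `BuildMyTarget` and `MySourceModuleCreate` are illustrative placeholders, not symbols from this patch:

    // Sketch only: assumes TVM ~0.7-dev headers; CHECK comes in via the TVM
    // headers, ExtractFuncInfo via "../build_common.h".
    #include <tvm/ir/module.h>
    #include <tvm/runtime/module.h>
    #include <tvm/tir/function.h>

    using namespace tvm;
    using namespace tvm::tir;

    runtime::Module BuildMyTarget(IRModule mod) {
      CodeGenOpenCL cg;  // any CodeGenC subclass is driven the same way
      cg.Init(/*output_ssa=*/false);
      for (auto kv : mod->functions) {
        // Source backends can only lower TIR PrimFuncs ...
        CHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only take PrimFunc";
        auto f = Downcast<PrimFunc>(kv.second);
        // ... and only device kernels, never host-side entry functions.
        auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
        CHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
            << "Can only take device kernel launch functions";
        cg.AddFunction(f);
      }
      std::string code = cg.Finish();
      // Each backend then wraps the source in its own module type.
      return MySourceModuleCreate(code, "my_fmt", ExtractFuncInfo(mod));
    }
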
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 0859428aa58b6..9b2f0345864f7 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -70,8 +70,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { } std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { - CHECK(!var_idmap_.count(v)) - << "Need input to be in SSA form dup " << v->name_hint; + CHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; std::string vid = GetUniqueName(key); var_idmap_[v] = vid; @@ -80,8 +79,7 @@ std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { auto it = var_idmap_.find(v); - CHECK(it != var_idmap_.end()) - << "Find undefined Variable " << v->name_hint; + CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 6723767b401f6..39016590abdc1 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -24,13 +24,15 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ +#include #include #include -#include -#include -#include + #include +#include #include +#include + #include "../../runtime/meta_data.h" namespace tvm { @@ -103,8 +105,7 @@ class CodeGenSourceBase { * \param src The source expression. * \param t The type of target. */ - virtual void PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) = 0; + virtual void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) = 0; /*! \brief the declaration stream */ std::ostringstream decl_stream; @@ -147,11 +148,8 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt); * \param fget_source a closure to replace default get source behavior. */ runtime::Module DeviceSourceModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, - std::function fget_source = nullptr); + std::string data, std::string fmt, std::unordered_map fmap, + std::string type_key, std::function fget_source = nullptr); } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 71c36264afa46..e60e1f5027d7a 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -20,11 +20,13 @@ /*! 
* \file codegen_vhls.cc */ -#include -#include #include "codegen_vhls.h" -#include "../build_common.h" + +#include +#include + #include "../../runtime/opencl/sdaccel/sdaccel_module.h" +#include "../build_common.h" namespace tvm { namespace codegen { @@ -40,37 +42,45 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { if (t.is_uint()) { switch (t.bits()) { case 8: - os << "unsigned char"; break; + os << "unsigned char"; + break; case 16: - os << "unsigned short"; break; + os << "unsigned short"; + break; case 32: - os << "unsigned int"; break; + os << "unsigned int"; + break; case 64: - os << "unsigned long long"; break; + os << "unsigned long long"; + break; default: - os << "ap_uint<" << t.bits() << ">"; break; + os << "ap_uint<" << t.bits() << ">"; + break; } } else if (t.is_int()) { switch (t.bits()) { case 8: - os << "char"; break; + os << "char"; + break; case 16: - os << "short"; break; + os << "short"; + break; case 32: - os << "int"; break; + os << "int"; + break; case 64: - os << "long long"; break; + os << "long long"; + break; default: - os << "ap_int<" << t.bits() << ">"; break; + os << "ap_int<" << t.bits() << ">"; + break; } } else { CodeGenC::PrintType(t, os); } } -void CodeGenVivadoHLS::PrintFuncPrefix() { - stream << "extern \"C\" void"; -} +void CodeGenVivadoHLS::PrintFuncPrefix() { stream << "extern \"C\" void"; } void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { for (size_t i = 0; i < f->params.size(); ++i) { @@ -84,9 +94,8 @@ void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n"; } -template -inline void PrintBinaryExpr(const T* op, - const char *opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenVivadoHLS* p) { os << opstr << '('; @@ -96,35 +105,38 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } -void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) - const char *opstr = "std::min"; +void CodeGenVivadoHLS::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) + const char* opstr = "std::min"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { case 32: - opstr = "fminf"; break; + opstr = "fminf"; + break; case 64: - opstr = "fmin"; break; + opstr = "fmin"; + break; } } PrintBinaryExpr(op, opstr, os, this); } -void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) - const char *opstr = "std::max"; +void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) + const char* opstr = "std::max"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { case 32: - opstr = "fmaxf"; break; + opstr = "fmaxf"; + break; case 64: - opstr = "fmax"; break; + opstr = "fmax"; + break; } } PrintBinaryExpr(op, opstr, os, this); } - runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { using tvm::runtime::Registry; bool output_ssa = false; @@ -133,9 +145,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { // Generate source code for get_source(). 
cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenVHLS: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -148,9 +159,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { // Generate source code for compilation. Array > kernel_info; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); CodeGenVivadoHLS cg; cg.Init(output_ssa); @@ -176,8 +186,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code); } -TVM_REGISTER_GLOBAL("target.build.sdaccel") -.set_body_typed(BuildSDAccel); +TVM_REGISTER_GLOBAL("target.build.sdaccel").set_body_typed(BuildSDAccel); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_vhls.h b/src/target/source/codegen_vhls.h index 10f9ea7679b6d..b9bec516bae9b 100644 --- a/src/target/source/codegen_vhls.h +++ b/src/target/source/codegen_vhls.h @@ -27,7 +27,9 @@ #include #include #include + #include + #include "codegen_c.h" namespace tvm { @@ -40,8 +42,8 @@ class CodeGenVivadoHLS final : public CodeGenC { void PrintFuncPrefix() final; void PreFunctionBody(const PrimFunc& f) final; - void VisitExpr_(const MinNode *op, std::ostream& os) final; - void VisitExpr_(const MaxNode *op, std::ostream& os) final; + void VisitExpr_(const MinNode* op, std::ostream& os) final; + void VisitExpr_(const MaxNode* op, std::ostream& os) final; }; } // namespace codegen diff --git a/src/target/source/intrin_rule_aocl.cc b/src/target/source/intrin_rule_aocl.cc index 6317a2fab381d..0cafd0255a864 100644 --- a/src/target/source/intrin_rule_aocl.cc +++ b/src/target/source/intrin_rule_aocl.cc @@ -27,73 +27,49 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt") -.set_body(DispatchExtern); 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index f40dd5e86bada..4e4abd9764c38 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -31,10 +31,14 @@ struct CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { switch (t.bits()) { - case 64: return name; - case 32: return name + 'f'; - case 16: return 'h' + name; - default: return ""; + case 64: + return name; + case 32: + return name + 'f'; + case 16: + return 'h' + name; + default: + return ""; } } return ""; @@ -55,14 +59,18 @@ struct CUDAFastMath : public CUDAMath { struct CUDAFastMathTan : public CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { - switch (t.bits()) { - case 64: return name; - // `__tanf` seems to produce some values too deviant from numpy tan version. - // So, let's use just `tanf` instead. - case 32: return name + 'f'; - case 16: LOG(FATAL) << "cuda tan unsupported for float16"; - default: return ""; - } + switch (t.bits()) { + case 64: + return name; + // `__tanf` seems to produce some values too deviant from numpy tan version. + // So, let's use just `tanf` instead. 
+ case 32: + return name + 'f'; + case 16: + LOG(FATAL) << "cuda tan unsupported for float16"; + default: + return ""; + } } return ""; } @@ -72,96 +80,104 @@ struct CUDAPopcount { std::string operator()(DataType t, std::string name) const { if (t.is_uint()) { switch (t.bits()) { - case 32: return "__popc"; - case 64: return "__popcll"; - default: return ""; + case 32: + return "__popc"; + case 64: + return "__popcll"; + default: + return ""; } } return ""; } }; +struct CUDAWarpIntrinsic { + const char* operator()(DataType t, const std::string& name) const { + if (name == intrinsic::tvm_warp_shuffle) { + return "__shfl_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_up) { + return "__shfl_up_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_down) { + return "__shfl_down_sync"; + } + if (name == intrinsic::tvm_warp_activemask) { + return "__activemask"; + } + return ""; + } +}; + +template static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); - CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size + CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2]}}; - *rv = CallNode::make( - call->dtype, "__shfl", cuda_args, CallNode::PureExtern); + const char* name = T()(call->dtype, call->name); + *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern); } -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos") 
-.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") -.set_body(DispatchCUDAShuffle); + .set_body(DispatchCUDAShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up") + .set_body(DispatchCUDAShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") + .set_body(DispatchCUDAShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") + .set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 8bc87d2b280f8..00fb9f9a95dec 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -27,65 +27,45 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern); 
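Context for the CUDA hunk above: `CUDAWarpIntrinsic` plus `DispatchCUDAShuffle` retarget TVM's `tvm_warp_shuffle`, `tvm_warp_shuffle_up`, and `tvm_warp_shuffle_down` onto the CUDA 9+ `__shfl_*_sync` builtins, whose first argument is an explicit lane mask (forwarded from `call->args[0]`; the trailing width/warp_size operands are dropped). An illustrative warp-sum kernel of the kind this lowering enables — written by hand here, not generated by the patch:

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void WarpReduceSum(const float* in, float* out) {
      float v = in[threadIdx.x];
      // 0xffffffff: all 32 lanes participate. This is the mask operand the
      // dispatcher above passes through as the builtin's first argument.
      for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xffffffff, v, offset);
      }
      if (threadIdx.x == 0) *out = v;  // lane 0 ends up holding the warp sum
    }

    int main() {
      float h_in[32], h_out = 0.0f, *d_in, *d_out;
      for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
      cudaMalloc(&d_in, sizeof(h_in));
      cudaMalloc(&d_out, sizeof(float));
      cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
      WarpReduceSum<<<1, 32>>>(d_in, d_out);
      cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      printf("warp sum = %f\n", h_out);  // expect 32.000000
      return 0;
    }

The `enable_warp_shuffle_` flag added to CodeGenCUDA earlier in this patch is set whenever these builtins appear in generated code; its consumer is outside this excerpt.
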
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 7374e6d40032c..60fbde7552aa3 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -22,71 +22,52 @@ * \brief OpenCL intrinsic rules. 
*/ #include + #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern); // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension @@ -94,17 +75,15 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != 
nullptr);
-  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
   arith::Analyzer analyzer;
-  CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
-      << "Intel warp shuffle dose not support width != warp_size";
-  Array cuda_args{{call->args[0], call->args[1]}};
-  *rv = CallNode::make(
-      call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
+  CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
+      << "Intel warp shuffle does not support width != warp_size";
+  Array opencl_args{{call->args[1], call->args[2]}};
+  *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern);
 }
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
-.set_body(DispatchIntelShuffle);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle);
 }  // namespace intrin
 }  // namespace codegen
diff --git a/src/target/source/intrin_rule_opengl.cc b/src/target/source/intrin_rule_opengl.cc
index 1710d45d8bd6e..1f2a21a112a05 100644
--- a/src/target/source/intrin_rule_opengl.cc
+++ b/src/target/source/intrin_rule_opengl.cc
@@ -27,53 +27,37 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh")
-.set_body(DispatchExtern);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh").set_body(DispatchExtern);
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos")
-.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_vhls.cc b/src/target/source/intrin_rule_vhls.cc index 41e76f260ff42..fb01d6566dabe 100644 --- a/src/target/source/intrin_rule_vhls.cc +++ b/src/target/source/intrin_rule_vhls.cc @@ -27,62 +27,43 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh") 
-.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 858ac8572a081..baf4ba733dce1 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) { } )"; +static constexpr const char* _cuda_warp_intrinsic_util = R"( +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700) +#define __shfl_sync(mask, var, lane, width) \ + __shfl((var), (lane), (width)) + +#define __shfl_down_sync(mask, var, offset, width) \ + __shfl_down((var), (offset), (width)) + +#define __shfl_up_sync(mask, var, offset, width) \ + __shfl_up((var), (offset), (width)) +#endif + +)"; + #endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 5f133212140c7..ba7f075d00456 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,43 +23,36 @@ */ #include #include -#include "codegen_source_base.h" + #include "../../runtime/file_util.h" #include "../../runtime/meta_data.h" +#include "codegen_source_base.h" namespace tvm { namespace codegen { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; +using runtime::FunctionInfo; using runtime::GetFileFormat; using runtime::GetMetaFilePath; -using runtime::FunctionInfo; using runtime::SaveBinaryToFile; // Simulator function class SourceModuleNode : public runtime::ModuleNode { public: - SourceModuleNode(std::string code, - std::string fmt) - : code_(code), fmt_(fmt) {} - const char* type_key() const { - return "source"; - } + SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + const char* type_key() const { return "source"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); } - std::string GetSource(const std::string& format) final { - return code_; - } + std::string GetSource(const std::string& format) final { return code_; } protected: std::string code_; @@ -74,35 +67,25 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { // Simulator function class CSourceModuleNode : public runtime::ModuleNode { public: - CSourceModuleNode(std::string code, - std::string fmt) - : code_(code), fmt_(fmt) {} - const char* type_key() const { - return "c"; - } + CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + const char* type_key() const { return "c"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "C Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); } - std::string GetSource(const std::string& format) final { - return code_; - } + std::string GetSource(const std::string& format) final { return code_; } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const 
std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cc") { CHECK_NE(code_.length(), 0); SaveBinaryToFile(file_name, code_); } else { - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; } } @@ -119,20 +102,12 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: - DeviceSourceModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, + DeviceSourceModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string type_key, std::function fget_source) - : data_(data), - fmt_(fmt), - fmap_(fmap), - type_key_(type_key), - fget_source_(fget_source) {} - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); @@ -146,15 +121,11 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { - return type_key_.c_str(); - } + const char* type_key() const { return type_key_.c_str(); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -175,19 +146,14 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { }; runtime::Module DeviceSourceModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, - std::function fget_source) { + std::string data, std::string fmt, std::unordered_map fmap, + std::string type_key, std::function fget_source) { auto n = make_object(data, fmt, fmap, type_key, fget_source); return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate") -.set_body_typed(SourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") -.set_body_typed(CSourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 4873557b6d5c3..86d1614dc863e 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -22,44 +22,37 @@ * \brief Build SPIRV block */ // Use libspirv for parsing and validating code. 
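// ---------------------------------------------------------------------------
// Editor's sketch for the source_module.cc hunk that ends above (not part of
// the diff): driving CSourceModuleCreate()/SaveToFile() from C++. The code
// string and output file name are hypothetical; CSourceModuleCreate's
// signature is taken from the hunk and re-declared here only because its
// header is internal to src/target.
#include <tvm/runtime/module.h>

#include <string>

namespace tvm {
namespace codegen {
runtime::Module CSourceModuleCreate(std::string code, std::string fmt);
}  // namespace codegen
}  // namespace tvm

void SaveGeneratedSource() {
  std::string code = "/* generated C source */";
  tvm::runtime::Module mod = tvm::codegen::CSourceModuleCreate(code, "c");
  // CSourceModuleNode::SaveToFile accepts "cc" for raw source (the code must
  // be non-empty) or the module's own format; anything else fails the
  // CHECK_EQ(fmt, fmt_) shown above.
  mod->SaveToFile("generated.cc", "cc");
}
// ---------------------------------------------------------------------------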
-#include #include +#include #include -#include "codegen_spirv.h" -#include "../build_common.h" - -#include "../../runtime/vulkan/vulkan_shader.h" #include "../../runtime/vulkan/vulkan_module.h" +#include "../../runtime/vulkan/vulkan_shader.h" +#include "../build_common.h" +#include "codegen_spirv.h" namespace tvm { namespace codegen { class SPIRVTools { public: - SPIRVTools() { - ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); - } - ~SPIRVTools() { - spvContextDestroy(ctx_); - } + SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); } + ~SPIRVTools() { spvContextDestroy(ctx_); } std::string BinaryToText(const std::vector& bin) { spv_text text = nullptr; spv_diagnostic diagnostic; spv_const_binary_t spv_bin{bin.data(), bin.size()}; spv_result_t res; - res = spvBinaryToText( - ctx_, spv_bin.code, spv_bin.wordCount, - SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | - SPV_BINARY_TO_TEXT_OPTION_INDENT, - &text, &diagnostic); + res = + spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, + &text, &diagnostic); - CHECK_EQ(res, SPV_SUCCESS) - << " line=" << diagnostic->position.line - << " column=" << diagnostic->position.column - << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; + CHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line + << " column=" << diagnostic->position.column + << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; std::string ret(text->str); spvTextDestroy(text); @@ -70,7 +63,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod, std::string target) { +runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -84,9 +77,8 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) { CodeGenSPIRV cg; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenSPIRV: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -98,7 +90,14 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) { std::string f_name = global_symbol.value(); VulkanShader shader; - shader.data = cg.BuildFunction(f); + std::string entry = webgpu_restriction ? 
"main" : f_name; + shader.data = cg.BuildFunction(f, entry); + + if (webgpu_restriction) { + for (auto param : f->params) { + CHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments"; + } + } if (postproc != nullptr) { TVMByteArray arr; @@ -114,12 +113,16 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) { smap[f_name] = std::move(shader); } - return runtime::VulkanModuleCreate( - smap, ExtractFuncInfo(mod), code_data.str()); + return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str()); } -TVM_REGISTER_GLOBAL("target.build.vulkan") -.set_body_typed(BuildSPIRV); +TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) { + return BuildSPIRV(mod, target, false); +}); + +TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) { + return BuildSPIRV(mod, target, true); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index be058b7306eb6..e76e8bee81fc1 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -21,19 +21,21 @@ * \file codegen_spirv.cc * \brief Generate SPIRV block */ -#include +#include "codegen_spirv.h" + #include +#include + #include -#include "codegen_spirv.h" + #include "../../arith/compute_expr.h" namespace tvm { namespace codegen { -std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { +std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) - << "SPIRV only takes restricted memory model"; + CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t num_buffer = 0; @@ -44,8 +46,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { auto* prim = ptr->element_type.as(); CHECK(prim); DataType value_type = prim->dtype; - spirv::Value arg_value = builder_->BufferArgument( - builder_->GetSType(value_type), 0, num_buffer); + spirv::Value arg_value = + builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer); storage_info_[arg.get()].UpdateContentType(value_type); var_map_[arg.get()] = arg_value; } else { @@ -67,8 +69,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { } spirv::Value ptr = builder_->DeclarePushConstant(value_types); for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetPushConstant( - ptr, value_types[i], static_cast(i)); + spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; } } @@ -77,12 +78,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { builder_->MakeInst(spv::OpReturn); builder_->MakeInst(spv::OpFunctionEnd); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - - builder_->CommitKernelFunction( - func_ptr, static_cast(global_symbol.value())); + builder_->CommitKernelFunction(func_ptr, name); return builder_->Finalize(); } @@ -96,15 +92,14 @@ void CodeGenSPIRV::InitFuncState() { builder_->InitHeader(); } -spirv::Value CodeGenSPIRV::GetThreadIndex( - const IterVar& iv, const PrimExpr& extent) { +spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); spirv::Value v; if 
(ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); auto* sizeptr = extent.as(); - CHECK(sizeptr) - << "SPIRV only allows constant thread group size " << " get " << extent; + CHECK(sizeptr) << "SPIRV only allows constant thread group size " + << " get " << extent; CHECK_LT(ts.dim_index, 3); workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); } else { @@ -121,12 +116,12 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { } else if (sync == "shared") { auto type_int = builder_->GetSType(DataType::Int(32)); builder_->MakeInst( - spv::OpControlBarrier, - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast( - spv::MemorySemanticsSequentiallyConsistentMask | - spv::MemorySemanticsWorkgroupMemoryMask))); + spv::OpControlBarrier, + builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), + builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), + builder_->IntImm(type_int, + static_cast(spv::MemorySemanticsSequentiallyConsistentMask | + spv::MemorySemanticsWorkgroupMemoryMask))); } else { LOG(FATAL) << "Do not support sync " << sync; } @@ -230,8 +225,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) { - return builder_->Select(MakeValue(op->condition), - MakeValue(op->true_value), + return builder_->Select(MakeValue(op->condition), MakeValue(op->true_value), MakeValue(op->false_value)); } @@ -245,14 +239,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = static_cast( - op->args[0].as()->value); + uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallGLSL450( - builder_->GetSType(op->dtype), inst_id, values); + return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); } else if (op->is_intrinsic(CallNode::bitwise_and)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -303,10 +295,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Label then_label = builder_->NewLabel(); spirv::Label else_label = builder_->NewLabel(); spirv::Label merge_label = builder_->NewLabel(); - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, else_label); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); // then block, must get label after we see the value builder_->StartLabel(then_label); spirv::Value then_value = MakeValue(op->args[1]); @@ -324,19 +314,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(1, else_value, else_value_label); return phi; } else if (op->is_intrinsic("popcount")) { - return builder_->MakeValue( - spv::OpBitCount, - builder_->GetSType(op->dtype), - MakeValue(op->args[0])); + return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), + MakeValue(op->args[0])); } else { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << 
op->dtype; - } else if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->name - << " with return type " << op->dtype; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { + LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } @@ -350,8 +334,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { for (int i = 0; i < op->lanes; ++i) { spirv::Value v = base; if (i != 0) { - spirv::Value offset = MakeValue( - make_const(op->stride.dtype(), i) * op->stride); + spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -379,8 +362,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { spirv::SType content_type = builder_->GetSType(info.content_type); spirv::Value buffer = MakeValue(op->buffer_var); - spirv::SType ptr_type = builder_->GetPointerType( - content_type, buffer.stype.storage_class); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; if (info.is_volatile) { @@ -390,18 +372,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK_EQ(info.content_type, op->dtype) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } else { if (op->dtype.element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. 
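// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): the scalarize pattern the loop below
// implements, reduced to plain C++. The callback receives (lane, per-lane
// index), mirroring the std::function parameter of CodeGenSPIRV::Scalarize;
// spirv::Value is modeled as an int here so the sketch stays self-contained.
#include <functional>
#include <vector>

void ScalarizeRamp(int base, int stride, int lanes,
                   const std::function<void(int, int)>& f) {
  for (int i = 0; i < lanes; ++i) {
    f(i, base + stride * i);  // mirrors: offset = ramp->base + ramp->stride * i
  }
}

// Usage mirroring the load path: one load per lane, results concatenated.
std::vector<int> LoadLanes() {
  std::vector<int> values;
  ScalarizeRamp(/*base=*/8, /*stride=*/1, /*lanes=*/4,
                [&](int lane, int index) { values.push_back(index); });
  return values;  // {8, 9, 10, 11}
}
// ---------------------------------------------------------------------------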
std::vector values; auto f = [&](int i, spirv::Value index) { - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); - values.emplace_back( - builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); }; this->Scalarize(op->index, f); return builder_->Concat(values); @@ -410,13 +389,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->dtype.lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); - CHECK((me->coeff % ramp->lanes) == 0 && - (me->base % ramp->lanes) == 0) + CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = analyzer_->Simplify( - ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, MakeValue(vec_index)); + PrimExpr vec_index = + analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } } @@ -427,8 +404,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { return spirv::Value(); } -void CodeGenSPIRV::Scalarize(const PrimExpr& e, - std::function f) { +void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { PrimExpr offset = ramp->base + ramp->stride * i; @@ -438,8 +414,7 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, spirv::SType etype = builder_->GetSType(e.dtype().element_of()); spirv::Value value = MakeValue(e); for (int i = 0; i < e.dtype().lanes(); ++i) { - f(i, builder_->MakeValue( - spv::OpCompositeExtract, etype, value, i)); + f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i)); } } } @@ -457,8 +432,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { spirv::SType content_type = builder_->GetSType(info.content_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::Value value = MakeValue(op->value); - spirv::SType ptr_type = builder_->GetPointerType( - content_type, buffer.stype.storage_class); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; if (info.is_volatile) { @@ -469,17 +443,14 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK_EQ(info.content_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); } else { if (op->value.dtype().element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. 
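// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): the aligned-vector-access rule used
// by the load path above and mirrored by the store path below, extracted as
// plain arithmetic. `coeff` and `base` stand in for arith::ModularSet, which
// proves indices have the form coeff * k + base.
#include <cassert>
#include <cstdint>

bool AlignedVectorAccess(int64_t coeff, int64_t base, int64_t lanes) {
  // mirrors: (me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0
  return coeff % lanes == 0 && base % lanes == 0;
}

int main() {
  // A[Ramp(4*i, 1, 4)]: provably a multiple of 4 -> one vector access at
  // struct-array index base / 4, i.e. index i.
  assert(AlignedVectorAccess(4, 0, 4));
  // A[Ramp(4*i + 2, 1, 4)]: misaligned -> trips the CHECK, since only aligned
  // vector access is allowed in SPIR-V codegen.
  assert(!AlignedVectorAccess(4, 2, 4));
  return 0;
}
// ---------------------------------------------------------------------------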
auto f = [&](int i, spirv::Value index) { - spirv::Value elem = builder_->MakeValue( - spv::OpCompositeExtract, content_type, value, i); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, elem, mask); }; this->Scalarize(op->index, f); @@ -488,13 +459,11 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->value.dtype().lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); - CHECK((me->coeff % ramp->lanes) == 0 && - (me->base % ramp->lanes) == 0) + CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = analyzer_->Simplify( - ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, MakeValue(vec_index)); + PrimExpr vec_index = + analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); builder_->MakeInst(spv::OpStore, ptr, value, mask); return; } @@ -522,14 +491,11 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, extent_value); - uint32_t control = ( - op->for_type == ForType::Unrolled ? - spv::LoopControlUnrollMask : spv::LoopControlMaskNone); - builder_->MakeInst( - spv::OpLoopMerge, merge_label, continue_label, control); - builder_->MakeInst( - spv::OpBranchConditional, loop_cond, body_label, merge_label, - weight_likely_branch_, 1); + uint32_t control = + (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); + builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); + builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, + weight_likely_branch_, 1); // loop body builder_->StartLabel(body_label); @@ -539,10 +505,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop continue builder_->StartLabel(continue_label); - spirv::Value one = - op->loop_var.dtype().is_int() ? - builder_->IntImm(loop_var.stype, 1) : - builder_->UIntImm(loop_var.stype, 1); + spirv::Value one = op->loop_var.dtype().is_int() ? 
builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); spirv::Value next_value = builder_->Add(loop_var, one); loop_var.SetIncoming(1, next_value, builder_->CurrentLabel()); builder_->MakeInst(spv::OpBranch, head_label); @@ -556,10 +520,8 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { spirv::Label merge_label = builder_->NewLabel(); if (op->else_case.defined()) { spirv::Label else_label = builder_->NewLabel(); - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, else_label); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); // then block builder_->StartLabel(then_label); this->VisitStmt(op->then_case); @@ -569,11 +531,9 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { this->VisitStmt(op->else_case); builder_->MakeInst(spv::OpBranch, merge_label); } else { - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, merge_label, - weight_likely_branch_, 1); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label, + weight_likely_branch_, 1); // then block builder_->StartLabel(then_label); this->VisitStmt(op->then_case); @@ -587,23 +547,20 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); CHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; StorageInfo& info = storage_info_[op->buffer_var.get()]; spirv::SType etype = builder_->GetSType(op->dtype); if (info.scope.rank == runtime::StorageRank::kLocal) { - buf = builder_->Allocate( - etype, static_cast(constant_size), - spv::StorageClassFunction); + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); } else { // shared memory CHECK(info.scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory - buf = builder_->Allocate( - etype, static_cast(constant_size), - spv::StorageClassWorkgroup); + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); } CHECK(!info.content_fixed); info.UpdateContentType(op->dtype); @@ -624,8 +581,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); - storage_info_[v].scope = - runtime::StorageScope::make(op->value.as()->value); + storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); @@ -653,9 +609,7 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { - MakeValue(op->value); -} +void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index b51e8edf027aa..a8af29a194d5a 100644 --- 
a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -26,15 +26,16 @@ #include #include -#include #include +#include -#include #include +#include #include +#include -#include "ir_builder.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_builder.h" namespace tvm { namespace codegen { @@ -44,24 +45,22 @@ using namespace tir; /*! * \brief Code generator into SPIRV */ -class CodeGenSPIRV: - public ExprFunctor, - public StmtFunctor { +class CodeGenSPIRV : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Compile and add function f to the current module. * \param f The function to be added. + * \param name The name of the target function. * \return The final spirv module. */ - virtual std::vector BuildFunction(const PrimFunc& f); + virtual std::vector BuildFunction(const PrimFunc& f, const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. * \return created value. */ - spirv::Value MakeValue(const PrimExpr& e) { - return VisitExpr(e); - } + spirv::Value MakeValue(const PrimExpr& e) { return VisitExpr(e); } // override codegen spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; @@ -116,8 +115,7 @@ class CodeGenSPIRV: // Update content type if it hasn't beenupdated. void UpdateContentType(DataType type) { if (content_fixed) { - CHECK_EQ(type, content_type) - << "Cannot use two different content type in GLSL model"; + CHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; } else { this->content_type = type; content_fixed = true; @@ -129,8 +127,7 @@ class CodeGenSPIRV: // Get the thread index spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); spirv::Value CreateStorageSync(const CallNode* op); - void Scalarize(const PrimExpr& e, - std::function f); + void Scalarize(const PrimExpr& e, std::function f); // The builder std::unique_ptr builder_; // Work group size of three @@ -148,5 +145,4 @@ class CodeGenSPIRV: } // namespace codegen } // namespace tvm - #endif // TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_ diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index ead6952b434ef..6b31bd73e05a9 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -20,9 +20,9 @@ /*! 
* \file intrin_rule_spirv.cc */ +#include #include #include -#include namespace tvm { namespace codegen { @@ -31,7 +31,7 @@ namespace spirv { using namespace runtime; // num_signature means number of arguments used to query signature -template +template inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -43,39 +43,55 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::CallNode::make(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin); + +// WebGPU rules. 
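// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): what DispatchGLSLPureIntrin produces
// for a single op, assuming the GLSL.std.450 header included by this file
// provides GLSLstd450Sqrt. The vulkan rules above and the webgpu rules
// registered next both funnel through this path; CodeGenSPIRV later matches
// the "spirv_glsl450" intrinsic and emits OpExtInst via CallGLSL450.
PrimExpr LowerSqrtToGLSL450(const tir::CallNode* call) {
  Array<PrimExpr> cargs;
  // the GLSL.std.450 instruction id rides along as the first argument
  cargs.push_back(IntImm(DataType::UInt(32), GLSLstd450Sqrt));
  for (PrimExpr arg : call->args) cargs.push_back(arg);
  return tir::CallNode::make(call->dtype, "spirv_glsl450", cargs,
                             tir::CallNode::PureIntrinsic);
}
// ---------------------------------------------------------------------------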
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor") + .set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round") + .set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc") + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh").set_body(DispatchGLSLPureIntrin); } // namespace spirv } // namespace codegen diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index bf43f11cce02d..305464ac398b2 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -32,10 +32,14 @@ namespace spirv { void IRBuilder::InitHeader() { CHECK_EQ(header_.size(), 0U); header_.push_back(spv::MagicNumber); - // Use SPIR-V v1.0. This needs to be kept in sync (or at least behind) - // `VkApplicationInfo.apiVersion` in `vulkan.cc` to ensure Vulkan API - // validation passes. + + // Use the spirv version as indicated in the SDK. 
+#if SPV_VERSION >= 0x10300 + header_.push_back(0x10300); +#else header_.push_back(0x10000); +#endif + // generator: set to 0, unknown header_.push_back(0U); // Bound: set during Finalize @@ -45,9 +49,9 @@ void IRBuilder::InitHeader() { // shader ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_); // memory model - ib_.Begin(spv::OpMemoryModel).AddSeq( - spv::AddressingModelLogical, - spv::MemoryModelGLSL450).Commit(&entry_); + ib_.Begin(spv::OpMemoryModel) + .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) + .Commit(&entry_); this->InitPreDefs(); } @@ -62,8 +66,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeFunction) - .AddSeq(t_void_func_, t_void_).Commit(&global_); + ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_); } SType IRBuilder::GetSType(const DataType& dtype) { @@ -89,8 +92,7 @@ SType IRBuilder::GetSType(const DataType& dtype) { return t; } -SType IRBuilder::GetPointerType(const SType& value_type, - spv::StorageClass storage_class) { +SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) { CHECK_NE(storage_class, spv::StorageClassMax); auto key = std::make_pair(value_type.id, storage_class); auto it = pointer_type_tbl_.find(key); @@ -102,14 +104,12 @@ SType IRBuilder::GetPointerType(const SType& value_type, t.type = DataType::Handle(); t.element_type_id = value_type.id; t.storage_class = storage_class; - ib_.Begin(spv::OpTypePointer) - .AddSeq(t, storage_class, value_type).Commit(&global_); + ib_.Begin(spv::OpTypePointer).AddSeq(t, storage_class, value_type).Commit(&global_); pointer_type_tbl_[key] = t; return t; } -SType IRBuilder::GetStructArrayType(const SType& value_type, - uint32_t num_elems) { +SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) { auto key = std::make_pair(value_type.id, num_elems); auto it = struct_array_type_tbl_.find(key); if (it != struct_array_type_tbl_.end()) { @@ -123,54 +123,50 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, if (num_elems != 0) { Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems); - ib_.Begin(spv::OpTypeArray) - .AddSeq(arr_type, value_type, length).Commit(&global_); + ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_); } else { - ib_.Begin(spv::OpTypeRuntimeArray) - .AddSeq(arr_type, value_type).Commit(&global_); + ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); } int nbits = value_type.type.bits() * value_type.type.lanes(); CHECK_EQ(nbits % 8, 0); uint32_t nbytes = static_cast(nbits) / 8; // decorate the array type. - this->Decorate(spv::OpDecorate, - arr_type, spv::DecorationArrayStride, nbytes); + this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); // declare struct of array SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); struct_type.element_type_id = value_type.id; - ib_.Begin(spv::OpTypeStruct) - .AddSeq(struct_type, arr_type).Commit(&global_); + ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); // decorate the array type. ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, 0, spv::DecorationOffset, 0) .Commit(&decorate_); + +#if SPV_VERSION < 0x10300 + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. 
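// ---------------------------------------------------------------------------
// Editor's note (illustrative, not part of the diff): the shape of the buffer
// declaration the two #if branches produce, sketched as SPIR-V assembly for a
// float32 runtime array (element size 4 bytes):
//   %arr = OpTypeRuntimeArray %float
//          OpDecorate %arr ArrayStride 4
//   %st  = OpTypeStruct %arr
//          OpMemberDecorate %st 0 Offset 0
//   pre-1.3:  OpDecorate %st BufferBlock   ; pointer class Uniform
//   1.3+:     OpDecorate %st Block         ; pointer class StorageBuffer
// The storage-class half of this switch lives in BufferArgument() below.
// ---------------------------------------------------------------------------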
// runtime array are always decorated as BufferBlock(shader storage buffer) if (num_elems == 0) { - this->Decorate(spv::OpDecorate, - struct_type, spv::DecorationBufferBlock); + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); } +#else + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); +#endif struct_array_type_tbl_[key] = struct_type; return struct_type; } -Value IRBuilder::StructArrayAccess(const SType& res_type, - Value buffer, - Value index) { +Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) { CHECK(buffer.flag == kStructArrayPtr); - return MakeValue(spv::OpInBoundsAccessChain, - res_type, buffer, - const_i32_zero_, index); + return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, const_i32_zero_, index); } Value IRBuilder::IntImm(const SType& dtype, int64_t value) { return GetConst_(dtype, reinterpret_cast(&value)); } -Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { - return GetConst_(dtype, &value); -} +Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { return GetConst_(dtype, &value); } Value IRBuilder::FloatImm(const SType& dtype, double value) { if (dtype.type.bits() == 64) { @@ -182,23 +178,28 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { return GetConst_(dtype, &data); } else { CHECK_EQ(dtype.type.bits(), 16); - return Cast(dtype, - FloatImm(GetSType(DataType::Float(32)), value)); + return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); } } -Value IRBuilder::BufferArgument(const SType& value_type, - uint32_t descriptor_set, +Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding) { + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. 
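// ---------------------------------------------------------------------------
// Editor's note (illustrative, not part of the diff): besides the storage
// class chosen just below, every buffer argument is also decorated with its
// descriptor binding. CodeGenSPIRV::BuildFunction() above calls
// BufferArgument(type, /*descriptor_set=*/0, num_buffer) with a running
// buffer count, so a kernel f(A, B, C) is expected to bind as:
//   A: DescriptorSet 0, Binding 0
//   B: DescriptorSet 0, Binding 1
//   C: DescriptorSet 0, Binding 2
// ---------------------------------------------------------------------------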
+#if SPV_VERSION >= 0x10300 + spv::StorageClass storage_class = spv::StorageClassStorageBuffer; +#else + spv::StorageClass storage_class = spv::StorageClassUniform; +#endif + SType sarr_type = GetStructArrayType(value_type, 0); - SType ptr_type = GetPointerType(sarr_type, spv::StorageClassUniform); + SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, spv::StorageClassUniform).Commit(&global_); - this->Decorate(spv::OpDecorate, - val, spv::DecorationDescriptorSet, descriptor_set); - this->Decorate(spv::OpDecorate, - val, spv::DecorationBinding, binding); + + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); + + this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set); + this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); return val; } @@ -220,37 +221,30 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { .Commit(&decorate_); DataType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); - CHECK_EQ(nbits % 8 , 0); + CHECK_EQ(nbits % 8, 0); offset += nbits / 8; } // Decorate push constants as UBO this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); - SType ptr_type = GetPointerType( - struct_type, spv::StorageClassPushConstant); + SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant); Value val = NewValue(ptr_type, kPushConstantPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); return val; } -Value IRBuilder::GetPushConstant( - Value ptr_push_const, const SType& v_type, uint32_t index) { +Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) { SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant); - Value ptr = this->MakeValue( - spv::OpAccessChain, ptr_vtype, ptr_push_const, - IntImm(t_int32_, static_cast(index))); + Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, + IntImm(t_int32_, static_cast(index))); return this->MakeValue(spv::OpLoad, v_type, ptr); } -Value IRBuilder::NewFunction() { - return NewValue(t_void_func_, kFunction); -} +Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { CHECK_EQ(func.flag, kFunction); - ib_.Begin(spv::OpEntryPoint) - .AddSeq(spv::ExecutionModelGLCompute, func, name); + ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name); if (workgroup_id_.id != 0) { ib_.Add(workgroup_id_); } @@ -262,34 +256,31 @@ void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) void IRBuilder::StartFunction(const Value& func) { CHECK_EQ(func.flag, kFunction); - this->MakeInst( - spv::OpFunction, t_void_, func, 0, t_void_func_); + // add function declaration to the header. 
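// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): SPIR-V validation requires every
// OpVariable with Function storage class to sit at the start of a function's
// first block. The new func_header_ segment holds OpFunction, the first
// OpLabel, and those local variables, so Allocate() can still add locals after
// body instructions have already been committed to function_. Finalize() (see
// the ir_builder.h hunk further down) splices the segments in order:
//   debug_ | decorate_ | global_ | func_header_ | function_
// ---------------------------------------------------------------------------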
+ ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_); + spirv::Label start_label = this->NewLabel(); - this->StartLabel(start_label); + ib_.Begin(spv::OpLabel).AddSeq(start_label).Commit(&func_header_); + curr_label_ = start_label; } -void IRBuilder::SetLocalSize(const Value& func, - uint32_t local_size[3]) { +void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { CHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpExecutionMode) - .AddSeq(func, spv::ExecutionModeLocalSize, - local_size[0], local_size[1], local_size[2]) + .AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2]) .Commit(&exec_mode_); } -Value IRBuilder::Allocate(const SType& value_type, - uint32_t num_elems, +Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class) { CHECK_NE(num_elems, 0U); SType sarr_type = GetStructArrayType(value_type, num_elems); SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); if (storage_class == spv::StorageClassFunction) { - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&function_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&func_header_); } else { - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); } return val; } @@ -297,19 +288,16 @@ Value IRBuilder::Allocate(const SType& value_type, Value IRBuilder::GetWorkgroupID(uint32_t dim_index) { if (workgroup_id_.id == 0) { SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); - SType ptr_type = this->GetPointerType( - vec3_type, spv::StorageClassInput); + SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); workgroup_id_ = NewValue(ptr_type, kVectorPtr); ib_.Begin(spv::OpVariable) .AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput) .Commit(&global_); - this->Decorate(spv::OpDecorate, workgroup_id_, - spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); + this->Decorate(spv::OpDecorate, workgroup_id_, spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); } SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue( - spv::OpAccessChain, pint_type, workgroup_id_, - IntImm(t_int32_, static_cast(dim_index))); + Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, workgroup_id_, + IntImm(t_int32_, static_cast(dim_index))); return this->MakeValue(spv::OpLoad, t_int32_, ptr); } @@ -318,16 +306,13 @@ Value IRBuilder::GetLocalID(uint32_t dim_index) { SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); local_id_ = NewValue(ptr_type, kVectorPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, local_id_, spv::StorageClassInput) - .Commit(&global_); - this->Decorate(spv::OpDecorate, local_id_, - spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, local_id_, spv::StorageClassInput).Commit(&global_); + this->Decorate(spv::OpDecorate, local_id_, spv::DecorationBuiltIn, + spv::BuiltInLocalInvocationId); } SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue( - spv::OpAccessChain, pint_type, local_id_, - UIntImm(t_int32_, static_cast(dim_index))); + Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, local_id_, + UIntImm(t_int32_, 
static_cast(dim_index))); return this->MakeValue(spv::OpLoad, t_int32_, ptr); } @@ -354,9 +339,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type.bits() > 32) { if (dtype.type.is_int()) { int64_t sign_mask = 0xFFFFFFFFL; - const int64_t* sign_ptr = - reinterpret_cast(pvalue); - ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); + const int64_t* sign_ptr = reinterpret_cast(pvalue); + ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); } else { ib_.Add(static_cast((pvalue[0] >> 32UL) & mask)); } @@ -390,8 +374,7 @@ SType IRBuilder::DeclareType(const DataType& dtype) { t.id = id_counter_++; t.type = dtype; SType base_type = GetSType(dtype.element_of()); - ib_.Begin(spv::OpTypeVector).AddSeq( - t, base_type, dtype.lanes()).Commit(&global_); + ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); return t; } } @@ -411,12 +394,10 @@ PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) { return phi; } -Value IRBuilder::CallGLSL450(const SType& ret_type, - uint32_t inst_id, +Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector& args) { Value val = NewValue(ret_type, kNormal); - ib_.Begin(spv::OpExtInst) - .AddSeq(ret_type, val, ext_glsl450_, inst_id); + ib_.Begin(spv::OpExtInst).AddSeq(ret_type, val, ext_glsl450_, inst_id); for (const Value& v : args) { ib_.Add(v); } @@ -486,14 +467,12 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { return MakeValue(spv::OpUConvert, dst_type, value); } else if (from.is_uint() && to.is_int()) { if (from.bits() != to.bits()) { - value = MakeValue( - spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); } else if (from.is_int() && to.is_uint()) { if (from.bits() != to.bits()) { - value = MakeValue( - spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); } else if (from.is_float() && to.is_int()) { @@ -507,21 +486,20 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (from.is_float() && to.is_float()) { return MakeValue(spv::OpFConvert, dst_type, value); } else { - LOG(FATAL) << "do not support type cast from " - << from << " to " << to; + LOG(FATAL) << "do not support type cast from " << from << " to " << to; return Value(); } } -#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI ## _Op, a.stype, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF ## _Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, a.stype, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } #define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ @@ -554,19 +532,19 @@ Value IRBuilder::Mod(Value a, Value b) { } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - 
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -574,17 +552,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index bdfea4ff7f1c6..c52f92fd7c207 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -27,14 +27,15 @@ #include #include +// clang-format off #include -#include -#include -#include #include +#include #include - +#include +#include #include +// clang-format on namespace tvm { namespace codegen { @@ -85,9 +86,7 @@ struct Label { class Instr { public: /*! \return the word count */ - uint32_t WordCount() const { - return word_count_; - } + uint32_t WordCount() const { return word_count_; } /*! * \brief Access idx-th word of instruction * \param idx The index @@ -122,9 +121,7 @@ struct PhiValue : public Value { * \param value The value to come * \param parent The parent label. 
*/ - void SetIncoming(uint32_t index, - const Value& value, - const Label& parent) { + void SetIncoming(uint32_t index, const Value& value, const Label& parent) { CHECK_EQ(this->stype.id, value.stype.id); instr[3 + index * 2] = value.id; instr[3 + index * 2 + 1] = parent.id; @@ -203,12 +200,10 @@ class InstrBuilder { */ InstrBuilder& Add(const std::string& v) { const uint32_t kWordSize = sizeof(uint32_t); - uint32_t nwords = - (static_cast(v.length()) + kWordSize) / kWordSize; + uint32_t nwords = (static_cast(v.length()) + kWordSize) / kWordSize; size_t begin = data_.size(); data_.resize(begin + nwords, 0U); - std::copy(v.begin(), v.end(), - reinterpret_cast(&data_[begin])); + std::copy(v.begin(), v.end(), reinterpret_cast(&data_[begin])); return *this; } /*! @@ -217,8 +212,8 @@ class InstrBuilder { * \return reference to self. * \tparams Args The positional arguments */ - template - InstrBuilder& AddSeq(Args&& ...args) { + template + InstrBuilder& AddSeq(Args&&... args) { AddSeqHelper helper; helper.builder = this; runtime::detail::for_each(helper, std::forward(args)...); @@ -252,7 +247,7 @@ class InstrBuilder { // The reference to builder InstrBuilder* builder; // invoke function - template + template void operator()(size_t, const T& v) const { builder->Add(v); } @@ -301,6 +296,7 @@ class IRBuilder { data.insert(data.end(), debug_.begin(), debug_.end()); data.insert(data.end(), decorate_.begin(), decorate_.end()); data.insert(data.end(), global_.begin(), global_.end()); + data.insert(data.end(), func_header_.begin(), func_header_.end()); data.insert(data.end(), function_.begin(), function_.end()); return data; } @@ -322,17 +318,15 @@ class IRBuilder { curr_label_ = label; } /*! \return The current label */ - Label CurrentLabel() const { - return curr_label_; - } + Label CurrentLabel() const { return curr_label_; } /*! * \brief Add code to debug segment. * \param op The operator * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void Debug(spv::Op op, Args&& ...args) { + template + void Debug(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&debug_); } /*! @@ -341,10 +335,9 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void ExecutionMode(Value func, Args&& ...args) { - ib_.Begin(spv::OpExecutionMode).AddSeq( - func, std::forward(args)...).Commit(&exec_mode_); + template + void ExecutionMode(Value func, Args&&... args) { + ib_.Begin(spv::OpExecutionMode).AddSeq(func, std::forward(args)...).Commit(&exec_mode_); } /*! * \brief Add code to decorate segment. @@ -352,8 +345,8 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void Decorate(spv::Op op, Args&& ...args) { + template + void Decorate(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); } /*! @@ -362,8 +355,8 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void DeclareGlobal(spv::Op op, Args&& ...args) { + template + void DeclareGlobal(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); } /*! @@ -374,8 +367,8 @@ class IRBuilder { * \return The result SSA value. * \tparams Args The positional arguments */ - template - Instr MakeInst(spv::Op op, Args&& ...args) { + template + Instr MakeInst(spv::Op op, Args&&... 
args) { return ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&function_); } /*! @@ -387,8 +380,8 @@ * \return The result SSA value. * \tparams Args The positional arguments */ - template - Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) { + template + Value MakeValue(spv::Op op, const SType& out_type, Args&&... args) { Value val = NewValue(out_type, kNormal); MakeInst(op, out_type, val, std::forward(args)...); return val; } /*! @@ -409,9 +402,7 @@ * \param args The arguments * \return The result value. */ - Value CallGLSL450(const SType& ret_type, - uint32_t inst_id, - const std::vector& args); + Value CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector& args); /*! * \brief Build vector by concatenating components * @@ -431,8 +422,7 @@ * \param storage_class The storage class * \return The corresponding spirv type. */ - SType GetPointerType(const SType& value_type, - spv::StorageClass storage_class); + SType GetPointerType(const SType& value_type, spv::StorageClass storage_class); /*! * \brief Get a struct{ value_type[num_elems] } type. * \param value_type the content value type. * \param num_elems number of elements in array * * \return The corresponding spirv type. */ - SType GetStructArrayType(const SType& value_type, - uint32_t num_elems); + SType GetStructArrayType(const SType& value_type, uint32_t num_elems); /*! * \brief Get a struct array access with a given index. * \param ptr_type The pointer type. * \param buffer The buffer ptr to struct array * \param index The array index. */ - Value StructArrayAccess(const SType& ptr_type, - Value buffer, - Value index); + Value StructArrayAccess(const SType& ptr_type, Value buffer, Value index); /*! * \brief Create a cast that casts value to dst_type * \param dst_type The target type. @@ -485,9 +472,7 @@ * \param binding The binding location in descriptor set. * \return The argument value. */ - Value BufferArgument(const SType& value_type, - uint32_t descriptor_set, - uint32_t binding); + Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding); /*! * \brief Declare POD arguments through push constants. * @@ -533,9 +518,7 @@ * \param num_elems Number of elements to allocate. * \param storage_class The storage class we want to store to. */ - Value Allocate(const SType& value_type, - uint32_t num_elems, - spv::StorageClass storage_class); + Value Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class); /* * \brief Get the i-th workgroup id. * \return The value representing the workgroup id. @@ -610,8 +593,10 @@ std::vector debug_; /*! \brief Annotation segment */ std::vector decorate_; - /*! \brief Global segment: types, variables, types */ + /*! \brief Global segment: types, variables, and constants */ std::vector global_; + /*! \brief Function header segment */ + std::vector func_header_; /*! \brief Function segment */ std::vector function_; }; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index b125f374ce393..6dd2ca0ecb6c4 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -20,14 +20,17 @@ /*! 
* \file codegen_stackvm.cc */ -#include -#include +#include "codegen_stackvm.h" + #include -#include +#include +#include #include +#include + #include #include -#include "codegen_stackvm.h" + #include "../../runtime/stackvm/stackvm_module.h" namespace tvm { @@ -40,19 +43,32 @@ using namespace tir; StackVM::StructFieldKind MapFieldKind(int64_t kind) { auto val = static_cast(kind); switch (val) { - case intrinsic::kArrData: return StackVM::kArrData; - case intrinsic::kArrShape: return StackVM::kArrShape; - case intrinsic::kArrAddr: return StackVM::kArrAddr; - case intrinsic::kArrStrides: return StackVM::kArrStrides; - case intrinsic::kArrNDim: return StackVM::kArrNDim; - case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode; - case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits; - case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes; - case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset; - case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId; - case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType; - case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent; - default: LOG(FATAL) << "Do not know how to map field " << kind; + case intrinsic::kArrData: + return StackVM::kArrData; + case intrinsic::kArrShape: + return StackVM::kArrShape; + case intrinsic::kArrAddr: + return StackVM::kArrAddr; + case intrinsic::kArrStrides: + return StackVM::kArrStrides; + case intrinsic::kArrNDim: + return StackVM::kArrNDim; + case intrinsic::kArrTypeCode: + return StackVM::kArrTypeCode; + case intrinsic::kArrTypeBits: + return StackVM::kArrTypeBits; + case intrinsic::kArrTypeLanes: + return StackVM::kArrTypeLanes; + case intrinsic::kArrByteOffset: + return StackVM::kArrByteOffset; + case intrinsic::kArrDeviceId: + return StackVM::kArrDeviceId; + case intrinsic::kArrDeviceType: + return StackVM::kArrDeviceType; + case intrinsic::kTVMValueContent: + return StackVM::kTVMValueContent; + default: + LOG(FATAL) << "Do not know how to map field " << kind; } return StackVM::kArrData; } @@ -84,8 +100,7 @@ void CodeGenStackVM::PushOp(StackVM::OpCode opcode) { } void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) { - CHECK(operand >= std::numeric_limits::min() && - operand <= std::numeric_limits::max()); + CHECK(operand >= std::numeric_limits::min() && operand <= std::numeric_limits::max()); vm_.code.at(operand_index).v_int = static_cast(operand); } @@ -120,8 +135,7 @@ int CodeGenStackVM::AllocVarID(const VarNode* v) { int CodeGenStackVM::GetVarID(const VarNode* v) const { auto it = var_idmap_.find(v); - CHECK(it != var_idmap_.end()) - << "Find undefined Variable " << v->name_hint; + CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } @@ -161,7 +175,7 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); this->Push(l->index); @@ -261,9 +275,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { } } -void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, - const PrimExpr& a, - const PrimExpr& b) { +void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b) { this->Push(a); this->Push(b); DataType t = a.dtype(); @@ -295,7 +307,7 @@ void CodeGenStackVM::VisitExpr_(const 
IntImmNode* op) { CHECK(op->value >= std::numeric_limits::min() && op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); + this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { @@ -312,25 +324,15 @@ void CodeGenStackVM::VisitExpr_(const CastNode* op) { PushCast(op->dtype, op->value.dtype()); } -void CodeGenStackVM::VisitExpr_(const AddNode* op) { - PushBinary(StackVM::ADD_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const SubNode* op) { - PushBinary(StackVM::SUB_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const MulNode* op) { - PushBinary(StackVM::MUL_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const DivNode* op) { - PushBinary(StackVM::DIV_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const ModNode* op) { - PushBinary(StackVM::MOD_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const MinNode* op) { this->Push(op->a); @@ -350,22 +352,16 @@ void CodeGenStackVM::VisitExpr_(const MaxNode* op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const EQNode* op) { - PushBinary(StackVM::EQ_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const EQNode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const LENode* op) { - PushBinary(StackVM::LE_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const LENode* op) { PushBinary(StackVM::LE_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const NENode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const LTNode* op) { - PushBinary(StackVM::LT_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const LTNode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const GENode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); @@ -431,7 +427,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) { +void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { if (is_const(ev->value)) return; const CallNode* op = ev->value.as(); if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) { @@ -482,9 +478,7 @@ void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) { this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const RampNode* op) { - LOG(FATAL) << "Ramp is not supported"; -} +void CodeGenStackVM::VisitExpr_(const RampNode* op) { LOG(FATAL) << "Ramp is not supported"; } void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) { LOG(FATAL) << "Broadcast is not supported"; @@ -506,9 +500,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { - this->Push(op->body); -} +void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { this->Push(op->body); } void CodeGenStackVM::VisitExpr_(const LetNode* op) { this->Push(op->value); @@ -521,17 +513,15 @@ 
runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) { std::unordered_map fmap; std::string entry_func; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenStackVM: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenStackVM: Can only take PrimFunc"; auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); - CHECK(!fmap.count(f_name)) - << "Function name " << f_name << "already exist in list"; + CHECK(!fmap.count(f_name)) << "Function name " << f_name << " already exists in list"; fmap[f_name] = std::move(vm); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { @@ -542,7 +532,6 @@ runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) { return runtime::StackVMModuleCreate(fmap, entry_func); } -TVM_REGISTER_GLOBAL("target.build.stackvm") -.set_body_typed(BuildStackVM); +TVM_REGISTER_GLOBAL("target.build.stackvm").set_body_typed(BuildStackVM); } // namespace codegen } // namespace tvm diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 31036822649d6..b77c40696de6c 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -24,12 +24,14 @@ #ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ #define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ +#include #include +#include #include -#include + #include -#include #include +#include #include "../../runtime/stackvm/stackvm.h" @@ -44,11 +46,10 @@ using runtime::StackVM; * This module is used to generate host wrapper * into device function when only device JIT is available. */ -class CodeGenStackVM - : public ExprFunctor, - public StmtFunctor { +class CodeGenStackVM : public ExprFunctor, + public StmtFunctor { public: - /*! + /*! * \brief Generate a stack VM representing * \param f The function to be compiled * \param device_funcs The extern device functions to be linked. @@ -59,9 +60,7 @@ class CodeGenStackVM /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); /*! \brief Push expr to generate new code */ - void Push(const PrimExpr& n) { - VisitExpr(n); - } + void Push(const PrimExpr& n) { VisitExpr(n); } /*! * \brief Push the opcode to the code. * \param opcode The code to be pushed. @@ -81,9 +80,7 @@ class CodeGenStackVM */ void SetOperand(int64_t operand_index, int64_t operand); /*! \return The current program pointer */ - int64_t GetPC() const { - return static_cast(vm_.code.size()); - } + int64_t GetPC() const { return static_cast(vm_.code.size()); } /*! * \brief Get string id in vm * \param key The string to get id. 
@@ -103,9 +100,7 @@ class CodeGenStackVM */ int GetVarID(const VarNode* v) const; // Push binary operator - void PushBinary(StackVM::OpCode op_int64, - const PrimExpr& a, - const PrimExpr& b); + void PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b); // push cast; void PushCast(DataType dst, DataType src); // overloadable functions diff --git a/src/target/target.cc b/src/target/target.cc index a72ce1c5b3e4c..010a14a19979f 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,11 +21,9 @@ * \file src/target/target.cc */ #include - -#include #include +#include #include - #include #include @@ -33,28 +31,27 @@ namespace tvm { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; TVM_REGISTER_NODE_TYPE(TargetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->str(); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->str(); + }); /*! -* \brief Construct a Target node from the given name and options. -* \param target_name The major target name. Should be one of -* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", -* "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} -* \param options Additional options appended to the target -* \return The constructed Target -*/ -Target CreateTarget(const std::string& target_name, - const std::vector& options) { + * \brief Construct a Target node from the given name and options. + * \param target_name The major target name. Should be one of + * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", + * "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} + * \param options Additional options appended to the target + * \return The constructed Target + */ +Target CreateTarget(const std::string& target_name, const std::vector& options) { auto t = make_object(); t->target_name = target_name; @@ -110,11 +107,13 @@ Target CreateTarget(const std::string& target_name, if (t->device_name == "intel_graphics") { t->thread_warp_size = 16; } - } else if (target_name == "metal" || target_name == "vulkan") { + } else if (target_name == "metal" || target_name == "vulkan" || target_name == "webgpu") { if (target_name == "metal") { t->device_type = kDLMetal; - } else { + } else if (target_name == "vulkan") { t->device_type = kDLVulkan; + } else { + t->device_type = kDLWebGPU; } t->keys_array.push_back(target_name); t->keys_array.push_back("gpu"); @@ -139,16 +138,18 @@ Target CreateTarget(const std::string& target_name, } else if (target_name == "hexagon") { t->keys_array.push_back("hexagon"); t->device_type = kDLHexagon; + } else if (target_name == "webgpu") { + t->keys_array.push_back("webgpu"); + t->device_type = kDLWebGPU; } else { - LOG(ERROR) << "Unknown target name " << target_name; + LOG(ERROR) << "Unknown target name " << target_name << "; falling back to stackvm"; return target::stackvm(); } return Target(t); } -TVM_REGISTER_GLOBAL("target.TargetCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_name = args[0]; std::vector options; for (int i = 1; i < args.num_args; ++i) { @@ -157,13 +158,12 @@ TVM_REGISTER_GLOBAL("target.TargetCreate") } *ret = CreateTarget(target_name, options); - }); 
+}); -TVM_REGISTER_GLOBAL("target.TargetFromString") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; *ret = Target::Create(target_str); - }); +}); std::vector TargetNode::keys() const { std::vector result; @@ -193,14 +193,13 @@ const std::string& TargetNode::str() const { if (str_repr_.length() != 0) return str_repr_; std::ostringstream result; result << target_name; - for (const auto &x : options()) { + for (const auto& x : options()) { result << " " << x; } str_repr_ = result.str(); return str_repr_; } - bool StartsWith(const std::string& str, const std::string& pattern) { return str.compare(0, pattern.length(), pattern) == 0; } @@ -250,104 +249,75 @@ struct TVMTargetThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; void Target::EnterWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); entry->context_stack.push(*this); } void Target::ExitWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } tvm::Target Target::Current(bool allow_not_defined) { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } CHECK(allow_not_defined) - << "Target context required. Please set it by constructing a TargetContext"; + << "Target context required. 
Please set it by constructing a TargetContext"; return Target(); } -TVM_REGISTER_GLOBAL("target.GetCurrentTarget") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; *ret = Target::Current(allow_not_defined); - }); +}); class Target::Internal { public: - static void EnterScope(Target target) { - target.EnterWithScope(); - } - static void ExitScope(Target target) { - target.ExitWithScope(); - } + static void EnterScope(Target target) { target.EnterWithScope(); } + static void ExitScope(Target target) { target.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("target.EnterTargetScope") -.set_body_typed(Target::Internal::EnterScope); +TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope); -TVM_REGISTER_GLOBAL("target.ExitTargetScope") -.set_body_typed(Target::Internal::ExitScope); +TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope); namespace target { std::vector MergeOptions(std::vector opts, - const std::vector& new_opts) { + const std::vector& new_opts) { opts.insert(opts.end(), new_opts.begin(), new_opts.end()); return opts; } -Target llvm(const std::vector& options) { - return CreateTarget("llvm", options); -} +Target llvm(const std::vector& options) { return CreateTarget("llvm", options); } -Target cuda(const std::vector& options) { - return CreateTarget("cuda", options); -} +Target cuda(const std::vector& options) { return CreateTarget("cuda", options); } -Target rocm(const std::vector& options) { - return CreateTarget("rocm", options); -} +Target rocm(const std::vector& options) { return CreateTarget("rocm", options); } -Target opencl(const std::vector& options) { - return CreateTarget("opencl", options); -} +Target opencl(const std::vector& options) { return CreateTarget("opencl", options); } -Target metal(const std::vector& options) { - return CreateTarget("metal", options); -} +Target metal(const std::vector& options) { return CreateTarget("metal", options); } Target mali(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=mali" - })); + return CreateTarget("opencl", MergeOptions(options, {"-device=mali"})); } Target intel_graphics(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=intel_graphics" - })); + return CreateTarget("opencl", MergeOptions(options, {"-device=intel_graphics"})); } -Target stackvm(const std::vector& options) { - return CreateTarget("stackvm", options); -} +Target stackvm(const std::vector& options) { return CreateTarget("stackvm", options); } -Target ext_dev(const std::vector& options) { - return CreateTarget("ext_dev", options); -} +Target ext_dev(const std::vector& options) { return CreateTarget("ext_dev", options); } -Target hexagon(const std::vector& options) { - return CreateTarget("hexagon", options); -} +Target hexagon(const std::vector& options) { return CreateTarget("hexagon", options); } } // namespace target -BuildConfig BuildConfig::Create() { - return BuildConfig(make_object()); -} +BuildConfig BuildConfig::Create() { return BuildConfig(make_object()); } /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMBuildConfigThreadLocalEntry { @@ -357,28 +327,26 @@ struct TVMBuildConfigThreadLocalEntry { /*! 
\brief The current build config context */ std::stack context_stack; - TVMBuildConfigThreadLocalEntry() : - default_config(BuildConfig::Create()) { - } + TVMBuildConfigThreadLocalEntry() : default_config(BuildConfig::Create()) {} }; /*! \brief Thread local store to hold the BuildConfig context stack. */ typedef dmlc::ThreadLocalStore TVMBuildConfigThreadLocalStore; void BuildConfig::EnterWithScope() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get(); entry->context_stack.push(*this); } void BuildConfig::ExitWithScope() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } tvm::BuildConfig BuildConfig::Current() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + TVMBuildConfigThreadLocalEntry* entry = TVMBuildConfigThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } @@ -389,80 +357,73 @@ tvm::BuildConfig BuildConfig::Current() { TVM_REGISTER_NODE_TYPE(BuildConfigNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "build_config("; - p->stream << "data_alignment=" << op->data_alignment << ", "; - p->stream << "offset_factor=" << op->offset_factor << ", "; - p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", "; - p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", "; - p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", "; - p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", "; - p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; - p->stream << "restricted_func=" << op->restricted_func << ", "; - p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; - p->stream << "partition_const_loop=" << op->partition_const_loop << ", "; - p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; - p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; - p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; - p->stream << "disable_vectorize=" << op->disable_vectorize; - p->stream << "disable_assert=" << op->disable_assert; - p->stream << ")"; -}); - -TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig") -.set_body([](TVMArgs args, TVMRetValue* ret) { + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "build_config("; + p->stream << "data_alignment=" << op->data_alignment << ", "; + p->stream << "offset_factor=" << op->offset_factor << ", "; + p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", "; + p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", "; + p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", "; + p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", "; + p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; + p->stream << "restricted_func=" << op->restricted_func << ", "; + p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; + p->stream << "partition_const_loop=" << op->partition_const_loop << ", "; + p->stream << 
"dump_pass_ir=" << op->dump_pass_ir << ", "; + p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; + p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; + p->stream << "disable_vectorize=" << op->disable_vectorize; + p->stream << "disable_assert=" << op->disable_assert; + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = BuildConfig::Current(); - }); +}); class BuildConfig::Internal { public: - static void EnterScope(BuildConfig target) { - target.EnterWithScope(); - } - static void ExitScope(BuildConfig target) { - target.ExitWithScope(); - } + static void EnterScope(BuildConfig target) { target.EnterWithScope(); } + static void ExitScope(BuildConfig target) { target.ExitWithScope(); } }; TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope") -.set_body_typed(BuildConfig::Internal::EnterScope); + .set_body_typed(BuildConfig::Internal::EnterScope); -TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") -.set_body_typed(BuildConfig::Internal::ExitScope); +TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope").set_body_typed(BuildConfig::Internal::ExitScope); TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig cfg = args[0]; - std::vector> add_lower_pass; - CHECK_EQ(args.size() % 2, 1); - for (int i = 1; i < args.size(); i += 2) { - add_lower_pass.push_back(std::make_pair( - args[i].operator int(), - args[i + 1].operator transform::Pass())); - } - cfg->add_lower_pass = add_lower_pass; - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + BuildConfig cfg = args[0]; + std::vector> add_lower_pass; + CHECK_EQ(args.size() % 2, 1); + for (int i = 1; i < args.size(); i += 2) { + add_lower_pass.push_back( + std::make_pair(args[i].operator int(), args[i + 1].operator transform::Pass())); + } + cfg->add_lower_pass = add_lower_pass; + }); TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo") -.set_body([](TVMArgs args, TVMRetValue* ret) { - // Return one of the following: - // * Size of add_lower_pass if num_args == 1 - // * Phase index of pass if args are (config, index, true) - // * Function of pass if args are (config, index, false) - BuildConfig cfg = args[0]; - if (args.num_args == 1) { - *ret = static_cast(cfg->add_lower_pass.size()); - } else { - int index = args[1]; - bool get_phase = args[2]; - auto item = cfg->add_lower_pass[index]; - if (get_phase) { - *ret = item.first; - } else { - *ret = item.second; - } - } -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + // Return one of the following: + // * Size of add_lower_pass if num_args == 1 + // * Phase index of pass if args are (config, index, true) + // * Function of pass if args are (config, index, false) + BuildConfig cfg = args[0]; + if (args.num_args == 1) { + *ret = static_cast(cfg->add_lower_pass.size()); + } else { + int index = args[1]; + bool get_phase = args[2]; + auto item = cfg->add_lower_pass[index]; + if (get_phase) { + *ret = item.first; + } else { + *ret = item.second; + } + } + }); } // namespace tvm diff --git a/src/target/target_info.cc b/src/target/target_info.cc index 73fe011cc9364..5ebb7edc80dc4 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -20,21 +20,21 @@ /*! 
* \file target/target_info.cc */ -#include #include +#include #include namespace tvm { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "mem-info(" - << "unit_bits=" << op->unit_bits << ", " - << "max_num_bits=" << op->max_num_bits << ", " - << "max_simd_bits=" << op->max_simd_bits << ", " - << "head_address=" << op->head_address << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "mem-info(" + << "unit_bits=" << op->unit_bits << ", " + << "max_num_bits=" << op->max_num_bits << ", " + << "max_simd_bits=" << op->max_simd_bits << ", " + << "head_address=" << op->head_address << ")"; + }); TVM_REGISTER_NODE_TYPE(MemoryInfoNode); diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc index b1c97e3cecb8b..874a512250994 100644 --- a/src/te/autodiff/ad_util.cc +++ b/src/te/autodiff/ad_util.cc @@ -21,10 +21,12 @@ * \file ad_util.cc * \brief Utility for tensor-level auto-differentiation. */ +#include "ad_util.h" + #include #include + #include -#include "ad_util.h" namespace tvm { namespace te { @@ -34,8 +36,7 @@ std::pair, Map> CloneIterVars(const Array Map vmap; for (const IterVar& iv : vars) { IterVar new_v = - IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), - iv->iter_type, iv->thread_tag); + IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag); new_vars.push_back(new_v); vmap.Set(iv->var, new_v->var); } @@ -53,8 +54,8 @@ PrimExpr CloneReduction(const PrimExpr& expr) { src_with_newaxis.push_back(tir::Substitute(src, vmap)); } - return ReduceNode::make(red->combiner, src_with_newaxis, - new_axis, tir::Substitute(red->condition, vmap), red->value_index); + return ReduceNode::make(red->combiner, src_with_newaxis, new_axis, + tir::Substitute(red->condition, vmap), red->value_index); } else { return expr; } diff --git a/src/te/autodiff/ad_util.h b/src/te/autodiff/ad_util.h index 7e511b1c5a22e..56ab6c18b929c 100644 --- a/src/te/autodiff/ad_util.h +++ b/src/te/autodiff/ad_util.h @@ -24,11 +24,12 @@ #ifndef TVM_TE_AUTODIFF_AD_UTIL_H_ #define TVM_TE_AUTODIFF_AD_UTIL_H_ -#include #include -#include +#include + #include #include +#include namespace tvm { namespace te { diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc index 0c54764e601ad..4afca681deebf 100644 --- a/src/te/autodiff/adjoint.cc +++ b/src/te/autodiff/adjoint.cc @@ -30,11 +30,12 @@ * (3) and sum them together to get the adjoint of the input itself. * The three steps are computed recursively. 
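 *
 * A small worked sketch of the recursion (tensor names illustrative): with
 * W2 = f(W1) and Y = g(W2),
 *
 *   adjoint(Y)  = head, or the diagonal Identity(Y) when head is null;
 *   adjoint(W2) = adjoint(Y) x Jacobian(Y, W2);
 *   adjoint(W1) = adjoint(W2) x Jacobian(W2, W1);
 *
 * where each "x" is the tensordot performed by VectorJacobianProduct below,
 * and a tensor consumed by several operations sums the contributions from
 * all of its consumers.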
*/ +#include +#include #include #include #include -#include -#include + #include #include @@ -47,27 +48,25 @@ Tensor Identity(const Tensor& output) { // add extra dimension for Jacobian shape.push_back(e); } - auto func = - [&output](const Array& input_indices) { - PrimExpr res = const_true(); - for (size_t i = 0; i < output->shape.size(); ++i) { - res = res && (PrimExpr(input_indices[i]) == - PrimExpr(input_indices[output->shape.size() + i])); - } - return CastNode::make(output->dtype, res); - }; + auto func = [&output](const Array& input_indices) { + PrimExpr res = const_true(); + for (size_t i = 0; i < output->shape.size(); ++i) { + res = + res && (PrimExpr(input_indices[i]) == PrimExpr(input_indices[output->shape.size() + i])); + } + return CastNode::make(output->dtype, res); + }; return te::compute(shape, func, "identity"); } -Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) { +Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head) { Tensor jac = Jacobian(output, input); Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(), output->op->name + "." + input->op->name + ".grad"); return result; } -Array Gradient(const Tensor& output, - const Array& inputs, +Array Gradient(const Tensor& output, const Array& inputs, const Tensor& head_or_null) { // Diagonal identity tensor Tensor head = head_or_null.get() ? head_or_null : Identity(output); @@ -95,41 +94,40 @@ Array Gradient(const Tensor& output, // This is a recursive function that does all the work. It computes the adjoint for a given // tensor, adds it to the map, and returns it std::function compute_adjoint; - compute_adjoint = - [&compute_adjoint, &adjoints, &reverse_dependencies, &head, &output] - (const Tensor& tensor) { - if (!adjoints.count(tensor)) { - // Here the adjoint hasn't been computed yet - Tensor res_adjoint; - std::vector direct_consumers = reverse_dependencies[tensor]; - if (direct_consumers.empty()) { - // No reverse dependencies means that the output does not depend on this tensor, - // return a zero tensor of the appropriate shape - // (i.e., output shape + tensor shape, aka shape of Jacobian) - Array result_shape(head->shape.begin(), - head->shape.end() + (-output->shape.size())); - for (auto e : tensor->shape) { - result_shape.push_back(e); - } - res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); - } else { - // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied - // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian - // and the multiplication is done in the function VectorJacobianProduct - for (const Tensor& direct_consumer : direct_consumers) { - // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) - Tensor part = VectorJacobianProduct( - direct_consumer, tensor, compute_adjoint(direct_consumer)); - res_adjoint = res_adjoint.get() ? 
topi::add(res_adjoint, part) : part; - } + compute_adjoint = [&compute_adjoint, &adjoints, &reverse_dependencies, &head, + &output](const Tensor& tensor) { + if (!adjoints.count(tensor)) { + // Here the adjoint hasn't been computed yet + Tensor res_adjoint; + std::vector direct_consumers = reverse_dependencies[tensor]; + if (direct_consumers.empty()) { + // No reverse dependencies means that the output does not depend on this tensor, + // return a zero tensor of the appropriate shape + // (i.e., output shape + tensor shape, aka shape of Jacobian) + Array result_shape(head->shape.begin(), + head->shape.end() + (-output->shape.size())); + for (auto e : tensor->shape) { + result_shape.push_back(e); } - - adjoints[tensor] = res_adjoint; - return res_adjoint; + res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); } else { - return adjoints[tensor]; + // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied + // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian + // and the multiplication is done in the function VectorJacobianProduct + for (const Tensor& direct_consumer : direct_consumers) { + // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) + Tensor part = + VectorJacobianProduct(direct_consumer, tensor, compute_adjoint(direct_consumer)); + res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; + } } - }; + + adjoints[tensor] = res_adjoint; + return res_adjoint; + } else { + return adjoints[tensor]; + } + }; // Adjoints corresponding to inputs Array result; @@ -141,15 +139,14 @@ Array Gradient(const Tensor& output, return result; } -TVM_REGISTER_GLOBAL("te.Gradient") -.set_body([](TVMArgs args, TVMRetValue *ret) { - LOG(WARNING) << "te.Gradient is an experimental feature."; - if (args.size() == 2) { - *ret = Gradient(args[0], args[1]); - } else if (args.size() == 3) { - *ret = Gradient(args[0], args[1], args[2]); - } - }); +TVM_REGISTER_GLOBAL("te.Gradient").set_body([](TVMArgs args, TVMRetValue* ret) { + LOG(WARNING) << "te.Gradient is an experimental feature."; + if (args.size() == 2) { + *ret = Gradient(args[0], args[1]); + } else if (args.size() == 3) { + *ret = Gradient(args[0], args[1], args[2]); + } +}); } // namespace te } // namespace tvm diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index d5b6fec3698ad..f770169e06e7b 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -23,19 +23,23 @@ * X must be direct input tensor of Y. * The result Jacobian shape will be (Y.shape, X.shape) */ -#include #include #include +#include #include #include + #include "ad_util.h" namespace tvm { namespace te { -#define NOT_IMPLEMENTED \ - { LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); throw; } +#define NOT_IMPLEMENTED \ + { \ + LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); \ + throw; \ + } /*! \brief Differentiate an expression wrt a variable or a tensor element */ class JacobianMutator : public ExprMutator { @@ -46,7 +50,7 @@ class JacobianMutator : public ExprMutator { * \param indices The indices of the element with respect to which to differentiate. */ explicit JacobianMutator(Tensor input, Array indices) - : input_(input), indices_(indices) {} + : input_(input), indices_(indices) {} /*! * \brief Differentiate wrt the input variable. * \param input The input variable. 
@@ -71,14 +75,13 @@ class JacobianMutator : public ExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); if (op->call_type == CallNode::CallType::Halide) { - if (input_.get() && op->func.same_as(input_->op) && - op->value_index == input_->value_index) { + if (input_.get() && op->func.same_as(input_->op) && op->value_index == input_->value_index) { // Tensor(indices) CHECK_EQ(indices_.size(), op->args.size()); PrimExpr condition = const_true(); @@ -99,86 +102,71 @@ class JacobianMutator : public ExprMutator { return MulNode::make(Mutate(op->args[0]), MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0), expr))); } else if (op->name == "sqrt") { - return DivNode::make(Mutate(op->args[0]), - MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); + return DivNode::make(Mutate(op->args[0]), MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); } else if (op->name == "tanh") { return MulNode::make(Mutate(op->args[0]), SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr, expr))); } else if (op->name == "pow") { auto x = op->args[0], y = op->args[1]; - return expr * (Mutate(y)*log(x) + Mutate(x)*y/x); + return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); } else if (op->name == "fabs") { auto type = op->args[0].dtype(); return MulNode::make(Mutate(op->args[0]), SelectNode::make(GENode::make(op->args[0], make_zero(type)), FloatImm(type, 1.0), FloatImm(type, -1.0))); } else if (op->name == intrinsic::tvm_if_then_else) { - Array new_args = {op->args[0], - Mutate(op->args[1]), - Mutate(op->args[2])}; - return CallNode::make(op->dtype, op->name, new_args, - op->call_type, op->func, op->value_index); + Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; + return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func, + op->value_index); } else if (piecewise_const.count(op->name)) { return FloatImm(expr.dtype(), 0.0); } else { throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); } } - NOT_IMPLEMENTED + NOT_IMPLEMENTED; } - PrimExpr VisitExpr_(const AddNode* op) { - return AddNode::make(Mutate(op->a), Mutate(op->b)); - } + PrimExpr VisitExpr_(const AddNode* op) { return AddNode::make(Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const SubNode* op) { - return SubNode::make(Mutate(op->a), Mutate(op->b)); - } + PrimExpr VisitExpr_(const SubNode* op) { return SubNode::make(Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MulNode* op) { - return AddNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))); + return AddNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))); } PrimExpr VisitExpr_(const DivNode* op) { return DivNode::make( - SubNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))), + SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))), MulNode::make(op->b, op->b)); } - PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const FloorDivNode* op) { return FloorDivNode::make( - SubNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))), + SubNode::make(MulNode::make(Mutate(op->a), op->b), 
MulNode::make(op->a, Mutate(op->b))), MulNode::make(op->b, op->b)); } - PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const MinNode* op) { - return SelectNode::make(LENode::make(op->a, op->b), - Mutate(op->a), Mutate(op->b)); + return SelectNode::make(LENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MaxNode* op) { - return SelectNode::make(GENode::make(op->a, op->b), - Mutate(op->a), Mutate(op->b)); + return SelectNode::make(GENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const ReduceNode* op) { // This case is relatively difficult because a reduction expression @@ -265,9 +253,8 @@ class JacobianMutator : public ExprMutator { CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); // Also simplify the resulting combiner // (mostly to get rid of unused components, e.g., the original expressions) - return analyzer_.Simplify( - ReduceNode::make(new_combiner, new_source, new_op->axis, - new_op->condition, new_op->value_index)); + return analyzer_.Simplify(ReduceNode::make(new_combiner, new_source, new_op->axis, + new_op->condition, new_op->value_index)); } PrimExpr VisitExpr_(const CastNode* op) { @@ -278,26 +265,21 @@ class JacobianMutator : public ExprMutator { } } - PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const SelectNode* op) { - return SelectNode::make(op->condition, - Mutate(op->true_value), Mutate(op->false_value)); + return SelectNode::make(op->condition, Mutate(op->true_value), Mutate(op->false_value)); } - PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED; - PrimExpr VisitExpr_(const IntImmNode* op) { - return IntImm(op->dtype, 0); - } + PrimExpr VisitExpr_(const IntImmNode* op) { return IntImm(op->dtype, 0); } - PrimExpr VisitExpr_(const FloatImmNode* op) { - return FloatImm(op->dtype, 0); - } + PrimExpr VisitExpr_(const FloatImmNode* op) { return FloatImm(op->dtype, 0); } - PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED; private: Tensor input_; @@ -336,8 +318,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& 
input) { Array input_indices; size_t i = 0; for (PrimExpr ext : input->shape) { - IterVar new_v = IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), - IterVarType::kDataPar); + IterVar new_v = + IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar); // Append jacobian iter to new_axis new_axis.push_back(new_v); // Differentiate wrt input[input_indices] @@ -345,8 +327,8 @@ } arith::Analyzer analyzer; // Compute Jacobian - PrimExpr new_body = Jacobian( - Substitute(op->body[output->value_index], vmap), input, input_indices); + PrimExpr new_body = + Jacobian(Substitute(op->body[output->value_index], vmap), input, input_indices); new_body = analyzer.Simplify(new_body); int value_index = 0; @@ -358,14 +340,14 @@ value_index = red->value_index; for (size_t idx = 0; idx < red->source.size(); ++idx) { new_bodies.push_back( - ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); + ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); } } else { new_bodies.push_back(new_body); } - auto new_op = ComputeOpNode::make( - op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + auto new_op = + ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); // Jacobian shape = output.shape + input.shape Array new_shape = output->shape; diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 2d9f13baedaca..d8ad839e777eb 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -21,46 +21,45 @@ * \brief Compute Op. * \file compute_op.cc */ +#include "compute_op.h" + +#include #include #include -#include -#include #include +#include #include -#include + #include +#include #include -#include "compute_op.h" -#include "op_util.h" -#include "../schedule/message_passing.h" + #include "../../arith/compute_expr.h" #include "../../arith/interval_set.h" +#include "../schedule/message_passing.h" +#include "op_util.h" namespace tvm { namespace te { using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "compute(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "compute(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. 
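/// The invariant enforced (see ComputeVerifier further down): either every
/// element of body is a ReduceNode or none is, and all top-level reductions
/// must share combiner, source, axis and condition, differing only in
/// value_index, which is exactly what ReduceEqual below tests. For
/// illustration (hypothetical bodies): {argmax_val(x), argmax_idx(x)} is
/// valid because both elements come from one multi-output reduce, while
/// {sum(x), x + 1} mixes a reduction with a plain expression and is rejected.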
-static void VerifyComputeOp(const ComputeOpNode *op); +static void VerifyComputeOp(const ComputeOpNode* op); inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && - (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && - (a->condition.same_as(b->condition)); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)); } -int ComputeOpNode::num_outputs() const { - return body.size(); -} +int ComputeOpNode::num_outputs() const { return body.size(); } Array BaseComputeOpNode::root_iter_vars() const { if (reduce_axis.size() == 0) return axis; @@ -87,10 +86,7 @@ Array BaseComputeOpNode::output_shape(size_t idx) const { return shape; } -Tensor compute(Array shape, - FCompute fcompute, - std::string name, - std::string tag, +Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, Map attrs) { auto op_node = make_object(); // compute dimension. @@ -100,20 +96,16 @@ Tensor compute(Array shape, for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back( + IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } - return ComputeOpNode::make( - name, tag, attrs, axis, {fcompute(args)}).output(0); + return ComputeOpNode::make(name, tag, attrs, axis, {fcompute(args)}).output(0); } -Array compute(Array shape, - FBatchCompute fcompute, - std::string name, - std::string tag, - Map attrs) { +Array compute(Array shape, FBatchCompute fcompute, std::string name, + std::string tag, Map attrs) { auto op_node = make_object(); // compute dimension. 
size_t ndim = shape.size(); @@ -122,8 +114,8 @@ Array compute(Array shape, for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back( + IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -135,11 +127,8 @@ Array compute(Array shape, return outputs; } -Operation ComputeOpNode::make(std::string name, - std::string tag, - Map attrs, - Array axis, - Array body) { +Operation ComputeOpNode::make(std::string name, std::string tag, Map attrs, + Array axis, Array body) { if (!attrs.defined()) { attrs = Map(); } @@ -157,9 +146,7 @@ Operation ComputeOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.ComputeOp") -.set_body_typed(ComputeOpNode::make); - +TVM_REGISTER_GLOBAL("te.ComputeOp").set_body_typed(ComputeOpNode::make); // The schedule related logics Array ComputeOpNode::InputTensors() const { @@ -167,22 +154,21 @@ Array ComputeOpNode::InputTensors() const { std::unordered_set visited; for (auto& e : body) { tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); - if (!visited.count(t)) { - ret.push_back(t); - visited.insert(t); - } + const tir::CallNode* call = n.as(); + if (call != nullptr && call->func.defined()) { + Tensor t = Downcast(call->func).output(call->value_index); + if (!visited.count(t)) { + ret.push_back(t); + visited.insert(t); } - }); + } + }); } return ret; } -Operation ComputeOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ComputeOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); VerifyComputeOp(this); Array arr; @@ -202,26 +188,22 @@ Operation ComputeOpNode::ReplaceInputs( arr = this->body; } } else { - arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) { - return te::ReplaceTensor(e, rmap); - }); + arr = + UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); }); } if (!arr.same_as(this->body)) { - return ComputeOpNode::make( - this->name, this->tag, this->attrs, this->axis, arr); + return ComputeOpNode::make(this->name, this->tag, this->attrs, this->axis, arr); } else { return self; } } -void ComputeOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { - auto *call = n.as(); + auto* call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); if (t->op.defined() && out_dom_map->count(t)) { @@ -231,7 +213,7 @@ void ComputeOpNode::PropBoundToInputs( // undefined behaviour), so we can intersect the estimated set of the argument with the // range expected by the tensor. However, intersection may result in overly complex // expressions, so we perform a more relaxed form of intersection. 
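        // A concrete illustration (numbers hypothetical): with t->shape[i] == 16
        // the shape bound is [0, 15]; an estimated argument set such as
        // [i0, i0 + 3] is replaced by the shape bound only when the analyzer can
        // prove the shape bound is at least as tight on both ends (or when the
        // estimate is unbounded on both ends); otherwise the estimate is kept
        // rather than forming the exact, more complex intersection.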
- IntSet arg_intset = EvalSet(call->args[i], dom_map); + IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map)); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); @@ -239,12 +221,14 @@ PrimExpr min_value = arg_interval->min_value; PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. - if (arith::is_neg_inf(min_value) || - analyzer->CanProve(shape_i_min_value >= min_value)) { + // We must update the bound's ends in pairs. Here is a counterexample: shape_i is + // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is + // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0], + // which is awkward for further analysis. + if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) || + (analyzer->CanProve(shape_i_min_value >= min_value) && + analyzer->CanProve(shape_i_max_value <= max_value))) { min_value = shape_i_min_value; - } - if (arith::is_pos_inf(max_value) || - analyzer->CanProve(shape_i_max_value <= max_value)) { max_value = shape_i_max_value; } dom.data[i].push_back(IntSet::interval(min_value, max_value)); @@ -258,10 +242,9 @@ for (auto& e : body) tir::PostOrderVisit(e, fvisit); } -void BaseComputeOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { +void BaseComputeOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); const TensorDom& tdom = tensor_dom.at(self.output(0)); for (size_t i = 0; i < this->axis.size(); ++i) { @@ -275,10 +258,9 @@ } } -Stmt BaseComputeOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { CHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -286,24 +268,22 @@ } Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { - Tensor t = stage->op.output(i-1); - realize = tir::RealizeNode::make(t->op, t->value_index, - t->dtype, bounds, const_true(), realize); + Tensor t = stage->op.output(i - 1); + realize = + tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { - Array tuple = {static_cast(i), - attr->dim_align_factor, - attr->dim_align_offset}; - realize = tir::AttrStmtNode::make( - t, tir::attr::buffer_dim_align, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), - realize); + Array tuple = {static_cast(i), attr->dim_align_factor, + attr->dim_align_offset}; + realize = + tir::AttrStmtNode::make(t, tir::attr::buffer_dim_align, + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, + tuple, CallNode::Intrinsic), + realize); } } } @@ -311,16 +291,12 @@ return realize; } -size_t 
ComputeOpNode::num_schedulable_dims() const { - return axis.size(); -} +size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); } // Build a reduction body. -void MakeReduction(const ComputeOpNode* op, - const Array& tensors, - Stmt* init, +void MakeReduction(const ComputeOpNode* op, const Array& tensors, Stmt* init, Stmt* provide) { - Array args; + Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } @@ -339,10 +315,8 @@ void MakeReduction(const ComputeOpNode* op, Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; - inits.emplace_back(ProvideNode::make( - t->op, t->value_index, init_value[i], args)); - provides.emplace_back(ProvideNode::make( - t->op, t->value_index, update_value[i], args)); + inits.emplace_back(ProvideNode::make(t->op, t->value_index, init_value[i], args)); + provides.emplace_back(ProvideNode::make(t->op, t->value_index, update_value[i], args)); } *init = SeqStmt::Flatten(inits); *provide = SeqStmt::Flatten(provides); @@ -352,8 +326,7 @@ void MakeReduction(const ComputeOpNode* op, } // Normal computation. -Stmt MakeProvide(const ComputeOpNode* op, - const Tensor& t) { +Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); @@ -361,8 +334,7 @@ Stmt MakeProvide(const ComputeOpNode* op, return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args); } -Stmt MakeComputeStmt(const ComputeOpNode* self, - const Stage& stage, +Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { // grab the nest structure @@ -381,10 +353,10 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, init = MergeNest(n.init_nest, init); init = Substitute(init, n.init_vmap); // common nest - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > reduce( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > reduce(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.end()); provide = MergeNest(reduce, provide); if (debug_keep_trivial_loop) { provide = MergeNest(common, provide); @@ -407,14 +379,9 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, } } -enum class ComputeType { - kNormal, - kCrossThreadReduction, - kTensorize -}; +enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize }; -ComputeType DetectComputeType(const ComputeOpNode* self, - const Stage& stage) { +ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) { // Verify correctness of leaf nest. int normal_red = 0, thread_red = 0, tensorize = 0; @@ -434,13 +401,11 @@ ComputeType DetectComputeType(const ComputeOpNode* self, ++normal_red; } } else { - CHECK_EQ(thread_red, 0) - << "Cross thread reduce cannot swap with normal data axis"; + CHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis"; } } if (tensorize != 0) { - CHECK(thread_red == 0) - << "Cannot mix cross thread reduction with Tensorize"; + CHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize"; return ComputeType::kTensorize; } if (thread_red != 0) { @@ -451,10 +416,9 @@ ComputeType DetectComputeType(const ComputeOpNode* self, } // implement the provide utility. 
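// (Editor's summary, not part of this patch.) DetectComputeType above
// classifies a stage by scanning its leaf iter vars:
//   any tensorized axis (thread_red must be 0)  -> kTensorize
//   else any thread-bound reduction axis        -> kCrossThreadReduction
//   else                                        -> kNormal
// Its CHECKs reject mixing tensorization with cross-thread reduction, and
// reject a normal data axis ordered inside a cross-thread reduction axis.
// BuildProvide below dispatches on this classification.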
-Stmt ComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ComputeOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); ComputeType ctype = DetectComputeType(this, stage); if (ctype == ComputeType::kCrossThreadReduction) { @@ -467,20 +431,16 @@ Stmt ComputeOpNode::BuildProvide( } } -ComputeLoopNest ComputeLoopNest::make( - const BaseComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +ComputeLoopNest ComputeLoopNest::make(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { CHECK_EQ(stage->op.operator->(), self); ComputeLoopNest ret; // make main loop nest - ret.main_nest = MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set(), &ret.main_vmap, - debug_keep_trivial_loop); - ret.main_predicates = MakeBoundCheck( - stage, dom_map, ret.main_vmap, false, - std::unordered_set()); + ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set(), + &ret.main_vmap, debug_keep_trivial_loop); + ret.main_predicates = + MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set()); for (auto& e : ret.main_predicates) { e = likely(e); } @@ -506,7 +466,8 @@ ComputeLoopNest ComputeLoopNest::make( auto iv = leaf_iter_vars[i]; int flag = update_state.at(iv); if ((flag & 2) != 0) { - begin_loop = i; break; + begin_loop = i; + break; } ret.init_vmap[iv] = ret.main_vmap.at(iv); } @@ -517,11 +478,9 @@ ComputeLoopNest ComputeLoopNest::make( int flag = kv.second; if (flag == 2) skip_iter.insert(kv.first); } - ret.init_nest = MakeLoopNest( - stage, dom_map, begin_loop, true, - skip_iter, &(ret.init_vmap), debug_keep_trivial_loop); - ret.init_predicates = MakeBoundCheck( - stage, dom_map, ret.init_vmap, true, skip_iter); + ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap), + debug_keep_trivial_loop); + ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter); for (auto& e : ret.init_predicates) { e = likely(e); } @@ -561,14 +520,12 @@ class ComputeVerifier final : protected tir::ExprVisitor { for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions const tir::ReduceNode* reduce = e.as(); - CHECK((reduce && reduce_) || (!reduce && !reduce_)) - << "All ComputeOp should be consistent " - << "with being Reduce operation or not."; + CHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent " + << "with being Reduce operation or not."; if (reduce && reduce_) { - CHECK(ReduceEqual(reduce, reduce_)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + CHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } level_ = 0; @@ -587,16 +544,15 @@ class ComputeVerifier final : protected tir::ExprVisitor { void VisitExpr_(const tir::ReduceNode* op) final { // Check for non top level reductions - CHECK(0 == level_) - << "Reductions are only allowed at the top level of compute. " - << "Please create another tensor for further composition."; + CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. 
" + << "Please create another tensor for further composition."; } //@} private: - const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify - const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation - int level_{0}; ///< Level of op being processed + const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify + const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation + int level_{0}; ///< Level of op being processed }; } // namespace @@ -606,11 +562,8 @@ static void VerifyComputeOp(const ComputeOpNode* op) { v.Run(); } -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update) { +Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, + const ComputeLoopNest& n, Stmt body, Stmt update) { Array conds; std::unordered_set banned; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { @@ -631,19 +584,17 @@ Stmt TransformUpdate(const Stage& stage, } } - auto fbanned = [&](const VarNode* node) { - return banned.count(node); - }; + auto fbanned = [&](const VarNode* node) { return banned.count(node); }; for (const PrimExpr& pred : n.main_predicates) { if (tir::ExprUseVar(pred, fbanned)) { - LOG(FATAL) << "Tensorize update transform failed, the condition " - << pred << " has a conflict with the reset condition"; + LOG(FATAL) << "Tensorize update transform failed, the condition " << pred + << " has a conflict with the reset condition"; } } - return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), - update, body); + return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), update, + body); } } // namespace te diff --git a/src/te/operation/compute_op.h b/src/te/operation/compute_op.h index 08db74f0d9a5f..610c014685098 100644 --- a/src/te/operation/compute_op.h +++ b/src/te/operation/compute_op.h @@ -24,10 +24,11 @@ #ifndef TVM_TE_OPERATION_COMPUTE_OP_H_ #define TVM_TE_OPERATION_COMPUTE_OP_H_ -#include #include -#include +#include + #include +#include namespace tvm { namespace te { @@ -58,11 +59,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The constructed loop nest */ - static ComputeLoopNest make( - const BaseComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); + static ComputeLoopNest make(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); }; /*! @@ -73,11 +72,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); +Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); /*! * \brief Build body of compute for tensorization. @@ -87,10 +84,8 @@ Stmt MakeCrossThreadReduction( * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ -Stmt MakeTensorize(const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); +Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, bool debug_keep_trivial_loop); /*! 
* \brief Transform the update part when there is no init func in tensorizing @@ -101,11 +96,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, * \param update The update func in tensorize intrin * \return Transformed result. */ -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update); +Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, + const ComputeLoopNest& n, Stmt body, Stmt update); } // namespace te } // namespace tvm diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 1ec17e9e38a9c..09056314f4001 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -28,21 +28,17 @@ namespace tvm { namespace te { using namespace tir; -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { - Array args; +Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { + Array args; for (IterVar iv : self->axis) { args.push_back(iv->var); } std::unordered_map value_map; - auto nest = MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set(), &value_map, debug_keep_trivial_loop); - auto conds = MakeBoundCheck( - stage, dom_map, value_map, false, - std::unordered_set()); + auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set(), &value_map, + debug_keep_trivial_loop); + auto conds = MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set()); size_t size = self->body.size(); CHECK_GT(size, 0); @@ -96,10 +92,10 @@ Stmt MakeCrossThreadReduction( Array update_value = (*combiner)(lhs, reduces[0]->source); for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; - normal_init.emplace_back(StoreNode::make( - normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); - normal_update.emplace_back(StoreNode::make( - normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); + normal_init.emplace_back( + StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); + normal_update.emplace_back( + StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); } } @@ -108,8 +104,7 @@ Stmt MakeCrossThreadReduction( for (size_t i = 0; i < size; ++i) { if (!normal_red.empty()) { DataType t = reduces[i]->dtype; - freduce_args.push_back(LoadNode::make( - t, normal_res_handles[i], 0, const_true(t.lanes()))); + freduce_args.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes()))); } else { freduce_args.push_back(reduces[0]->source[i]); } @@ -124,8 +119,7 @@ Stmt MakeCrossThreadReduction( for (IterVar iv : stage->leaf_iter_vars) { if (iv->iter_type == kCommReduce) { auto it = stage->iter_var_attrs.find(iv); - if (it != stage->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { + if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { IterVar tv = (*it).second->bind_thread; freduce_args.push_back(tv->var); } @@ -138,14 +132,9 @@ Stmt MakeCrossThreadReduction( } Stmt reduce_body = EvaluateNode::make(CallNode::make( - DataType::Handle(), - tir::intrinsic::tvm_thread_allreduce, - freduce_args, CallNode::Intrinsic)); - reduce_body = AttrStmtNode::make( - reduces[0]->combiner, - tir::attr::reduce_scope, - make_zero(DataType::Handle()), - reduce_body); + DataType::Handle(), 
tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic)); + reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope, + make_zero(DataType::Handle()), reduce_body); if (!normal_red.empty()) { Stmt init_body = SeqStmt::Flatten(normal_init); @@ -159,23 +148,22 @@ Stmt MakeCrossThreadReduction( for (size_t idx = 0; idx < size; ++idx) { DataType t = reduces[idx]->dtype; assigns[idx] = ProvideNode::make( - stage->op, idx, - LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + stage->op, idx, LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(MakeIfNest(conds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = AllocateNode::make( - res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make( - res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); + body = + AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, + StringImmNode::make("local"), body); if (!normal_red.empty()) { - body = AllocateNode::make( - normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make( - normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); + body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, + const_true(), body); + body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope, + StringImmNode::make("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 9d95e329c8f29..59d1ec10c6f73 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -21,11 +21,13 @@ * \brief External computation rule. 
* \file extern_op.cc */ +#include #include #include -#include #include + #include + #include "op_util.h" namespace tvm { @@ -33,37 +35,24 @@ namespace te { using namespace tir; // ExternOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "extern(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "extern(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ExternOpNode); -int ExternOpNode::num_outputs() const { - return static_cast(output_placeholders.size()); -} - -Array ExternOpNode::root_iter_vars() const { - return {}; -} +int ExternOpNode::num_outputs() const { return static_cast(output_placeholders.size()); } -DataType ExternOpNode::output_dtype(size_t i) const { - return output_placeholders[i]->dtype; -} +Array ExternOpNode::root_iter_vars() const { return {}; } -Array ExternOpNode::output_shape(size_t i) const { - return output_placeholders[i]->shape; -} +DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } +Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } -Operation ExternOpNode::make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array input_placeholders, - Array output_placeholders, - Stmt body) { +Operation ExternOpNode::make(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -76,7 +65,7 @@ Operation ExternOpNode::make(std::string name, CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) { - CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); + CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); } CHECK_EQ(input_placeholders[i]->strides.size(), 0U); } @@ -87,17 +76,12 @@ Operation ExternOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.ExternOp") -.set_body_typed(ExternOpNode::make); +TVM_REGISTER_GLOBAL("te.ExternOp").set_body_typed(ExternOpNode::make); +Array ExternOpNode::InputTensors() const { return inputs; } -Array ExternOpNode::InputTensors() const { - return inputs; -} - -Operation ExternOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ExternOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = ReplaceTensor(this->body, rmap); @@ -108,65 +92,54 @@ Operation ExternOpNode::ReplaceInputs( } } - if (body.same_as(n->body) && - inputs.same_as(n->inputs)) { + if (body.same_as(n->body) && inputs.same_as(n->inputs)) { return self; } else { return Operation(n); } } -void ExternOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ExternOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; TensorDom& dom = it->second; for (size_t i = 0; i < t->shape.size(); ++i) { 
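// (Editor's note, not part of this patch.) An extern op's body is opaque to
// bound inference, so the push below conservatively assumes every input may
// be read over its entire extent [0, t->shape[i]); no tighter region can be
// derived for extern ops.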
dom.data[i].emplace_back(IntSet::range( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i]))); + Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } -void ExternOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { -} +void ExternOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const {} -Stmt ExternOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt ExternOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { Tensor t = stage->op.output(k); Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { - bounds.push_back( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i])); + bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::RealizeNode::make( - t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + realize_body = + tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); } return realize_body; } -Stmt ExternOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ExternOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make( - make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = + AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; Array tuple; diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 7bb5d6153d8df..0022b6f1493c9 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -21,54 +21,44 @@ * \brief Hybrid computation rule. 
* \file hybrid_op.cc */ +#include "hybrid_op.h" + +#include #include #include -#include -#include -#include #include +#include #include -#include +#include + #include +#include #include + #include "op_util.h" -#include "hybrid_op.h" namespace tvm { namespace te { using namespace tir; // HybridOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "hybrid(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "hybrid(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(HybridOpNode); -int HybridOpNode::num_outputs() const { - return static_cast(outputs.size()); -} - -Array HybridOpNode::root_iter_vars() const { - return this->axis; -} +int HybridOpNode::num_outputs() const { return static_cast(outputs.size()); } -DataType HybridOpNode::output_dtype(size_t i) const { - return outputs[i]->dtype; -} +Array HybridOpNode::root_iter_vars() const { return this->axis; } -Array HybridOpNode::output_shape(size_t i) const { - return outputs[i]->shape; -} +DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } +Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } -Operation HybridOpNode::make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body) { +Operation HybridOpNode::make(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -84,9 +74,7 @@ Operation HybridOpNode::make(std::string name, return res; } -TVM_REGISTER_GLOBAL("te.HybridOp") -.set_body_typed(HybridOpNode::make); - +TVM_REGISTER_GLOBAL("te.HybridOp").set_body_typed(HybridOpNode::make); Array HybridOpNode::InputTensors() const { // Because input tensors could be potentially inlined into hybrid scripts, @@ -98,21 +86,20 @@ Array HybridOpNode::InputTensors() const { std::unordered_set visited; Array curr_inputs; tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); - if (orig_inputs.count(t) && !visited.count(t)) { - curr_inputs.push_back(t); - visited.insert(t); - } + const tir::CallNode* call = n.as(); + if (call != nullptr && call->func.defined()) { + Tensor t = Downcast(call->func).output(call->value_index); + if (orig_inputs.count(t) && !visited.count(t)) { + curr_inputs.push_back(t); + visited.insert(t); } + } }); return curr_inputs; } -Operation HybridOpNode::ReplaceInputs( - const Operation &self, - const std::unordered_map &rmap) const { +Operation HybridOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = te::ReplaceTensor(this->body, rmap); @@ -123,46 +110,40 @@ Operation HybridOpNode::ReplaceInputs( } } - if (body.same_as(n->body) && - inputs.same_as(n->inputs)) { + if (body.same_as(n->body) && inputs.same_as(n->inputs)) { return self; } else { return Operation(n); } } -void HybridOpNode::PropBoundToInputs( - const Operation &self, - arith::Analyzer* analyzer, - const std::unordered_map &dom_map, - std::unordered_map* out_dom_map) const { +void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& 
dom_map, + std::unordered_map* out_dom_map) const { auto curr_inputs = InputTensors(); for (Tensor t : curr_inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; - TensorDom &dom = it->second; + TensorDom& dom = it->second; for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i]))); + Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } -void HybridOpNode::GatherBound( - const Operation &self, - const std::unordered_map &tensor_dom, - std::unordered_map* out_dom_map) const { +void HybridOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { for (auto iter_var : axis) { CHECK(!out_dom_map->count(iter_var)); out_dom_map->operator[](iter_var) = iter_var->dom; } } -Stmt HybridOpNode::BuildRealize( - const Stage &stage, - const std::unordered_map &realize_map, - const Stmt &body) const { +Stmt HybridOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -170,24 +151,20 @@ Stmt HybridOpNode::BuildRealize( Tensor t = stage->op.output(k); Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { - bounds.push_back( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i])); + bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::RealizeNode::make( - t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + realize_body = + tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), realize_body); } return realize_body; } -Stmt HybridOpNode::BuildProvide( - const Stage &stage, - const std::unordered_map &dom_map, - bool debug_keep_trivial_loop) const { +Stmt HybridOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make( - make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = + AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); @@ -223,45 +200,44 @@ Stmt HybridOpNode::BuildProvide( return ret; } -Stmt ApplyLoopShapes(const Stage &stage, - const std::unordered_map &dom_map, Stmt stmt) { +Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt) { class LoopSpliter : public StmtExprMutator { PrimExpr factor; - const VarNode *parent; + const VarNode* parent; IterVar inner, outer; public: bool splitted; - LoopSpliter(const SplitNode *split, - const std::unordered_map &dom_map) : - factor(split->factor), splitted(false) { + LoopSpliter(const SplitNode* split, const std::unordered_map& dom_map) + : factor(split->factor), splitted(false) { parent = split->parent->var.get(); - auto &inner_ = split->inner; + auto& inner_ = split->inner; CHECK(dom_map.count(inner_)); - auto &inner_dom = dom_map.find(inner_)->second; + auto& inner_dom = dom_map.find(inner_)->second; CHECK(is_const_int(inner_dom->min, 0)); - auto &outer_ = split->outer; + auto& outer_ = split->outer; CHECK(dom_map.count(outer_)); - auto &outer_dom = 
dom_map.find(outer_)->second; + auto& outer_dom = dom_map.find(outer_)->second; CHECK(is_const_int(outer_dom->min, 0)); inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type); outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type); } - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == parent) { - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; Stmt ret = tir::Substitute(op->body, rmap); PrimExpr cond = likely(outer * factor < (op->extent - inner)); ret = IfThenElseNode::make(cond, ret); ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent, - IterVarTypeToForType(inner->iter_type), op->device_api, ret); + IterVarTypeToForType(inner->iter_type), op->device_api, ret); ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent, - IterVarTypeToForType(outer->iter_type), op->device_api, ret); + IterVarTypeToForType(outer->iter_type), op->device_api, ret); splitted = true; return ret; } @@ -270,24 +246,27 @@ Stmt ApplyLoopShapes(const Stage &stage, }; class LoopFuser : public StmtExprMutator { - const IterVar &parent; - const VarNode *inner; - const VarNode *outer; + const IterVar& parent; + const VarNode* inner; + const VarNode* outer; bool under_outer; PrimExpr extent; public: bool fused; - explicit LoopFuser(const FuseNode *fuse_) - : parent(fuse_->fused), inner(fuse_->inner->var.get()), - outer(fuse_->outer->var.get()), under_outer(false), - extent(0), fused(false) {} + explicit LoopFuser(const FuseNode* fuse_) + : parent(fuse_->fused), + inner(fuse_->inner->var.get()), + outer(fuse_->outer->var.get()), + under_outer(false), + extent(0), + fused(false) {} // TODO(@were): Handle imperfect loops Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == inner) { CHECK(under_outer); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(parent, op->extent); extent = op->extent; fused = true; @@ -295,15 +274,15 @@ Stmt ApplyLoopShapes(const Stage &stage, } else if (op->loop_var.get() == outer) { under_outer = true; Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, - op->for_type, op->device_api, body); + return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type, + op->device_api, body); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); body = tir::Substitute(body, rmap); extent = extent * op->extent; @@ -313,12 +292,12 @@ Stmt ApplyLoopShapes(const Stage &stage, } }; - for (auto &rel : stage->relations) { - if (const SplitNode *split = rel.as()) { + for (auto& rel : stage->relations) { + if (const SplitNode* split = rel.as()) { LoopSpliter Spliter(split, dom_map); stmt = Spliter(stmt); CHECK(Spliter.splitted); - } else if (const FuseNode *fuse = rel.as()) { + } else if (const FuseNode* fuse = rel.as()) { LoopFuser Fuser(fuse); stmt = Fuser(stmt); CHECK(Fuser.fused); @@ -328,45 +307,45 @@ Stmt ApplyLoopShapes(const Stage &stage, return stmt; } -Stmt ApplyLoopAnnotations(const Stage &stage, - const std::unordered_map &rebased, Stmt stmt) { +Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map& 
rebased, + Stmt stmt) { class LoopAnnotator : public StmtMutator { - const VarNode *var; - const IterVarAttr &attr; + const VarNode* var; + const IterVarAttr& attr; public: - LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} + LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {} - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const ForNode* op) final { tir::ExprDeepEqual expr_equal; if (op->loop_var.get() == var) { if (attr->bind_thread.defined()) { - const auto &iter_var = attr->bind_thread; + const auto& iter_var = attr->bind_thread; if (iter_var->dom.defined()) { CHECK(is_const_int(iter_var->dom->min, 0)); CHECK(expr_equal(iter_var->dom->extent, op->extent)) - << "Thread extent and loop extent mismatch!\n"; + << "Thread extent and loop extent mismatch!\n"; } - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; Stmt body = tir::Substitute(op->body, rmap); return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body); } else { return ForNode::make(op->loop_var, op->min, op->extent, - IterVarTypeToForType(attr->iter_type), op->device_api, op->body); + IterVarTypeToForType(attr->iter_type), op->device_api, op->body); } } return StmtMutator::VisitStmt_(op); } }; - for (auto &iter_var : stage->leaf_iter_vars) { + for (auto& iter_var : stage->leaf_iter_vars) { bool need_change = false; int found = 0; - const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; - const VarNode *var = actual->var.get(); + const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; + const VarNode* var = actual->var.get(); ForType expected = IterVarTypeToForType(iter_var->iter_type); IterVarAttr attr; if (stage->iter_var_attrs.count(iter_var)) { @@ -374,9 +353,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, expected = IterVarTypeToForType(attr->iter_type); } - PostOrderVisit(stmt, - [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { - if (const ForNode *op = node.as()) { + PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { + if (const ForNode* op = node.as()) { if (op->loop_var.get() == var) { ++found; need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined()); @@ -392,23 +370,21 @@ Stmt ApplyLoopAnnotations(const Stage &stage, return stmt; } -Stmt ApplyLoopOrder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &rebased, Stmt stmt) { +Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& rebased, Stmt stmt) { std::vector current_order; PostOrderVisit(stmt, [¤t_order](const ObjectRef& node) { - if (const ForNode *op = node.as()) - current_order.push_back(op->loop_var.get()); + if (const ForNode* op = node.as()) current_order.push_back(op->loop_var.get()); }); std::reverse(current_order.begin(), current_order.end()); - auto &required_ord = stage->leaf_iter_vars; + auto& required_ord = stage->leaf_iter_vars; CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!"; - std::unordered_map reorder; + std::unordered_map reorder; bool need_reorder = false; for (size_t i = 0; i < current_order.size(); ++i) { - auto ¤t = current_order[i]; - const IterVar &iter_var = required_ord[i]; - const IterVar &required = rebased.count(iter_var) ? 
rebased.find(iter_var)->second : iter_var; + auto& current = current_order[i]; + const IterVar& iter_var = required_ord[i]; + const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n"; reorder[current] = required; if (current != required->var.get()) { @@ -417,15 +393,14 @@ Stmt ApplyLoopOrder(const Stage &stage, } class LoopReorder : public StmtMutator { - const Stage &stage; - const std::unordered_map &dom_map; - const std::unordered_map &reorder; + const Stage& stage; + const std::unordered_map& dom_map; + const std::unordered_map& reorder; public: - LoopReorder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &reorder) - : stage(stage), dom_map(dom_map), reorder(reorder) {} + LoopReorder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& reorder) + : stage(stage), dom_map(dom_map), reorder(reorder) {} Stmt VisitStmt_(const ForNode* op) final { // Reorder from in to out @@ -434,25 +409,23 @@ Stmt ApplyLoopOrder(const Stage &stage, auto target = reorder.find(op->loop_var.get())->second; if (body_.same_as(op->body) && op->loop_var.get() == target->var.get()) return GetRef(op); - const Stmt &body = op->body.same_as(body_) ? op->body : body_; + const Stmt& body = op->body.same_as(body_) ? op->body : body_; ForType for_type = IterVarTypeToForType(target->iter_type); if (stage->iter_var_attrs.count(target)) { for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type); } - const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second; - return ForNode::make(target->var, range->min, range->extent, - for_type, DeviceAPI::None, body); + const Range& range = target->dom.defined() ? 
target->dom : dom_map.find(target)->second; + return ForNode::make(target->var, range->min, range->extent, for_type, DeviceAPI::None, body); } }; - if (need_reorder) - return LoopReorder(stage, dom_map, reorder)(stmt); + if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt); return stmt; } -Stmt ApplySchedule(const Stage &stage, - const std::unordered_map &dom_map, Stmt stmt) { +Stmt ApplySchedule(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt) { // TODO(@were): Eliminate loop rebase in script parser and move the burden here // Gather rebased variables std::unordered_map rebased; @@ -473,7 +446,7 @@ std::vector GatherLoopVars(Stmt stmt) { // TODO(@were): Write a comprehensive pass to analyze iter var types std::vector res_; PostOrderVisit(stmt, [&res_](const ObjectRef& node) { - if (const ForNode *op = node.as()) { + if (const ForNode* op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::make_by_min_extent(op->min, op->extent); res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type))); @@ -486,15 +459,14 @@ std::vector GatherLoopVars(Stmt stmt) { // replacer to replace tensors' usage in Provide class ProviderReplacer : public tir::StmtMutator { public: - explicit ProviderReplacer(const std::unordered_map &vmap) - : vmap_(vmap) {} + explicit ProviderReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} Stmt VisitStmt_(const tir::ProvideNode* op) final { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - Stmt ret = tir::ProvideNode::make( - it->second->op, it->second->value_index, op->value, op->args); + Stmt ret = + tir::ProvideNode::make(it->second->op, it->second->value_index, op->value, op->args); found = true; return this->VisitStmt(ret); } @@ -505,11 +477,10 @@ class ProviderReplacer : public tir::StmtMutator { bool found{false}; private: - const std::unordered_map &vmap_; + const std::unordered_map& vmap_; }; -Stmt ReplaceProvideTensor(Stmt stmt, - const std::unordered_map &replace) { +Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map& replace) { ProviderReplacer repl(replace); Stmt ret = repl(stmt); return repl.found ? ret : stmt; diff --git a/src/te/operation/hybrid_op.h b/src/te/operation/hybrid_op.h index dadfecd3bdefc..a11ae89e23f71 100644 --- a/src/te/operation/hybrid_op.h +++ b/src/te/operation/hybrid_op.h @@ -24,16 +24,16 @@ #ifndef TVM_TE_OPERATION_HYBRID_OP_H_ #define TVM_TE_OPERATION_HYBRID_OP_H_ -#include #include +#include #include #include #include -#include "../schedule/message_passing.h" -#include "../../tir/transforms/ir_util.h" #include "../../tir/transforms/arg_binder.h" +#include "../../tir/transforms/ir_util.h" +#include "../schedule/message_passing.h" namespace tvm { namespace te { @@ -49,8 +49,7 @@ std::vector GatherLoopVars(Stmt stmt); * \param stmt The statement to be processed. * \param replace The replacement rule. */ -Stmt ReplaceProvideTensor(Stmt stmt, - const std::unordered_map& replace); +Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map& replace); /*! * \brief Apply the schedule manipulation on the function body. @@ -58,8 +57,8 @@ Stmt ReplaceProvideTensor(Stmt stmt, * \param dom_map The extents of the iterative variables may be used. * \param stage The schedule information to be applied. */ -Stmt ApplySchedule(const Stage& stage, - const std::unordered_map& dom_map, Stmt stmt); +Stmt ApplySchedule(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt); /*! 
* \brief Apply loop splits and fuses in the schedule on the function body. @@ -67,9 +66,8 @@ Stmt ApplySchedule(const Stage& stage, * \param dom_map The extents of the iterative variables may be used. * \param stmt The statement to be processed. */ -Stmt ApplyLoopShapes(const Stage &stage, - const std::unordered_map& dom_map, Stmt stmt); - +Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt); /*! * \brief Apply loop annotation in the schedule on the function body. @@ -77,8 +75,8 @@ Stmt ApplyLoopShapes(const Stage &stage, * \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables. * \param stmt The statement to be processed. */ -Stmt ApplyLoopAnnotations(const Stage &stage, - const std::unordered_map& rebased, Stmt stmt); +Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map& rebased, + Stmt stmt); /*! * \brief Apply loop order in the schedule on the function body. @@ -87,9 +85,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, * \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables. * \param stmt The statement to be processed. */ -Stmt ApplyLoopOrder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &rebased, Stmt stmt); +Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& rebased, Stmt stmt); } // namespace te } // namespace tvm diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index f7e0e51fd16a0..5b200ac0ce940 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -21,14 +21,17 @@ * \brief Utility to make loop nest. * \file op_util.cc */ +#include "op_util.h" + +#include #include #include -#include + #include -#include "op_util.h" -#include "../schedule/message_passing.h" + #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "../schedule/message_passing.h" namespace tvm { namespace te { @@ -36,14 +39,12 @@ namespace te { using namespace arith; using namespace tir; -std::vector > -MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, - bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop) { +std::vector > MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; Stmt no_op = EvaluateNode::make(0); // create the loop nest @@ -84,14 +85,21 @@ MakeLoopNest(const Stage& stage, } if (it_attr.defined()) { switch (it_attr->iter_type) { - case kUnrolled: for_type = ForType::Unrolled; break; - case kVectorized: for_type = ForType::Vectorized; break; - case kParallelized: for_type = ForType::Parallel; break; - case kDataPar: break; - case kTensorized: break; - default: LOG(FATAL) << "Unknown iter type" - << it_attr->iter_type - << " in the iter_var_attrs"; + case kUnrolled: + for_type = ForType::Unrolled; + break; + case kVectorized: + for_type = ForType::Vectorized; + break; + case kParallelized: + for_type = ForType::Parallel; + break; + case kDataPar: + break; + case kTensorized: + break; + default: + LOG(FATAL) << "Unknown iter type" << it_attr->iter_type << " in the iter_var_attrs"; } CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size()); for (size_t k = 0; k < 
it_attr->pragma_keys.size(); ++k) { @@ -105,38 +113,30 @@ } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { - nest[i + 1].emplace_back( - LetStmtNode::make(var, dom->min, no_op)); + nest[i + 1].emplace_back(LetStmtNode::make(var, dom->min, no_op)); value_map[iv] = dom->min; } else if (is_zero(dom->min)) { nest[i + 1].emplace_back( - ForNode::make(var, 0, dom->extent, - for_type, DeviceAPI::None, no_op)); + ForNode::make(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); value_map[iv] = var; } else { Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype()); nest[i + 1].emplace_back( - ForNode::make(idx, 0, dom->extent, - for_type, DeviceAPI::None, no_op)); + ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; - nest[i + 1].emplace_back( - LetStmtNode::make(var, new_value, no_op)); + nest[i + 1].emplace_back(LetStmtNode::make(var, new_value, no_op)); } if (it_attr.defined() && it_attr->prefetch_data.size() != 0) { - CHECK(!is_one(dom->extent)) - << "Cannot prefetch on trivial loop with extent=1"; - CHECK_EQ(it_attr->prefetch_data.size(), - it_attr->prefetch_offset.size()); + CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1"; + CHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size()); for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { - nest[i + 1].emplace_back( - AttrStmtNode::make(it_attr->prefetch_data[j], - tir::attr::prefetch_scope, - it_attr->prefetch_offset[j], no_op)); + nest[i + 1].emplace_back(AttrStmtNode::make(it_attr->prefetch_data[j], + tir::attr::prefetch_scope, + it_attr->prefetch_offset[j], no_op)); } } - } else if (bind_iv->thread_tag == "vthread" || - bind_iv->thread_tag == "cthread") { + } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") { // virtual thread // Always restrict threaded IterVar to start from 0. CHECK(is_zero(dom->min)); @@ -163,9 +163,21 @@ value_map[iv] = dom->min; } else { runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag); - if (stage->scope == "" || stage->scope == "warp" || + if (stage->scope == "" || static_cast(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) { value_map[iv] = var; + } else if (stage->scope == "warp" && ts.rank == 1) { + // To determine whether a thread index is inside or outside a warp, we need + // to know the thread extent. For now, we only emit a warning. + if (ts.dim_index == 0) { + value_map[iv] = var; + } else { + LOG(WARNING) + << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. " + << "TVM assumes only threadIdx.x indicates threads inside a warp, " + << "while threadIdx.y and threadIdx.z indicate different warps."; + value_map[iv] = dom->min; + } } else { value_map[iv] = dom->min; } @@ -173,8 +185,7 @@ } // annotate the extent of the IterVar if (!new_loop_var) { - nest[i + 1].emplace_back( - AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op)); + nest[i + 1].emplace_back(AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op)); } } // message passing to get offset of root iter vars. 
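(Editor's note on the warp-scope branch above.) The new handling assumes the CUDA convention that a warp is the set of threads that differ only in threadIdx.x. A minimal standalone sketch of that decomposition, assuming warpSize == 32 and blockDim.x == 32, the layout the warning presumes; Dim3, LaneId, and WarpId are hypothetical helper names, not TVM or CUDA API:

  struct Dim3 { int x, y, z; };

  // Position of a thread inside its warp: only x varies within a warp.
  int LaneId(Dim3 t) { return t.x; }

  // Which warp of the block a thread belongs to: with blockDim = {32, 4, 1},
  // threads that share x but differ in y land in four different warps.
  int WarpId(Dim3 t, Dim3 blockDim) { return t.y + t.z * blockDim.y; }

Under that layout, keeping threadIdx.x (ts.dim_index == 0) as a live index into warp-scope storage is sound, while threadIdx.y and threadIdx.z select whole warps, so the code pins them to dom->min and warns.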
@@ -194,17 +205,15 @@ std::vector MakeIfNest(const std::vector& predicates) { // replacer to replace tensors class TensorReplacer : public tir::StmtExprMutator { public: - explicit TensorReplacer(const std::unordered_map& vmap) - : vmap_(vmap) {} + explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const tir::CallNode* op) final { if (op->call_type == tir::CallNode::Halide) { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { - PrimExpr ret = tir::CallNode::make( - op->dtype, it->second->op->name, op->args, - op->call_type, it->second->op, it->second->value_index); + PrimExpr ret = tir::CallNode::make(op->dtype, it->second->op->name, op->args, op->call_type, + it->second->op, it->second->value_index); found = true; return this->VisitExpr(ret); } @@ -219,22 +228,18 @@ class TensorReplacer : public tir::StmtExprMutator { const std::unordered_map& vmap_; }; -Stmt ReplaceTensor(Stmt stmt, - const std::unordered_map& replace) { +Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace) { TensorReplacer repl(replace); Stmt ret = repl(stmt); return repl.found ? ret : stmt; } -PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace) { +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace) { TensorReplacer repl(replace); PrimExpr ret = repl(expr); return repl.found ? ret : expr; } - -Stmt Substitute(Stmt s, - const std::unordered_map& value_map) { +Stmt Substitute(Stmt s, const std::unordered_map& value_map) { std::unordered_map init; for (const auto& kv : value_map) { init[kv.first->var.get()] = kv.second; @@ -244,31 +249,31 @@ Stmt Substitute(Stmt s, IterVarType ForTypeToIterVarType(tir::ForType for_type) { switch (for_type) { - case ForType::Serial: - return kDataPar; - case ForType::Parallel: - return kParallelized; - case ForType::Vectorized: - return kVectorized; - case ForType::Unrolled: - return kUnrolled; - default: - return kDataPar; + case ForType::Serial: + return kDataPar; + case ForType::Parallel: + return kParallelized; + case ForType::Vectorized: + return kVectorized; + case ForType::Unrolled: + return kUnrolled; + default: + return kDataPar; } } tir::ForType IterVarTypeToForType(IterVarType iter_type) { switch (iter_type) { - case kDataPar: - return ForType::Serial; - case kParallelized: - return ForType::Parallel; - case kVectorized: - return ForType::Vectorized; - case kUnrolled: - return ForType::Unrolled; - default: - return ForType::Serial; + case kDataPar: + return ForType::Serial; + case kParallelized: + return ForType::Parallel; + case kVectorized: + return ForType::Vectorized; + case kUnrolled: + return ForType::Unrolled; + default: + return ForType::Serial; } } diff --git a/src/te/operation/op_util.h b/src/te/operation/op_util.h index f95f84ac4d017..6c864fca67d50 100644 --- a/src/te/operation/op_util.h +++ b/src/te/operation/op_util.h @@ -24,13 +24,15 @@ #ifndef TVM_TE_OPERATION_OP_UTIL_H_ #define TVM_TE_OPERATION_OP_UTIL_H_ -#include #include +#include + #include #include #include -#include "../../tir/transforms/ir_util.h" + #include "../../tir/transforms/arg_binder.h" +#include "../../tir/transforms/ir_util.h" #include "../schedule/message_passing.h" namespace tvm { @@ -49,14 +51,12 @@ using tir::MergeNest; * \param p_value_map The result value of each IterVar. 
* \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 */ -std::vector > -MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, - bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop); +std::vector > MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop); /*! * \brief Create a nest of if checking the predicates. @@ -71,15 +71,13 @@ std::vector MakeIfNest(const std::vector& predicates); * \param stmt The statement to be processed. * \param replace The replacement rule. */ -Stmt ReplaceTensor(Stmt stmt, - const std::unordered_map& replace); +Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace); /*! * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. * \param expr The expression to be processed. * \param replace The replacement rule. */ -PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace); +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace); /*! * \brief Substitute the variables of stmt by value map. @@ -87,8 +85,7 @@ PrimExpr ReplaceTensor(PrimExpr expr, * \param value_map The value map. * \return Substituted result. */ -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); +Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); /*! * \brief Converts Halide ForType to its corresponding IterVarType diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index d48be4c536681..9c536ebb87859 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -29,20 +29,16 @@ namespace te { // PlaceholderOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "placeholder(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "placeholder(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(PlaceholderOpNode); -int PlaceholderOpNode::num_outputs() const { - return 1; -} +int PlaceholderOpNode::num_outputs() const { return 1; } -Array PlaceholderOpNode::root_iter_vars() const { - return {}; -} +Array PlaceholderOpNode::root_iter_vars() const { return {}; } DataType PlaceholderOpNode::output_dtype(size_t i) const { CHECK_EQ(i, 0U); @@ -54,9 +50,7 @@ Array PlaceholderOpNode::output_shape(size_t i) const { return shape; } -Operation PlaceholderOpNode::make(std::string name, - Array shape, - DataType dtype) { +Operation PlaceholderOpNode::make(std::string name, Array shape, DataType dtype) { auto n = make_object(); n->name = name; n->shape = shape; @@ -69,44 +63,35 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { } TVM_REGISTER_GLOBAL("te.Placeholder") -.set_body_typed([](Array shape, DataType dtype, std::string name) { - return placeholder(shape, dtype, name); -}); + .set_body_typed([](Array shape, DataType dtype, std::string name) { + return placeholder(shape, dtype, name); + }); -Array PlaceholderOpNode::InputTensors() const { - return {}; -} +Array PlaceholderOpNode::InputTensors() const { return {}; } -Operation PlaceholderOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) 
const { +Operation PlaceholderOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { return self; } void PlaceholderOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, + const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { -} + std::unordered_map* out_dom_map) const {} -void PlaceholderOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { -} +void PlaceholderOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const {} -Stmt PlaceholderOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { return body; } -Stmt PlaceholderOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt PlaceholderOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { return Stmt(); } } // namespace te diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 49929282efb3d..582e290035073 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -24,23 +24,22 @@ #include #include #include -#include "op_util.h" + #include "../schedule/graph.h" +#include "op_util.h" namespace tvm { namespace te { using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "scan(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "scan(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ScanOpNode); -int ScanOpNode::num_outputs() const { - return static_cast(update.size()); -} +int ScanOpNode::num_outputs() const { return static_cast(update.size()); } Array ScanOpNode::root_iter_vars() const { Array ret{scan_axis}; for (IterVar iv : spatial_axis_) { @@ -49,23 +48,16 @@ Array ScanOpNode::root_iter_vars() const { return ret; } -DataType ScanOpNode::output_dtype(size_t i) const { - return update[i]->dtype; -} +DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } Array ScanOpNode::output_shape(size_t i) const { CHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } -Operation ScanOpNode::make(std::string name, - std::string tag, - Map attrs, - IterVar axis, - Array init, - Array update, - Array state_placeholder, - Array inputs) { +Operation ScanOpNode::make(std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array inputs) { if (!attrs.defined()) { attrs = Map(); } @@ -82,31 +74,26 @@ Operation ScanOpNode::make(std::string name, CHECK_EQ(init[i]->dtype, update[i]->dtype); CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; - CHECK(prove_equal( - state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) + CHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) << "state_placeholder.shape[0] need to match" << " scan_axis.dom.min + scan_axis.dom.extent"; CHECK_EQ(state_placeholder[i].ndim(), 
init[i].ndim()) << "The dimension of init need to match state_placeholder"; CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) << "The update.ndim need to be state_placeholder.ndim - 1"; - for (size_t k = 0; k < update[i].ndim(); ++k) { - CHECK(prove_equal( - update[i]->shape[k], state_placeholder[i]->shape[k])); + for (size_t k = 0; k < update[i].ndim(); ++k) { + CHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k])); if (k != 0) { // setup spatial axis std::ostringstream spatial_name; spatial_name << name << ".out" << i << ".i" << k; - n->spatial_axis_.push_back( - IterVarNode::make( - Range::make_by_min_extent(0, update[i]->shape[k]), - Var(spatial_name.str()), kOpaque)); + n->spatial_axis_.push_back(IterVarNode::make( + Range::make_by_min_extent(0, update[i]->shape[k]), Var(spatial_name.str()), kOpaque)); } } - for (size_t k = 1; k < init[i].ndim(); ++k) { - CHECK(prove_equal( - init[i]->shape[k], state_placeholder[i]->shape[k])); + for (size_t k = 1; k < init[i].ndim(); ++k) { + CHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k])); } } n->name = std::move(name); @@ -120,25 +107,16 @@ Operation ScanOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.ScanOp") -.set_body_typed(ScanOpNode::make); - +TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make); -Array scan(Array init, - Array update, - Array state_placeholder, - Array inputs, - std::string name, - std::string tag, +Array scan(Array init, Array update, Array state_placeholder, + Array inputs, std::string name, std::string tag, Map attrs) { - IterVar scan_axis = - IterVarNode::make( - Range::make_by_min_extent( - init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), - Var(name + ".idx"), kOrdered); - Operation op = ScanOpNode::make( - name, tag, attrs, scan_axis, - init, update, state_placeholder, inputs); + IterVar scan_axis = IterVarNode::make( + Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), + Var(name + ".idx"), kOrdered); + Operation op = + ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); @@ -157,9 +135,8 @@ Array ScanOpNode::InputTensors() const { return ret; } -Operation ScanOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ScanOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); for (size_t i = 0; i < n->init.size(); ++i) { @@ -170,19 +147,16 @@ Operation ScanOpNode::ReplaceInputs( n->update.Set(i, rmap.at(n->update[i])); } } - if (!n->init.same_as(init) || - !n->update.same_as(update)) { + if (!n->init.same_as(init) || !n->update.same_as(update)) { return Operation(n); } else { return self; } } -void ScanOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) { TensorDom* init_dom = nullptr; @@ -195,8 +169,8 @@ void ScanOpNode::PropBoundToInputs( } // first dimension, always needed. 
if (init_dom) { - init_dom->data[0].push_back(IntSet::range( - Range::make_by_min_extent(0, this->init[i]->shape[0]))); + init_dom->data[0].push_back( + IntSet::range(Range::make_by_min_extent(0, this->init[i]->shape[0]))); } if (update_dom) { update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get())); @@ -214,10 +188,9 @@ void ScanOpNode::PropBoundToInputs( } } -void ScanOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { +void ScanOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); CHECK(!out_dom_map->count(this->scan_axis)); std::vector output(this->num_outputs()); @@ -234,8 +207,8 @@ void ScanOpNode::GatherBound( arith::Analyzer analyzer; Range sdom = this->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); - (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( - sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); + (*out_dom_map)[this->scan_axis] = + Range::make_by_min_extent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; @@ -256,15 +229,12 @@ void ScanOpNode::GatherBound( } } -Stmt ScanOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& dom_map, - const Stmt& body) const { +Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, + const Stmt& body) const { arith::Analyzer analyzer; CHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); - Range tdom = Range::make_by_min_extent( - 0, analyzer.Simplify(sdom->extent + sdom->min)); + Range tdom = Range::make_by_min_extent(0, analyzer.Simplify(sdom->extent + sdom->min)); Stmt ret = body; size_t sp_idx = 0; for (size_t i = 0; i < update.size(); ++i) { @@ -276,25 +246,19 @@ Stmt ScanOpNode::BuildRealize( IterVar sp_ax = this->spatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, - bounds, const_true(), ret); + ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, bounds, const_true(), ret); } return ret; } -Stmt ScanOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt provide = AttrStmtNode::make( - stage->op, tir::attr::scan_update_scope, this->scan_axis->var, - EvaluateNode::make(0)); - Stmt init = AttrStmtNode::make( - stage->op, tir::attr::scan_init_scope, 0, - EvaluateNode::make(0)); + Stmt provide = AttrStmtNode::make(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, + EvaluateNode::make(0)); + Stmt init = AttrStmtNode::make(stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0)); size_t begin_scan = 0; - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) { CHECK_EQ(begin_scan, i); begin_scan = i + 1; @@ -302,12 +266,9 @@ Stmt ScanOpNode::BuildProvide( } std::unordered_map vmap; std::unordered_set empty; - auto nest = MakeLoopNest( - stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop); + auto nest = MakeLoopNest(stage, dom_map, 0, false, empty, &vmap, 
debug_keep_trivial_loop); nest[begin_scan].push_back(init); - nest.push_back( - MakeIfNest( - MakeBoundCheck(stage, dom_map, vmap, false, empty))); + nest.push_back(MakeIfNest(MakeBoundCheck(stage, dom_map, vmap, false, empty))); return MergeNest(nest, provide); } } // namespace te diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index f714691f41711..236aff68b44cd 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -21,26 +21,27 @@ * \brief Tensor Compute Op. * \file tensor_compute_op.cc */ +#include #include #include -#include #include #include + #include -#include "./op_util.h" -#include "./compute_op.h" #include "../../arith/compute_expr.h" +#include "./compute_op.h" +#include "./op_util.h" namespace tvm { namespace te { using namespace tir; // TensorComputeOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorComputeOpNode); @@ -52,15 +53,10 @@ DataType TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } -Operation TensorComputeOpNode::make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - int schedulable_ndim, - TensorIntrin intrin, - Array tensors, - Array regions, - Array scalar_inputs) { +Operation TensorComputeOpNode::make(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, + TensorIntrin intrin, Array tensors, + Array regions, Array scalar_inputs) { auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); @@ -74,17 +70,12 @@ Operation TensorComputeOpNode::make(std::string name, return Operation(n); } -TVM_REGISTER_GLOBAL("te.TensorComputeOp") -.set_body_typed(TensorComputeOpNode::make); +TVM_REGISTER_GLOBAL("te.TensorComputeOp").set_body_typed(TensorComputeOpNode::make); +Array TensorComputeOpNode::InputTensors() const { return inputs; } -Array TensorComputeOpNode::InputTensors() const { - return inputs; -} - -Operation TensorComputeOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation TensorComputeOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); auto intrin = make_object(*(this->intrin.operator->())); @@ -104,8 +95,7 @@ Operation TensorComputeOpNode::ReplaceInputs( if (intrin->body.same_as(n->intrin->body) && intrin->reduce_init.same_as(n->intrin->reduce_init) && - intrin->reduce_update.same_as(n->intrin->reduce_update) && - inputs.same_as(n->inputs)) { + intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) { return self; } else { n->intrin = TensorIntrin(intrin); @@ -114,8 +104,7 @@ Operation TensorComputeOpNode::ReplaceInputs( } void TensorComputeOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, + const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { @@ -131,14 +120,11 @@ void TensorComputeOpNode::PropBoundToInputs( } } -size_t 
TensorComputeOpNode::num_schedulable_dims() const { - return schedulable_ndim; -} +size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; } -Stmt TensorComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); // Start bind data. @@ -161,9 +147,8 @@ Stmt TensorComputeOpNode::BuildProvide( } input_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // output binding @@ -187,9 +172,8 @@ Stmt TensorComputeOpNode::BuildProvide( output_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // Check variable remap @@ -213,8 +197,7 @@ Stmt TensorComputeOpNode::BuildProvide( ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop); if (this->reduce_axis.size() == 0) { - std::vector > nest( - n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); CHECK(this->intrin->body.defined()) @@ -224,24 +207,23 @@ Stmt TensorComputeOpNode::BuildProvide( body = tir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); body = te::Substitute(body, n.main_vmap); - Stmt ret = MergeNest(nest, body); + Stmt ret = MergeNest(nest, body); return ret; } else { // Need to split reduction - CHECK(this->intrin->reduce_update.defined()) - << "Reduction update op is not defined"; + CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined"; // Need init and update steps CHECK_NE(this->reduce_axis.size(), 0U); - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > update_nest( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.begin() + tloc + 1); update_nest.emplace_back(MakeIfNest(n.main_predicates)); if (this->intrin->reduce_init.defined()) { // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + std::vector > init_nest(n.init_nest.begin(), + n.init_nest.begin() + tloc + 1); init_nest.emplace_back(MakeIfNest(n.init_predicates)); Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); init = te::Substitute(init, n.init_vmap); @@ -256,11 +238,9 @@ Stmt TensorComputeOpNode::BuildProvide( return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. 
- CHECK(this->intrin->body.defined()) - << "Normal body op is not defined"; - Stmt update = TransformUpdate(stage, dom_map, n, - this->intrin->body, - this->intrin->reduce_update); + CHECK(this->intrin->body.defined()) << "Normal body op is not defined"; + Stmt update = + TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = tir::Substitute(update, vmap); diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 31d4b368ad895..f322e12f8db1f 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -21,14 +21,14 @@ * \brief Logic related to tensorize, used by ComputeOpNode. * \file tensorize.cc */ +#include +#include #include #include -#include -#include -#include "op_util.h" -#include "compute_op.h" #include "../schedule/message_passing.h" +#include "compute_op.h" +#include "op_util.h" namespace tvm { namespace te { @@ -39,12 +39,10 @@ using namespace tir; // out_dom: the domain of root iter vars in output op // in_region: region of each input tensor. // return The location of the tensorized scope start. -size_t InferTensorizeRegion( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - std::unordered_map* out_dom, - std::unordered_map >* in_region) { +size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + std::unordered_map* out_dom, + std::unordered_map >* in_region) { // Get the bound of the tensorized scope. bool found_point = false; size_t loc_scope = 0; @@ -52,8 +50,7 @@ size_t InferTensorizeRegion( // Loop over the leaves for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) { IterVar iv = stage->leaf_iter_vars[i - 1]; - CHECK(iv->iter_type == kDataPar || - iv->iter_type == kCommReduce); + CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce); auto vit = dom_map.find(iv); CHECK(vit != dom_map.end()); const Range& vrange = vit->second; @@ -69,8 +66,7 @@ size_t InferTensorizeRegion( if (iit != stage->iter_var_attrs.end()) { const IterVarAttr& attr = (*iit).second; if (!found_point) { - CHECK(!attr->bind_thread.defined()) - << "Do not allow thread in tensorize scope"; + CHECK(!attr->bind_thread.defined()) << "Do not allow thread in tensorize scope"; } if (attr->iter_type == kTensorized) { CHECK(!found_point) << "Do not allow two tensorized point"; @@ -113,18 +109,15 @@ size_t InferTensorizeRegion( return loc_scope; } -void VerifyTensorizeLoopNest(const ComputeOpNode* self, - const Stage& stage, - const ComputeLoopNest& n, - size_t tloc) { +void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage, + const ComputeLoopNest& n, size_t tloc) { // Verification step. 
std::unordered_set banned; CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1); - CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || - n.init_nest.size() == 0); + CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0); auto f_push_banned = [&banned](const Stmt& s) { if (const ForNode* op = s.as()) { - banned.insert(op->loop_var.get()); + banned.insert(op->loop_var.get()); } else if (const AttrStmtNode* op = s.as()) { if (const IterVarNode* iv = op->node.as()) { banned.insert(iv->var.get()); @@ -144,20 +137,18 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } } - auto fbanned = [&](const VarNode* node) { - return banned.count(node); - }; + auto fbanned = [&](const VarNode* node) { return banned.count(node); }; for (const PrimExpr& pred : n.main_predicates) { if (tir::ExprUseVar(pred, fbanned)) { - LOG(FATAL) << "Tensorize failed, split condition " - << pred << " relies on var defined inside tensorize scope"; + LOG(FATAL) << "Tensorize failed, split condition " << pred + << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { if (tir::ExprUseVar(pred, fbanned)) { - LOG(FATAL) << "Tensorize failed, split condition " - << pred << " relies on var defined inside tensorize scope"; + LOG(FATAL) << "Tensorize failed, split condition " << pred + << " relies on var defined inside tensorize scope"; } } } @@ -178,9 +169,8 @@ class TensorIntrinMatcher final : public StmtExprMutator { for (size_t i = e.start; i < e.region.size(); ++i) { args.push_back(op->args[i] - e.region[i]->min); } - return CallNode::make( - op->dtype, e.tensor->op->name, args, - op->call_type, e.tensor->op, e.tensor->value_index); + return CallNode::make(op->dtype, e.tensor->op->name, args, op->call_type, e.tensor->op, + e.tensor->value_index); } } return expr; @@ -205,16 +195,13 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis.push_back(it->second); } } - return ReduceNode::make( - op->combiner, op->source, axis, op->condition, op->value_index); + return ReduceNode::make(op->combiner, op->source, axis, op->condition, op->value_index); } - void Init(const ComputeOpNode* self, - const Stage& stage, + void Init(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin, + const std::unordered_map >& in_region, const TensorIntrin& intrin, Map* compute_intrin_iter_space) { CHECK(self == stage->op.get()); @@ -243,8 +230,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { CHECK(is_one(canonical_extent)) << "Tensorize " << intrin->name << ":" << " Input dimension mismatch with tensor intrin " - << " expected shape=" << e.tensor->shape - << ", given region=" << e.region; + << " expected shape=" << e.tensor->shape << ", given region=" << e.region; } in_remap_[inputs[i]] = e; } @@ -257,10 +243,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { size_t axis_start = self->axis.size() - intrin_compute->axis.size(); for (size_t i = 0; i < axis_start; ++i) { Range r = out_dom.at(self->axis[i]); - CHECK(is_one(r->extent)) - << "Tensorize: Output mismatch with tensor intrin " - << " intrin-dim=" << intrin_compute->axis.size() - << ", tensorize-dim=" << self->axis.size(); + CHECK(is_one(r->extent)) << "Tensorize: Output mismatch with tensor intrin " + << " intrin-dim=" << intrin_compute->axis.size() + << ", tensorize-dim=" << self->axis.size(); 
var_remap_[self->axis[i]->var.get()] = r->min; } // Assume we tensorize at region axis i [min, min + extent) @@ -280,10 +265,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); for (size_t i = 0; i < axis_start; ++i) { Range r = out_dom.at(self->reduce_axis[i]); - CHECK(is_one(r->extent)) - << "Tensorize: Reduction mismatch with tensor intrin " - << " intrin-dim=" << intrin_compute->reduce_axis.size() - << ", tensorize-dim=" << self->reduce_axis.size(); + CHECK(is_one(r->extent)) << "Tensorize: Reduction mismatch with tensor intrin " + << " intrin-dim=" << intrin_compute->reduce_axis.size() + << ", tensorize-dim=" << self->reduce_axis.size(); var_remap_[self->reduce_axis[i]->var.get()] = r->min; } for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { @@ -314,14 +298,12 @@ class TensorIntrinMatcher final : public StmtExprMutator { }; // Try to match tensor dataflow of the stage with the intrinsic -Array MatchTensorizeBody( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin, - Map* compute_intrin_iter_space) { +Array MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + const std::unordered_map& out_dom, + const std::unordered_map >& in_region, + const TensorIntrin& intrin, + Map* compute_intrin_iter_space) { TensorIntrinMatcher matcher; matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); Array ret; @@ -331,21 +313,18 @@ Array MatchTensorizeBody( return ret; } -void VerifyTensorizeBody( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin) { +void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + const std::unordered_map& out_dom, + const std::unordered_map >& in_region, + const TensorIntrin& intrin) { StructuralEqual expr_equal; Map compute_intrin_iter_space; Array body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, - &compute_intrin_iter_space); + &compute_intrin_iter_space); const ComputeOpNode* intrin_compute = intrin->op.as(); CHECK(intrin_compute) << "Only support compute intrinsic for now"; - CHECK_EQ(body.size(), intrin_compute->body.size()) - << "Tensorize failed: body size mismatch"; + CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; arith::Analyzer ana; ana.Bind(compute_intrin_iter_space); @@ -353,29 +332,23 @@ void VerifyTensorizeBody( PrimExpr lhs = ana.Simplify(body[i]); PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); if (lhs.dtype() != rhs.dtype()) { - LOG(FATAL) - << "Failed to match the data type with TensorIntrin " - << intrin->name << "'s declaration " - << " provided=" << lhs.dtype() - << ", intrin=" << rhs.dtype(); + LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name + << "'s declaration " + << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype(); } - CHECK(expr_equal(lhs, rhs)) - << "Failed to match the compute with TensorIntrin " - << intrin->name << "'s declaration " - << " provided= " << lhs - << ", intrin= " << rhs; + CHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name + << "'s declaration " + 
<< " provided= " << lhs << ", intrin= " << rhs; } } -Stmt MakeTensorize(const ComputeOpNode* self, - const Stage& stage, +Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { std::unordered_map out_dom; std::unordered_map > in_region; size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); - TensorIntrin intrin = stage->iter_var_attrs.at( - stage->leaf_iter_vars[tloc])->tensor_intrin; + TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin; CHECK(intrin.defined()); ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); @@ -384,8 +357,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, Stmt nop = EvaluateNode::make(0); std::vector input_bind_nest, output_bind_nest; Array inputs = self->InputTensors(); - CHECK_EQ(inputs.size(), intrin->inputs.size()) - << "Tensorize failed: input size mismatch "; + CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch "; // input binding for (size_t i = 0; i < intrin->inputs.size(); ++i) { Tensor tensor = inputs[i]; @@ -401,9 +373,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, } input_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -423,9 +394,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + nop)); } // Check variable remap std::unordered_map vmap; @@ -437,8 +407,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, IterVar iv = self->reduce_axis[i]; auto it = out_dom.find(iv); CHECK(it != out_dom.end()); - CHECK(is_one(it->second->extent)) - << "Tensorization fail: reduction axis size do not match"; + CHECK(is_one(it->second->extent)) << "Tensorization fail: reduction axis size do not match"; } for (size_t i = start; i < self->reduce_axis.size(); ++i) { IterVar iv = self->reduce_axis[i]; @@ -447,17 +416,14 @@ Stmt MakeTensorize(const ComputeOpNode* self, CHECK(it != out_dom.end()); binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0), "tensir_intrin.reduction.min"); - binder.Bind(target->dom->extent, it->second->extent, - "tensir_intrin.reduction.extent"); + binder.Bind(target->dom->extent, it->second->extent, "tensir_intrin.reduction.extent"); } if (tloc <= n.num_common_loop) { // No need to split reduction - std::vector > nest( - n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); - CHECK(intrin->body.defined()) - << "Normal store op for intrin " << intrin << " is not defined"; + CHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined"; Stmt body = MergeNest(output_bind_nest, intrin->body); body = MergeNest(input_bind_nest, body); body = 
tir::Substitute(body, vmap); @@ -470,16 +436,16 @@ Stmt MakeTensorize(const ComputeOpNode* self, << "Reduction update op for intrin " << intrin << " is not defined"; // Need init and update steps CHECK_NE(self->reduce_axis.size(), 0U); - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > update_nest( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.begin() + tloc + 1); update_nest.emplace_back(MakeIfNest(n.main_predicates)); if (intrin->reduce_init.defined()) { // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + std::vector > init_nest(n.init_nest.begin(), + n.init_nest.begin() + tloc + 1); init_nest.emplace_back(MakeIfNest(n.init_predicates)); Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); init = te::Substitute(init, n.init_vmap); @@ -494,11 +460,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(intrin->body.defined()) - << "Normal body op for intrin " << intrin << " is not defined"; - Stmt update = TransformUpdate(stage, dom_map, n, - intrin->body, - intrin->reduce_update); + CHECK(intrin->body.defined()) << "Normal body op for intrin " << intrin << " is not defined"; + Stmt update = TransformUpdate(stage, dom_map, n, intrin->body, intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = tir::Substitute(update, vmap); @@ -511,36 +474,26 @@ Stmt MakeTensorize(const ComputeOpNode* self, } // Register functions for unittests -TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Stage stage = args[0]; - Map dmap = args[1]; - std::unordered_map out_dom; - std::unordered_map > in_region; - CHECK(stage->op.as()); - InferTensorizeRegion(stage->op.as(), - stage, - as_unordered_map(dmap), - &out_dom, &in_region); - *ret = Array{Map(out_dom), - Map >(in_region)}; - }); +TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion").set_body([](TVMArgs args, TVMRetValue* ret) { + Stage stage = args[0]; + Map dmap = args[1]; + std::unordered_map out_dom; + std::unordered_map > in_region; + CHECK(stage->op.as()); + InferTensorizeRegion(stage->op.as(), stage, as_unordered_map(dmap), &out_dom, + &in_region); + *ret = Array{Map(out_dom), Map >(in_region)}; +}); -TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Stage stage = args[0]; - Map out_dom = args[1]; - Map > in_region = args[2]; - TensorIntrin intrin = args[3]; - Map vrange; - CHECK(stage->op.as()); - *ret = MatchTensorizeBody(stage->op.as(), - stage, - {{}}, - as_unordered_map(out_dom), - as_unordered_map(in_region), - intrin, - &vrange); - }); +TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody").set_body([](TVMArgs args, TVMRetValue* ret) { + Stage stage = args[0]; + Map out_dom = args[1]; + Map > in_region = args[2]; + TensorIntrin intrin = args[3]; + Map vrange; + CHECK(stage->op.as()); + *ret = MatchTensorizeBody(stage->op.as(), stage, {{}}, as_unordered_map(out_dom), + as_unordered_map(in_region), intrin, &vrange); +}); } // namespace te } // namespace tvm diff --git a/src/te/schedule/auto_inline_elem_wise.cc 
b/src/te/schedule/auto_inline_elem_wise.cc index 6d79f4a8d1d67..e2b7215158b24 100644 --- a/src/te/schedule/auto_inline_elem_wise.cc +++ b/src/te/schedule/auto_inline_elem_wise.cc @@ -21,8 +21,8 @@ * \file auto_inline_elem_wise.cc */ #include -#include #include +#include #include namespace tvm { @@ -61,7 +61,6 @@ class ElemWiseDetector : public tir::ExprVisitor { Array axis_; }; - bool IsElemWise(const Operation& op) { if (const ComputeOpNode* compute = op.as()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); @@ -112,12 +111,9 @@ void AutoInlineInjective(Schedule sch) { } } -TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise") -.set_body_typed(AutoInlineElemWise); - +TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise); -TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective") -.set_body_typed(AutoInlineInjective); +TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective); } // namespace te } // namespace tvm diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 4dde945baf8cc..01d4f93db45a0 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -22,13 +22,15 @@ * \brief The bound inference logic. */ #include -#include #include +#include + #include #include + +#include "../../runtime/thread_storage_scope.h" #include "graph.h" #include "message_passing.h" -#include "../../runtime/thread_storage_scope.h" namespace tvm { namespace te { @@ -49,13 +51,11 @@ struct GraphContext { std::unordered_map op2stage_; }; -bool NeedRelax(const IterVar& iv, - bool found_attach, +bool NeedRelax(const IterVar& iv, bool found_attach, const std::unordered_map& bind_map, const runtime::StorageScope& scope) { auto it = bind_map.find(iv); - const std::string& tag = ( - it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); + const std::string& tag = (it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag.length() == 0 || tag == "pipeline") { return !found_attach; } @@ -63,25 +63,21 @@ bool NeedRelax(const IterVar& iv, // When there is warp memory // threadIdx.x must be set to be warp index. - if (scope.rank == StorageRank::kWarp && - ts.rank == 1 && - ts.dim_index == 0) { + if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) { return true; } return static_cast(scope.rank) <= ts.rank; } // infer storage scope, if not given -StorageScope InferStorageScope( - const Stage& stage, const GraphContext& ctx) { +StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) { if (stage->scope.length() != 0) { return StorageScope::make(stage->scope); } int max_rank = -1; for (IterVar iv : ctx.attach_path.at(stage->op)) { auto it = ctx.bind_map.find(iv); - const std::string& tag = ( - it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); + const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag != "pipeline" && tag.length() != 0) { max_rank = std::max(max_rank, ThreadScope::make(tag).rank); } @@ -91,20 +87,16 @@ StorageScope InferStorageScope( return s; } - -void InferRootBound(const Stage& stage, - const GraphContext& ctx, +void InferRootBound(const Stage& stage, const GraphContext& ctx, std::unordered_map* rmap) { - CHECK_NE(stage->attach_type, kInline) - << "call schedule.normalize before scheduleops"; + CHECK_NE(stage->attach_type, kInline) << "call schedule.normalize before scheduleops"; if (stage->attach_type == kInlinedAlready) return; if (stage->is_output) { // verify correctness. 
- CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) - << "Output must be attached at root"; + CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) << "Output must be attached at root"; } if (stage->is_output || stage->op.as()) { - for (auto iv : stage->op->root_iter_vars()) { + for (auto iv : stage->op->root_iter_vars()) { CHECK(iv->dom.defined()); CHECK(!rmap->count(iv)); (*rmap)[iv] = iv->dom; @@ -137,7 +129,7 @@ void InferRootBound(const Stage& stage, Array stage_attach = ctx.attach_path.at(stage->op); // The parent set. for (const Operation& op : consumers) { - std::unordered_map relax_set; + Map relax_set; std::unordered_map up_state; bool found_attach = false; CHECK(ctx.op2stage_.count(op.get())); @@ -154,9 +146,8 @@ void InferRootBound(const Stage& stage, if (is_one(vrange->extent)) { up_state[iv] = IntSet::single_point(vrange->min); } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - CHECK(is_zero(vrange->min)) - << "InferBound requires every leaf iter var's min equals 0, " - << " call schedule.normalize to achieve this. "; + CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " + << " call schedule.normalize to achieve this. "; if (ctx.bind_map.count(iv)) { up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var); } else { @@ -172,13 +163,12 @@ void InferRootBound(const Stage& stage, found_attach = true; } Range vrange = rmap->at(iv); - CHECK(is_zero(vrange->min)) - << "InferBound requires every leaf iter var's min equals 0, " - << "call schedule.normalize to achieve this."; + CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " + << "call schedule.normalize to achieve this."; if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - relax_set[iv->var.get()] = IntSet::range(vrange); + relax_set.Set(iv->var, IntSet::range(vrange)); if (ctx.bind_map.count(iv)) { - relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange); + relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange)); } } } @@ -190,6 +180,9 @@ void InferRootBound(const Stage& stage, // Relax if needed. 
std::unordered_map dom_map; arith::Analyzer analyzer; + for (auto entry : *rmap) { + analyzer.Bind(entry.first->var, entry.second); + } for (auto iv : op->root_iter_vars()) { Range r; if (up_state.count(iv)) { @@ -198,11 +191,13 @@ void InferRootBound(const Stage& stage, r = iv->dom; } if (relax_set.size() != 0) { - dom_map[iv->var.get()] = EvalSet(r, relax_set); + dom_map[iv->var.get()] = + IntSet::interval(analyzer.int_set(r->min, relax_set).min(), + analyzer.int_set(r->min + r->extent - 1, relax_set).max()); } else { dom_map[iv->var.get()] = IntSet::range(r); } - analyzer.Bind(iv->var, r); + analyzer.Bind(iv->var, r, true); } op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); } @@ -252,15 +247,13 @@ Map InferBound(const Schedule& sch) { } } for (auto& p : ret) { - ret[p.first] = Range::make_by_min_extent( - analyzer.Simplify(p.second->min), - analyzer.Simplify(p.second->extent)); + ret[p.first] = Range::make_by_min_extent(analyzer.Simplify(p.second->min), + analyzer.Simplify(p.second->extent)); } return Map(ret.begin(), ret.end()); } -TVM_REGISTER_GLOBAL("schedule.InferBound") -.set_body_typed(InferBound); +TVM_REGISTER_GLOBAL("schedule.InferBound").set_body_typed(InferBound); } // namespace te } // namespace tvm diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc index 9dce36f220ef8..6414822290261 100644 --- a/src/te/schedule/graph.cc +++ b/src/te/schedule/graph.cc @@ -21,14 +21,16 @@ * \file graph.cc * \brief Utilities to get information about schedule graph. */ +#include "graph.h" + #include +#include #include #include -#include -#include -#include + #include -#include "graph.h" +#include +#include namespace tvm { namespace te { @@ -39,22 +41,14 @@ struct TensorDimKey { int dim; TensorDimKey() {} TensorDimKey(const tir::CallNode* op, int dim) - : f(op->func), value_index(op->value_index), dim(dim) { - } - TensorDimKey(const Tensor& t, int dim) - : f(t->op), value_index(t->value_index), dim(dim) { - } + : f(op->func), value_index(op->value_index), dim(dim) {} + TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {} TensorDimKey(const Tensor& t, size_t dim) - : f(t->op), value_index(t->value_index), dim(static_cast(dim)) { - } + : f(t->op), value_index(t->value_index), dim(static_cast(dim)) {} inline bool operator==(const TensorDimKey& other) const { - return f == other.f && - value_index == other.value_index && - dim == other.dim; - } - inline bool operator!=(const TensorDimKey& other) const { - return !operator==(other); + return f == other.f && value_index == other.value_index && dim == other.dim; } + inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); } }; } // namespace te } // namespace tvm @@ -64,15 +58,13 @@ template <> struct hash<::tvm::te::TensorDimKey> { std::size_t operator()(const ::tvm::te::TensorDimKey& k) const { size_t lhs = ::tvm::ObjectHash()(k.f); - size_t rhs = static_cast(k.value_index) << 16UL | - static_cast(k.dim); + size_t rhs = static_cast(k.value_index) << 16UL | static_cast(k.dim); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; } }; } // namespace std - namespace tvm { namespace te { @@ -105,12 +97,9 @@ ReadGraph CreateReadGraph(const Array& roots) { // Do DFS visit to get the subgraph. // Return if op is inside the subgraph. 
-bool GetSubGraphByPostDFS_( - const Operation& op, - const std::unordered_set& boundary, - bool include_bounary, - std::unordered_map* visited, - Array* result) { +bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set& boundary, + bool include_bounary, std::unordered_map* visited, + Array* result) { if (visited->count(op.get())) { return visited->at(op.get()); } @@ -127,9 +116,7 @@ bool GetSubGraphByPostDFS_( // check if we can reach boundary. bool reach_boundary = false; for (Tensor t : op->InputTensors()) { - if (GetSubGraphByPostDFS_(t->op, boundary, - include_bounary, - visited, result)) { + if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) { reach_boundary = true; } } @@ -140,8 +127,7 @@ bool GetSubGraphByPostDFS_( return reach_boundary; } -Array GetSubGraph(const Array& outputs, - const Array& inputs, +Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs) { Array result; std::unordered_set boundary; @@ -150,16 +136,12 @@ Array GetSubGraph(const Array& outputs, } std::unordered_map visited; for (Tensor t : outputs) { - GetSubGraphByPostDFS_(t->op, boundary, include_inputs, - &visited, &result); + GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result); } return result; } - -void PostDFSOrder(const Operation& op, - const ReadGraph& g, - std::unordered_set* visited, +void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, Array* post_order) { if (visited->count(op)) return; visited->insert(op); @@ -169,9 +151,7 @@ void PostDFSOrder(const Operation& op, post_order->push_back(op); } -Array PostDFSOrder( - const Array& roots, - const ReadGraph& g) { +Array PostDFSOrder(const Array& roots, const ReadGraph& g) { std::unordered_set visited; Array post_order; for (Operation op : roots) { @@ -196,8 +176,7 @@ AttachPath CreateAttachPath(Schedule sch) { std::unordered_set visited; Array path; for (Stage s = stage; s.defined();) { - CHECK(!visited.count(s.get())) - << "Find loop in compute_at attach group"; + CHECK(!visited.count(s.get())) << "Find loop in compute_at attach group"; visited.insert(s.get()); Stage spec = s.GetAttachSpec(); bool start_attach; @@ -221,9 +200,8 @@ AttachPath CreateAttachPath(Schedule sch) { } if (start_attach) path.push_back(iv); } - CHECK(start_attach) - << "Invalid Schedule: cannot find attach point " << attach_ivar - << " in the schedule of " << s->op; + CHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar + << " in the schedule of " << s->op; } if (!ret.count(stage->op)) { ret.Set(stage->op, path); @@ -233,7 +211,7 @@ AttachPath CreateAttachPath(Schedule sch) { } // graph of push reach relation of tensor dimensions -using ReachGraph = std::unordered_map >; +using ReachGraph = std::unordered_map>; ReachGraph GetReachGraph(const Array& ops) { ReachGraph reach; @@ -249,10 +227,8 @@ ReachGraph GetReachGraph(const Array& ops) { for (size_t i = 0; i < update.size(); ++i) { Tensor t = op.output(i); for (int k = 1; k < static_cast(update[i]->shape.size()); ++k) { - reach[TensorDimKey(t, k)].emplace_back( - TensorDimKey(update[i], k)); - reach[TensorDimKey(t, k)].emplace_back( - TensorDimKey(init[i], k)); + reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k)); + reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k)); } } } else if (const auto* compute_op = op.as()) { @@ -264,13 +240,13 @@ ReachGraph GetReachGraph(const Array& ops) { reach[TensorDimKey(t, i)] = {}; } auto fvisit = [&vmap, &reach, 
&bset](const ObjectRef& n) { - const tir::CallNode *call = n.as(); + const tir::CallNode* call = n.as(); if (call != nullptr && call->func.defined()) { if (!bset.count(call->func.get())) return; for (size_t i = 0; i < call->args.size(); ++i) { TensorDimKey dkey(call, static_cast(i)); auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { - const VarNode *v = node.as(); + const VarNode* v = node.as(); auto it = vmap.find(v); if (it != vmap.end()) { reach[it->second].push_back(dkey); @@ -315,8 +291,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } // merge exact reach - auto f_merge_key = [&exact_reach, &fail_set]( - const TensorDimKey& dst, const TensorDimKey& src) { + auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) { auto sit = exact_reach.find(src); if (sit == exact_reach.end()) return; auto dit = exact_reach.find(dst); @@ -343,7 +318,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map > vmap; + std::unordered_map> vmap; const auto& axis = compute_op->axis; for (size_t i = 0; i < axis.size(); ++i) { std::vector keys; @@ -352,9 +327,8 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } vmap[axis[i]->var.get()] = std::move(keys); } - auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( - const ObjectRef& n) { - const tir::CallNode *call = n.as(); + auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) { + const tir::CallNode* call = n.as(); if (call != nullptr && call->func.defined()) { for (size_t i = 0; i < call->args.size(); ++i) { auto it = vmap.find(call->args[i].get()); @@ -391,8 +365,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { TensorDimKey key(scan->update[i], k); TensorDimKey target(scan->state_placeholder[i], k); IterVar sp_iv = scan->spatial_axis_[sp_idx]; - if (fail_set.count(sp_iv.get()) || - !exact_reach.count(key) || + if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) || exact_reach.at(key) != sp_iv.get()) { ret.Set(sp_iv, make_const(DataType::Int(32), 0)); } else { @@ -430,24 +403,18 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { return ret; } - -TVM_REGISTER_GLOBAL("schedule.CreateReadGraph") -.set_body_typed(CreateReadGraph); +TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); TVM_REGISTER_GLOBAL("schedule.PostDFSOrder") -.set_body_typed([](const Array& roots, - const ReadGraph& g) { - return PostDFSOrder(roots, g); -}); + .set_body_typed([](const Array& roots, const ReadGraph& g) { + return PostDFSOrder(roots, g); + }); -TVM_REGISTER_GLOBAL("schedule.CreateAttachPath") -.set_body_typed(CreateAttachPath); +TVM_REGISTER_GLOBAL("schedule.CreateAttachPath").set_body_typed(CreateAttachPath); -TVM_REGISTER_GLOBAL("schedule.ScanGetBody") -.set_body_typed(ScanGetBody); +TVM_REGISTER_GLOBAL("schedule.ScanGetBody").set_body_typed(ScanGetBody); -TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis") -.set_body_typed(ScanFixPointAnalysis); +TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis").set_body_typed(ScanFixPointAnalysis); } // namespace te } // namespace tvm diff --git a/src/te/schedule/graph.h b/src/te/schedule/graph.h index c3478c705145e..bb98ff4b706d3 100644 --- a/src/te/schedule/graph.h +++ b/src/te/schedule/graph.h @@ -24,9 +24,10 @@ #ifndef TVM_TE_SCHEDULE_GRAPH_H_ #define TVM_TE_SCHEDULE_GRAPH_H_ -#include -#include #include +#include +#include + #include #include #include @@ -72,8 +73,7 @@ ReadGraph CreateReadGraph(const Array& 
roots); * * \return The subgraph. */ -Array GetSubGraph(const Array& outputs, - const Array& inputs, +Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs); /*! @@ -85,8 +85,7 @@ Array GetSubGraph(const Array& outputs, * \note PostDFSOrder is a special case of Topological order, * and can be used when topological order is needed. */ -Array PostDFSOrder( - const Array& roots, const ReadGraph& g); +Array PostDFSOrder(const Array& roots, const ReadGraph& g); /*! * \brief Create feedgraph for a given Schedule diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 1453ed0683e4e..4f0e98243a020 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -21,9 +21,11 @@ * \file message_passing.cc * \brief The message passing domain. */ +#include "message_passing.h" + #include #include -#include "message_passing.h" + #include "../../arith/compute_expr.h" namespace tvm { @@ -31,22 +33,18 @@ namespace te { using namespace tir; -void Update(std::unordered_map* p_state, - const IterVar& iv, - Range r, +void Update(std::unordered_map* p_state, const IterVar& iv, Range r, arith::Analyzer* analyzer) { auto it = p_state->find(iv); if (it == p_state->end()) { (*p_state)[iv] = r; analyzer->Bind(iv->var, r); } else { - bool match = is_zero(it->second->min) && - analyzer->CanProve(r->extent - it->second->extent == 0); - CHECK(match) - << iv - << " domain already inferred," - << " cannot prove their extents are the same " - << it->second->extent << " vs " << r->extent; + bool match = + is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0); + CHECK(match) << iv << " domain already inferred," + << " cannot prove their extents are the same " << it->second->extent << " vs " + << r->extent; } } @@ -89,10 +87,8 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* } } -void PassDownDomain(const Stage& stage, - std::unordered_map* p_state, - arith::Analyzer* actx, - bool allow_missing) { +void PassDownDomain(const Stage& stage, std::unordered_map* p_state, + arith::Analyzer* actx, bool allow_missing) { auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); } @@ -100,7 +96,7 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; - auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { + auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(a < b)) { return actx->Simplify(a); } @@ -138,20 +134,16 @@ void PassDownDomain(const Stage& stage, }; if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->inner, r->factor)), + Range::make_by_min_extent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx); Update(p_state, r->outer, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->factor)), actx); + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->factor)), actx); } else { Update(p_state, r->outer, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->outer, r->nparts)), + Range::make_by_min_extent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx); Update(p_state, r->inner, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->nparts)), actx); + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx); } } else if (const FuseNode* r = rel.as()) { if 
(!state.count(r->outer) || !state.count(r->inner)) { @@ -160,16 +152,13 @@ void PassDownDomain(const Stage& stage, } const Range& range_outer = state.at(r->outer); const Range& range_inner = state.at(r->inner); - state[r->fused] = Range::make_by_min_extent( - 0, range_outer->extent * range_inner->extent); + state[r->fused] = Range::make_by_min_extent(0, range_outer->extent * range_inner->extent); } else if (const RebaseNode* r = rel.as()) { if (!state.count(r->parent)) { CHECK(allow_missing); continue; } - Update(p_state, r->rebased, - Range::make_by_min_extent( - 0, state.at(r->parent)->extent), actx); + Update(p_state, r->rebased, Range::make_by_min_extent(0, state.at(r->parent)->extent), actx); } else if (const SingletonNode* s = rel.as()) { Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx); } else { @@ -185,10 +174,8 @@ void PassDownDomain(const Stage& stage, } } -void PassUpIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing) { +void PassUpIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; @@ -244,10 +231,8 @@ void PassUpIndex(const Stage& stage, } } -void PassDownIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing) { +void PassDownIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { if (const SplitNode* s = rel.as()) { @@ -292,16 +277,10 @@ void PassDownIndex(const Stage& stage, } // Domain message passing. -void PassUpDomain(const SplitNode* s, - const std::unordered_map& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent) { - if (dom_map.count(s->outer) && - dom_map.count(s->inner) && - dom_map.count(s->parent) && - outer.match_range(dom_map.at(s->outer)) && - inner.match_range(dom_map.at(s->inner))) { +void PassUpDomain(const SplitNode* s, const std::unordered_map& dom_map, + const IntSet& outer, const IntSet& inner, IntSet* parent) { + if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) && + outer.match_range(dom_map.at(s->outer)) && inner.match_range(dom_map.at(s->inner))) { *parent = IntSet::range(dom_map.at(s->parent)); return; } @@ -310,16 +289,12 @@ void PassUpDomain(const SplitNode* s, CHECK(outer.defined()); CHECK(inner.defined()); CHECK(factor.defined()); - *parent = arith::EvalSet( - s->outer->var * factor + s->inner->var + parent_min, - {{s->outer, outer}, {s->inner, inner}}); + *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min, + {{s->outer, outer}, {s->inner, inner}}); } -void PassUpDomain(const FuseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner) { +void PassUpDomain(const FuseNode* s, const std::unordered_map& dom_map, + const IntSet& fused, IntSet* outer, IntSet* inner) { CHECK(dom_map.count(s->outer)); CHECK(dom_map.count(s->inner)); CHECK(dom_map.count(s->fused)); @@ -336,8 +311,8 @@ void PassUpDomain(const FuseNode* s, if (fused.is_single_point()) { PrimExpr value = fused.point_value(); PrimExpr factor = dom_map.at(s->inner)->extent; - PrimExpr v_outer = indexdiv(value, factor); - PrimExpr v_inner = indexmod(value, factor); + PrimExpr v_outer = indexdiv(value, factor); + PrimExpr v_inner = indexmod(value, factor); if 
(!is_zero(outer_min)) v_outer = v_outer + outer_min; if (!is_zero(inner_min)) v_inner = v_inner + inner_min; *outer = IntSet::single_point(v_outer); @@ -345,9 +320,8 @@ void PassUpDomain(const FuseNode* s, } else { PrimExpr fused_extent = (fused.max() - fused.min() + 1); PrimExpr inner_extent = dom_map.at(s->inner)->extent; - *outer = IntSet::interval( - outer_min + indexdiv(fused.min(), inner_extent), - outer_min + indexdiv(fused.max(), inner_extent)); + *outer = IntSet::interval(outer_min + indexdiv(fused.min(), inner_extent), + outer_min + indexdiv(fused.max(), inner_extent)); if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) && is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) { // fused never spans multiple rows, make a tight bounding box @@ -357,8 +331,8 @@ void PassUpDomain(const FuseNode* s, } else { // fused may span multiple rows, use full row widths if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) || !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) { - LOG(WARNING) << - "fused and original axes are not aligned, this may cause redundant computations"; + LOG(WARNING) + << "fused and original axes are not aligned, this may cause redundant computations"; } *inner = IntSet::range(dom_map.at(s->inner)); } @@ -366,44 +340,34 @@ void PassUpDomain(const FuseNode* s, } } -void PassUpDomain(const RebaseNode* s, - const std::unordered_map& dom_map, - const IntSet& rebased, - IntSet* parent) { +void PassUpDomain(const RebaseNode* s, const std::unordered_map& dom_map, + const IntSet& rebased, IntSet* parent) { CHECK(dom_map.count(s->parent)); if (rebased.match_range(dom_map.at(s->rebased))) { *parent = IntSet::range(dom_map.at(s->parent)); return; } PrimExpr parent_min = dom_map.at(s->parent)->min; - *parent = arith::EvalSet(s->rebased->var + parent_min, - {{s->rebased, rebased}}); + *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); } -void PassUpDomain(const Stage& stage, - const std::unordered_map& dom_map, +void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; if (const SplitNode* r = rel.as()) { IntSet parent; - PassUpDomain(r, dom_map, - state.at(r->outer), state.at(r->inner), - &parent); + PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent); state[r->parent] = parent; } else if (const FuseNode* r = rel.as()) { IntSet outer, inner; - PassUpDomain(r, dom_map, - state.at(r->fused), - &outer, &inner); + PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner); state[r->outer] = outer; state[r->inner] = inner; } else if (const RebaseNode* r = rel.as()) { IntSet parent; - PassUpDomain(r, dom_map, - state.at(r->rebased), - &parent); + PassUpDomain(r, dom_map, state.at(r->rebased), &parent); state[r->parent] = parent; } else if (rel.as()) { } else { @@ -413,8 +377,7 @@ void PassUpDomain(const Stage& stage, } // Pass up bit mask with or relation. 
-void PassUpBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassUpBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { @@ -461,8 +424,7 @@ void PassUpBitMaskOr(const Stage& stage, } } -void PassDownBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassDownBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { @@ -509,17 +471,14 @@ void PassDownBitMaskOr(const Stage& stage, } } - /*! * \brief message passing to find if boundary checking on IterVar is needed. * \param s The stage to be used. * \param p_state The message passing state * IterVar->flag */ -void PassUpBoundCheck(const Stage& s, - const Map& dom_map, - std::unordered_map* p_state, - arith::Analyzer* analyzer) { +void PassUpBoundCheck(const Stage& s, const Map& dom_map, + std::unordered_map* p_state, arith::Analyzer* analyzer) { auto& state = *p_state; for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; @@ -560,16 +519,14 @@ bool IsRangeSame(const Range input_1, const Range input_2) { arith::Analyzer analyzer; if (input_1.same_as(input_2)) return true; - return (analyzer.CanProve(input_1->min == input_2->min) - && analyzer.CanProve(input_1->extent == input_2->extent)); + return (analyzer.CanProve(input_1->min == input_2->min) && + analyzer.CanProve(input_1->extent == input_2->extent)); } -std::vector MakeBoundCheck( - const Stage& stage, - const Map& dom_map, - const std::unordered_map& value_map, - bool skip_ivar_domain, - const std::unordered_set& skip_iter) { +std::vector MakeBoundCheck(const Stage& stage, const Map& dom_map, + const std::unordered_map& value_map, + bool skip_ivar_domain, + const std::unordered_set& skip_iter) { arith::Analyzer analyzer; std::unordered_map bound_state; @@ -579,11 +536,15 @@ std::vector MakeBoundCheck( PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; - std::unordered_map iset_dmap; + Map iset_dmap; // setup domain map for set analysis for (const auto& kv : dom_map) { - iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); + iset_dmap.Set(kv.first->var, IntSet::range(kv.second)); + } + + for (auto entry : dom_map) { + analyzer.Bind(entry.first->var, entry.second); } for (const IterVar& iv : stage->all_iter_vars) { @@ -591,7 +552,7 @@ std::vector MakeBoundCheck( if (bound_state.at(iv)) { Range dom = dom_map.at(iv); PrimExpr value = value_map.at(iv) - dom->min; - PrimExpr vmax = EvalSet(value, iset_dmap).max(); + PrimExpr vmax = analyzer.int_set(value, iset_dmap).max(); if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } @@ -603,7 +564,7 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { PrimExpr value = value_map.at(iv) - iv->dom->min; - IntSet s = EvalSet(value, iset_dmap); + IntSet s = analyzer.int_set(value, iset_dmap); PrimExpr vmin = s.min(); PrimExpr vmax = s.max(); // The range of `value` resides in [vmin, vmax] diff --git a/src/te/schedule/message_passing.h b/src/te/schedule/message_passing.h index 187723516f974..c382b90d630c7 100644 --- a/src/te/schedule/message_passing.h +++ b/src/te/schedule/message_passing.h @@ -25,10 +25,11 @@ #ifndef TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ #define TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ -#include -#include -#include 
#include +#include +#include +#include + #include #include #include @@ -45,11 +46,8 @@ namespace te { * \param analyzer Analyzer context, storing information about bounds in p_state. * \param allow_missing Whether allow missing value. */ -void PassDownDomain( - const Stage& stage, - std::unordered_map* p_state, - arith::Analyzer* analyzer, - bool allow_missing = false); +void PassDownDomain(const Stage& stage, std::unordered_map* p_state, + arith::Analyzer* analyzer, bool allow_missing = false); /*! * \param Upward inference of index of each IterVar. @@ -60,10 +58,8 @@ void PassDownDomain( * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassUpIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing = false); +void PassUpIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing = false); /*! * \param Downward inference of index of each IterVar. @@ -74,10 +70,8 @@ void PassUpIndex(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassDownIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing = false); +void PassDownIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing = false); /*! * \param Upward inference of domain set of each IterVar. @@ -87,8 +81,7 @@ void PassDownIndex(const Stage& stage, * \param dom_map The domain map of each iteration variable's maximum domain. * \param p_state The index state of each IterVar. */ -void PassUpDomain(const Stage& stage, - const std::unordered_map& dom_map, +void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state); /*! @@ -97,8 +90,7 @@ void PassUpDomain(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassUpBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassUpBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing = false); /*! @@ -107,8 +99,7 @@ void PassUpBitMaskOr(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassDownBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassDownBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing = false); /*! @@ -120,13 +111,10 @@ void PassDownBitMaskOr(const Stage& stage, * \param skip_iter The set of variables to skip bound condition. * \return List of predicates that we need to check. */ -std::vector -MakeBoundCheck( - const Stage& stage, - const Map& dom_map, - const std::unordered_map& value_map, - bool skip_ivar_domain, - const std::unordered_set& skip_iter); +std::vector MakeBoundCheck(const Stage& stage, const Map& dom_map, + const std::unordered_map& value_map, + bool skip_ivar_domain, + const std::unordered_set& skip_iter); } // namespace te } // namespace tvm diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index c3f333e522c8c..8c8f092b70087 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -20,14 +20,16 @@ /*! 
* \file operation_inline.cc */ +#include "operation_inline.h" + +#include #include #include -#include #include + #include -#include "operation_inline.h" -#include "../../tir/transforms/ir_util.h" +#include "../../tir/transforms/ir_util.h" namespace tvm { namespace te { @@ -62,8 +64,7 @@ class OperationInliner final : public StmtExprMutator { for (size_t i = 0; i < args_.size(); ++i) { vmap.Set(args_[i], op->args[i]); } - expr = Substitute( - EvaluateNode::make(expr), vmap).as()->value; + expr = Substitute(EvaluateNode::make(expr), vmap).as()->value; } return expr; } else { @@ -77,12 +78,8 @@ class OperationInliner final : public StmtExprMutator { PrimExpr body_; }; -Stmt Inline(Stmt stmt, - Operation f, - Array args, - PrimExpr body) { - CHECK_EQ(f->num_outputs(), 1) - << "can only inline output single value operation"; +Stmt Inline(Stmt stmt, Operation f, Array args, PrimExpr body) { + CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation"; Stmt ret = OperationInliner(f, args, body)(std::move(stmt)); if (ret.same_as(stmt)) return ret; return ConvertSSA(ret); diff --git a/src/te/schedule/operation_inline.h b/src/te/schedule/operation_inline.h index d7d55cc660272..d475fbe3787ef 100644 --- a/src/te/schedule/operation_inline.h +++ b/src/te/schedule/operation_inline.h @@ -22,10 +22,10 @@ #ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_ #define TVM_TE_SCHEDULE_OPERATION_INLINE_H_ -#include -#include #include #include +#include +#include namespace tvm { namespace te { @@ -41,10 +41,7 @@ namespace te { * * \note All the passes in this file use SSA form and output SSA form. */ -Stmt Inline(Stmt stmt, - Operation op, - Array args, - PrimExpr body); +Stmt Inline(Stmt stmt, Operation op, Array args, PrimExpr body); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index f3e76a45e7db5..ed2880653d63d 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -20,20 +20,21 @@ /*! * \file schedule_dataflow_rewrite.cc */ -#include #include +#include #include + #include -#include "message_passing.h" -#include "operation_inline.h" -#include "../../tir/transforms/ir_util.h" #include "../../arith/compute_expr.h" +#include "../../tir/transforms/ir_util.h" +#include "message_passing.h" +#include "operation_inline.h" namespace tvm { namespace te { // find first occurrence location in leaf -template +template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { @@ -45,9 +46,7 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { // The replacer of cache.
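Before the class itself, a standalone sketch of the substitution pattern VarReplacer implements: a tree visitor that rewrites leaf variables through a lookup table and shares unchanged subtrees (editor's toy model, not tir::StmtExprMutator):

// Editor's sketch: rewrite leaf variables via a map, leave everything else as-is.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Expr {
  std::string var;              // non-empty for a leaf variable
  char op = 0;                  // '+' for an interior node
  std::shared_ptr<Expr> lhs, rhs;
};

std::shared_ptr<Expr> Substitute(const std::shared_ptr<Expr>& e,
                                 const std::unordered_map<std::string, std::string>& vsub) {
  if (!e->var.empty()) {
    auto it = vsub.find(e->var);
    if (it != vsub.end()) return std::make_shared<Expr>(Expr{it->second});
    return e;  // unchanged leaves stay shared, mirroring the same_as checks
  }
  return std::make_shared<Expr>(Expr{"", e->op, Substitute(e->lhs, vsub), Substitute(e->rhs, vsub)});
}

int main() {
  auto x = std::make_shared<Expr>(Expr{"x"});
  auto y = std::make_shared<Expr>(Expr{"y"});
  auto sum = std::make_shared<Expr>(Expr{"", '+', x, y});
  auto out = Substitute(sum, {{"x", "x.c"}});
  std::cout << out->lhs->var << " + " << out->rhs->var << "\n";  // x.c + y
}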
class VarReplacer : public tir::StmtExprMutator { public: - explicit VarReplacer( - const std::unordered_map& vsub) - : vsub_(vsub) {} + explicit VarReplacer(const std::unordered_map& vsub) : vsub_(vsub) {} PrimExpr VisitExpr_(const VarNode* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; @@ -56,19 +55,16 @@ class VarReplacer : public tir::StmtExprMutator { tir::CommReducer MutateCommReducer(tir::CommReducer combiner) { // Replace free variables in combiner - auto new_identity = tir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) { - return this->VisitExpr(e); - }); - auto new_result = tir::UpdateArray(combiner->result, [this] (const PrimExpr& e) { - return this->VisitExpr(e); - }); + auto new_identity = tir::UpdateArray(combiner->identity_element, + [this](const PrimExpr& e) { return this->VisitExpr(e); }); + auto new_result = tir::UpdateArray(combiner->result, + [this](const PrimExpr& e) { return this->VisitExpr(e); }); if (combiner->identity_element.same_as(new_identity) && combiner->identity_element.same_as(new_result)) { return combiner; } else { - return tir::CommReducerNode::make( - combiner->lhs, combiner->rhs, new_result, new_identity); + return tir::CommReducerNode::make(combiner->lhs, combiner->rhs, new_result, new_identity); } } @@ -79,12 +75,8 @@ class VarReplacer : public tir::StmtExprMutator { if (op->combiner.same_as(new_combiner)) { return new_e; } else { - return tir::ReduceNode::make( - new_combiner, - new_reduce->source, - new_reduce->axis, - new_reduce->condition, - new_reduce->value_index); + return tir::ReduceNode::make(new_combiner, new_reduce->source, new_reduce->axis, + new_reduce->condition, new_reduce->value_index); } } @@ -92,8 +84,7 @@ class VarReplacer : public tir::StmtExprMutator { const std::unordered_map& vsub_; }; -PrimExpr InjectPredicate(const Array& predicates, - PrimExpr body) { +PrimExpr InjectPredicate(const Array& predicates, PrimExpr body) { using tir::ReduceNode; using tir::SelectNode; if (predicates.size() == 0) return body; @@ -103,16 +94,14 @@ PrimExpr InjectPredicate(const Array& predicates, n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); return PrimExpr(n); } - return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), - body, - make_zero(body.dtype())); + return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), body, + make_zero(body.dtype())); } // Replace the data flow that appears in all stages given the tensor change. // Also update vmap if subsequent dataflow needs to be replaced. // Need to keep the transitive closure property on the vmap up to date by maintaining a reverse map.
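The reverse-map bookkeeping described in the comment above can be sketched standalone: when key -> a already exists and a is later replaced by b, the reverse map lets the update rewrite key -> b directly, so the forward map never develops chains (editor's toy with hypothetical tensor names):

// Editor's sketch of keeping a forward map transitively closed via a reverse map.
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::string> vmap;   // old tensor -> replacement
  std::unordered_map<std::string, std::string> rvmap;  // replacement -> old tensor
  auto replace = [&](const std::string& from, const std::string& to) {
    auto it = rvmap.find(from);
    if (it != rvmap.end()) {  // 'from' is itself a replacement: shortcut the chain
      vmap[it->second] = to;
      rvmap[to] = it->second;
      rvmap.erase(it);
    } else {
      vmap[from] = to;
      rvmap[to] = from;
    }
  };
  replace("A", "A.1");
  replace("A.1", "A.2");
  assert(vmap.at("A") == "A.2");  // closure kept flat: no A -> A.1 -> A.2 chain
}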
-void ReplaceDataFlow(const Array& stages, - std::unordered_map* vmap, +void ReplaceDataFlow(const Array& stages, std::unordered_map* vmap, std::unordered_map* rvmap) { for (Stage s : stages) { Operation op = s->op->ReplaceInputs(s->op, *vmap); @@ -132,14 +121,11 @@ void ReplaceDataFlow(const Array& stages, } inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && - (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && - (a->condition.same_as(b->condition)); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)); } -Tensor Schedule::cache_read(const Tensor& tensor, - const std::string& scope, +Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, const Array& readers) { (*this)->InvalidateCache(); // create identity mapping. std::ostringstream os; os << tensor->op->name; if (tensor->op->num_outputs() != 1) { os << ".v" << tensor->value_index; } os << "." << scope; std::unordered_map vsub; Stage s = operator[](tensor->op); Tensor sugar_tensor = s->op.output(tensor->value_index); - Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array& i) { - return sugar_tensor(Array(i.begin(), i.end())); - }, os.str()); + Tensor cache = compute( + sugar_tensor->shape, + [&sugar_tensor](const Array& i) { + return sugar_tensor(Array(i.begin(), i.end())); + }, + os.str()); vsub[sugar_tensor] = cache; std::unordered_map vmap; @@ -163,9 +152,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, for (Operation op : readers) { Stage s = operator[](op); Operation repl_op = s->op->ReplaceInputs(s->op, vsub); - CHECK(!repl_op.same_as(s->op)) - << "Cannot find " << tensor - << " in the inputs of " << s->op; + CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op; vmap[s->op.output(0)] = repl_op.output(0); rvmap[repl_op.output(0)] = s->op.output(0); s->op = repl_op; @@ -177,8 +164,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, Stage cache_stage = Stage(cache->op); cache_stage.set_scope(scope); CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos + 1, - cache_stage); + stages->data.insert(stages->data.begin() + pos + 1, cache_stage); (*this)->stage_map.Set(cache->op, cache_stage); // Update group cache_stage->group = op_stage->group; @@ -188,12 +174,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } -template -void PrepareAxisMapping(Stage orig_stage, - OpType* op, - std::unordered_set* p_red_axis, - Array* p_new_axis, - std::unordered_map* p_dom_map, +template +void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set* p_red_axis, + Array* p_new_axis, std::unordered_map* p_dom_map, std::unordered_map* p_vsub, std::unordered_map* p_vsub2newvar, std::vector* p_predicates) { @@ -218,11 +201,9 @@ void PrepareAxisMapping(Stage orig_stage, std::unordered_map value_map; for (IterVar iv : orig_stage->leaf_iter_vars) { if (red_axis.count(iv)) continue; - CHECK_EQ(iv->iter_type, kDataPar) - << "Can only relayout with in data parallel dimensions"; + CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout within data parallel dimensions"; Range dom = dom_map.at(iv); - IterVar new_iv = IterVarNode::make( - dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVarNode::make(dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); if (is_one(dom->min)) { value_map[iv] = dom->min; @@ -237,8 +218,7 @@ void PrepareAxisMapping(Stage
orig_stage, skip_bound_check.insert(iv); } PassUpIndex(orig_stage, dom_map, &value_map, true); - predicates = MakeBoundCheck( - orig_stage, dom_map, value_map, true, skip_bound_check); + predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis for (IterVar iv : op->axis) { if (value_map.count(iv)) { @@ -248,12 +228,8 @@ void PrepareAxisMapping(Stage orig_stage, } } -Array ReplaceOriginalOp(Schedule sch, - Stage orig_stage, - const std::string& scope, - Operation cache_op, - Operation orig_new_op, - size_t tensor_size) { +Array ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope, + Operation cache_op, Operation orig_new_op, size_t tensor_size) { Array cache_tensor_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); @@ -280,8 +256,7 @@ Array ReplaceOriginalOp(Schedule sch, Stage cache_stage = Stage(cache_op); cache_stage.set_scope(scope); CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage); + stages->data.insert(stages->data.begin() + pos, cache_stage); sch->stage_map.Set(cache_op, cache_stage); // Update group cache_stage->group = orig_stage->group; @@ -291,10 +266,8 @@ Array ReplaceOriginalOp(Schedule sch, return cache_tensor_list; } - // Cache write and relayout the data according to loop pattern -Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, +Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); @@ -310,8 +283,8 @@ Array CacheWriteWithReLayout(Schedule sch, std::unordered_map vsub2newvar; std::vector predicates; - PrepareAxisMapping(orig_stage, compute, - &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, + &predicates); PrimExpr body; Array body_list; @@ -326,17 +299,14 @@ Array CacheWriteWithReLayout(Schedule sch, const tir::ReduceNode* reduce_body = body.as(); if (first_reduce != nullptr) { CHECK(ReduceEqual(reduce_body, first_reduce)); - body = tir::ReduceNode::make(first_reduce->combiner, - first_reduce->source, - first_reduce->axis, - first_reduce->condition, - reduce_body->value_index); + body = + tir::ReduceNode::make(first_reduce->combiner, first_reduce->source, first_reduce->axis, + first_reduce->condition, reduce_body->value_index); } else { first_reduce = reduce_body; } } else { - CHECK(first_reduce == nullptr) - << "cannot mix reduce and other node in ONE compute bodys"; + CHECK(first_reduce == nullptr) << "cannot mix reduce and other nodes in ONE compute body"; } body_list.push_back(body); } @@ -354,26 +324,21 @@ Array CacheWriteWithReLayout(Schedule sch, args.push_back(value_map.at(iv)); } } - Operation cache_op = ComputeOpNode::make( - compute->name + "." + scope, compute->tag, compute->attrs, - new_axis, body_list); + Operation cache_op = ComputeOpNode::make(compute->name + "."
+ scope, compute->tag, + compute->attrs, new_axis, body_list); Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make( - compute->name, compute->tag, compute->attrs, - compute->axis, cache_expr_list); - return ReplaceOriginalOp(sch, orig_stage, scope, - cache_op, orig_new_op, tensor_size); + Operation orig_new_op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, + compute->axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } - // for tensor compute op -Array CacheWriteWithReLayoutTensor(Schedule sch, - const Array& tensor_array, +Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); @@ -391,14 +356,12 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, std::unordered_map vsub2newvar; std::vector predicates; - PrepareAxisMapping(orig_stage, tensor_op, - &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); - + PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, + &predicates); for (int i = tensor_op->schedulable_ndim; i < static_cast(tensor_op->axis.size()); ++i) { IterVar iv = tensor_op->axis[i]; - IterVar new_iv = IterVarNode::make( - iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVarNode::make(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); } Array new_regions; @@ -417,10 +380,10 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } - Operation cache_op = TensorComputeOpNode::make( - tensor_op->name + "." + scope, tensor_op->tag, new_axis, - tensor_op->reduce_axis, tensor_op->schedulable_ndim, - tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs); + Operation cache_op = TensorComputeOpNode::make(tensor_op->name + "." 
+ scope, tensor_op->tag, + new_axis, tensor_op->reduce_axis, + tensor_op->schedulable_ndim, tensor_op->intrin, + tensor_op->inputs, new_regions, new_scalar_inputs); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; @@ -455,19 +418,14 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make( - tensor_op->name, tensor_op->tag, {}, - compute_axis, cache_expr_list); - return ReplaceOriginalOp(sch, orig_stage, scope, - cache_op, orig_new_op, tensor_size); + Operation orig_new_op = + ComputeOpNode::make(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } - -Array Schedule::cache_write(const Array& tensor_array, - const std::string& scope) { +Array Schedule::cache_write(const Array& tensor_array, const std::string& scope) { (*this)->InvalidateCache(); - CHECK(tensor_array.size() > 0) - << "size of tensor_array must be greater than 0"; + CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0"; Tensor tensor = tensor_array[0]; Stage orig_stage = operator[](tensor->op); const ComputeOpNode* compute = tensor->op.as(); @@ -475,15 +433,12 @@ Array Schedule::cache_write(const Array& tensor_array, << "size of input tensor list must be same as number of stage outputs"; for (size_t i = 1; i < tensor_array.size(); i++) { Stage tmp_stage = operator[](tensor_array[i]->op); - CHECK(orig_stage.same_as(tmp_stage)) - << "Input tensor list must be generated by ONE computeOp"; + CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } return CacheWriteWithReLayout(*this, tensor_array, scope); } - -Tensor Schedule::cache_write(const Tensor& tensor, - const std::string& scope) { +Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { // support both original compute and tensor compute (*this)->InvalidateCache(); if (tensor->op.as()) { @@ -496,7 +451,6 @@ Tensor Schedule::cache_write(const Tensor& tensor, } } - void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { @@ -506,16 +460,14 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); for (IterVar iv : root_iter_vars) { size_t idx = FindNodeRef(leaf_vars, iv); - auto it = s->iter_var_attrs.find(iv); + auto it = s->iter_var_attrs.find(iv); // don't need to rebase paths that are bound.
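For context on rebasing (editor's standalone sketch, not TVM code): a loop over [min, min + extent) becomes a loop over [0, extent) whose body indexes at min + i; axes that are bound to a thread keep their original variable, hence the skip in the hunk below.

// Editor's sketch of rebasing a non-zero-min loop to start at zero.
#include <cstdio>

int main() {
  const int min = 3, extent = 4;       // made-up original domain [3, 7)
  for (int i = 0; i < extent; ++i) {   // rebased iteration space [0, 4)
    int original = min + i;            // value the pre-rebase var would have taken
    std::printf("%d ", original);      // 3 4 5 6
  }
  std::printf("\n");
}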
- if (it != s->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { + if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } if (idx < leaf_vars->data.size()) { // insert rebase - IterVar rebased = IterVarNode::make( - Range(), iv->var.copy_with_suffix(""), iv->iter_type); + IterVar rebased = IterVarNode::make(Range(), iv->var.copy_with_suffix(""), iv->iter_type); s->relations.push_back(RebaseNode::make(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); @@ -557,13 +509,11 @@ void InjectInline(ScheduleNode* sch) { { // setup args const ComputeOpNode* compute = stage->op.as(); - CHECK(compute) - << "can only inline compute op"; + CHECK(compute) << "can only inline compute op"; for (auto iv : compute->axis) { args.push_back(iv->var); } - CHECK_EQ(compute->body.size(), 1U) - << "can only inline compute op with 1 output"; + CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output"; body = compute->body[0]; } for (size_t j = i; j < sch->stages.size(); ++j) { @@ -580,12 +530,13 @@ void InjectInline(ScheduleNode* sch) { for (size_t k = 1; k < new_body[j].size(); ++k) { const tir::ReduceNode* reduce_ = new_body[j][k].as(); CHECK(reduce_); - CHECK(ReduceEqual(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } - PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]), - stage->op, args, body).as()->value; + PrimExpr new_value = + Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; const tir::ReduceNode* r = new_value.as(); @@ -600,8 +551,10 @@ void InjectInline(ScheduleNode* sch) { } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]), - stage->op, args, body).as()->value; + PrimExpr new_value = + Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); changed[j] = true; @@ -632,9 +585,8 @@ void InjectInline(ScheduleNode* sch) { CHECK(compute); Operation op = s->op; if (changed[i]) { - op = ComputeOpNode::make( - compute->name, compute->tag, compute->attrs, - compute->axis, new_body[i]); + op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, compute->axis, + new_body[i]); } op = op->ReplaceInputs(op, repl); if (!op.same_as(s->op)) { @@ -646,9 +598,8 @@ void InjectInline(ScheduleNode* sch) { } else if (hybrid_changed[i]) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); CHECK(hybrid); - Operation op = HybridOpNode::make( - hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + Operation op = HybridOpNode::make(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, + hybrid->outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); @@ -674,13 +625,10 @@ Schedule Schedule::normalize() { } // Handle reduction factor. 
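Before the rfactor implementation, a plain-C++ sketch of what the transformation achieves (editor's illustration; n, factor, and the arrays are made up): one reduction over k becomes per-chunk partial sums, plus a short final reduction over the chunks.

// Editor's sketch: factor a single sum over k into partial sums over k.outer
// (the new factored operation) and a final sum over the chunks (the replaced
// original operation).
#include <cstdio>

int main() {
  const int n = 16, factor = 4, chunks = n / factor;
  int data[n];
  for (int k = 0; k < n; ++k) data[k] = k;

  int partial[chunks] = {0};           // factored tensor, indexed by k.outer
  for (int ko = 0; ko < chunks; ++ko)
    for (int ki = 0; ki < factor; ++ki)
      partial[ko] += data[ko * factor + ki];

  int total = 0;                       // replacement reduction over the new axis
  for (int ko = 0; ko < chunks; ++ko) total += partial[ko];
  std::printf("%d\n", total);          // 120, same as summing 0..15 directly
}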
-Array Schedule::rfactor(const Tensor& tensor, - const IterVar& axis, - int factor_axis) { +Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) { (*this)->InvalidateCache(); using tir::ReduceNode; - CHECK_EQ(axis->iter_type, kCommReduce) - << "Can only factor reduction axis"; + CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis"; Stage reduce_stage = operator[](tensor->op); const ComputeOpNode* compute_op = reduce_stage->op.as(); CHECK(compute_op) << "Can only factor ComputeOp"; @@ -699,8 +647,7 @@ Array Schedule::rfactor(const Tensor& tensor, std::unordered_set skip_bound_check; // Verify normal axes are not touched. for (IterVar iv : compute_op->axis) { - CHECK(!touch_map.count(iv)) - << "Factor axis touches normal axis."; + CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis."; skip_bound_check.insert(iv); } // get analyzer. @@ -728,11 +675,11 @@ Array Schedule::rfactor(const Tensor& tensor, } } te::PassUpIndex(reduce_stage, dom_map, &value_map, true); - std::vector predicates = MakeBoundCheck( - reduce_stage, dom_map, value_map, true, skip_bound_check); + std::vector predicates = + MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check); // Get the factored op node. - const int factor_axis_pos = \ + const int factor_axis_pos = factor_axis >= 0 ? factor_axis : static_cast(compute_op->axis.size() + 1) + factor_axis; CHECK_LE(factor_axis_pos, compute_op->axis.size()); auto n = make_object(); @@ -741,8 +688,7 @@ Array Schedule::rfactor(const Tensor& tensor, // axis replacement. auto iv_node = make_object(); iv_node->dom = dom_map.at(axis); - CHECK(is_zero(iv_node->dom->min)) - << "Can only factor reduction domain starting from 0"; + CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0"; iv_node->var = axis->var; iv_node->iter_type = kDataPar; @@ -786,18 +732,15 @@ Array Schedule::rfactor(const Tensor& tensor, } } VarReplacer replacer(vsub); - Array new_source = tir::UpdateArray(reduce->source, - [&replacer] (const PrimExpr& e) { return replacer(e); }); + Array new_source = + tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); }); PrimExpr new_pred = replacer(predicate); std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { - body.emplace_back(ReduceNode::make(reduce->combiner, - new_source, - n->reduce_axis, - new_pred, - idx)); + body.emplace_back( + ReduceNode::make(reduce->combiner, new_source, n->reduce_axis, new_pred, idx)); } n->body = Array(body); // refresh relations, keep the un-touched relations. @@ -824,16 +767,14 @@ Array Schedule::rfactor(const Tensor& tensor, Stage factor_stage = Stage(factor_op); factor_stage->relations = rels; CHECK_LT(stage_pos, stages->data.size()); - stages->data.insert(stages->data.begin() + stage_pos, - factor_stage); + stages->data.insert(stages->data.begin() + stage_pos, factor_stage); (*this)->stage_map.Set(factor_op, factor_stage); factor_stage->group = reduce_stage->group; if (factor_stage->group.defined()) { ++factor_stage->group->num_child_stages; } // Replace the old reduction.
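The index juggling in the replacement compute below can be shown standalone (editor's toy; SpliceIndex is a hypothetical helper, not TVM API): the factored axis k.v is spliced into the output indices at factor_axis_pos.

// Editor's sketch: insert the factored reduction axis into the index tuple.
#include <cstdio>
#include <vector>

std::vector<int> SpliceIndex(std::vector<int> indices, int pos, int kv) {
  indices.insert(indices.begin() + pos, kv);  // mirrors the push_back logic below
  return indices;
}

int main() {
  // original output indices (i, j) = (5, 7); factored axis value k.v = 2
  auto idx = SpliceIndex({5, 7}, /*factor_axis_pos=*/0, /*kv=*/2);
  for (int v : idx) std::printf("%d ", v);  // 2 5 7: k.v leads when factor_axis == 0
  std::printf("\n");
}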
- IterVar repl_red_axis = reduce_axis( - dom_map.at(axis), axis->var->name_hint + ".v"); + IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v"); Array factor_tensors; Array old_tensors; int size = factor_op->num_outputs(); @@ -841,32 +782,33 @@ Array Schedule::rfactor(const Tensor& tensor, factor_tensors.push_back(factor_op.output(idx)); old_tensors.push_back(reduce_stage->op.output(idx)); } - Array repl_tensors = compute(old_tensors[0]->shape, - [&](const Array& i) { - Array indices; - const int idx_size = static_cast(i.size()); - for (int idx = 0; idx < idx_size; ++idx) { - if (factor_axis_pos == idx) { - indices.push_back(repl_red_axis->var); + Array repl_tensors = compute( + old_tensors[0]->shape, + [&](const Array& i) { + Array indices; + const int idx_size = static_cast(i.size()); + for (int idx = 0; idx < idx_size; ++idx) { + if (factor_axis_pos == idx) { + indices.push_back(repl_red_axis->var); + } + indices.push_back(i[idx]); } - indices.push_back(i[idx]); - } - if (factor_axis_pos == idx_size) { + if (factor_axis_pos == idx_size) { indices.push_back(repl_red_axis->var); - } - Array factor_exprs; - for (int idx = 0; idx < size; ++idx) { - factor_exprs.push_back(factor_tensors[idx](indices)); - } - Array reductions; - Array axis = {repl_red_axis}; - PrimExpr cond = const_true(); - for (int idx = 0; idx < size; ++idx) { - reductions.push_back(ReduceNode::make(reduce->combiner, - factor_exprs, axis, cond, idx)); - } - return reductions; - }, reduce_stage->op->name + ".repl"); + } + Array factor_exprs; + for (int idx = 0; idx < size; ++idx) { + factor_exprs.push_back(factor_tensors[idx](indices)); + } + Array reductions; + Array axis = {repl_red_axis}; + PrimExpr cond = const_true(); + for (int idx = 0; idx < size; ++idx) { + reductions.push_back(ReduceNode::make(reduce->combiner, factor_exprs, axis, cond, idx)); + } + return reductions; + }, + reduce_stage->op->name + ".repl"); std::unordered_map vmap; std::unordered_map rvmap; diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index bfee0d5a0a6b8..74ddca5486160 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -22,17 +22,19 @@ */ #include #include -#include #include +#include + #include #include + #include "graph.h" namespace tvm { namespace te { // find first occurrence location in leaf -template +template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { @@ -46,30 +48,23 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) { if (pos < leaf_vars->data.size()) return pos; if (FindNodeRef(all_vars, v) < all_vars->data.size()) { - LOG(FATAL) << "Operate on iter var " << v - << "that has already been split"; + LOG(FATAL) << "Operate on iter var " << v << " that has already been split"; } else { - LOG(FATAL) << "Operate on iter var " << v - << "that is not part of the schedule"; + LOG(FATAL) << "Operate on iter var " << v << " that is not part of the schedule"; } return 0; } -void Split(StageNode* self, - IterVar parent, - PrimExpr factor, - PrimExpr nparts, - IterVar* p_outer, +void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // Check if split is valid.
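A standalone sketch of the split arithmetic this function sets up (editor's illustration with made-up N and factor): outer covers ceil(N / factor) iterations, the parent index is outer * factor + inner, and a tail predicate keeps out-of-range points from being visited.

// Editor's sketch of split: parent index reconstruction plus the tail guard.
#include <cstdio>

int main() {
  const int N = 10, factor = 4;
  const int outer_extent = (N + factor - 1) / factor;  // ceil(10 / 4) == 3
  int visited = 0;
  for (int o = 0; o < outer_extent; ++o)
    for (int i = 0; i < factor; ++i) {
      int parent = o * factor + i;
      if (parent < N) ++visited;  // the bound check MakeBoundCheck would emit
    }
  std::printf("%d of %d\n", visited, N);  // 10 of 10: every point covered once
}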
- CHECK(parent->iter_type == kDataPar || - parent->iter_type == kCommReduce || + CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) << "Cannot split on " << IterVarType2String(parent->iter_type); - IterVar outer = IterVarNode::make( - Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); - IterVar inner = IterVarNode::make( - Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); + IterVar outer = + IterVarNode::make(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); + IterVar inner = + IterVarNode::make(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); *p_outer = outer; *p_inner = inner; // The splits @@ -112,8 +107,7 @@ bool Stage::is_scheduled() const { Stage Stage::GetAttachSpec() const { Stage attach_spec = *this; - while (attach_spec->attach_type == kGroupRoot && - attach_spec->group.defined()) { + while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) { attach_spec = attach_spec->group; } return attach_spec; @@ -124,9 +118,8 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*) return *this; } -Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; // Group constraint checking. Stage group = (*this)->group; if (group.defined()) { @@ -134,8 +127,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) while (pg.defined() && !pg.same_as(group)) { pg = pg->group; } - CHECK(pg.same_as(group)) - << "Can only assign compute_at to stages within the same group"; + CHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group"; } (*this)->attach_type = kScope; @@ -144,34 +136,30 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) bool found = false; for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) { if (scope == parent->leaf_iter_vars[i]) { - found = true; break; + found = true; + break; } } - CHECK(found) - << "Cannot find the axis " << scope - << " in parent's leaf_iter_vars" - << " parent=" << parent; + CHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars" + << " parent=" << parent; return *this; } -Stage& Stage::compute_inline() { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_inline() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kInline; return *this; } -Stage& Stage::compute_root() { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_root() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kGroupRoot; return *this; } -Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) +Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) StageNode* self = operator->(); - CHECK(ivar->iter_type == kDataPar || - ivar->iter_type == kCommReduce) + CHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce) << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread"; CHECK(thread_ivar->iter_type == 
kThreadIndex) << "Cannot rebase by " << IterVarType2String(ivar->iter_type) @@ -184,10 +172,8 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) ObjectPtr n; if (it != self->iter_var_attrs.end()) { n = make_object(*(*it).second.operator->()); - if (n->bind_thread.defined() && - !n->bind_thread.same_as(thread_ivar)) { - LOG(WARNING) << "Axis " << ivar - << " is already bind to another thread " << n->bind_thread; + if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { + LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; } } else { n = make_object(); @@ -201,18 +187,15 @@ Stage& Stage::env_threads(Array threads) { StageNode* self = operator->(); CHECK(self->op.defined() && self->op.as()) << "env_threads is only valid for composite ops such as ScanOp"; - CHECK_EQ(self->env_threads.size(), 0U) - << "Already set env_threads"; + CHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads"; ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); std::vector temp; for (IterVar iv : threads) { temp.push_back(iv); } - leaf_vars->data.insert( - leaf_vars->data.begin(), temp.begin(), temp.end()); - all_vars->data.insert( - all_vars->data.end(), temp.begin(), temp.end()); + leaf_vars->data.insert(leaf_vars->data.begin(), temp.begin(), temp.end()); + all_vars->data.insert(all_vars->data.end(), temp.begin(), temp.end()); self->env_threads = threads; return *this; } @@ -223,36 +206,32 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { return *this; } -Stage& Stage::split( - IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, + IterVar* p_inner) { // NOLINT(*) Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } -Stage& Stage::split_by_nparts( - IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) +Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, + IterVar* p_inner) { // NOLINT(*) Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*) StageNode* self = operator->(); - CHECK(outer->iter_type == kDataPar || - outer->iter_type == kCommReduce || + CHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce || outer->iter_type == kOrdered) << "Cannot fuse " << IterVarType2String(outer->iter_type); - CHECK(inner->iter_type == kDataPar || - inner->iter_type == kCommReduce || + CHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce || inner->iter_type == kOrdered) << "Cannot fuse " << IterVarType2String(inner->iter_type); IterVarType iter_type = outer->iter_type; if (inner->iter_type > iter_type) iter_type = inner->iter_type; - std::string fused_name = - outer->var->name_hint + "." + inner->var->name_hint + ".fused"; + std::string fused_name = outer->var->name_hint + "." 
+ inner->var->name_hint + ".fused"; - IterVar fused = IterVarNode::make( - Range(), Var(fused_name, outer->var.dtype()), iter_type); + IterVar fused = IterVarNode::make(Range(), Var(fused_name, outer->var.dtype()), iter_type); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -269,8 +248,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT all_vars->data.push_back(fused); leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, leaf_vars->data.begin() + pos_inner + 1); - leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, - fused); + leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, fused); *p_target = fused; return *this; } @@ -286,9 +264,8 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* StageNode* self = operator->(); // special handle fuse empty array. // insert at the outer most loop - IterVar singleton = IterVarNode::make( - Range::make_by_min_extent(0, 1), - Var("singleton", DataType::Int(32)), kDataPar); + IterVar singleton = IterVarNode::make(Range::make_by_min_extent(0, 1), + Var("singleton", DataType::Int(32)), kDataPar); self->relations.push_back(SingletonNode::make(singleton)); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -303,14 +280,11 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) std::unordered_set seen_var; StageNode* self = operator->(); for (IterVar iv : order) { - CHECK(iv->iter_type == kDataPar || - iv->iter_type == kCommReduce || + CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce || iv->iter_type == kThreadIndex) - << "Cannot reorder IterVar(" - << IterVarType2String(iv->iter_type) << ")"; + << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")"; - CHECK_EQ(seen_var.count(iv), 0) - << "Same axis can not appear more than once " << iv; + CHECK_EQ(seen_var.count(iv), 0) << "Same axis can not appear more than once " << iv; seen_var.insert(iv); } ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); @@ -331,20 +305,16 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) return *this; } -Stage& Stage::tile(IterVar x_parent, IterVar y_parent, - PrimExpr x_factor, PrimExpr y_factor, - IterVar* p_x_outer, IterVar* p_y_outer, - IterVar* p_x_inner, IterVar* p_y_inner) { +Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, + IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) { split(x_parent, x_factor, p_x_outer, p_x_inner); split(y_parent, y_factor, p_y_outer, p_y_inner); reorder(Array({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); return *this; } -template -inline void UpdateIterVarAttr(StageNode* self, - IterVar var, - FUpdate fupdate, +template +inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate, bool need_leaf = true) { if (need_leaf) { ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); @@ -363,60 +333,53 @@ inline void UpdateIterVarAttr(StageNode* self, } inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) { - UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { - n->iter_type = iter_type; - }); + UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; }); } -Stage& Stage::vectorize(IterVar var) { // NOLINT(*) - CHECK(var->iter_type == kDataPar || - var->iter_type == kOpaque || - var->iter_type == kUnrolled || - var->iter_type 
== kVectorized || - var->iter_type == kTensorized || +Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + CHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || + var->iter_type == kVectorized || var->iter_type == kTensorized || var->iter_type == kParallelized) << "Cannot vectorize on " << IterVarType2String(var->iter_type); SetAttrIterType(operator->(), var, kVectorized); return *this; } -Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) +Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { - n->iter_type = kTensorized; - n->tensor_intrin = f; - }); + n->iter_type = kTensorized; + n->tensor_intrin = f; + }); return *this; } -Stage& Stage::unroll(IterVar var) { // NOLINT(*) +Stage& Stage::unroll(IterVar var) { // NOLINT(*) SetAttrIterType(operator->(), var, kUnrolled); return *this; } -Stage& Stage::parallel(IterVar var) { // NOLINT(*) +Stage& Stage::parallel(IterVar var) { // NOLINT(*) SetAttrIterType(operator->(), var, kParallelized); return *this; } -Stage& Stage::pragma(IterVar var, - const std::string& pragma_type, - const PrimExpr& pragma_value) { // NOLINT(*) +Stage& Stage::pragma(IterVar var, const std::string& pragma_type, + const PrimExpr& pragma_value) { // NOLINT(*) if (pragma_type == "unroll") { this->unroll(var); } else if (pragma_type == "vectorize") { this->vectorize(var); } else { - UpdateIterVarAttr( - operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { - n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); - n->pragma_values.push_back(pragma_value); - }); + UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { + n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); + n->pragma_values.push_back(pragma_value); + }); } return *this; } -Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { - StageNode *self = operator->(); +Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { + StageNode* self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); FindLeafVar(all_vars, leaf_vars, var); @@ -434,16 +397,19 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { } Stage& Stage::storage_align(IterVar axis, int factor, int offset) { - StageNode *self = operator->(); - UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) { - n->dim_align_factor = factor; - n->dim_align_offset = offset; - }, false); + StageNode* self = operator->(); + UpdateIterVarAttr( + self, axis, + [factor, offset](IterVarAttrNode* n) { + n->dim_align_factor = factor; + n->dim_align_offset = offset; + }, + false); return *this; } Stage& Stage::double_buffer() { - StageNode *self = operator->(); + StageNode* self = operator->(); CHECK(!self->is_output) << "Cannot apply double buffer on output"; self->double_buffer = true; return *this; @@ -451,7 +417,7 @@ Stage& Stage::double_buffer() { Stage& Stage::opengl() { CHECK(!is_scheduled()) << "Must be a fresh schedule"; - StageNode *self = operator->(); + StageNode* self = operator->(); auto all_iter_vars = self->all_iter_vars; // curr version of all_iter_vars CHECK(!all_iter_vars.empty()) << "At least one iter var"; @@ -475,8 +441,7 @@ Stage& Stage::opengl() { break; } default: { - LOG(ERROR) << "Invalid iter var type " - << IterVarType2String(iter_var->iter_type); + LOG(ERROR) << "Invalid iter var 
type " << IterVarType2String(iter_var->iter_type); break; } } @@ -492,8 +457,7 @@ Stage& Stage::opengl() { } Stage CopyStage(const Stage& s) { - ObjectPtr n = - make_object(*s.operator->()); + ObjectPtr n = make_object(*s.operator->()); return Stage(n); } @@ -521,24 +485,22 @@ Schedule Schedule::copy() const { for (Stage s : n->stages) { if (s->attach_stage.defined()) { CHECK(smap.find(s->attach_stage) != smap.end()) - << s->attach_stage << " not found in " << (*this); + << s->attach_stage << " not found in " << (*this); s->attach_stage = smap.at(s->attach_stage); } if (s->group.defined()) { - CHECK(smap.find(s->group) != smap.end()) - << s->group << " not found in " << (*this); + CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); s->group = smap.at(s->group); } } for (Stage s : n->groups) { if (s->attach_stage.defined()) { CHECK(smap.find(s->attach_stage) != smap.end()) - << s->attach_stage << " not found in " << (*this); + << s->attach_stage << " not found in " << (*this); s->attach_stage = smap.at(s->attach_stage); } if (s->group.defined()) { - CHECK(smap.find(s->group) != smap.end()) - << s->group << " not found in " << (*this); + CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); s->group = smap.at(s->group); } } @@ -548,8 +510,7 @@ Schedule Schedule::copy() const { Stage Schedule::operator[](const Operation& op) { auto it = (*this)->stage_map.find(op); CHECK(it != (*this)->stage_map.end()) - << "Cannot find Stage for operator " << op - << " in the schedule"; + << "Cannot find Stage for operator " << op << " in the schedule"; return (*it).second; } @@ -570,15 +531,13 @@ Stage LeastCommonAncestor(Stage g1, Stage g2) { return g; } -Array RemapTensor(ScheduleNode* self, - const Array& arr) { +Array RemapTensor(ScheduleNode* self, const Array& arr) { self->InitCache(); const auto& op2stage_cache = self->op2stage_cache_; Array ret; for (Tensor t : arr) { if (!op2stage_cache.count(t->op.get())) { - CHECK(self->stage_map.count(t->op)) - << "Given tensor is not in the schedule plan"; + CHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan"; t = self->stage_map[t->op]->op.output(t->value_index); } ret.push_back(t); @@ -587,17 +546,14 @@ Array RemapTensor(ScheduleNode* self, } // Group the schedule stages. -Stage Schedule::create_group(const Array& outputs, - const Array& inputs, +Stage Schedule::create_group(const Array& outputs, const Array& inputs, bool include_inputs) { ScheduleNode* self = operator->(); self->InitCache(); const auto& op2stage_cache = self->op2stage_cache_; // Get the ops. - Array ops = te::GetSubGraph( - RemapTensor(self, outputs), - RemapTensor(self, inputs), - include_inputs); + Array ops = + te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs); // local counter entry // Automatically initialize to 0 during creation. struct Entry { @@ -631,7 +587,7 @@ Stage Schedule::create_group(const Array& outputs, // Propagate the counter statistics from by checking if subgroup // Is full and propagate. std::vector stack; - for (auto &kv : counter) { + for (auto& kv : counter) { if (!kv.first.same_as(parent_group)) { if (kv.first->num_child_stages == kv.second.count) { stack.push_back(kv.first); @@ -650,7 +606,7 @@ Stage Schedule::create_group(const Array& outputs, } } // Verification and remappig the subgroups. 
- for (auto &kv : counter) { + for (auto& kv : counter) { if (kv.first.same_as(parent_group)) continue; CHECK_EQ(kv.first->num_child_stages, kv.second.count) << "Trying to group region that intersect with an already existed group"; @@ -695,9 +651,7 @@ Stage Schedule::create_group(const Array& outputs, return gstage; } -void ScheduleNode::InvalidateCache() { - op2stage_cache_.clear(); -} +void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); } void ScheduleNode::InitCache() { if (op2stage_cache_.size() == stages.size()) return; @@ -753,10 +707,7 @@ Schedule ScheduleNode::make(Array ops) { return sch; } -IterVarRelation SplitNode::make(IterVar parent, - IterVar outer, - IterVar inner, - PrimExpr factor, +IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { auto n = make_object(); n->parent = parent; @@ -767,8 +718,7 @@ IterVarRelation SplitNode::make(IterVar parent, return IterVarRelation(n); } -IterVarRelation FuseNode::make( - IterVar outer, IterVar inner, IterVar fused) { +IterVarRelation FuseNode::make(IterVar outer, IterVar inner, IterVar fused) { auto n = make_object(); n->outer = outer; n->inner = inner; @@ -805,19 +755,19 @@ struct TVMSpecializationThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMSpecializationThreadLocalStore; void SpecializedCondition::EnterWithScope() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); entry->condition_stack.push(*this); } void SpecializedCondition::ExitWithScope() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); CHECK(!entry->condition_stack.empty()); CHECK(entry->condition_stack.top().same_as(*this)); entry->condition_stack.pop(); } SpecializedCondition SpecializedCondition::Current() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); SpecializedCondition cond; if (entry->condition_stack.size() > 0) { cond = entry->condition_stack.top(); @@ -827,13 +777,9 @@ SpecializedCondition SpecializedCondition::Current() { class SpecializedCondition::Internal { public: - static void EnterScope(SpecializedCondition cond) { - cond.EnterWithScope(); - } + static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); } - static void ExitScope(SpecializedCondition cond) { - cond.ExitWithScope(); - } + static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); } }; TVM_REGISTER_NODE_TYPE(StageNode); @@ -847,193 +793,158 @@ TVM_REGISTER_NODE_TYPE(SpecializedConditionNode); // Printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - if (op->op.defined()) { - p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; - } else { - p->stream << "group-stage(" << op << ")"; - } -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << IterVarType2String(op->iter_type); -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "split(parent="; - p->Print(op->parent); - p->stream << ", outer="; - p->Print(op->outer); - p->stream << ", inner="; - p->Print(op->inner); - p->stream << ')'; 
-}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "split("; - p->stream << "outer="; - p->Print(op->outer); - p->stream << ", inner="; - p->Print(op->inner); - p->stream << ", fused="; - p->Print(op->fused); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "rebase("; - p->stream << "parent="; - p->Print(op->parent); - p->stream << ", rebased="; - p->Print(op->rebased); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "singleton("; - p->Print(op->iter); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "schedule(" << op << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "specialized_condition("; - p->Print(op->clauses); - p->stream << ')'; -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + if (op->op.defined()) { + p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; + } else { + p->stream << "group-stage(" << op << ")"; + } + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << IterVarType2String(op->iter_type); + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "split(parent="; + p->Print(op->parent); + p->stream << ", outer="; + p->Print(op->outer); + p->stream << ", inner="; + p->Print(op->inner); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "split("; + p->stream << "outer="; + p->Print(op->outer); + p->stream << ", inner="; + p->Print(op->inner); + p->stream << ", fused="; + p->Print(op->fused); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "rebase("; + p->stream << "parent="; + p->Print(op->parent); + p->stream << ", rebased="; + p->Print(op->rebased); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "singleton("; + p->Print(op->iter); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "schedule(" << op << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "specialized_condition("; + p->Print(op->clauses); + p->stream << ')'; + }); -TVM_REGISTER_GLOBAL("te.CreateSchedule") -.set_body_typed(create_schedule); +TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule); -TVM_REGISTER_GLOBAL("te.StageSetScope") -.set_body_method(&Stage::set_scope); +TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope); -TVM_REGISTER_GLOBAL("te.StageBind") -.set_body_method(&Stage::bind); +TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind); TVM_REGISTER_GLOBAL("te.StageSplitByFactor") -.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { - IterVar outer, inner; - stage.split(parent, factor, &outer, &inner); - return Array({outer, inner}); -}); + .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { + IterVar outer, 
inner; + stage.split(parent, factor, &outer, &inner); + return Array({outer, inner}); + }); TVM_REGISTER_GLOBAL("te.StageSplitByNParts") -.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { - IterVar outer, inner; - stage.split_by_nparts(parent, nparts, &outer, &inner); - return Array({outer, inner}); -}); + .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { + IterVar outer, inner; + stage.split_by_nparts(parent, nparts, &outer, &inner); + return Array({outer, inner}); + }); -TVM_REGISTER_GLOBAL("te.StageFuse") -.set_body_typed([](Stage stage, Array axes) { - IterVar fused; - stage.fuse(axes, &fused); - return fused; - }); +TVM_REGISTER_GLOBAL("te.StageFuse").set_body_typed([](Stage stage, Array axes) { + IterVar fused; + stage.fuse(axes, &fused); + return fused; +}); -TVM_REGISTER_GLOBAL("te.StageComputeAt") -.set_body_method(&Stage::compute_at); +TVM_REGISTER_GLOBAL("te.StageComputeAt").set_body_method(&Stage::compute_at); -TVM_REGISTER_GLOBAL("te.StageComputeInline") -.set_body_method(&Stage::compute_inline); +TVM_REGISTER_GLOBAL("te.StageComputeInline").set_body_method(&Stage::compute_inline); -TVM_REGISTER_GLOBAL("te.StageComputeRoot") -.set_body_method(&Stage::compute_root); +TVM_REGISTER_GLOBAL("te.StageComputeRoot").set_body_method(&Stage::compute_root); -TVM_REGISTER_GLOBAL("te.StageReorder") -.set_body_method(&Stage::reorder); +TVM_REGISTER_GLOBAL("te.StageReorder").set_body_method(&Stage::reorder); TVM_REGISTER_GLOBAL("te.StageTile") -.set_body_typed([]( - Stage stage, - IterVar x_parent, IterVar y_parent, - PrimExpr x_factor, PrimExpr y_factor -) { - IterVar x_outer, y_outer, x_inner, y_inner; - stage.tile(x_parent, y_parent, - x_factor, y_factor, - &x_outer, &y_outer, - &x_inner, &y_inner); - return Array({x_outer, y_outer, x_inner, y_inner}); - }); + .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor, + PrimExpr y_factor) { + IterVar x_outer, y_outer, x_inner, y_inner; + stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner); + return Array({x_outer, y_outer, x_inner, y_inner}); + }); -TVM_REGISTER_GLOBAL("te.StageEnvThreads") -.set_body_method(&Stage::env_threads); +TVM_REGISTER_GLOBAL("te.StageEnvThreads").set_body_method(&Stage::env_threads); -TVM_REGISTER_GLOBAL("te.StageSetStorePredicate") -.set_body_method(&Stage::set_store_predicate); +TVM_REGISTER_GLOBAL("te.StageSetStorePredicate").set_body_method(&Stage::set_store_predicate); -TVM_REGISTER_GLOBAL("te.StageUnroll") -.set_body_method(&Stage::unroll); +TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll); -TVM_REGISTER_GLOBAL("te.StageVectorize") -.set_body_method(&Stage::vectorize); +TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize); -TVM_REGISTER_GLOBAL("te.StageTensorize") -.set_body_method(&Stage::tensorize); +TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize); -TVM_REGISTER_GLOBAL("te.StageParallel") -.set_body_method(&Stage::parallel); +TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel); -TVM_REGISTER_GLOBAL("te.StagePragma") -.set_body_method(&Stage::pragma); +TVM_REGISTER_GLOBAL("te.StagePragma").set_body_method(&Stage::pragma); -TVM_REGISTER_GLOBAL("te.StagePrefetch") -.set_body_method(&Stage::prefetch); +TVM_REGISTER_GLOBAL("te.StagePrefetch").set_body_method(&Stage::prefetch); -TVM_REGISTER_GLOBAL("te.StageStorageAlign") -.set_body_method(&Stage::storage_align); 
+TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_align); -TVM_REGISTER_GLOBAL("te.StageDoubleBuffer") -.set_body_method(&Stage::double_buffer); +TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer); -TVM_REGISTER_GLOBAL("te.StageOpenGL") -.set_body_method(&Stage::opengl); +TVM_REGISTER_GLOBAL("te.StageOpenGL").set_body_method(&Stage::opengl); -TVM_REGISTER_GLOBAL("te.ScheduleNormalize") -.set_body_method(&Schedule::normalize); +TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); -TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup") -.set_body_method(&Schedule::create_group); +TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); -TVM_REGISTER_GLOBAL("te.ScheduleCacheRead") -.set_body_method(&Schedule::cache_read); +TVM_REGISTER_GLOBAL("te.ScheduleCacheRead").set_body_method(&Schedule::cache_read); -TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[1].IsObjectRef()) { - *ret = args[0].operator Schedule() - .cache_write(args[1].operator Tensor(), args[2]); - } else { - *ret = args[0].operator Schedule() - .cache_write(args[1].operator Array(), args[2]); - } - }); +TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args[1].IsObjectRef()) { + *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]); + } else { + *ret = args[0].operator Schedule().cache_write(args[1].operator Array(), args[2]); + } +}); -TVM_REGISTER_GLOBAL("te.ScheduleRFactor") -.set_body_method(&Schedule::rfactor); +TVM_REGISTER_GLOBAL("te.ScheduleRFactor").set_body_method(&Schedule::rfactor); -TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition") -.set_body_typed([](Array condition) { - return SpecializedCondition(condition); +TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition").set_body_typed([](Array condition) { + return SpecializedCondition(condition); }); -TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SpecializedCondition::Current(); +TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SpecializedCondition::Current(); }); TVM_REGISTER_GLOBAL("te.EnterSpecializationScope") -.set_body_typed(SpecializedCondition::Internal::EnterScope); + .set_body_typed(SpecializedCondition::Internal::EnterScope); TVM_REGISTER_GLOBAL("te.ExitSpecializationScope") -.set_body_typed(SpecializedCondition::Internal::ExitScope); + .set_body_typed(SpecializedCondition::Internal::ExitScope); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index bdb77b6ba472a..3a26e9842c8df 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -21,31 +21,30 @@ * \file schedule_ops.cc */ #include -#include -#include -#include #include #include -#include +#include +#include +#include + #include #include -#include "graph.h" -#include "../operation/op_util.h" +#include + #include "../../tir/transforms/ir_util.h" +#include "../operation/op_util.h" +#include "graph.h" namespace tvm { namespace te { using namespace tir; -Stmt MakePipeline(const Stage& s, - const std::unordered_map& dom_map, - Stmt consumer, +Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_map, Stmt consumer, bool debug_keep_trivial_loop) { Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); if (s->double_buffer) { - 
producer = AttrStmtNode::make( - s->op, tir::attr::double_buffer_scope, 1, producer); + producer = AttrStmtNode::make(s->op, tir::attr::double_buffer_scope, 1, producer); } Stmt pipeline = producer; @@ -54,14 +53,12 @@ Stmt MakePipeline(const Stage& s, } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. - pipeline = AttrStmtNode::make( - s->op, tir::attr::realize_scope, - StringImmNode::make(s->scope), - pipeline); + pipeline = + AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImmNode::make(s->scope), pipeline); if (s->is_opengl) { - pipeline = AttrStmtNode::make( - s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline); + pipeline = + AttrStmtNode::make(s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline); } return pipeline; } @@ -69,28 +66,25 @@ Stmt MakePipeline(const Stage& s, // inject the operator's realization on the stmt. class InjectAttach : public StmtMutator { public: - InjectAttach(const Stage& stage, - const Stage& attach_spec, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) - : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), + InjectAttach(const Stage& stage, const Stage& attach_spec, + const std::unordered_map& dom_map, bool debug_keep_trivial_loop) + : stage_(stage), + attach_spec_(attach_spec), + dom_map_(dom_map), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} Stmt VisitStmt(const Stmt& input_stmt) final { CHECK(input_stmt.defined()); auto stmt = StmtMutator::VisitStmt(input_stmt); const AttrStmtNode* op = stmt.as(); - if (op != nullptr && - op->attr_key == tir::attr::loop_scope) { - if (attach_spec_->attach_type == kScope && - op->node == attach_spec_->attach_ivar) { - CHECK(!found_attach) - << "Find IterVar" << attach_spec_->attach_ivar - << " in multiple places in the IR"; + if (op != nullptr && op->attr_key == tir::attr::loop_scope) { + if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) { + CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar + << " in multiple places in the IR"; found_attach = true; - stmt = AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = + AttrStmtNode::make(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -113,27 +107,27 @@ class InjectAttach : public StmtMutator { // inject the operator's realization on the stmt. 
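For reference, InjectAttach above (and InjectScanStep just below) share one mechanism: mutate the statement tree bottom-up, and when the attribute node marking the attach point is found, splice the producer pipeline built by MakePipeline around that node's body, recording found_attach so a missing or duplicated attach point can be diagnosed. The following is a stand-alone sketch of that splice pattern using hypothetical node types, not the TVM API:

// Stand-alone sketch (hypothetical node types) of the pattern InjectAttach
// implements: walk the tree, find the marked attach point, and wrap its body
// with a stand-in for MakePipeline(...).
#include <iostream>
#include <memory>
#include <string>

struct Stmt {
  std::string tag;             // e.g. "loop_scope:i", "body", "pipeline"
  std::shared_ptr<Stmt> body;  // a single child suffices for the sketch
};

std::shared_ptr<Stmt> Inject(const std::shared_ptr<Stmt>& s,
                             const std::string& attach_tag, bool* found) {
  if (!s) return nullptr;
  auto ret = std::make_shared<Stmt>(*s);
  ret->body = Inject(s->body, attach_tag, found);  // visit children first
  if (ret->tag == attach_tag) {
    *found = true;  // mirrors found_attach
    auto pipeline = std::make_shared<Stmt>();
    pipeline->tag = "pipeline";  // stands in for MakePipeline(stage, ...)
    pipeline->body = ret->body;
    ret->body = pipeline;
  }
  return ret;
}

int main() {
  auto leaf = std::make_shared<Stmt>(Stmt{"body", nullptr});
  auto loop = std::make_shared<Stmt>(Stmt{"loop_scope:i", leaf});
  bool found = false;
  auto out = Inject(loop, "loop_scope:i", &found);
  std::cout << found << " " << out->body->tag << "\n";  // prints: 1 pipeline
}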
class InjectScanStep : public StmtMutator { public: - InjectScanStep(const Stage& stage, - const Operation& scan_op, - const std::unordered_map& dom_map, - bool is_init, + InjectScanStep(const Stage& stage, const Operation& scan_op, + const std::unordered_map& dom_map, bool is_init, bool debug_keep_trivial_loop) - : stage_(stage), scan_op_(scan_op), - dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} + : stage_(stage), + scan_op_(scan_op), + dom_map_(dom_map), + is_init_(is_init), + debug_keep_trivial_loop_(debug_keep_trivial_loop) {} Stmt VisitStmt(const Stmt& input_stmt) final { CHECK(input_stmt.defined()); auto stmt = StmtMutator::VisitStmt(input_stmt); // update const AttrStmtNode* op = stmt.as(); - if (op != nullptr && - ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || - (op->attr_key == tir::attr::scan_init_scope && is_init_))) { + if (op != nullptr && ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || + (op->attr_key == tir::attr::scan_init_scope && is_init_))) { if (op->node.same_as(scan_op_)) { found_attach = true; - stmt = AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = + AttrStmtNode::make(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -169,8 +163,7 @@ class SchedulePostProc : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::loop_scope || - op->attr_key == tir::attr::scan_init_scope) { + if (op->attr_key == tir::attr::loop_scope || op->attr_key == tir::attr::scan_init_scope) { return this->VisitStmt(op->body); } else if (op->attr_key == tir::attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); @@ -194,8 +187,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - Stmt ret = AttrStmtNode::make( - it->second, op->attr_key, op->value, op->body); + Stmt ret = AttrStmtNode::make(it->second, op->attr_key, op->value, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -208,8 +200,8 @@ class SchedulePostProc : public StmtExprMutator { if (it != replace_op_.end()) { if (it->second.defined()) { return AttrStmtNode::make( - Array{tuple[0], it->second.output(tensor->value_index)}, - op->attr_key, op->value, this->VisitStmt(op->body)); + Array{tuple[0], it->second.output(tensor->value_index)}, op->attr_key, + op->value, this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -219,9 +211,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - return AttrStmtNode::make( - it->second.output(tensor->value_index), - op->attr_key, op->value, this->VisitStmt(op->body)); + return AttrStmtNode::make(it->second.output(tensor->value_index), op->attr_key, op->value, + this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -235,9 +226,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = RealizeNode::make( - it->second->op, it->second->value_index, - op->dtype, op->bounds, op->condition, op->body); + Stmt ret = RealizeNode::make(it->second->op, it->second->value_index, op->dtype, op->bounds, + 
op->condition, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -252,8 +242,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - Stmt ret = ProvideNode::make( - dst->op, dst->value_index, op->value, op->args); + Stmt ret = ProvideNode::make(dst->op, dst->value_index, op->value, op->args); return this->VisitStmt(ret); } else { return StmtExprMutator::VisitStmt_(op); @@ -266,9 +255,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - PrimExpr ret = CallNode::make( - op->dtype, dst->op->name, op->args, - op->call_type, dst->op, dst->value_index); + PrimExpr ret = CallNode::make(op->dtype, dst->op->name, op->args, op->call_type, dst->op, + dst->value_index); return this->VisitExpr(ret); } } @@ -299,8 +287,7 @@ class SchedulePostProc : public StmtExprMutator { if (!s->op.same_as(s->origin_op)) { for (int i = 0; i < s->op->num_outputs(); ++i) { Tensor target = s->origin_op.output(i); - AddReplace(s->op.output(i), target, - target, s->origin_op); + AddReplace(s->op.output(i), target, target, s->origin_op); } } // Specially add replacements for scan op. @@ -316,9 +303,7 @@ class SchedulePostProc : public StmtExprMutator { } private: - void AddReplace(Tensor src, - Tensor dst, - Tensor repl_realize = Tensor(), + void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { TensorKey key{src->op, src->value_index}; replace_buffer_[key] = dst; @@ -339,8 +324,7 @@ class SchedulePostProc : public StmtExprMutator { arith::Analyzer analyzer_; }; -Stmt ScheduleOps( - Schedule sch, Map dom_map_, bool debug_keep_trivial_loop) { +Stmt ScheduleOps(Schedule sch, Map dom_map_, bool debug_keep_trivial_loop) { Stmt body = Stmt(); std::unordered_map dom_map = as_unordered_map(dom_map_); // scan init and scan updates @@ -350,8 +334,7 @@ Stmt ScheduleOps( if (!scan) continue; for (Tensor t : scan->init) { if (scan_init.count(t->op)) { - CHECK(scan_init.at(t->op).same_as(s->op)) - << "Scan init tensor can only belong to one scan"; + CHECK(scan_init.at(t->op).same_as(s->op)) << "Scan init tensor can only belong to one scan"; } else { scan_init[t->op] = s->op; } @@ -365,8 +348,7 @@ Stmt ScheduleOps( // reverse the post DFS order. for (size_t i = sch->stages.size(); i != 0; --i) { Stage s = sch->stages[i - 1]; - CHECK_NE(s->attach_type, kInline) - << "call schedule.normalize before scheduleops"; + CHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops"; CHECK(s->op.defined()); // no need to specify place holder op. 
if (s->op.as()) continue; @@ -377,15 +359,13 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); body = mu(std::move(body)); - CHECK(mu.found_attach) - << "did not find attachment point for scan.init"; + CHECK(mu.found_attach) << "did not find attachment point for scan.init"; } else if (attach_spec->attach_type == kScanUpdate) { // Handle scan update CHECK(body.defined()); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop); body = mu(std::move(body)); - CHECK(mu.found_attach) - << "did not find attachment point for scan.update"; + CHECK(mu.found_attach) << "did not find attachment point for scan.update"; } else if (attach_spec->attach_type == kInlinedAlready) { // do nothing } else if (attach_spec->attach_type == kGroupRoot) { @@ -396,11 +376,10 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop); body = mutator(std::move(body)); - CHECK(mutator.found_attach) - << "did not find attachment point for " << s << " in " - << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar - << ", body:\n" - << body; + CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in " + << attach_spec->attach_stage->op << " x " + << attach_spec->attach_ivar << ", body:\n" + << body; } } SchedulePostProc post_proc; @@ -408,8 +387,7 @@ Stmt ScheduleOps( return post_proc(std::move(body)); } -TVM_REGISTER_GLOBAL("schedule.ScheduleOps") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("schedule.ScheduleOps").set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 2) *ret = ScheduleOps(args[0], args[1], false); else diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 2198827279ac1..84166d11881b5 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -23,17 +23,19 @@ * \brief Rewrite the Stmt generated by ScheduleOps * to accomondate tensorcore. */ +#include #include +#include +#include +#include +#include #include +#include #include -#include #include -#include -#include -#include -#include -#include + #include + #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -41,10 +43,10 @@ namespace tvm { namespace te { using namespace te; +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; -using intrinsic::tvm_address_of; struct Tile { int m{-1}; @@ -61,7 +63,7 @@ std::string simplify_name(std::string input) { } } -PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { +PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { auto cast = input.as(); if (cast == nullptr) { return input; @@ -74,7 +76,7 @@ PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { // MMAMatcher matches C = Cast(A)*Cast(B)+C, // where A & B are fp16/int8 local buffers, // and C is fp32/int32 local buffer. 
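Concretely, the computation this matcher accepts is the mixed-precision accumulation below: a store whose value is an Add of a load of C and a Mul of two casts. This is a stand-alone plain-C++ illustration, not TVM IR, using the int8-operand/int32-accumulator combination; fp16 operands with an fp32 accumulator is the other principal case:

// C = cast(A) * cast(B) + C, with low-precision A/B and a wider accumulator.
#include <cstdint>
#include <iostream>

int main() {
  const int M = 2, N = 2, K = 2;
  int8_t A[M][K] = {{1, 2}, {3, 4}};
  int8_t B[K][N] = {{5, 6}, {7, 8}};
  int32_t C[M][N] = {};
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        // The store below corresponds to the ProvideNode the matcher inspects:
        // add->a is the load of C; mul is cast(A[i][k]) * cast(B[k][j]).
        C[i][j] = static_cast<int32_t>(A[i][k]) * static_cast<int32_t>(B[k][j]) + C[i][j];
  std::cout << C[0][0] << " " << C[0][1] << " "
            << C[1][0] << " " << C[1][1] << "\n";  // prints: 19 22 43 50
}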
-class MMAMatcher: public StmtVisitor { +class MMAMatcher : public StmtVisitor { public: explicit MMAMatcher(Map extern_buffer) { for (auto kv : extern_buffer) { @@ -130,7 +132,7 @@ class MMAMatcher: public StmtVisitor { } } - inline bool Matched() const {return matched_;} + inline bool Matched() const { return matched_; } friend class ScheduleAnalyser; friend class BufferAnalyser; @@ -141,7 +143,7 @@ class MMAMatcher: public StmtVisitor { DataType dtype; bool external{false}; bool released{false}; - bool same_as(const BufferInfo &bi) { + bool same_as(const BufferInfo& bi) { if (this->dtype != bi.dtype) return false; if (this->name != bi.name) return false; if (this->external != bi.external) return false; @@ -183,10 +185,8 @@ class MMAMatcher: public StmtVisitor { auto* load_c = add->a.as(); BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) - || !buffer_c.same_as(store_buffer) - || !(buffer_c.dtype == DataType::Float(32) || - buffer_c.dtype == DataType::Int(32))) { + if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || + !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { return false; } @@ -198,26 +198,20 @@ class MMAMatcher: public StmtVisitor { auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); auto load_a = load_a_expr.as(); BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) - || !(buffer_a.dtype == DataType::Float(16) || - buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || - buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || - buffer_a.dtype == DataType::Int(1))) { + if (!check_local_buffer_(load_a, &buffer_a) || + !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || + buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || + buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { return false; } auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); auto load_b = load_b_expr.as(); BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) - || !(buffer_b.dtype == DataType::Float(16) || - buffer_b.dtype == DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || - buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || - buffer_a.dtype == DataType::Int(1))) { + if (!check_local_buffer_(load_b, &buffer_b) || + !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || + buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || + buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { return false; } @@ -226,8 +220,7 @@ class MMAMatcher: public StmtVisitor { frag_reg_.insert(buffer_b.name); buf_name_.insert(std::make_pair(load_a, buffer_a.name)); buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - mma_sync_.insert(std::make_pair(op, - Array{load_a_expr, load_b_expr, add->a})); + mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); return true; } @@ -280,9 +273,8 @@ class BodyVisitor : public StmtExprVisitor { // ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major class ScheduleAnalyser { public: - explicit ScheduleAnalyser(const MMAMatcher &mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), - buf_name_(mma_matcher.buf_name_) {} + explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) + : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} bool MatrixIdentify(Schedule schedule) { // 
TODO(minmin): handle the case where MatMul is not the output stage @@ -299,8 +291,8 @@ class ScheduleAnalyser { } const VarNode* axis_var[2]; const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size()-2]->var.as(); - axis_var[1] = axis[axis.size()-1]->var.as(); + axis_var[0] = axis[axis.size() - 2]->var.as(); + axis_var[1] = axis[axis.size() - 1]->var.as(); reduce_axis_var = reduce_axis[0]->var.as(); BodyVisitor body_visitor; @@ -342,8 +334,8 @@ class ScheduleAnalyser { matrix_major_.insert(std::make_pair(compute->name, "col_major")); } - for (auto &mma_sync : mma_sync_) { - auto &operands = mma_sync.second; + for (auto& mma_sync : mma_sync_) { + auto& operands = mma_sync.second; auto* load_a = operands[0].as(); auto* load_b = operands[1].as(); auto input0 = simplify_name(buf_name_.find(load_a)->second); @@ -398,8 +390,7 @@ class IndexVisitor : public StmtExprVisitor { class BufferAnalyser : public StmtExprVisitor { public: explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser &schedule_analyser, - const MMAMatcher &mma_matcher) + const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) : matrix_abc_(schedule_analyser.matrix_abc_), matrix_major_(schedule_analyser.matrix_major_), frag_reg_(mma_matcher.frag_reg_) { @@ -418,9 +409,7 @@ class BufferAnalyser : public StmtExprVisitor { if (op->attr_key == tir::attr::thread_extent) { if (const IntImmNode* value = op->value.as()) { thread_extent_.insert( - std::make_pair( - op->node.as()->var->name_hint, - value->value)); + std::make_pair(op->node.as()->var->name_hint, value->value)); } StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == tir::attr::realize_scope) { @@ -447,11 +436,9 @@ class BufferAnalyser : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; const BufferInfo& bi = it->second; - CHECK(!bi.released) - << "Read a buffer that is already out of scope"; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; if (matrix_abc_.count(key.GetName())) { if (bi.shape.size() < 2) { @@ -483,12 +470,7 @@ class BufferAnalyser : public StmtExprVisitor { strides_.insert(std::make_pair(key.GetName(), strides)); if (frag_reg_.count(bi.name)) { - PrimExpr dst = CallNode::make(bi.dtype, - bi.name, - op->args, - CallNode::Halide, - op->func, - 0); + PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0); frag_load_.insert(std::make_pair(op, dst)); auto rel_index = bi.RelIndex(op->args); @@ -545,12 +527,7 @@ class BufferAnalyser : public StmtExprVisitor { const CallNode* value = op->value.as(); if (value != nullptr && frag_reg_.count(value->name)) { - PrimExpr dst = CallNode::make(bi.dtype, - bi.name, - op->args, - CallNode::Halide, - op->func, - 0); + PrimExpr dst = CallNode::make(bi.dtype, bi.name, op->args, CallNode::Halide, op->func, 0); frag_store_.insert(std::make_pair(op, dst)); } } @@ -560,11 +537,9 @@ class BufferAnalyser : public StmtExprVisitor { if (op->call_type == CallNode::Halide) { TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key.f; const BufferInfo& bi = it->second; - CHECK(!bi.released) - << "Read a buffer that is already out of 
scope"; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; if (matrix_abc_.count(op->name)) { if (bi.shape.size() < 2) { @@ -642,8 +617,7 @@ class BufferAnalyser : public StmtExprVisitor { if (dim < avec.size() && avec[dim].align_factor != 0) { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + \ - indexmod(factor + offset - indexmod(stride, factor), factor); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); stride = analyzer_.Simplify(stride); } rstrides.push_back(stride); @@ -730,29 +704,19 @@ class BufferAnalyser : public StmtExprVisitor { } bool supported_warp_tile_() { - if (warp_tile_.m == 16 && - warp_tile_.n == 16 && - warp_tile_.k == 16) { + if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 32 && - warp_tile_.k == 16) { + if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 32 && - warp_tile_.n == 8 && - warp_tile_.k == 16) { + if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 8 && - warp_tile_.k == 32) { + if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 8 && - warp_tile_.k == 128) { + if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { return true; } @@ -760,7 +724,7 @@ class BufferAnalyser : public StmtExprVisitor { } std::unordered_map buf_map_; - std::unordered_map > dim_align_; + std::unordered_map> dim_align_; std::unordered_map storage_scope_; std::unordered_map matrix_abc_; std::unordered_map matrix_major_; @@ -780,7 +744,7 @@ class BufferAnalyser : public StmtExprVisitor { // ThreadIdxMutator does the thread index unification inside a warp class ThreadIdxMutator : public StmtExprMutator { public: - explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {} + explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} PrimExpr VisitExpr_(const VarNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); @@ -807,18 +771,18 @@ class ThreadIdxMutator : public StmtExprMutator { // based on tensor core intrinsics class TensorCoreIRMutator : public StmtExprMutator { public: - explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser, - const BufferAnalyser &buffer_analyser) + explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, + const BufferAnalyser& buffer_analyser) : matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} + matrix_major_(schedule_analyser.matrix_major_), + mma_sync_(schedule_analyser.mma_sync_), + strides_(buffer_analyser.strides_), + frag_reg_(buffer_analyser.frag_reg_), + loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), + frag_load_(buffer_analyser.frag_load_), + frag_store_(buffer_analyser.frag_store_), + warp_tile_(buffer_analyser.warp_tile_), + warp_threads_y_(buffer_analyser.warp_threads_y_) {} Stmt VisitStmt_(const RealizeNode* op) final 
{ TensorKey key{op->func, op->value_index}; @@ -836,16 +800,14 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 0; i < op->bounds.size() - 2; ++i) { new_bounds.push_back(op->bounds[i]); } - CHECK_GE(op->bounds.size(), 2) - << "Less than 2 dimensions for matrix " << key.GetName(); - new_bounds.push_back(Range::make_by_min_extent( - op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back(Range::make_by_min_extent( - op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return RealizeNode::make(op->func, op->value_index, - op->dtype, new_bounds, - op->condition, op->body); + CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key.GetName(); + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); + + return RealizeNode::make(op->func, op->value_index, op->dtype, new_bounds, op->condition, + op->body); } return stmt; } @@ -860,14 +822,10 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto it = matrix_abc_.find(simplify_name(node->name)); - CHECK(it != matrix_abc_.end()) - << "Cannot find matrix info for " << node->name; + CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second); Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make(op->node, - op->attr_key, - matrix_abc, - body); + return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body); } } return stmt; @@ -877,7 +835,7 @@ class TensorCoreIRMutator : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { - const auto &operands = it->second; + const auto& operands = it->second; PrimExpr a = operands[0]; auto ca = a.as(); PrimExpr b = operands[1]; @@ -889,97 +847,75 @@ class TensorCoreIRMutator : public StmtExprMutator { ObjectPtr buffer_node_b = make_object(); ObjectPtr buffer_node_c = make_object(); - auto mma_sync_call = - [&buffer_node_a, &buffer_node_b, &ca, &cb] - (const Buffer &buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_bmma_sync, - {buffer->data, buffer->elem_offset, - buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, - buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); - } else { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_mma_sync, - {buffer->data, buffer->elem_offset, - buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, - buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); - } - }; + auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { + Buffer buffer_a(buffer_node_a); + Buffer buffer_b(buffer_node_b); + if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { + return EvaluateNode::make(CallNode::make( + DataType::Handle(), intrinsic::tvm_bmma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); + } else { + return EvaluateNode::make(CallNode::make( + DataType::Handle(), intrinsic::tvm_mma_sync, + {buffer->data, 
buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); + } + }; - auto call_add_c = - [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, - TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->dtype); - }; + auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { + return add_buffer_bind_scope_(cc, buffer_node_c, TensorKey{cc->func, cc->value_index}, + mma_sync_call, cc->dtype); + }; - auto call_add_b = - [this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, - TensorKey{cb->func, cb->value_index}, call_add_c, cb->dtype); - }; + auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { + return add_buffer_bind_scope_(cb, buffer_node_b, TensorKey{cb->func, cb->value_index}, + call_add_c, cb->dtype); + }; - return add_buffer_bind_scope_(ca, buffer_node_a, - TensorKey{ca->func, ca->value_index}, call_add_b, ca->dtype); + return add_buffer_bind_scope_(ca, buffer_node_a, TensorKey{ca->func, ca->value_index}, + call_add_b, ca->dtype); } auto it2 = frag_load_.find(op); if (it2 != frag_load_.end()) { PrimExpr dst = it2->second; - if (op->value.as() != nullptr || - op->value.as() != nullptr) { + if (op->value.as() != nullptr || op->value.as() != nullptr) { auto call = dst.as(); - auto fill_fragment_call = - [this, &op](const Buffer &buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_fill_fragment, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value}, - CallNode::Intrinsic)); - }; + auto fill_fragment_call = [this, &op](const Buffer& buffer) { + return EvaluateNode::make(CallNode::make(DataType::Handle(), intrinsic::tvm_fill_fragment, + {buffer->data, warp_tile_.m, warp_tile_.n, + warp_tile_.k, buffer->elem_offset, op->value}, + CallNode::Intrinsic)); + }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{call->func, call->value_index}, + return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, fill_fragment_call, call->dtype); } const CallNode* value = op->value.as(); - CHECK(value != nullptr) - << "Can only load fragment from a buffer"; + CHECK(value != nullptr) << "Can only load fragment from a buffer"; auto it = strides_.find(value->name); - CHECK(it != strides_.end()) - << "Cannot find stride for " << value->name; + CHECK(it != strides_.end()) << "Cannot find stride for " << value->name; auto strides = it->second; CHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size() - 2]; // thread index unification inside a warp PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - PrimExpr src = CallNode::make(value->dtype, - "&", - {mutated_value}, - CallNode::Extern); + PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern); auto call = dst.as(); PrimExpr matrix_major; auto iter2 = matrix_major_.find(simplify_name(call->name)); - CHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << call->name; + CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << call->name; if (iter2->second == "col_major") { 
matrix_major = StringImmNode::make("col_major"); } else if (iter2->second == "row_major") { @@ -988,20 +924,16 @@ class TensorCoreIRMutator : public StmtExprMutator { LOG(FATAL) << "invalid matrix major for " << call->name; } - auto load_matrix_call = - [this, &src, &stride, &matrix_major](const Buffer &buffer) { + auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_load_matrix_sync, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major}, - CallNode::Intrinsic)); + CallNode::make(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, src, stride, matrix_major}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{op->func, op->value_index}, + return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index}, load_matrix_call, call->dtype); } @@ -1009,39 +941,30 @@ class TensorCoreIRMutator : public StmtExprMutator { if (it3 != frag_store_.end()) { TensorKey key{op->func, op->value_index}; auto it = strides_.find(key.GetName()); - CHECK(it != strides_.end()) - << "Cannot find stride for " << key.GetName(); + CHECK(it != strides_.end()) << "Cannot find stride for " << key.GetName(); auto strides = it->second; CHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size() - 2]; PrimExpr dst = it3->second; // thread index unification inside a warp PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = CallNode::make(DataType::Handle(), - "&", - {dst}, - CallNode::Extern); + dst = CallNode::make(DataType::Handle(), "&", {dst}, CallNode::Extern); auto call = op->value.as(); - auto store_matrix_call = - [this, &dst, &stride](const Buffer &buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_store_matrix_sync, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, - StringImmNode::make("col_major")}, - CallNode::Intrinsic)); - }; + auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { + return EvaluateNode::make( + CallNode::make(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, dst, stride, StringImmNode::make("col_major")}, + CallNode::Intrinsic)); + }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{call->func, call->value_index}, + return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, store_matrix_call, call->dtype); } @@ -1056,54 +979,54 @@ class TensorCoreIRMutator : public StmtExprMutator { if (it != loop_scaling_.end()) { int scale_factor = it->second; int scaled_extent_value = 1; - if (const IntImmNode *ori_extent = op->extent.as()) { + if (const IntImmNode* ori_extent = op->extent.as()) { int ori_extent_value = ori_extent->value; scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, - op->device_api, op->body); + stmt = ForNode::make(op->loop_var, op->min, scaled_extent, 
op->for_type, op->device_api, + op->body); } } return stmt; } private: - Array get_tile_size_(const std::string &name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - Array tile_size = {size0, size1}; - return tile_size; + Array get_tile_size_(const std::string& name) { + auto it = matrix_abc_.find(name); + auto it2 = matrix_major_.find(name); + CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) + << "Cannot find matrix info for " << name; + PrimExpr size0 = make_const(DataType::Int(32), 16); + PrimExpr size1 = make_const(DataType::Int(32), 16); + if (it->second == "matrix_a" && it2->second == "col_major") { + size0 = make_const(DataType::Int(32), warp_tile_.k); + size1 = make_const(DataType::Int(32), warp_tile_.m); + } + if (it->second == "matrix_a" && it2->second == "row_major") { + size0 = make_const(DataType::Int(32), warp_tile_.m); + size1 = make_const(DataType::Int(32), warp_tile_.k); + } + if (it->second == "matrix_b" && it2->second == "row_major") { + size0 = make_const(DataType::Int(32), warp_tile_.k); + size1 = make_const(DataType::Int(32), warp_tile_.n); + } + if (it->second == "matrix_b" && it2->second == "col_major") { + size0 = make_const(DataType::Int(32), warp_tile_.n); + size1 = make_const(DataType::Int(32), warp_tile_.k); + } + if (it->second == "matrix_c") { + size0 = make_const(DataType::Int(32), warp_tile_.n); + size1 = make_const(DataType::Int(32), warp_tile_.m); + } + Array tile_size = {size0, size1}; + return tile_size; } - Stmt add_buffer_bind_scope_(const CallNode* call, - const ObjectPtr &buffer_node, const TensorKey &key, - const std::function &call_back, - DataType datatype) { + Stmt add_buffer_bind_scope_(const CallNode* call, const ObjectPtr& buffer_node, + const TensorKey& key, + const std::function& call_back, + DataType datatype) { auto it = bounds_.find(key); CHECK(it != bounds_.end()); Array min_bound; @@ -1134,13 +1057,11 @@ class TensorCoreIRMutator : public StmtExprMutator { CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { elem_offset = AddNode::make( - elem_offset, MulNode::make( - strides[i], SubNode::make(call->args[i], min_bound[i]))); + elem_offset, MulNode::make(strides[i], SubNode::make(call->args[i], min_bound[i]))); } auto it2 = matrix_abc_.find(simplify_name(call->name)); - CHECK(it2 != matrix_abc_.end()) - << "Cannot find matrix info for " << call->name; + CHECK(it2 != matrix_abc_.end()) << "Cannot 
find matrix info for " << call->name; buffer_node->data = Var(call->name, DataType::Handle()); buffer_node->name = call->name; buffer_node->scope = "wmma." + it2->second; @@ -1164,15 +1085,10 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(call->args[i]); args.push_back(shape[i]); } - auto tuple = CallNode::make(DataType::Handle(), - intrinsic::tvm_tuple, - args, - CallNode::Intrinsic); + auto tuple = + CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); Array node = {buffer, tensor}; - return AttrStmtNode::make(node, - "buffer_bind_scope", - tuple, - call_back(buffer)); + return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer)); } std::unordered_map matrix_abc_; @@ -1189,10 +1105,8 @@ class TensorCoreIRMutator : public StmtExprMutator { int warp_threads_y_{-1}; }; -Stmt SchedulePostProcRewriteForTensorCore( - Stmt stmt, - Schedule schedule, - Map extern_buffer) { +Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, + Map extern_buffer) { // Check if current lower target is CUDA auto target = tvm::Target::Current(true); if (target.defined() && target->target_name != "cuda") { @@ -1217,8 +1131,7 @@ Stmt SchedulePostProcRewriteForTensorCore( return stmt; } - BufferAnalyser buffer_analyser(extern_buffer, - schedule_analyser, mma_matcher); + BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); buffer_analyser(stmt); if (!buffer_analyser.QualifiedForTensorCore()) { return stmt; @@ -1228,12 +1141,9 @@ Stmt SchedulePostProcRewriteForTensorCore( } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") -.set_body_typed([](Stmt stmt, - Schedule schedule, - Map extern_buffer) { - return SchedulePostProcRewriteForTensorCore( - stmt, schedule, extern_buffer); -}); + .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { + return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); + }); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index bb52be4c02026..57e5528870ea1 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -36,14 +36,15 @@ * - Add annotation of extern buffers using the buffer_map field * in the PrimFunc type. 
*/ -#include #include +#include +#include #include #include #include -#include -#include + #include +#include namespace tvm { namespace te { @@ -62,8 +63,7 @@ Buffer CreateBufferFor(const Tensor& tensor) { class TensorToBufferMapper : public StmtExprMutator { public: explicit TensorToBufferMapper(std::unordered_map buffer_map) - : buffer_map_(buffer_map) { - } + : buffer_map_(buffer_map) {} Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); @@ -76,22 +76,19 @@ class TensorToBufferMapper : public StmtExprMutator { Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); - body = AttrStmtNode::make( - buffer, op->attr_key, op->value, body); + body = AttrStmtNode::make(buffer, op->attr_key, op->value, body); } return body; } else if (op->attr_key == tir::attr::buffer_bind_scope) { - Array tuple = Downcast >(op->node); + Array tuple = Downcast>(op->node); Tensor tensor = Downcast(tuple[1]); - return AttrStmtNode::make( - Array{tuple[0], GetOrAllocBuffer(tensor)}, - op->attr_key, op->value, op->body); - } else if (op->attr_key == tir::attr::buffer_dim_align|| + return AttrStmtNode::make(Array{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, + op->value, op->body); + } else if (op->attr_key == tir::attr::buffer_dim_align || op->attr_key == tir::attr::prefetch_scope) { Tensor tensor = Downcast(op->node); Buffer buffer = GetOrAllocBuffer(tensor); - return AttrStmtNode::make( - buffer, op->attr_key, op->value, op->body); + return AttrStmtNode::make(buffer, op->attr_key, op->value, op->body); } else { return ret; } @@ -131,9 +128,7 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { - return GetBuffer(tensor, true); - } + Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { auto it = buffer_map_.find(tensor); @@ -149,9 +144,7 @@ class TensorToBufferMapper : public StmtExprMutator { std::unordered_map buffer_map_; }; - -PrimFunc SchedulePostProcToPrimFunc(Array arg_list, - Stmt body, +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, Optional> extern_buffer_opt) { std::unordered_map extern_buffer; @@ -188,7 +181,7 @@ PrimFunc SchedulePostProcToPrimFunc(Array arg_list, } TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") -.set_body_typed(SchedulePostProcToPrimFunc); + .set_body_typed(SchedulePostProcToPrimFunc); } // namespace te } // namespace tvm diff --git a/src/te/schedule/verify_compact_buffer.cc b/src/te/schedule/verify_compact_buffer.cc index 759adb9e76fc3..0089c36dc6070 100644 --- a/src/te/schedule/verify_compact_buffer.cc +++ b/src/te/schedule/verify_compact_buffer.cc @@ -22,12 +22,12 @@ * \brief Verify if there was any compact buffer bound to a statement. 
*/ #include +#include +#include #include #include #include #include -#include -#include #include @@ -57,8 +57,7 @@ bool VerifyCompactBuffer(const Stmt& stmt) { return verifier.Verify(stmt); } -TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer") -.set_body_typed(VerifyCompactBuffer); +TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer").set_body_typed(VerifyCompactBuffer); } // namespace te } // namespace tvm diff --git a/src/te/tensor.cc b/src/te/tensor.cc index cb14f6a35270d..606797da5e872 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -21,27 +21,24 @@ * \file tensor.cc */ #include -#include #include +#include #include + #include namespace tvm { namespace te { IterVar thread_axis(Range dom, std::string tag) { - return IterVarNode::make( - dom, Var(tag), kThreadIndex, tag); + return IterVarNode::make(dom, Var(tag), kThreadIndex, tag); } IterVar reduce_axis(Range dom, std::string name) { - return IterVarNode::make( - dom, Var(name), kCommReduce); + return IterVarNode::make(dom, Var(name), kCommReduce); } -Var var(std::string name_hint, DataType t) { - return Var(name_hint, t); -} +Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor PrimExpr Tensor::operator()(Array indices) const { @@ -52,13 +49,11 @@ PrimExpr Tensor::operator()(Array indices) const { PrimExpr Tensor::operator()(Array indices) const { using tir::CallNode; if (ndim() != 0) { - CHECK_EQ(ndim(), indices.size()) - << "Tensor dimension mismatch in read" - << "ndim = " << ndim() << ", indices.size=" << indices.size(); + CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read" + << "ndim = " << ndim() << ", indices.size=" << indices.size(); } - auto n = CallNode::make( - (*this)->dtype, (*this)->op->name, indices, CallNode::Halide, - (*this)->op, (*this)->value_index); + auto n = CallNode::make((*this)->dtype, (*this)->op->name, indices, CallNode::Halide, (*this)->op, + (*this)->value_index); return n; } @@ -71,10 +66,7 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor TensorNode::make(Array shape, - DataType dtype, - Operation op, - int value_index) { +Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; @@ -84,25 +76,18 @@ Tensor TensorNode::make(Array shape, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* t = static_cast(node.get()); - p->stream << "Tensor(shape=" << t->shape - << ", op.name=" << t->op->name << ')'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* t = static_cast(node.get()); + p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; + }); TVM_REGISTER_NODE_TYPE(TensorNode); - // TensorIntrin -TensorIntrin TensorIntrinNode::make(std::string name, - Operation op, - Array inputs, - Array buffers, - Array scalar_params, - Stmt body, - Stmt reduce_init, - Stmt reduce_update) { +TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update) { auto n = make_object(); n->name = std::move(name); n->op = std::move(op); @@ -116,20 +101,17 @@ TensorIntrin TensorIntrinNode::make(std::string name, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; - }); + 
.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorIntrinNode); - // TensorIntrinCall -TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, - Array tensors, - Array regions, - Array reduce_axis, +TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, + Array regions, Array reduce_axis, Array scalar_inputs) { auto n = make_object(); n->intrin = std::move(intrin); @@ -141,40 +123,32 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* n = static_cast(node.get()); - p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* n = static_cast(node.get()); + p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); -TVM_REGISTER_GLOBAL("te.Tensor") -.set_body_typed(TensorNode::make); +TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make); -TVM_REGISTER_GLOBAL("te.TensorIntrin") -.set_body_typed(TensorIntrinNode::make); +TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make); -TVM_REGISTER_GLOBAL("te.TensorIntrinCall") -.set_body_typed(TensorIntrinCallNode::make); +TVM_REGISTER_GLOBAL("te.TensorIntrinCall").set_body_typed(TensorIntrinCallNode::make); -TVM_REGISTER_GLOBAL("te.TensorEqual") -.set_body_method(&Tensor::operator==); +TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); -TVM_REGISTER_GLOBAL("te.TensorHash") -.set_body_typed([](Tensor tensor) -> int64_t { - return static_cast(std::hash()(tensor)); - }); +TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { + return static_cast(std::hash()(tensor)); +}); -TVM_REGISTER_GLOBAL("te.OpGetOutput") -.set_body_typed([](Operation op, int64_t output) { +TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); -TVM_REGISTER_GLOBAL("te.OpNumOutputs") -.set_body_method(&OperationNode::num_outputs); +TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); -TVM_REGISTER_GLOBAL("te.OpInputTensors") -.set_body_method(&OperationNode::InputTensors); +TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); } // namespace te } // namespace tvm diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 763e3eb7cdae1..7eb8013f2a854 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -21,16 +21,15 @@ * \file tir/analysis/deep_equal.cc * \brief Deep equality checking. */ -#include #include +#include #include #include namespace tvm { namespace tir { -class DeepCmpSEqualHandler : - public SEqualReducer::Handler { +class DeepCmpSEqualHandler : public SEqualReducer::Handler { public: // use direct recursion. 
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { @@ -41,12 +40,9 @@ class DeepCmpSEqualHandler : return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false)); } - ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { - return ObjectRef(nullptr); - } + ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); } - void MarkGraphNode() final { - } + void MarkGraphNode() final {} private: // reflection vtable @@ -67,9 +63,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { } TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") -.set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { - return ExprDeepEqual()(lhs, rhs); -}); + .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { + return ExprDeepEqual()(lhs, rhs); + }); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/side_effect.cc b/src/tir/analysis/side_effect.cc index 10039d9c1f114..b5fb328bf2b9c 100644 --- a/src/tir/analysis/side_effect.cc +++ b/src/tir/analysis/side_effect.cc @@ -21,9 +21,9 @@ * \file side_effect.cc * \brief side effect analysis */ +#include #include #include -#include namespace tvm { namespace tir { @@ -37,7 +37,8 @@ class ExprSideEffect : public ExprVisitor { void VisitExpr_(const CallNode* op) final { if (!op->is_pure()) { - has_side_effect_ = true; return; + has_side_effect_ = true; + return; } else { ExprVisitor::VisitExpr_(op); } diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc index ffc7792a15c73..2a23329555829 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tir/analysis/var_touch.cc @@ -21,27 +21,23 @@ * \file simple_analysis.cc * \brief Implementation of simple passes */ +#include #include #include -#include namespace tvm { namespace tir { class VarTouchVisitor : public ExprVisitor { public: - explicit VarTouchVisitor( - std::function var_set) - : var_set_(var_set) {} + explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} void VisitExpr(const PrimExpr& e) final { if (use_var_) return; ExprVisitor::VisitExpr(e); } - void VisitExpr_(const VarNode* op) final { - Handle(op); - } + void VisitExpr_(const VarNode* op) final { Handle(op); } void VisitExpr_(const LoadNode* op) final { Handle(op->buffer_var.get()); @@ -58,9 +54,7 @@ class VarTouchVisitor : public ExprVisitor { std::function var_set_; }; - -bool ExprUseVar(const PrimExpr& e, - std::function var_set) { +bool ExprUseVar(const PrimExpr& e, std::function var_set) { VarTouchVisitor visitor(var_set); visitor(e); return visitor.use_var_; diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 3dd15002ea735..2ad20ffd4f007 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -35,12 +35,8 @@ namespace tir { class GPUCodeVerifier : public StmtVisitor { public: - bool Verify(Stmt stmt, - int64_t max_local_memory_per_block, - int64_t max_shared_memory_per_block, - int64_t max_threads_per_block, - int64_t max_thread_x, - int64_t max_thread_y, + bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, + int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); @@ -84,7 +80,7 @@ class GPUCodeVerifier : public StmtVisitor { } Var var = op->node.as()->var; - const auto *extent = 
op->value.as(); + const auto* extent = op->value.as(); CHECK(extent); // record the number of threads in a block @@ -136,8 +132,8 @@ class GPUCodeVerifier : public StmtVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -164,8 +160,7 @@ class GPUCodeVerifier : public StmtVisitor { } }; -bool VerifyGPUCode(const PrimFunc& func, - Map constraints) { +bool VerifyGPUCode(const PrimFunc& func, Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; @@ -193,18 +188,11 @@ bool VerifyGPUCode(const PrimFunc& func, LOG(FATAL) << "Invalid check item: " << iter.first; } - return verifier.Verify(func->body, - max_local_memory_per_block, - max_shared_memory_per_block, - max_threads_per_block, - max_thread_x, - max_thread_y, - max_thread_z); + return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); } - -TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code") -.set_body_typed(VerifyGPUCode); +TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); namespace transform { @@ -213,9 +201,7 @@ Pass VerifyGPUCode(Map constraints) { for (auto kv : mod->functions) { if (auto* n = kv.second.as()) { auto func = GetRef(n); - CHECK(VerifyGPUCode(func, constraints)) - << "RuntimeError: GPU constraint violated" - << func; + CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func; } } return mod; @@ -223,8 +209,7 @@ Pass VerifyGPUCode(Map constraints) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode") -.set_body_typed(VerifyGPUCode); +TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 03a36066bc089..8eb846b7d6181 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -21,13 +21,12 @@ * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. 
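The `VerifyGPUCode` wrapper above reads its `constraints` map with every missing entry left at `INT64_MAX`, then hands the resolved limits to `GPUCodeVerifier::Verify`; only an unknown key is a fatal error. A standalone sketch of that defaulting-and-checking shape follows; the `KernelStats` struct and `Verify` function are invented stand-ins, whereas the real verifier gathers these numbers by walking the TIR.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Stand-in for what GPUCodeVerifier accumulates while visiting a PrimFunc:
// per-block memory usage and the product of the thread extents.
struct KernelStats {
  int64_t local_memory_per_block;
  int64_t shared_memory_per_block;
  int64_t threads_per_block;
};

// Check the stats against a user-supplied constraint map, defaulting each
// missing limit to "unbounded" (INT64_MAX), as VerifyGPUCode does.
bool Verify(const KernelStats& s, const std::map<std::string, int64_t>& constraints) {
  auto limit = [&](const std::string& key) {
    auto it = constraints.find(key);
    return it == constraints.end() ? INT64_MAX : it->second;
  };
  return s.local_memory_per_block <= limit("max_local_memory_per_block") &&
         s.shared_memory_per_block <= limit("max_shared_memory_per_block") &&
         s.threads_per_block <= limit("max_threads_per_block");
}

int main() {
  KernelStats s{4 * 1024, 48 * 1024, 1024};
  std::map<std::string, int64_t> limits{{"max_threads_per_block", 512}};
  std::cout << (Verify(s, limits) ? "ok" : "violated") << "\n";  // violated: 1024 > 512
}
```

Treating an absent key as "unbounded" lets callers pass only the limits they care about, which is consistent with the pass failing only on invalid keys rather than on missing ones.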
*/ -#include #include +#include +#include #include +#include #include -#include -#include - namespace tvm { namespace tir { @@ -47,13 +46,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ - explicit MemoryAccessVerifier(PrimFunc f, int device_type) - : func_(f), dev_type_(device_type) {} + explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {} virtual ~MemoryAccessVerifier() = default; - MemoryAccessVerifier(const MemoryAccessVerifier &) = delete; - MemoryAccessVerifier(MemoryAccessVerifier &&) = delete; - MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete; - MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete; + MemoryAccessVerifier(const MemoryAccessVerifier&) = delete; + MemoryAccessVerifier(MemoryAccessVerifier&&) = delete; + MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete; + MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete; //@} /// Interface to perform memory access verification @@ -68,12 +66,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { protected: /// Visitor implementation //@{ - void VisitExpr(const PrimExpr &n) final { + void VisitExpr(const PrimExpr& n) final { if (Failed()) return; StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt &n) final { + void VisitStmt(const Stmt& n) final { if (Failed()) return; StmtExprVisitor::VisitStmt(n); } @@ -85,8 +83,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (!InThreadEnv() && (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope)) { + if (!InThreadEnv() && + (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); StmtExprVisitor::VisitStmt_(op); ExitThreadEnv(); @@ -107,8 +105,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { //@} /// Check if the value of a Variable comes from function argument. - bool IsFromFunctionArgs(const VarNode *var) const { - const VarNode *V = var; + bool IsFromFunctionArgs(const VarNode* var) const { + const VarNode* V = var; for (auto kv : func_->buffer_map) { if (V == kv.second->data.get()) return true; } @@ -119,9 +117,9 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { // The value is expected to come from a tvm_struct_get Call. // Get the first argument of tvm_struct_get, and continue. - const auto &iter = defs_.find(V); + const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; - const CallNode *C = iter->second.as(); + const CallNode* C = iter->second.as(); if (!C || C->name != intrinsic::tvm_struct_get) return false; V = C->args[0].as(); } @@ -129,7 +127,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } /// Handle memory access to a Variable - void HandleLoadStoreToVariable(const Var &var) { + void HandleLoadStoreToVariable(const Var& var) { // We skip the access within thread env. if (InThreadEnv()) return; @@ -153,14 +151,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. 
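`IsFromFunctionArgs` above chases a variable's definition chain through `defs_`: starting from the accessed variable, it repeatedly steps to the value the variable was derived from until it reaches a function argument (a legal host-visible handle) or runs out of definitions. A simplified standalone model of that walk is sketched below, with plain strings standing in for `VarNode` pointers and a direct name map standing in for the `tvm_struct_get` unwrapping; all identifiers in it are illustrative.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// defs maps a variable to the variable it was derived from; args holds the
// function arguments. MemoryAccessVerifier::IsFromFunctionArgs performs the
// same walk over its defs_ map, except each derivation step there must be a
// tvm_struct_get call whose first argument names the next variable.
bool IsFromFunctionArgs(std::string var,
                        const std::unordered_map<std::string, std::string>& defs,
                        const std::unordered_set<std::string>& args) {
  while (true) {
    if (args.count(var)) return true;    // reached a function argument
    auto it = defs.find(var);
    if (it == defs.end()) return false;  // underived: a locally created value
    var = it->second;                    // follow one derivation step
  }
}

int main() {
  std::unordered_map<std::string, std::string> defs{
      {"handle", "packed_arg"}, {"data_ptr", "handle"}};
  std::unordered_set<std::string> args{"packed_arg"};
  std::cout << IsFromFunctionArgs("data_ptr", defs, args) << "\n";   // 1: traces to an arg
  std::cout << IsFromFunctionArgs("local_buf", defs, args) << "\n";  // 0: no definition
}
```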
static bool IsGPUDevice(int dev_type) { - return kDLGPU == dev_type || kDLOpenCL == dev_type || - kDLVulkan == dev_type || kDLMetal == dev_type || - kDLROCM == dev_type || kOpenGL == dev_type; + return kDLGPU == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || + kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type; } /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device. - static bool IsFPGADevice(int dev_type) { - return kDLSDAccel == dev_type || kDLAOCL == dev_type; - } + static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; } private: /// Status of visitor @@ -168,21 +163,19 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { bool in_thread_env_{false}; bool failure_{false}; ///< If the verification fails (i.e. has illegal access) //@} - tir::PrimFunc func_{nullptr}; ///< Function to be verified. - int dev_type_{kDLCPU}; ///< Device type - std::unordered_map defs_; ///< Variable definitions + tir::PrimFunc func_{nullptr}; ///< Function to be verified. + int dev_type_{kDLCPU}; ///< Device type + std::unordered_map defs_; ///< Variable definitions }; } // namespace /// Interface of VerifyMemory pass bool VerifyMemory(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; + CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->device_type); v.Run(); return !v.Failed(); @@ -191,29 +184,28 @@ bool VerifyMemory(const PrimFunc& func) { } } -TVM_REGISTER_GLOBAL("tir.analysis.verify_memory") -.set_body_typed(VerifyMemory); +TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); namespace transform { Pass VerifyMemory() { - auto pass_func = [=](IRModule mod, PassContext ctx) { - for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - CHECK(VerifyMemory(func)) - << "RuntimeError: Direct host side access to device memory is detected." - << " Did you forget to bind?\n" - << func; - } - } - return mod; - }; + auto pass_func = + [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifyMemory(func)) + << "RuntimeError: Direct host side access to device memory is detected." 
+ << " Did you forget to bind?\n" + << func; + } + } + return mod; + }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory") -.set_body_typed(VerifyMemory); +TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index 97eaf2437523c..c57cbf7d0703d 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -24,11 +24,12 @@ * \file verify_ssa.cc */ #include +#include #include #include -#include -#include + #include +#include #include namespace tvm { @@ -101,7 +102,8 @@ class IRVerifySSA final : public StmtExprVisitor { void MarkDef(const VarNode* v, bool allow_dup = false) { if (defined_.count(v) != 0) { if (!allow_dup) { - is_ssa = false; return; + is_ssa = false; + return; } } else { defined_[v] = 1; @@ -112,16 +114,13 @@ class IRVerifySSA final : public StmtExprVisitor { std::unordered_map defined_; }; - bool VerifySSA(const PrimFunc& func) { IRVerifySSA visitor; visitor.Run(func); return visitor.is_ssa; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa") -.set_body_typed(VerifySSA); - +TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); namespace transform { @@ -130,9 +129,7 @@ Pass VerifySSA() { for (auto kv : mod->functions) { if (auto* n = kv.second.as()) { auto func = GetRef(n); - CHECK(VerifySSA(func)) - << "RuntimeError: IR is not in SSA form" - << func; + CHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func; } } return mod; @@ -140,8 +137,7 @@ Pass VerifySSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifySSA") -.set_body_typed(VerifySSA); +TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); } // namespace transform diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 0f1c572fd0a48..45b9680ab3f3c 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -20,15 +20,16 @@ /*! * \file buffer.cc */ +#include +#include #include +#include #include -#include #include -#include -#include #include #include + #include "../../arith/compute_expr.h" namespace tvm { @@ -44,23 +45,13 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, - DataType dtype, - std::string name) { - return BufferNode::make( - Var(name, PointerType(PrimType(dtype))), - dtype, - shape, - Array(), - PrimExpr(), - name, - "", - 0, 0, - kDefault); +Buffer decl_buffer(Array shape, DataType dtype, std::string name) { + return BufferNode::make(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), + PrimExpr(), name, "", 0, 0, kDefault); } // Split the given expression w.r.t the add operator -inline std::vector ExprSplitAddition(const PrimExpr &expr) { +inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; std::vector ret; std::stack split_buffer; @@ -79,7 +70,6 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { return ret; } - // Searches for the following types of expr: // mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki // mod_l_expr = c @@ -87,9 +77,9 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... 
* ki + c) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const PrimExpr &mult_expr, - const PrimExpr &mod_l_expr, - const PrimExpr &mod_r_expr) { +inline std::pair MergeMulModInner(const PrimExpr& mult_expr, + const PrimExpr& mod_l_expr, + const PrimExpr& mod_r_expr) { using namespace tir; const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, PrimExpr()); @@ -124,9 +114,8 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, return std::make_pair(false, PrimExpr()); } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; - if (expr_equal(overall_mult, inner_div_ptr->b) - && expr_equal(overall_mult, mod_r_expr) - && expr_equal(inner_div_ptr->a, mod_l_expr)) { + if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) && + expr_equal(inner_div_ptr->a, mod_l_expr)) { // Found! PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; return std::make_pair(true, ret); @@ -157,9 +146,7 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, inline void MergeMulModInsertElements(const std::vector& eles, std::list* mult_exprs, std::list >* mod_exprs, - PrimExpr* no_opt_sum, - bool* has_mult, - bool* has_mod) { + PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { using namespace tir; *has_mult = false; *has_mod = false; @@ -185,7 +172,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { +inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { using namespace tir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and @@ -199,8 +186,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { PrimExpr no_opt_sum; bool has_mult; bool has_mod; - MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, - &no_opt_sum, &has_mult, &has_mod); + MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); bool find_opt = false; std::list >::iterator search_mod_it = mod_exprs.begin(); // 2. 
Exhaustive Search @@ -208,9 +194,8 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { std::list::iterator mult_it = mult_exprs.begin(); bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { - std::pair ret = MergeMulModInner(*mult_it, - search_mod_it->first, - search_mod_it->second); + std::pair ret = + MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; @@ -218,8 +203,8 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { mod_exprs.erase(temp_mod_it); mult_exprs.erase(mult_it); std::vector ret_eles = ExprSplitAddition(ret.second); - MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, - &no_opt_sum, &has_mult, &has_mod); + MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, + &has_mod); if (has_mult) { search_mod_it = mod_exprs.begin(); } else if (has_mod && search_mod_it == mod_exprs.end()) { @@ -242,9 +227,9 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it; } for (std::list >::iterator it = mod_exprs.begin(); - it != mod_exprs.end(); ++it) { - no_opt_sum = no_opt_sum.get() ? - no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second); + it != mod_exprs.end(); ++it) { + no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second) + : indexmod(it->first, it->second); } return no_opt_sum; } @@ -300,20 +285,16 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp PrimExpr Buffer::vload(Array begin, DataType dtype) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); - CHECK(dtype.element_of() == n->dtype.element_of() && - dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype - << " from buffer of " << n->dtype; + CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { return tir::CastNode::make( DataType::Bool(), - tir::LoadNode::make( - DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), - const_true())); + tir::LoadNode::make(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), + const_true())); } else { - return tir::LoadNode::make( - dtype, n->data, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + return tir::LoadNode::make(dtype, n->data, BufferOffset(n, begin, dtype), + const_true(dtype.lanes())); } } @@ -321,18 +302,14 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); DataType dtype = value.dtype(); - CHECK(dtype.element_of() == n->dtype.element_of() && - dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype - << " from buffer of " << n->dtype; + CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return tir::StoreNode::make(n->data, - tir::CastNode::make(DataType::Int(8), value), - BufferOffset(n, begin, DataType::Int(8)), - const_true()); + return tir::StoreNode::make(n->data, tir::CastNode::make(DataType::Int(8), value), + BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { return tir::StoreNode::make(n->data, value, BufferOffset(n, 
begin, dtype), - const_true(dtype.lanes())); + const_true(dtype.lanes())); } } @@ -342,7 +319,7 @@ Buffer Buffer::MakeStrideView() const { std::vector temp; auto n = make_object(*operator->()); PrimExpr acc = make_const(n->DefaultIndexType(), 1); - for (size_t i = n->shape.size(); i != 0 ; --i) { + for (size_t i = n->shape.size(); i != 0; --i) { temp.push_back(acc); acc = acc * n->shape[i - 1]; } @@ -364,8 +341,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const // check if stride is needed. for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { - if (!is_zero(begins[i]) || - !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { + if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } @@ -376,21 +352,11 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return BufferNode::make(n->data, - n->dtype, - extents, - strides, - elem_offset, - n->name + "_slice", - n->scope, - n->data_alignment, - 0, - n->buffer_type); + return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", + n->scope, n->data_alignment, 0, n->buffer_type); } -PrimExpr Buffer::access_ptr(int access_mask, - DataType ptr_type, - int content_lanes, +PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset) const { const BufferNode* self = operator->(); PrimExpr e_dtype; @@ -407,28 +373,19 @@ PrimExpr Buffer::access_ptr(int access_mask, if (content_lanes > 1) { e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.dtype(), content_lanes); - elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), - content_lanes); + elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); } else { e_dtype = tir::TypeAnnotation(self->dtype); } - Array acc_args{ - e_dtype, self->data, elem_offset, - extent, make_const(DataType::Int(32), access_mask)}; - return tir::CallNode::make( - ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); + Array acc_args{e_dtype, self->data, elem_offset, extent, + make_const(DataType::Int(32), access_mask)}; + return tir::CallNode::make(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, + tir::CallNode::Intrinsic); } -Buffer BufferNode::make(Var data, - DataType dtype, - Array shape, - Array strides, - PrimExpr elem_offset, - std::string name, - std::string scope, - int data_alignment, - int offset_factor, - BufferType buffer_type) { +Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, std::string name, std::string scope, + int data_alignment, int offset_factor, BufferType buffer_type) { auto n = make_object(); n->data = std::move(data); n->dtype = dtype; @@ -461,31 +418,26 @@ Buffer BufferNode::make(Var data, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "buffer(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "buffer(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(BufferNode); +TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size(), 10); + auto buffer_type = args[9].operator std::string(); + BufferType type = (buffer_type == "auto_broadcast") ? 
kAutoBroadcast : kDefault; + *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], + args[8], type); +}); -TVM_REGISTER_GLOBAL("tir.Buffer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size(), 10); - auto buffer_type = args[9].operator std::string(); - BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], - args[5], args[6], args[7], args[8], type); - }); - -TVM_REGISTER_GLOBAL("tir.BufferAccessPtr") -.set_body_method(&Buffer::access_ptr); +TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); -TVM_REGISTER_GLOBAL("tir.BufferVLoad") -.set_body_method(&Buffer::vload); +TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); -TVM_REGISTER_GLOBAL("tir.BufferVStore") -.set_body_method(&Buffer::vstore); +TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 77de9f4aacbc4..23e13edadde57 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -21,46 +21,43 @@ * \file src/lang/data_layout.cc * \brief Data Layout expression. */ +#include #include #include #include -#include #include namespace tvm { namespace tir { -using tir::Var; using tir::IterVar; using tir::IterVarNode; +using tir::Var; TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode); const LayoutAxis LayoutAxis::UPPER_CASE[] = { - LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), - LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'), - LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'), - LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'), - LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'), - LayoutAxis('Z') -}; + LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), + LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'), + LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'), + LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'), + LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'), + LayoutAxis('Z')}; const LayoutAxis LayoutAxis::LOWER_CASE[] = { - LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'), - LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'), - LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'), - LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'), - LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'), - LayoutAxis('z') -}; + LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'), + LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'), + LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'), + LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'), + LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'), + LayoutAxis('z')}; const LayoutAxis& LayoutAxis::Get(const char name) { CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && 
name <= 'z')) - << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; - return (name >= 'A' && name <= 'Z') ? - LayoutAxis::UPPER_CASE[name-'A'] : - LayoutAxis::LOWER_CASE[name-'a']; + << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; + return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A'] + : LayoutAxis::LOWER_CASE[name - 'a']; } const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { @@ -83,8 +80,8 @@ Layout::Layout(const Array& axes) { CHECK_GT(factor->value, 0); repr << factor->value; } - CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis " - << axis->var.get()->name_hint; + CHECK_EQ(axis->var.get()->name_hint.size(), 1) + << "Invalid layout axis " << axis->var.get()->name_hint; char c = axis->var.get()->name_hint[0]; CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; repr << axis->var.get()->name_hint; @@ -93,7 +90,7 @@ Layout::Layout(const Array& axes) { data_ = std::move(node); } -Layout::Layout(const std::string& name) { // NOLINT(*) +Layout::Layout(const std::string& name) { // NOLINT(*) if (name == "__undef__") return; auto node = make_object(); @@ -105,19 +102,18 @@ Layout::Layout(const std::string& name) { // NOLINT(*) int32_t factor = 0; for (char c : name) { if (c >= 'A' && c <= 'Z') { - CHECK_EQ(factor, 0) << "Invalid layout " << name - << ": invalid factor size " << factor + CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); - IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), - Var(std::string(1, c)), tir::kDataPar); + IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), + tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { - CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " - << factor << " for dimension " << c; - IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(std::string(1, c)), tir::kDataPar); + CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor + << " for dimension " << c; + IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), + tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { @@ -141,16 +137,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*) for (const IterVar& v : node->axes) { char axis = v->var.get()->name_hint[0]; if (axis >= 'a' && axis <= 'z') { - CHECK(exist_axis[axis-'a'+'A']) << "Invalid layout " << name << ": missing axis " - << std::toupper(axis); + CHECK(exist_axis[axis - 'a' + 'A']) + << "Invalid layout " << name << ": missing axis " << std::toupper(axis); } } data_ = std::move(node); } -Layout LayoutNode::make(const std::string& layout) { - return Layout(layout); -} +Layout LayoutNode::make(const std::string& layout) { return Layout(layout); } Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); @@ -164,16 +158,16 @@ Layout Layout::SubLayout(size_t pos, size_t len) const { return Layout(new_layout); } -Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const { +Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const { if (!defined()) return Layout::Undef(); const std::string& name = operator->()->name; const auto axes = 
operator->()->axes; - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name; + CHECK(target_pos <= this->ndim()) + << "Invalid split position " << target_pos << " for layout " << name; CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; - CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis - << " has already been split in " << name; + CHECK(!this->Contains(axis.ToSubordinate())) + << "Axis " << axis << " has already been split in " << name; CHECK(factor > 0) << "Invalid split size " << factor; Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { @@ -202,16 +196,15 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* l = static_cast(node.get()); - p->stream << "Layout(" << l->name << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* l = static_cast(node.get()); + p->stream << "Layout(" << l->name << ")"; + }); -inline bool GetStoreRule(Array* rule, - const Layout& src_layout, +inline bool GetStoreRule(Array* rule, const Layout& src_layout, const Layout& dst_layout) { - if (!src_layout.defined() || src_layout.name().empty() || - !dst_layout.defined() || dst_layout.name().empty()) { + if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() || + dst_layout.name().empty()) { return false; } for (size_t i = 0; i < dst_layout.ndim(); ++i) { @@ -273,16 +266,15 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(src_index.size(), self->src_layout->axes.size()) - << "Input mismatch with layout " << self->src_layout; + << "Input mismatch with layout " << self->src_layout; return TransformIndex(src_index, self->src_layout->axes, self->forward_rule); } - Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) - << "Output mismatch with layout " << self->dst_layout; + << "Output mismatch with layout " << self->dst_layout; return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule); } @@ -310,8 +302,8 @@ inline Array TransformShape(const Array& src_shape, const auto* orig_axis_extent = orig_axis->dom->extent.as(); if (orig_shape_const) { CHECK_EQ(orig_shape_const->value, orig_axis_extent->value) - << "Input shape mismatch at index " << i << ". Expected " - << orig_axis->dom->extent << ", get " << orig_shape; + << "Input shape mismatch at index " << i << ". 
Expected " << orig_axis->dom->extent + << ", get " << orig_shape; } } bind_map[orig_axis->var.get()] = PrimExpr(0); @@ -343,15 +335,13 @@ inline Array TransformShape(const Array& src_shape, Array BijectiveLayout::ForwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->src_layout->axes, - self->dst_layout->axes, self->forward_rule); + return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule); } Array BijectiveLayout::BackwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->dst_layout->axes, - self->src_layout->axes, self->backward_rule); + return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule); } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { @@ -369,51 +359,47 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* b = static_cast(node.get()); - p->stream << "BijectiveLayout(" << b->src_layout.name() - << "->" << b->dst_layout.name() << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* b = static_cast(node.get()); + p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() + << ")"; + }); -TVM_REGISTER_GLOBAL("tir.Layout") -.set_body_typed(LayoutNode::make); +TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make); -TVM_REGISTER_GLOBAL("tir.LayoutIndexOf") -.set_body_typed([](Layout layout, std::string axis) -> int { +TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { return layout.IndexOf(LayoutAxis::make(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") -.set_body_typed([](Layout layout, std::string axis) -> int { - return layout.FactorOf(LayoutAxis::make(axis)); -}); + .set_body_typed([](Layout layout, std::string axis) -> int { + return layout.FactorOf(LayoutAxis::make(axis)); + }); -TVM_REGISTER_GLOBAL("tir.LayoutNdim") -.set_body_typed([](Layout layout) -> int { +TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { return layout.ndim(); }); -TVM_REGISTER_GLOBAL("tir.LayoutGetItem") -.set_body_typed([](Layout layout, int idx) -> std::string { +TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string { const LayoutAxis& axis = layout[idx]; return axis.name(); }); TVM_REGISTER_GLOBAL("tir.BijectiveLayout") -.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { - return BijectiveLayout(src_layout, dst_layout); -}); + .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { + return BijectiveLayout(src_layout, dst_layout); + }); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") -.set_body_method(&BijectiveLayout::ForwardIndex); + .set_body_method(&BijectiveLayout::ForwardIndex); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") -.set_body_method(&BijectiveLayout::BackwardIndex); + .set_body_method(&BijectiveLayout::BackwardIndex); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") -.set_body_method(&BijectiveLayout::ForwardShape); + .set_body_method(&BijectiveLayout::ForwardShape); 
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") -.set_body_method(&BijectiveLayout::BackwardShape); + .set_body_method(&BijectiveLayout::BackwardShape); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index a36d81f8f306a..569415546ec80 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -22,11 +22,12 @@ */ #include #include -#include #include +#include #include -#include + #include +#include #include "../../support/str_escape.h" @@ -68,9 +69,7 @@ SizeVar::SizeVar(std::string name_hint, DataType dtype) { data_ = std::move(n); } - -TVM_REGISTER_GLOBAL("tir.Var") -.set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { +TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { if (type.IsObjectRef()) { return Var(name_hint, type.operator Type()); } else { @@ -78,16 +77,11 @@ TVM_REGISTER_GLOBAL("tir.Var") } }); -TVM_REGISTER_GLOBAL("tir.SizeVar") -.set_body_typed([](std::string s, DataType t) { - return SizeVar(s, t); +TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](std::string s, DataType t) { + return SizeVar(s, t); }); - -IterVar IterVarNode::make(Range dom, - Var var, - IterVarType t, - std::string thread_tag) { +IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) { ObjectPtr n = make_object(); n->dom = dom; n->var = var; @@ -97,29 +91,25 @@ IterVar IterVarNode::make(Range dom, } TVM_REGISTER_GLOBAL("tir.IterVar") -.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { - return IterVarNode::make( - dom, var, - static_cast(iter_type), - thread_tag); -}); + .set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { + return IterVarNode::make(dom, var, static_cast(iter_type), thread_tag); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "iter_var("; - if (op->var->name_hint.length() != 0) { - p->stream << op->var->name_hint << ", "; - } - if (op->dom.defined()) { - p->stream << op->dom; - } - if (op->thread_tag.length() != 0) { - p->stream << ", " << op->thread_tag; - } - p->stream << ")"; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "iter_var("; + if (op->var->name_hint.length() != 0) { + p->stream << op->var->name_hint << ", "; + } + if (op->dom.defined()) { + p->stream << op->dom; + } + if (op->thread_tag.length() != 0) { + p->stream << ", " << op->thread_tag; + } + p->stream << ")"; + }); TVM_REGISTER_NODE_TYPE(IterVarNode); @@ -130,9 +120,7 @@ PrimExpr StringImmNode::make(std::string value) { return PrimExpr(node); } -TVM_REGISTER_GLOBAL("tir.StringImm") -.set_body_typed(StringImmNode::make); - +TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed(StringImmNode::make); PrimExpr CastNode::make(DataType t, PrimExpr value) { CHECK(value.defined()); @@ -143,7 +131,6 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) { return PrimExpr(node); } - PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; @@ -172,7 +159,6 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { return PrimExpr(node); } - PrimExpr NotNode::make(PrimExpr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); @@ -183,15 +169,12 @@ PrimExpr NotNode::make(PrimExpr a) { return PrimExpr(node); } - - PrimExpr 
SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; CHECK(condition.dtype().is_bool()); - CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || - condition.dtype().lanes() == 1); + CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); @@ -259,11 +242,24 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { return PrimExpr(node); } -const char* CallNode::vectorizable_intrinsics[] = { - "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", - "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right, - tir::CallNode::likely, tir::CallNode::popcount -}; +const char* CallNode::vectorizable_intrinsics[] = {"floor", + "ceil", + "sign", + "trunc", + "fabs", + "round", + "exp", + "tanh", + "sqrt", + "log", + "sin", + "cos", + "pow", + "tan", + tir::CallNode::shift_left, + tir::CallNode::shift_right, + tir::CallNode::likely, + tir::CallNode::popcount}; bool CallNode::is_vectorizable() const { size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); @@ -275,12 +271,8 @@ bool CallNode::is_vectorizable() const { return false; } -PrimExpr CallNode::make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func, - int value_index) { +PrimExpr CallNode::make(DataType dtype, std::string name, Array args, CallType call_type, + FunctionRef func, int value_index) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } @@ -301,8 +293,7 @@ PrimExpr CallNode::make(DataType dtype, return PrimExpr(node); } -PrimExpr ShuffleNode::make(Array vectors, - Array indices) { +PrimExpr ShuffleNode::make(Array vectors, Array indices) { CHECK_NE(vectors.size(), 0U); CHECK_NE(indices.size(), 0U); @@ -341,9 +332,7 @@ PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) { return make({vector}, {Integer(index)}); } -CommReducer CommReducerNode::make(Array lhs, - Array rhs, - Array result, +CommReducer CommReducerNode::make(Array lhs, Array rhs, Array result, Array identity_element) { auto node = make_object(); node->lhs = lhs; @@ -363,24 +352,19 @@ Array CommReducerNode::operator()(Array a, Array b value_map.Set(rhs[i], b[i]); } auto ret = this->result; - ret.MutateByApply([&value_map] (const PrimExpr& e) { - return Substitute(e, value_map); - }); + ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); return ret; } -TVM_REGISTER_GLOBAL("tir.CommReducer") -.set_body_typed(CommReducerNode::make); +TVM_REGISTER_GLOBAL("tir.CommReducer").set_body_typed(CommReducerNode::make); TVM_REGISTER_GLOBAL("tir.CommReducerCombine") -.set_body_method(&tir::CommReducerNode::operator()); - + .set_body_method(&tir::CommReducerNode::operator()); -PrimExpr ReduceNode::make(CommReducer combiner, Array source, - Array axis, PrimExpr condition, int value_index) { +PrimExpr ReduceNode::make(CommReducer combiner, Array source, Array axis, + PrimExpr condition, int value_index) { for (size_t i = 0; i < axis.size(); ++i) { - CHECK_EQ(axis[i]->iter_type, kCommReduce) - << "Can only take axis created by reduce_axis"; + CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only 
take axis created by reduce_axis"; } if (!condition.defined()) { condition = const_true(); @@ -399,10 +383,7 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array source, return PrimExpr(n); } - -TVM_REGISTER_GLOBAL("tir.Reduce") -.set_body_typed(ReduceNode::make); - +TVM_REGISTER_GLOBAL("tir.Reduce").set_body_typed(ReduceNode::make); PrimExpr AnyNode::make() { auto n = make_object(); @@ -417,285 +398,277 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BufferLoad") -.set_body_typed([](Buffer buffer, Array indices) { +TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { return BufferLoad(buffer, indices); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '\"' << support::StrEscape(op->value) << '\"'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '\"' << support::StrEscape(op->value) << '\"'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dtype << '('; - p->Print(op->value); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - // omit the type - // stream << op->name << "." << op->type; - p->stream << op->name_hint; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " + "; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " - "; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "*"; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "/"; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " % "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "min("; - p->Print(op->a); - p->stream << ", "; - p->Print(op->b); - p->stream << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "max("; - p->Print(op->a); - p->stream << ", "; - p->Print(op->b); - p->stream << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " == "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " != "; - p->Print(op->b); - p->stream << ')'; -}) 
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " < "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " <= "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " > "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " >= "; - p->Print(op->b); - p->stream << ')'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dtype << '('; + p->Print(op->value); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + // omit the type + // stream << op->name << "." << op->type; + p->stream << op->name_hint; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " + "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " - "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "*"; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "/"; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " % "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "min("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "max("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " == "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " != "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " < "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " <= "; + 
p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " > "; + p->Print(op->b); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " >= "; + p->Print(op->b); + p->stream << ')'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floordiv(" << op->a << ", " << op->b << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floordiv(" << op->a << ", " << op->b << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floormod(" << op->a << ", " << op->b << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floormod(" << op->a << ", " << op->b << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " && "; - p->Print(op->b); - p->stream << ')'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " && "; + p->Print(op->b); + p->stream << ')'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " || "; - p->Print(op->b); - p->stream << ')'; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " || "; + p->Print(op->b); + p->stream << ')'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '!'; - p->Print(op->a); -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '!'; + p->Print(op->a); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "select("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->true_value); - p->stream << ", "; - p->Print(op->false_value); - p->stream << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "select("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->true_value); + p->stream << ", "; + p->Print(op->false_value); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "]"; - if (!is_one(op->predicate)) { + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "]"; + if (!is_one(op->predicate)) { p->stream << " if "; p->Print(op->predicate); - } -}); + } + }); 
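All of the binary-operator dispatches above share one printing convention: emit `'('`, the left operand, the operator, the right operand, `')'`. Unconditional parenthesization means the printer never has to track operator precedence, and nested expressions read back unambiguously. Here is a minimal standalone printer in the same style, using a toy `Node` type rather than TVM's `ReprPrinter` vtable; all names in it are illustrative.

```cpp
#include <iostream>
#include <memory>
#include <string>

// Toy expression node: op is empty for a leaf, otherwise a binary operator.
struct Node {
  std::string op;    // "+", "*", ...
  std::string name;  // leaf variable name
  std::shared_ptr<Node> a, b;
};

// Print every binary node as "(a <op> b)", exactly the convention the
// ReprPrinter dispatches above use for AddNode, MulNode, and friends.
void Print(const Node& n, std::ostream& os) {
  if (n.op.empty()) { os << n.name; return; }
  os << '(';
  Print(*n.a, os);
  os << ' ' << n.op << ' ';
  Print(*n.b, os);
  os << ')';
}

int main() {
  auto leaf = [](std::string s) { auto n = std::make_shared<Node>(); n->name = std::move(s); return n; };
  auto bin = [](std::string op, std::shared_ptr<Node> a, std::shared_ptr<Node> b) {
    auto n = std::make_shared<Node>(); n->op = std::move(op); n->a = a; n->b = b; return n;
  };
  Print(*bin("+", leaf("x"), bin("*", leaf("y"), leaf("z"))), std::cout);
  std::cout << "\n";  // (x + (y * z))
}
```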
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ramp("; - p->Print(op->base); - p->stream << ", "; - p->Print(op->stride); - p->stream << ", " << op->lanes << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ramp("; + p->Print(op->base); + p->stream << ", "; + p->Print(op->stride); + p->stream << ", " << op->lanes << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "x" << op->lanes << "("; - p->Print(op->value); - p->stream << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "x" << op->lanes << "("; + p->Print(op->value); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->name << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->name << "("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) { + p->stream << ", "; + } } - } - p->stream << ")"; - }); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } } - } - p->stream << "]"; - }); + p->stream << "]"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(let " << op->var << " = "; - p->Print(op->value); - p->stream << " in "; - p->Print(op->body); - p->stream << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "(let " << op->var << " = "; + p->Print(op->value); + p->stream << " in "; + p->Print(op->body); + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << "?"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "reduce(combiner=" - << op->combiner; - p->stream << ", source=" << op->source; - p->stream << ", axis=" << op->axis; - p->stream << ", where=" << op->condition; - p->stream << ", value_index=" << op->value_index; - p->stream << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "reduce(combiner=" << op->combiner; + p->stream << ", source=" << op->source; + p->stream << ", axis=" << op->axis; + 
p->stream << ", where=" << op->condition; + p->stream << ", value_index=" << op->value_index; + p->stream << ")"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "comm_reducer(result=" << op->result - << ", lhs=" << op->lhs - << ", rhs=" << op->rhs - << ", identity_element=" << op->identity_element - << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs + << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")"; + }); TVM_REGISTER_NODE_TYPE(StringImmNode); TVM_REGISTER_NODE_TYPE(CastNode); @@ -728,112 +701,78 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode); TVM_REGISTER_NODE_TYPE(ReduceNode); TVM_REGISTER_NODE_TYPE(AnyNode); +TVM_REGISTER_GLOBAL("tir.Add").set_body_typed(AddNode::make); -TVM_REGISTER_GLOBAL("tir.Add") -.set_body_typed(AddNode::make); +TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed(SubNode::make); -TVM_REGISTER_GLOBAL("tir.Sub") -.set_body_typed(SubNode::make); +TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed(MulNode::make); -TVM_REGISTER_GLOBAL("tir.Mul") -.set_body_typed(MulNode::make); +TVM_REGISTER_GLOBAL("tir.Div").set_body_typed(DivNode::make); -TVM_REGISTER_GLOBAL("tir.Div") -.set_body_typed(DivNode::make); +TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed(ModNode::make); -TVM_REGISTER_GLOBAL("tir.Mod") -.set_body_typed(ModNode::make); +TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed(FloorDivNode::make); -TVM_REGISTER_GLOBAL("tir.FloorDiv") -.set_body_typed(FloorDivNode::make); +TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed(FloorModNode::make); -TVM_REGISTER_GLOBAL("tir.FloorMod") -.set_body_typed(FloorModNode::make); +TVM_REGISTER_GLOBAL("tir.Min").set_body_typed(MinNode::make); -TVM_REGISTER_GLOBAL("tir.Min") -.set_body_typed(MinNode::make); +TVM_REGISTER_GLOBAL("tir.Max").set_body_typed(MaxNode::make); -TVM_REGISTER_GLOBAL("tir.Max") -.set_body_typed(MaxNode::make); +TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed(EQNode::make); -TVM_REGISTER_GLOBAL("tir.EQ") -.set_body_typed(EQNode::make); +TVM_REGISTER_GLOBAL("tir.NE").set_body_typed(NENode::make); -TVM_REGISTER_GLOBAL("tir.NE") -.set_body_typed(NENode::make); +TVM_REGISTER_GLOBAL("tir.LT").set_body_typed(LTNode::make); -TVM_REGISTER_GLOBAL("tir.LT") -.set_body_typed(LTNode::make); +TVM_REGISTER_GLOBAL("tir.LE").set_body_typed(LENode::make); -TVM_REGISTER_GLOBAL("tir.LE") -.set_body_typed(LENode::make); +TVM_REGISTER_GLOBAL("tir.GT").set_body_typed(GTNode::make); -TVM_REGISTER_GLOBAL("tir.GT") -.set_body_typed(GTNode::make); +TVM_REGISTER_GLOBAL("tir.GE").set_body_typed(GENode::make); -TVM_REGISTER_GLOBAL("tir.GE") -.set_body_typed(GENode::make); +TVM_REGISTER_GLOBAL("tir.And").set_body_typed(AndNode::make); -TVM_REGISTER_GLOBAL("tir.And") -.set_body_typed(AndNode::make); +TVM_REGISTER_GLOBAL("tir.Or").set_body_typed(OrNode::make); -TVM_REGISTER_GLOBAL("tir.Or") -.set_body_typed(OrNode::make); +TVM_REGISTER_GLOBAL("tir.Not").set_body_typed(NotNode::make); -TVM_REGISTER_GLOBAL("tir.Not") -.set_body_typed(NotNode::make); +TVM_REGISTER_GLOBAL("tir.Select").set_body_typed(SelectNode::make); -TVM_REGISTER_GLOBAL("tir.Select") -.set_body_typed(SelectNode::make); +TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed(RampNode::make); -TVM_REGISTER_GLOBAL("tir.Ramp") -.set_body_typed(RampNode::make); +TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed(CastNode::make); 
-TVM_REGISTER_GLOBAL("tir.Cast") -.set_body_typed(CastNode::make); +TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed(BroadcastNode::make); -TVM_REGISTER_GLOBAL("tir.Broadcast") -.set_body_typed(BroadcastNode::make); +TVM_REGISTER_GLOBAL("tir.Shuffle").set_body_typed(ShuffleNode::make); -TVM_REGISTER_GLOBAL("tir.Shuffle") -.set_body_typed(ShuffleNode::make); - -TVM_REGISTER_GLOBAL("tir.Let") -.set_body_typed(LetNode::make); - -TVM_REGISTER_GLOBAL("tir.Load") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DataType t = args[0]; - if (args.size() == 3) { - *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); - } else { - *ret = LoadNode::make(t, args[1], args[2], args[3]); - } - }); +TVM_REGISTER_GLOBAL("tir.Let").set_body_typed(LetNode::make); -TVM_REGISTER_GLOBAL("tir.Call") -.set_body_typed([]( - DataType type, std::string name, - Array args, int call_type, - FunctionRef func, int value_index -) { - Array prim_expr_args; - for (const auto& it : args) { - CHECK(it->IsInstance() || - it->IsInstance()); - if (const auto* str = it.as()) { - prim_expr_args.push_back(StringImmNode::make(str->data)); - } else { - prim_expr_args.push_back(Downcast(it)); - } +TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { + DataType t = args[0]; + if (args.size() == 3) { + *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); + } else { + *ret = LoadNode::make(t, args[1], args[2], args[3]); } - return CallNode::make(type, - name, - prim_expr_args, - static_cast(call_type), - func, - value_index); }); +TVM_REGISTER_GLOBAL("tir.Call") + .set_body_typed([](DataType type, std::string name, Array args, int call_type, + FunctionRef func, int value_index) { + Array prim_expr_args; + for (const auto& it : args) { + CHECK(it->IsInstance() || it->IsInstance()); + if (const auto* str = it.as()) { + prim_expr_args.push_back(StringImmNode::make(str->data)); + } else { + prim_expr_args.push_back(Downcast(it)); + } + } + return CallNode::make(type, name, prim_expr_args, static_cast(call_type), + func, value_index); + }); + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 57ff627ceaf13..7f30abea36131 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -20,6 +20,7 @@ * \file expr_functor.cc */ #include + #include "functor_common.h" namespace tvm { @@ -49,10 +50,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); } -#define DEFINE_BINOP_VISIT_(OP) \ - void ExprVisitor::VisitExpr_(const OP* op) { \ - this->VisitExpr(op->a); \ - this->VisitExpr(op->b); \ +#define DEFINE_BINOP_VISIT_(OP) \ + void ExprVisitor::VisitExpr_(const OP* op) { \ + this->VisitExpr(op->a); \ + this->VisitExpr(op->b); \ } DEFINE_BINOP_VISIT_(AddNode); @@ -79,20 +80,16 @@ void ExprVisitor::VisitExpr_(const StringImmNode* op) {} void ExprVisitor::VisitExpr_(const ReduceNode* op) { VisitArray(op->axis, [this](const IterVar& r) { - this->VisitExpr(r->dom->min); - this->VisitExpr(r->dom->extent); - }); + this->VisitExpr(r->dom->min); + this->VisitExpr(r->dom->extent); + }); VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->condition); } -void ExprVisitor::VisitExpr_(const CastNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const NotNode* op) { - this->VisitExpr(op->a); -} +void 
ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } void ExprVisitor::VisitExpr_(const SelectNode* op) { this->VisitExpr(op->condition); @@ -110,13 +107,9 @@ void ExprVisitor::VisitExpr_(const ShuffleNode* op) { VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); } -void ExprVisitor::VisitExpr_(const BroadcastNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } -PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { - return GetRef(op); -} +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); @@ -145,8 +138,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -160,34 +152,26 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, - op->name, - args, - op->call_type, - op->func, - op->value_index); + return CallNode::make(op->dtype, op->name, args, op->call_type, op->func, op->value_index); } } -#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP *op) { \ - return GetRef(op); \ - } +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef(op); } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) -#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return OP::make(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return OP::make(a, b); \ + } \ } DEFINE_BIOP_EXPR_MUTATE_(AddNode); @@ -209,17 +193,15 @@ DEFINE_BIOP_EXPR_MUTATE_(AndNode); DEFINE_BIOP_EXPR_MUTATE_(OrNode); PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { - auto fitervar = [this](const IterVar& v) { + auto fitervar = [this](const IterVar& v) { Range r = v->dom; PrimExpr min = this->VisitExpr(r->min); PrimExpr extent = this->VisitExpr(r->extent); - if (min.same_as(r->min) && - extent.same_as(r->extent)) { + if (min.same_as(r->min) && extent.same_as(r->extent)) { return v; } else { - return IterVarNode::make( - Range::make_by_min_extent(min, extent), - v->var, v->iter_type, v->thread_tag); + return IterVarNode::make(Range::make_by_min_extent(min, extent), v->var, v->iter_type, + v->thread_tag); } }; Array axis = MutateArray(op->axis, fitervar); @@ -229,13 +211,10 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { PrimExpr condition = this->VisitExpr(op->condition); - if (axis.same_as(op->axis) && - source.same_as(op->source) && - condition.same_as(op->condition)) { + if (axis.same_as(op->axis) && 
source.same_as(op->source) && condition.same_as(op->condition)) { return GetRef(op); } else { - return ReduceNode::make( - op->combiner, source, axis, condition, op->value_index); + return ReduceNode::make(op->combiner, source, axis, condition, op->value_index); } } @@ -261,8 +240,7 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr true_value = this->VisitExpr(op->true_value); PrimExpr false_value = this->VisitExpr(op->false_value); - if (condition.same_as(op->condition) && - true_value.same_as(op->true_value) && + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { @@ -273,8 +251,7 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - if (base.same_as(op->base) && - stride.same_as(op->stride)) { + if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { return RampNode::make(base, stride, op->lanes); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index ecaad586f8941..1149e039cae4e 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -29,11 +29,8 @@ namespace tvm { namespace tir { // Get the function type of a PrimFunc -PrimFunc::PrimFunc(Array params, - Stmt body, - Type ret_type, - Map buffer_map, - DictAttrs attrs) { +PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, + Map buffer_map, DictAttrs attrs) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -60,29 +57,25 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - // TODO(tvm-team) redirect to Text printer once we have a good text format. - auto* node = static_cast(ref.get()); - p->stream << "PrimFunc(" << node->params << ") "; - if (node->attrs.defined()) { - p->stream << "attrs=" << node->attrs; - } - p->stream << " {\n"; - p->indent += 2; - p->Print(node->body); - p->indent -= 2; - p->stream << "}\n"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + // TODO(tvm-team) redirect to Text printer once we have a good text format. 
+ auto* node = static_cast(ref.get());
+ p->stream << "PrimFunc(" << node->params << ") ";
+ if (node->attrs.defined()) {
+ p->stream << "attrs=" << node->attrs;
+ }
+ p->stream << " {\n";
+ p->indent += 2;
+ p->Print(node->body);
+ p->indent -= 2;
+ p->stream << "}\n";
+ });

 TVM_REGISTER_GLOBAL("tir.PrimFunc")
-.set_body_typed([](Array params,
- Stmt body,
- Type ret_type,
- Map buffer_map,
- DictAttrs attrs) {
- return PrimFunc(params, body, ret_type, buffer_map, attrs);
-});
+ .set_body_typed([](Array params, Stmt body, Type ret_type,
+ Map buffer_map, DictAttrs attrs) {
+ return PrimFunc(params, body, ret_type, buffer_map, attrs);
+ });

 } // namespace tir
 } // namespace tvm
diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h
index 76a91ea42d427..f63dcfe003c60 100644
--- a/src/tir/ir/functor_common.h
+++ b/src/tir/ir/functor_common.h
@@ -27,7 +27,7 @@ namespace tvm {
 namespace tir {

 // Implementation of Visitors
-template
+template
 inline void VisitArray(const Array& arr, F fvisit) {
 for (size_t i = 0; i < arr.size(); i++) {
 fvisit(arr[i]);
@@ -35,10 +35,8 @@ inline void VisitArray(const Array& arr, F fvisit) {
 }

 // Implementation of mutators
-template
-inline Array MutateArray(const Array& arr,
- F fmutate,
- bool allow_copy_on_write = false) {
+template
+inline Array MutateArray(const Array& arr, F fmutate, bool allow_copy_on_write = false) {
 if (allow_copy_on_write) {
 // if we allow copy on write, we can directly
 // call the inplace mutate function.
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index 4ad244ff02b23..2757c2fa2ddd0 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+
 #include
 // Centralized header for constant folders.
 #include "../../arith/const_fold.h"
@@ -32,15 +33,15 @@
 namespace tvm {

 using namespace tir;
-
 runtime::DataType GetRuntimeDataType(const Type& type) {
- if (auto * n = type.as()) {
+ if (auto* n = type.as()) {
 return n->dtype;
 } else if (type.as()) {
 return DataType::Handle();
+ } else if (IsVoidType(type)) {
+ return DataType::Void();
 } else {
- LOG(FATAL) << "Type " << type
- << " does not have a corresponding runtime::DataType";
+ LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType";
 return DataType::Handle();
 }
 }
@@ -57,9 +58,8 @@ Type GetType(const PrimExpr& expr) {
 }
 // Default: return the type indicated by the dtype.
 runtime::DataType dtype = expr.dtype();
- // These types already implies the specific type.
- if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
- return PrimType(dtype);
+ if (dtype.is_void()) {
+ return VoidType();
 }
 return PrimType(dtype);
 }
@@ -73,8 +73,7 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
 PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
 return tir::CallNode::make(
 t, tir::intrinsic::tvm_large_uint_imm,
- {make_const(DataType::UInt(32), low),
- make_const(DataType::UInt(32), high)},
+ {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)},
 tir::CallNode::PureIntrinsic);
 }
@@ -88,8 +87,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*)
 } else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
 rhs = tir::BroadcastNode::make(rhs, ltype.lanes());
 } else {
- CHECK(ltype.lanes() == rtype.lanes())
- << "Cannot match type " << ltype << " vs " << rtype;
+ CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
 }
 if (lhs.dtype() == rhs.dtype()) return;
 // Only do very simple type coversion
@@ -196,8 +194,8 @@ PrimExpr infinity(const DataType& dtype) {
 }

 namespace tir {
-template
-inline bool ConstPowerHelper(ValueType val, int *shift) {
+template
+inline bool ConstPowerHelper(ValueType val, int* shift) {
 if (val <= 0) return false;
 shift[0] = 0;
 while (val != 0) {
@@ -253,8 +251,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) {

 PrimExpr reinterpret(const DataType& t, PrimExpr value) {
 if (value.dtype() == t) return value;
- return tir::CallNode::make(
- t, tir::CallNode::reinterpret, { value }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator+(PrimExpr a, PrimExpr b) {
@@ -266,8 +263,8 @@ PrimExpr operator+(PrimExpr a, PrimExpr b) {

 // negation
 PrimExpr operator-(PrimExpr a) {
- using tir::IntImmNode;
 using tir::FloatImmNode;
+ using tir::IntImmNode;
 const IntImmNode* pa = a.as();
 const FloatImmNode* fa = a.as();
 if (pa) return IntImm(a.dtype(), -pa->value);
@@ -309,22 +306,14 @@ PrimExpr truncmod(PrimExpr a, PrimExpr b) {
 return tir::ModNode::make(a, b);
 }

-PrimExpr operator/(PrimExpr a, PrimExpr b) {
- return div(a, b);
-}
+PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); }

-PrimExpr operator%(PrimExpr a, PrimExpr b) {
- return truncmod(a, b);
-}
+PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); }

 // TODO(tqchen): switch to floordiv
-PrimExpr indexdiv(PrimExpr a, PrimExpr b) {
- return floordiv(a, b);
-}
+PrimExpr indexdiv(PrimExpr a, PrimExpr b) { return floordiv(a, b); }

-PrimExpr indexmod(PrimExpr a, PrimExpr b) {
- return floormod(a, b);
-}
+PrimExpr indexmod(PrimExpr a, PrimExpr b) { return floormod(a, b); }

 PrimExpr floordiv(PrimExpr a, PrimExpr b) {
 CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
@@ -346,8 +335,8 @@ PrimExpr floormod(PrimExpr a, PrimExpr b) {

 PrimExpr min(PrimExpr a, PrimExpr b) {
 // inf-aware simplificaiton
- using arith::is_pos_inf;
 using arith::is_neg_inf;
+ using arith::is_pos_inf;
 if (is_pos_inf(a)) return b;
 if (is_neg_inf(a)) return a;
 if (is_pos_inf(b)) return a;
@@ -360,8 +349,8 @@ PrimExpr max(PrimExpr a, PrimExpr b) {
 // inf-aware simplificaiton
- using arith::is_pos_inf;
 using arith::is_neg_inf;
+ using arith::is_pos_inf;
 if (is_pos_inf(a)) return a;
 if (is_neg_inf(a)) return b;
 if (is_pos_inf(b)) return b;
@@ -383,19 +372,14 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value)
 return false_value;
 }
 }
- return tir::CallNode::make(
- true_value.dtype(),
- tir::intrinsic::tvm_if_then_else,
- {cond, true_value, false_value},
- tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(true_value.dtype(), tir::intrinsic::tvm_if_then_else,
+ {cond, true_value, false_value}, tir::CallNode::PureIntrinsic);
 }

 PrimExpr likely(PrimExpr cond) {
 if (is_const(cond)) return cond;
- return tir::CallNode::make(cond.dtype(),
- tir::CallNode::likely,
- { cond },
- tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(cond.dtype(), tir::CallNode::likely, {cond},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator>(PrimExpr a, PrimExpr b) {
@@ -468,17 +452,18 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
 CHECK(b.dtype().is_int() || b.dtype().is_uint());
 BinaryOpMatchTypes(a, b);
 TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
- "Shift amount must be non-negative and less than " << rtype.bits()
- << " for type " << rtype;
- if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
- if (pb) {
- if (pb->value == 0) return a;
- }
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::shift_right, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pb)
+ CHECK(pb->value >= 0 && pb->value < rtype.bits())
+ << "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
+ << rtype;
+ if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
+ if (pb) {
+ if (pb->value == 0) return a;
+ }
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::shift_right, {a, b},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator<<(PrimExpr a, PrimExpr b) {
@@ -486,17 +471,18 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
 CHECK(a.dtype().is_int() || a.dtype().is_uint());
 CHECK(b.dtype().is_int() || b.dtype().is_uint());
 BinaryOpMatchTypes(a, b);
 TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) <<
- "Shift amount must be non-negative and less than " << rtype.bits()
- << " for type " << rtype;
- if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
- if (pb) {
- if (pb->value == 0) return a;
- }
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::shift_left, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pb)
+ CHECK(pb->value >= 0 && pb->value < rtype.bits())
+ << "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
+ << rtype;
+ if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
+ if (pb) {
+ if (pb->value == 0) return a;
+ }
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::shift_left, {a, b},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator&(PrimExpr a, PrimExpr b) {
@@ -504,11 +490,11 @@
 CHECK(a.dtype().is_int() || a.dtype().is_uint());
 CHECK(b.dtype().is_int() || b.dtype().is_uint());
 BinaryOpMatchTypes(a, b);
 TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_and, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_and, {a, b},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator|(PrimExpr a, PrimExpr b) {
@@ -516,11 +502,11 @@
 CHECK(b.dtype().is_int() || b.dtype().is_uint());
 BinaryOpMatchTypes(a, b);
 TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_or, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_or, {a, b},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator^(PrimExpr a, PrimExpr b) {
@@ -528,24 +514,23 @@
 CHECK(b.dtype().is_int() || b.dtype().is_uint());
 BinaryOpMatchTypes(a, b);
 TVM_INDEX_CONST_PROPAGATION({
- const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
- });
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_xor, { a, b }, tir::CallNode::PureIntrinsic);
+ const DataType& rtype = a.dtype();
+ if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
+ });
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_xor, {a, b},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr operator~(PrimExpr a) {
 CHECK(a.dtype().is_int() || a.dtype().is_uint());
- return tir::CallNode::make(
- a.dtype(), tir::CallNode::bitwise_not, { a }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_not, {a},
+ tir::CallNode::PureIntrinsic);
 }

 PrimExpr pow(PrimExpr x, PrimExpr y) {
 BinaryOpMatchTypes(x, y);
 CHECK(x.dtype().is_float()) << "power only applies to float";
- return tir::CallNode::make(
- x.dtype(), "pow", { x, y }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic);
 }

 PrimExpr abs(PrimExpr x) {
@@ -567,7 +552,7 @@ PrimExpr abs(PrimExpr x) {
 return x;
 } else {
 LOG(FATAL) << "Data type " << x.dtype()
- <<" not supported for absolute op. Skipping absolute op...";
+ << " not supported for absolute op. Skipping absolute op...";
 return x;
 }
 }
@@ -584,14 +569,13 @@ PrimExpr isnan(PrimExpr x) {
 }
 if (x.dtype().bits() == 16) {
 return tir::CallNode::make(t, tir::CallNode::isnan,
- {cast(DataType::Float(32, t.lanes()), std::move(x))},
- tir::CallNode::PureIntrinsic);
+ {cast(DataType::Float(32, t.lanes()), std::move(x))},
+ tir::CallNode::PureIntrinsic);
 } else {
 return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic);
 }
 } else {
- LOG(FATAL) << "Data type " << x.dtype()
- <<" not supported for isnan op. Skipping isnan op...";
+ LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
 return x;
 }
 }
@@ -615,8 +599,7 @@ PrimExpr sum(PrimExpr source, Array rdom) {
 Var x("x", source.dtype()), y("y", source.dtype());
 PrimExpr result = tir::AddNode::make(x, y);
 PrimExpr identity_element = make_zero(source.dtype());
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
 return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
@@ -625,8 +608,7 @@ PrimExpr all(PrimExpr source, Array rdom) {
 Var x("x", source.dtype()), y("y", source.dtype());
 PrimExpr result = tir::AndNode::make(x, y);
 PrimExpr identity_element = make_const(source.dtype(), true);
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
 return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
@@ -635,8 +617,7 @@ PrimExpr any(PrimExpr source, Array rdom) {
 Var x("x", source.dtype()), y("y", source.dtype());
 PrimExpr result = tir::OrNode::make(x, y);
 PrimExpr identity_element = make_const(source.dtype(), false);
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
 return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
@@ -644,8 +625,7 @@ PrimExpr max(PrimExpr source, Array rdom) {
 Var x("x", source.dtype()), y("y", source.dtype());
 PrimExpr result = tir::MaxNode::make(x, y);
 PrimExpr identity_element = min_value(source.dtype());
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
 return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
@@ -653,8 +633,7 @@ PrimExpr min(PrimExpr source, Array rdom) {
 Var x("x", source.dtype()), y("y", source.dtype());
 PrimExpr result = tir::MinNode::make(x, y);
 PrimExpr identity_element = max_value(source.dtype());
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
 return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
@@ -662,15 +641,14 @@ PrimExpr prod(PrimExpr source, Array rdom) {
 Var x("x", source.dtype()), y("y", source.dtype());
 PrimExpr result = tir::MulNode::make(x, y);
 PrimExpr identity_element = make_const(source.dtype(), 1);
- tir::CommReducer combiner =
- tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+ tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
 return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }

 PrimExpr fmod(PrimExpr x, PrimExpr y) {
 BinaryOpMatchTypes(x, y);
 CHECK(x.dtype().is_float()) << "fmod only applies to float";
- return tir::CallNode::make(x.dtype(), "fmod", { x, y }, tir::CallNode::PureIntrinsic);
+ return tir::CallNode::make(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic);
 }

 PrimExpr floor(PrimExpr x) {
@@ -720,91 +698,69 @@ PrimExpr trunc(PrimExpr x) {
 using tir::FloatImmNode;
 const FloatImmNode* fx = x.as();
 if (fx) {
- return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
- std::floor(fx->value)));
+ return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)));
 }
 return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic);
 }
-
 // expose basic functions to node namespace
-TVM_REGISTER_GLOBAL("node._const")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- if (args[0].type_code() == kDLInt) {
- *ret = tir::make_const(args[1], args[0].operator int64_t());
- } else if (args[0].type_code() == kDLFloat) {
- *ret = tir::make_const(args[1], args[0].operator double());
- } else {
- LOG(FATAL) << "only accept int or float";
- }
- });
-
-TVM_REGISTER_GLOBAL("node.LargeUIntImm")
-.set_body_typed(LargeUIntImm);
-
-TVM_REGISTER_GLOBAL("node.String")
-.set_body_typed(tir::StringImmNode::make);
+TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
+ if (args[0].type_code() == kDLInt) {
+ *ret = tir::make_const(args[1], args[0].operator int64_t());
+ } else if (args[0].type_code() == kDLFloat) {
+ *ret = tir::make_const(args[1], args[0].operator double());
+ } else {
+ LOG(FATAL) << "only accept int or float";
+ }
+});

-TVM_REGISTER_GLOBAL("tir.min_value")
-.set_body_typed(min_value);
+TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm);

-TVM_REGISTER_GLOBAL("tir.max_value")
-.set_body_typed(max_value);
+TVM_REGISTER_GLOBAL("node.String").set_body_typed(tir::StringImmNode::make);

-TVM_REGISTER_GLOBAL("tir.abs")
-.set_body_typed(tvm::abs);
+TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value);

-TVM_REGISTER_GLOBAL("tir.isnan")
-.set_body_typed(tvm::isnan);
+TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);

-TVM_REGISTER_GLOBAL("tir.isfinite")
-.set_body_typed(tvm::isfinite);
+TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);

-TVM_REGISTER_GLOBAL("tir.isinf")
-.set_body_typed(tvm::isinf);
+TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);

-TVM_REGISTER_GLOBAL("tir.floor")
-.set_body_typed(tvm::floor);
+TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);

-TVM_REGISTER_GLOBAL("tir.ceil")
-.set_body_typed(tvm::ceil);
+TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf);

-TVM_REGISTER_GLOBAL("tir.round")
-.set_body_typed(tvm::round);
+TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor);

-TVM_REGISTER_GLOBAL("tir.nearbyint")
-.set_body_typed(tvm::nearbyint);
+TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil);

-TVM_REGISTER_GLOBAL("tir.trunc")
-.set_body_typed(tvm::trunc);
+TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round);

-TVM_REGISTER_GLOBAL("tir._cast")
-.set_body_typed(tvm::cast);
+TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint);
+TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
+TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);

 // operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func) \
- TVM_REGISTER_GLOBAL("tir."#Node) \
- .set_body_typed([](PrimExpr a, PrimExpr b) { \
- return (Func(a, b)); \
+#define REGISTER_MAKE_BINARY_OP(Node, Func) \
+ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b) { \
+ return (Func(a, b)); \
 })

-#define REGISTER_MAKE_BIT_OP(Node, Func) \
- TVM_REGISTER_GLOBAL("tir."#Node) \
- .set_body([](TVMArgs args, TVMRetValue *ret) { \
- bool lhs_is_int = args[0].type_code() == kDLInt; \
- bool rhs_is_int = args[1].type_code() == kDLInt; \
- if (lhs_is_int) { \
- *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
- } else if (rhs_is_int) { \
- *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
- } else { \
- *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
- } \
+#define REGISTER_MAKE_BIT_OP(Node, Func) \
+ TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) { \
+ bool lhs_is_int = args[0].type_code() == kDLInt; \
+ bool rhs_is_int = args[1].type_code() == kDLInt; \
+ if (lhs_is_int) { \
+ *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
+ } else if (rhs_is_int) { \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
+ } else { \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
+ } \
 })
-
 REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
 REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
 REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
@@ -821,20 +777,20 @@ REGISTER_MAKE_BINARY_OP(_OpMin, min);
 REGISTER_MAKE_BINARY_OP(_OpMax, max);
 REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
 REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
-REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLT, operator<);   // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLE, operator<=);  // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGT, operator>);   // NOLINT(*)
 REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
 REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
 REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
 REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
 REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
 REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
-REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
+REGISTER_MAKE_BIT_OP(left_shift, operator<<);  // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, operator>>);

 TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
-.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
- return if_then_else(cond, true_value, false_value);
-});
+ .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
+ return if_then_else(cond, true_value, false_value);
+ });

 } // namespace tvm
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index cc61e7e6abaee..4c58fd63c69f6 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -21,9 +21,8 @@
 * \file tvm/tir/stmt.cc
 */
 #include
-#include
 #include
-
+#include

 namespace tvm {
 namespace tir {
@@ -40,13 +39,9 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
 return Stmt(node);
 }

-TVM_REGISTER_GLOBAL("tir.LetStmt")
-.set_body_typed(LetStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed(LetStmtNode::make);

-Stmt AttrStmtNode::make(ObjectRef node,
- std::string attr_key,
- PrimExpr value,
- Stmt body) {
+Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
 auto n = make_object();
 n->node = node;
 n->attr_key = std::move(attr_key);
@@ -55,15 +50,12 @@ Stmt AttrStmtNode::make(ObjectRef node,
 return Stmt(n);
 }

-TVM_REGISTER_GLOBAL("tir.AttrStmt")
-.set_body_typed(AttrStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.AttrStmt").set_body_typed(AttrStmtNode::make); Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { CHECK(condition.defined()); - CHECK(message.dtype() == DataType::Int(32) || - message.as()) - << "TypeError: AssertStmt message must be an int or string:" - << message << "\n"; + CHECK(message.dtype() == DataType::Int(32) || message.as()) + << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; ObjectPtr node = make_object(); node->condition = std::move(condition); @@ -73,21 +65,17 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { } TVM_REGISTER_GLOBAL("tir.AssertStmt") -.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { - if (const auto* str = message.as()) { - auto msg = StringImmNode::make(str->data); - return AssertStmtNode::make(condition, msg, body); - } else { - return AssertStmtNode::make(condition, Downcast(message), body); - } -}); + .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { + if (const auto* str = message.as()) { + auto msg = StringImmNode::make(str->data); + return AssertStmtNode::make(condition, msg, body); + } else { + return AssertStmtNode::make(condition, Downcast(message), body); + } + }); -Stmt ForNode::make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body) { +Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, + DeviceAPI device_api, Stmt body) { CHECK(min.defined()); CHECK(extent.defined()); CHECK(min.dtype().is_scalar()); @@ -105,19 +93,12 @@ Stmt ForNode::make(Var loop_var, return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.For") -.set_body_typed([]( - Var loop_var, PrimExpr min, PrimExpr extent, - int for_type, int device_api, Stmt body) { - return ForNode::make(loop_var, - min, - extent, - static_cast(for_type), - static_cast(device_api), - body); +TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, + int for_type, int device_api, Stmt body) { + return ForNode::make(loop_var, min, extent, static_cast(for_type), + static_cast(device_api), body); }); - Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { CHECK(value.defined()); CHECK(index.defined()); @@ -133,20 +114,17 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr return Stmt(node); } - -TVM_REGISTER_GLOBAL("tir.Store") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PrimExpr value = args[1]; - if (args.size() == 3) { - *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); - } else { - *ret = StoreNode::make(args[0], value, args[2], args[3]); - } - }); - +TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { + PrimExpr value = args[1]; + if (args.size() == 3) { + *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); + } else { + *ret = StoreNode::make(args[0], value, args[2], args[3]); + } +}); Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array args) { - CHECK(value_index >=0 && value_index < func->num_outputs()) + CHECK(value_index >= 0 && value_index < func->num_outputs()) << "value index output function return value bound"; CHECK(value.defined()) << "Provide of undefined value\n"; @@ -162,45 +140,39 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array< return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Provide") 
-.set_body_typed(ProvideNode::make); +TVM_REGISTER_GLOBAL("tir.Provide").set_body_typed(ProvideNode::make); - -Stmt AllocateNode::make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, +Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body) { - for (size_t i = 0; i < extents.size(); ++i) { - CHECK(extents[i].defined()); - CHECK(extents[i].dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->dtype = dtype; - node->extents = std::move(extents); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); + for (size_t i = 0; i < extents.size(); ++i) { + CHECK(extents[i].defined()); + CHECK(extents[i].dtype().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.dtype().is_bool()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->dtype = dtype; + node->extents = std::move(extents); + node->condition = std::move(condition); + node->body = std::move(body); + return Stmt(node); } // overloaded, needs special handling // has default args TVM_REGISTER_GLOBAL("tir.Allocate") -.set_body_typed([]( - Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body - ){ - return AllocateNode::make(buffer_var, type, extents, condition, body); -}); + .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, + Stmt body) { + return AllocateNode::make(buffer_var, type, extents, condition, body); + }); int32_t AllocateNode::constant_allocation_size(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImmNode *int_size = extents[i].as()) { + if (const IntImmNode* int_size = extents[i].as()) { result *= int_size->value; if (result > std::numeric_limits::max()) { return 0; @@ -218,16 +190,10 @@ Stmt FreeNode::make(Var buffer_var) { return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Free") -.set_body_typed(FreeNode::make); - +TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make); -Stmt RealizeNode::make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body) { +Stmt RealizeNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds, + PrimExpr condition, Stmt body) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); CHECK(bounds[i]->extent.defined()); @@ -248,29 +214,23 @@ Stmt RealizeNode::make(FunctionRef func, return Stmt(node); } - -TVM_REGISTER_GLOBAL("tir.Realize") -.set_body_typed(RealizeNode::make); - +TVM_REGISTER_GLOBAL("tir.Realize").set_body_typed(RealizeNode::make); Prefetch::Prefetch(Buffer buffer, Array bounds) { data_ = make_object(buffer, bounds); } -TVM_REGISTER_GLOBAL("tir.Prefetch") -.set_body_typed([](Buffer buffer, Array bounds) { +TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array bounds) { return Prefetch(buffer, bounds); }); - SeqStmt::SeqStmt(Array seq) { auto node = make_object(); node->seq = std::move(seq); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.SeqStmt") -.set_body_typed([](Array seq) { +TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq) { return SeqStmt(std::move(seq)); }); @@ -286,9 +246,7 @@ Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) { return Stmt(node); } 
-TVM_REGISTER_GLOBAL("tir.IfThenElse") -.set_body_typed(IfThenElseNode::make); - +TVM_REGISTER_GLOBAL("tir.IfThenElse").set_body_typed(IfThenElseNode::make); Stmt EvaluateNode::make(PrimExpr value) { CHECK(value.defined()); @@ -298,8 +256,7 @@ Stmt EvaluateNode::make(PrimExpr value) { return Stmt(node); } -TVM_REGISTER_GLOBAL("tir.Evaluate") -.set_body_typed(EvaluateNode::make); +TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed(EvaluateNode::make); BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { ObjectPtr node = make_object(); @@ -310,69 +267,60 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) } TVM_REGISTER_GLOBAL("tir.BufferStore") -.set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { - return BufferStore(buffer, value, indices); -}); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { + return BufferStore(buffer, value, indices); + }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); - -BufferRealize::BufferRealize(Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body) { - data_ = make_object( - buffer, bounds, condition, body); +BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + data_ = make_object(buffer, bounds, condition, body); } TVM_REGISTER_GLOBAL("tir.BufferRealize") -.set_body_typed([](Buffer buffer, - Array bounds, - PrimExpr condition, - Stmt body) { - return BufferRealize(buffer, bounds, condition, body); -}); + .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + return BufferRealize(buffer, bounds, condition, body); + }); TVM_REGISTER_NODE_TYPE(BufferRealizeNode); // Printers TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "let " << op->var << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "let " << op->var << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "// attr ["; - p->Print(op->node); - p->stream << "] " - << op->attr_key << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "// attr ["; + p->Print(op->node); + p->stream << "] " << op->attr_key << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "assert("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->message); - p->stream << ")\n"; - p->Print(op->body); - }); - -std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "assert("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->message); + p->stream << ")\n"; + p->Print(op->body); + }); + +std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*) switch (type) { case 
ForType::Serial: out << "for"; @@ -391,221 +339,221 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->for_type << " (" << op->loop_var << ", "; - p->Print(op->min); - p->stream << ", "; - p->Print(op->extent); - p->stream << ") {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->for_type << " (" << op->loop_var << ", "; + p->Print(op->min); + p->stream << ", "; + p->Print(op->extent); + p->stream << ") {\n"; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "] = "; - p->Print(op->value); - if (!is_one(op->predicate)) { - p->stream << " if "; - p->Print(op->predicate); - } - p->stream << '\n'; - }); + p->indent += 2; + p->Print(op->body); + p->indent -= 2; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->func->func_name() << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - p->stream << " ="; - p->Print(op->value); - p->stream << '\n'; - }); + p->PrintIndent(); + p->stream << "}\n"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) p->stream << ", "; - } - p->stream << "]"; - p->stream << " = "; - p->Print(op->value); - p->stream << '\n'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "] = "; + p->Print(op->value); + if (!is_one(op->predicate)) { + p->stream << " if "; + p->Print(op->predicate); + } + p->stream << '\n'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "allocate " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - p->stream << " * "; - p->Print(op->extents[i]); - } - p->stream << "]"; - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << "\n"; - p->Print(op->body); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->func->func_name() << "("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (op->func->num_outputs() != 1) { + p->stream << ".value[" << op->value_index << "]"; + } + p->stream << " ="; + p->Print(op->value); + p->stream << '\n'; + }); 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->stream << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) p->stream << ", ";
- }
- p->stream << "]";
- p->stream << " = ";
- p->Print(op->value);
- p->stream << '\n';
- });
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer_var << "[";
+ p->Print(op->index);
+ p->stream << "] = ";
+ p->Print(op->value);
+ if (!is_one(op->predicate)) {
+ p->stream << " if ";
+ p->Print(op->predicate);
+ }
+ p->stream << '\n';
+ });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
- for (size_t i = 0; i < op->extents.size(); ++i) {
- p->stream << " * ";
- p->Print(op->extents[i]);
- }
- p->stream << "]";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << "\n";
- p->Print(op->body);
- });
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->args.size(); ++i) {
+ p->Print(op->args[i]);
+ if (i < op->args.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ p->stream << " =";
+ p->Print(op->value);
+ p->stream << '\n';
+ });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->stream << "free " << op->buffer_var;
- p->stream << '\n';
- });
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer->name << "[";
+ for (size_t i = 0; i < op->indices.size(); ++i) {
+ p->Print(op->indices[i]);
+ if (i < op->indices.size() - 1) p->stream << ", ";
+ }
+ p->stream << "]";
+ p->stream << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->stream << "buffer_realize " << op->buffer->name << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
+ for (size_t i = 0; i < op->extents.size(); ++i) {
+ p->stream << " * ";
+ p->Print(op->extents[i]);
+ }
 p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- p->stream << "}\n";
- });
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << "\n";
+ p->Print(op->body);
+ });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->stream << "realize " << op->func->func_name() << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (op->func->num_outputs() != 1) {
- p->stream << ".value[" << op->value_index << "]";
- }
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << "free " << op->buffer_var;
+ p->stream << '\n';
+ });

- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << "buffer_realize " << op->buffer->name << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";

- p->PrintIndent();
- p->stream << "}\n";
- });
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;

-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->stream << "prefetch " << op->buffer << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- });
+ p->PrintIndent();
+ p->stream << "}\n";
+ });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- for (Stmt stmt : op->seq) {
- p->Print(stmt);
- }
- });
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << "realize " << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";

-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- while (true) {
- p->stream << "if (" << op->condition << ") {\n";
 p->indent += 2;
- p->Print(op->then_case);
+ p->Print(op->body);
 p->indent -= 2;

- if (!op->else_case.defined()) {
- break;
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->stream << "prefetch " << op->buffer << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
 }
+ p->stream << ")";
+ });

- if (const IfThenElseNode *nested_if = op->else_case.as()) {
- p->PrintIndent();
- p->stream << "} else ";
- op = nested_if;
- } else {
- p->PrintIndent();
- p->stream << "} else {\n";
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ for (Stmt stmt : op->seq) {
+ p->Print(stmt);
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ while (true) {
+ p->stream << "if (" << op->condition << ") {\n";
 p->indent += 2;
- p->Print(op->else_case);
+ p->Print(op->then_case);
 p->indent -= 2;
- break;
+
+ if (!op->else_case.defined()) {
+ break;
+ }
+
+ if (const IfThenElseNode* nested_if = op->else_case.as()) {
+ p->PrintIndent();
+ p->stream << "} else ";
+ op = nested_if;
+ } else {
+ p->PrintIndent();
+ p->stream << "} else {\n";
+ p->indent += 2;
+ p->Print(op->else_case);
+ p->indent -= 2;
+ break;
+ }
 }
- }
- p->PrintIndent();
- p->stream << "}\n";
-});
+ p->PrintIndent();
+ p->stream << "}\n";
+ });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->PrintIndent();
- p->Print(op->value);
- p->stream << "\n";
- });
-
-template
-void PrintList(const Array &exprs, ReprPrinter* p) {
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->PrintIndent();
+ p->Print(op->value);
+ p->stream << "\n";
+ });
+
+template
+void PrintList(const Array& exprs, ReprPrinter* p) {
 for (size_t i = 0; i < exprs.size(); ++i) {
 p->Print(exprs[i]);
 if (i < exprs.size() - 1) {
@@ -615,14 +563,14 @@ void PrintList(const Array &exprs, ReprPrinter* p) {
 }

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast(node.get());
- p->stream << "shuffle(";
- PrintList(op->vectors, p);
- p->stream << ", ";
- PrintList(op->indices, p);
- p->stream << ")";
- });
+ .set_dispatch([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast(node.get());
+ p->stream << "shuffle(";
+ PrintList(op->vectors, p);
+ p->stream << ", ";
+ PrintList(op->indices, p);
+ p->stream << ")";
+ });

 TVM_REGISTER_NODE_TYPE(AttrStmtNode);
 TVM_REGISTER_NODE_TYPE(PrefetchNode);
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ec97b03c88c47..13d0b098dd4a0 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -21,7 +21,9 @@
 */
 #include
 #include
+
 #include
+
 #include "functor_common.h"

 namespace tvm {
@@ -62,9 +64,9 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {

 void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
 VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
 this->VisitExpr(op->condition);
 this->VisitStmt(op->body);
 }
@@ -92,30 +94,25 @@ void StmtVisitor::VisitStmt_(const ProvideNode* op) {

 void StmtVisitor::VisitStmt_(const RealizeNode* op) {
 VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
 this->VisitStmt(op->body);
 this->VisitExpr(op->condition);
 }

 void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
 VisitArray(op->bounds, [this](const Range& r) {
- this->VisitExpr(r->min);
- this->VisitExpr(r->extent);
- });
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
 }

 void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
- VisitArray(op->seq, [this](const Stmt& s) {
- this->VisitStmt(s);
- });
-}
-
-void StmtVisitor::VisitStmt_(const EvaluateNode* op) {
- this->VisitExpr(op->value);
+ VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
 }

+void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }

 class StmtMutator::Internal {
 public:
@@ -146,8 +143,7 @@
 Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
 PrimExpr value = this->VisitExpr(op->value);
 Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -160,8 +156,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
 Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
 PrimExpr value = this->VisitExpr(op->value);
 Stmt body = this->VisitStmt(op->body);
- if (value.same_as(op->value) &&
- body.same_as(op->body)) {
+ if (value.same_as(op->value) && body.same_as(op->body)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -175,9 +170,7 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) {
 PrimExpr min = this->VisitExpr(op->min);
 PrimExpr extent = this->VisitExpr(op->extent);
 Stmt body = this->VisitStmt(op->body);
- if (min.same_as(op->min) &&
- extent.same_as(op->extent) &&
- body.same_as(op->body)) {
+ if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -193,9 +186,7 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
 Stmt body = this->VisitStmt(op->body);
 PrimExpr condition = this->VisitExpr(op->condition);

- if (extents.same_as(op->extents) &&
- body.same_as(op->body) &&
- condition.same_as(op->condition)) {
+ if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -213,8 +204,7 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
 if (op->else_case.defined()) {
 else_case = this->VisitStmt(op->else_case);
 }
- if (condition.same_as(op->condition) &&
- then_case.same_as(op->then_case) &&
+ if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
 else_case.same_as(op->else_case)) {
 return GetRef(op);
 } else {
@@ -230,9 +220,7 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
 PrimExpr value = this->VisitExpr(op->value);
 PrimExpr index = this->VisitExpr(op->index);
 PrimExpr predicate = this->VisitExpr(op->predicate);
- if (value.same_as(op->value) &&
- index.same_as(op->index) &&
- predicate.same_as(op->predicate)) {
+ if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -247,8 +235,7 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
 PrimExpr value = this->VisitExpr(op->value);
 Array indices = Internal::Mutate(this, op->indices);

- if (value.same_as(op->value) &&
- indices.same_as(op->indices)) {
+ if (value.same_as(op->value) && indices.same_as(op->indices)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -263,9 +250,7 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
 PrimExpr condition = this->VisitExpr(op->condition);
 Stmt body = this->VisitStmt(op->body);

- if (bounds.same_as(op->bounds) &&
- condition.same_as(op->condition) &&
- body.same_as(op->body)) {
+ if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -279,8 +264,7 @@ Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
 Array args = Internal::Mutate(this, op->args);
 PrimExpr value = this->VisitExpr(op->value);
- if (args.same_as(op->args) &&
- value.same_as(op->value)) {
+ if (args.same_as(op->args) && value.same_as(op->value)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -294,9 +278,7 @@ Stmt StmtMutator::VisitStmt_(const RealizeNode* op) {
 Region bounds = Internal::Mutate(this, op->bounds);
 Stmt body = this->VisitStmt(op->body);
 PrimExpr condition = this->VisitExpr(op->condition);
- if (bounds.same_as(op->bounds) &&
- body.same_as(op->body) &&
- condition.same_as(op->condition)) {
+ if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) {
 return GetRef(op);
 } else {
 auto n = CopyOnWrite(op);
@@ -330,8 +312,7 @@ Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
 }

 // advanced visit function for seqstmt.
-Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op,
- bool flatten_before_visit,
+Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
 std::function fmutate) {
 if (flatten_before_visit) {
 // Pass 1, check if we need to flatten.
@@ -344,10 +325,8 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, } // function to run the visit. auto frunvisit = [&](const SeqStmtNode* op) { - Array seq = - fmutate != nullptr ? - MutateArray(op->seq, fmutate, allow_copy_on_write_) : - Internal::Mutate(this, op->seq); + Array seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_) + : Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { return GetRef(op); } else { @@ -380,9 +359,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { PrimExpr message = this->VisitExpr(op->message); Stmt body = this->VisitStmt(op->body); - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { + if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -404,14 +381,10 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } } -Stmt StmtMutator::VisitStmt_(const FreeNode* op) { - return GetRef(op); -} - +Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef(op); } // Implementations of IRTransform, PostOrderVisit and Substitute -class IRApplyVisit : - public StmtExprVisitor { +class IRApplyVisit : public StmtExprVisitor { public: explicit IRApplyVisit(std::function f) : f_(f) {} @@ -434,8 +407,7 @@ class IRApplyVisit : std::unordered_set visited_; }; -void PostOrderVisit(const ObjectRef& node, - std::function fvisit) { +void PostOrderVisit(const ObjectRef& node, std::function fvisit) { if (node.as()) { IRApplyVisit visitor(fvisit); visitor(Downcast(node)); @@ -445,42 +417,29 @@ void PostOrderVisit(const ObjectRef& node, } } -class IRTransformer final : - public StmtExprMutator { +class IRTransformer final : public StmtExprMutator { public: - IRTransformer(const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, + IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, const std::unordered_set& only_enable) - : f_preorder_(f_preorder), - f_postorder_(f_postorder), - only_enable_(only_enable) { - } + : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} Stmt VisitStmt(const Stmt& stmt) final { - return MutateInternal(stmt, [this](const Stmt& s) { - return this->BaseVisitStmt(s); - }); + return MutateInternal(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); }); } PrimExpr VisitExpr(const PrimExpr& expr) final { - return MutateInternal(expr, [this](const PrimExpr& e) { - return this->BaseVisitExpr(e); - }); + return MutateInternal(expr, + [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); } private: // NOTE: redirect to parent's call // This is used to get around limitation of gcc-4.8 - Stmt BaseVisitStmt(const Stmt& s) { - return StmtMutator::VisitStmt(s); - } - PrimExpr BaseVisitExpr(const PrimExpr& e) { - return ExprMutator::VisitExpr(e); - } + Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } + PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } template T MutateInternal(const T& node, F fmutate) { - if (only_enable_.size() && - !only_enable_.count(node->type_index())) { + if (only_enable_.size() && !only_enable_.count(node->type_index())) { return fmutate(node); } if (f_preorder_ != nullptr) { @@ -501,10 +460,8 @@ class IRTransformer final : const std::unordered_set& only_enable_; }; -Stmt IRTransform(Stmt ir_node, - const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& 
f_postorder, - Optional> only_enable) { +Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, Optional> only_enable) { std::unordered_set only_type_index; if (only_enable.defined()) { for (auto s : only_enable.value()) { @@ -517,9 +474,7 @@ Stmt IRTransform(Stmt ir_node, class IRSubstitue : public StmtExprMutator { public: - explicit IRSubstitue(std::function(const Var&)> vmap) - : vmap_(vmap) { - } + explicit IRSubstitue(std::function(const Var&)> vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); @@ -533,8 +488,7 @@ class IRSubstitue : public StmtExprMutator { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { - return LoadNode::make( - op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); + return LoadNode::make(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); } else { return ret; } @@ -545,8 +499,8 @@ class IRSubstitue : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { - return StoreNode::make( - Downcast(mapped_var.value()), op->value, op->index, op->predicate); + return StoreNode::make(Downcast(mapped_var.value()), op->value, op->index, + op->predicate); } else { return ret; } @@ -556,36 +510,28 @@ class IRSubstitue : public StmtExprMutator { std::function(const Var&)> vmap_; }; -Stmt Substitute(Stmt stmt, - std::function(const Var&)> vmap) { +Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { return IRSubstitue(vmap)(std::move(stmt)); } -PrimExpr Substitute(PrimExpr expr, - std::function(const Var&)> vmap) { +PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { return IRSubstitue(vmap)(std::move(expr)); } +TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); -TVM_REGISTER_GLOBAL("tir.IRTransform") -.set_body_typed(IRTransform); - - -TVM_REGISTER_GLOBAL("tir.PostOrderVisit") -.set_body_typed([](ObjectRef node, PackedFunc f) { - tir::PostOrderVisit(node, [f](const ObjectRef& n) { - f(n); - }); +TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { + tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); TVM_REGISTER_GLOBAL("tir.Substitute") -.set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef{ - if (node->IsInstance()) { - return Substitute(Downcast(node), vmap); - } else { - return Substitute(Downcast(node), vmap); - } -}); + .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { + if (node->IsInstance()) { + return Substitute(Downcast(node), vmap); + } else { + return Substitute(Downcast(node), vmap); + } + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index dda9ff460cf08..30d5f0f507741 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -21,16 +21,14 @@ * \file tir/ir/transform.cc * \brief TIR specific transformation passes. */ -#include #include +#include #include - namespace tvm { namespace tir { namespace transform { - /*! * \brief Function level pass that applies transformations to all * TIR functions within the module. @@ -43,9 +41,7 @@ class PrimFuncPassNode : public PassNode { /*! \brief The pass function called on each. */ runtime::TypedPackedFunc pass_func; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! 
* \brief Run a function pass on given pass context. @@ -90,8 +86,7 @@ PrimFuncPass::PrimFuncPass( } // Perform Module -> Module optimizations at the PrimFunc level. -IRModule PrimFuncPassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); @@ -123,9 +118,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { + int opt_level, const std::string& name, const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return PrimFuncPass(pass_func, pass_info); } @@ -133,18 +126,16 @@ Pass CreatePrimFuncPass( TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") -.set_body_typed([](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return PrimFuncPass(pass_func, pass_info); -}); + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "PrimFuncPass(" << info->name - << ", opt_level=" << info->opt_level << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")"; + }); } // namespace transform } // namespace tir diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 1e6f6c6ed3a04..67a88f5d922eb 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -20,15 +20,15 @@ /*! 
* \file hoist_if_then_else.cc */ +#include #include #include #include -#include -#include +#include #include #include -#include + #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" @@ -152,13 +152,12 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - PackedFunc replace_target_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& current_for = args[0]; - if (current_for.get() == top_for_node) { - *ret = new_if_stmt; - } - }); + PackedFunc replace_target_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& current_for = args[0]; + if (current_for.get() == top_for_node) { + *ret = new_if_stmt; + } + }); return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"For"}); } @@ -170,21 +169,19 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { Stmt else_for; CHECK(if_stmt.as()); - PackedFunc replace_then_case = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->then_case; - } - }); + PackedFunc replace_then_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); - PackedFunc replace_else_case = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->else_case; - } - }); + PackedFunc replace_else_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"IfThenElse"}); if (if_stmt.as()->else_case.defined()) { @@ -196,7 +193,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { // Locate all For nodes and capture child IfThenElse nodes. void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { - PostOrderVisit(stmt, [&](const ObjectRef& node){ + PostOrderVisit(stmt, [&](const ObjectRef& node) { const ForNode* for_node = node.as(); if (!for_node) return; @@ -269,10 +266,8 @@ void IfThenElseHoist::LocateTopFor() { CHECK(for_node); std::vector new_for_list{for_stmt}; for_tracking_map_.insert({for_stmt.get(), new_for_list}); - if (cond_var_map_[if_stmt] - .count(for_node->loop_var.get())) { - std::vector updated_for_list(for_list.begin(), - for_list.begin() + i); + if (cond_var_map_[if_stmt].count(for_node->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), for_list.begin() + i); if2for_map_[if_stmt] = updated_for_list; break; } else { @@ -315,13 +310,11 @@ void IfThenElseHoist::LocateTopFor() { // We keep all For nodes tracing in for_tracking_map_. When we get a // hoisted IfThenElse, we match it with tracing For nodes to pick // the updated one. 
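// As background for the matching logic below, the rewrite this pass performs
// can be pictured on a small loop nest (schematic TIR, illustrative only):
//
//   // before hoisting:
//   for (i, 0, n) {
//     if (cond) {          // cond does not depend on i
//       A[i] = B[i];
//     }
//   }
//   // after hoisting:
//   if (cond) {
//     for (i, 0, n) {
//       A[i] = B[i];
//     }
//   }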
-size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, - const Stmt& if_stmt) { +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) { std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; size_t updated_for_idx = 0; for (size_t i = 0; i < tracked_for_list.size(); ++i) { - const Stmt& current_for = - tracked_for_list.at(tracked_for_list.size() - 1 - i); + const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i); if (is_first_if(current_for, if_stmt)) { updated_for_idx = tracked_for_list.size() - 1 - i; break; @@ -340,11 +333,11 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); - const Stmt& updated_for_node = - for_tracking_map_[for_stmt.get()].at(updated_for_idx); + const Stmt& updated_for_node = for_tracking_map_[for_stmt.get()].at(updated_for_idx); auto generated_for_pair = RemoveIf(updated_for_node, new_if); const Stmt& then_for = generated_for_pair.first; - const Stmt& else_for = generated_for_pair.second;; + const Stmt& else_for = generated_for_pair.second; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; if (else_for.get()) { @@ -356,12 +349,10 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for); if (i < if2for_map_[if_stmt.get()].size() - 1) { const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); - const Stmt& actual_next_for = - for_tracking_map_[original_next_for.get()].at(updated_for_idx); + const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx); Stmt update_for_stmt = update_for(actual_next_for, new_if); - for_tracking_map_[original_next_for.get()]. - at(updated_for_idx) = update_for_stmt; + for_tracking_map_[original_next_for.get()].at(updated_for_idx) = update_for_stmt; } } return new_if; @@ -369,56 +360,46 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { // Mutate For nodes in post order DFS manner. 
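// PostOrderMutate below, like update_for and RemoveIf above, drives its
// rewrite through IRTransform with a PackedFunc post-order callback. A
// stripped-down sketch of that calling convention (target_node and new_stmt
// are hypothetical names assumed to already be in scope):
//
//   PackedFunc replace = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
//     const ObjectRef& node = args[0];
//     if (node.get() == target_node) {  // the one node we want to swap out
//       *ret = new_stmt;                // setting *ret substitutes the node
//     }
//   });
//   Stmt rewritten = IRTransform(stmt, nullptr, replace, Array<String>{"For"});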
Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { - PackedFunc replace_top_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& current_for = args[0]; - const ForNode* for_node = current_for.as(); - if (!for_node) return; - - if (top_for_var_map_.count(for_node->loop_var.get())) { - std::vector new_if_list; - for (const Stmt& if_stmt : - top_for_var_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(HoistIf(if_stmt)); - } + PackedFunc replace_top_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& current_for = args[0]; + const ForNode* for_node = current_for.as(); + if (!for_node) return; - const IfThenElseNode* next_if_node; - const IfThenElseNode* current_if_node = - new_if_list.back().as(); - Stmt new_for = Stmt(); - for (size_t i = new_if_list.size() - 1; i > 0; --i) { - CHECK(current_if_node); - const Stmt current_if_stmt = - IfThenElseNode::make(current_if_node->condition, - current_if_node->then_case, - current_if_node->else_case); - next_if_node = new_if_list[i - 1].as(); - CHECK(next_if_node); - new_for = IfThenElseNode::make(next_if_node->condition, current_if_stmt, - next_if_node->else_case); - current_if_node = new_for.as(); - } + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } - if (!new_for.get()) { - const IfThenElseNode* first_if_node = new_if_list[0].as(); - CHECK(first_if_node); - new_for = IfThenElseNode::make(first_if_node->condition, - first_if_node->then_case, - first_if_node->else_case); - } - *ret = new_for; + const IfThenElseNode* next_if_node; + const IfThenElseNode* current_if_node = new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = IfThenElseNode::make( + current_if_node->condition, current_if_node->then_case, current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = + IfThenElseNode::make(next_if_node->condition, current_if_stmt, next_if_node->else_case); + current_if_node = new_for.as(); } - }); - return IRTransform(stmt, nullptr, replace_top_for, Array{"For"}); -} -Stmt HoistIfThenElse(Stmt stmt) { - return IfThenElseHoist().VisitAndMutate(stmt); + if (!new_for.get()) { + const IfThenElseNode* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElseNode::make(first_if_node->condition, first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + }); + return IRTransform(stmt, nullptr, replace_top_for, Array{"For"}); } +Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } -TVM_REGISTER_GLOBAL("testing.HoistIfThenElse") -.set_body_typed(HoistIfThenElse); +TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index a68e4ee0db84d..01a69969b4890 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -21,23 +21,23 @@ * \file arg_binder.cc * \brief Helper utility to match and bind arguments. 
*/ -#include -#include -#include "ir_util.h" #include "arg_binder.h" + +#include +#include + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { -void BinderAddAssert(arith::Analyzer* ana, - PrimExpr cond, - const std::string& arg_name, +void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg_name, std::vector* asserts) { PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { - LOG(FATAL) << "Bind have an unmet assertion: " - << cond << ", " << " on argument " << arg_name; + LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " + << " on argument " << arg_name; } if (!is_one(scond)) { std::ostringstream os; @@ -47,9 +47,7 @@ void BinderAddAssert(arith::Analyzer* ana, } } -bool ArgBinder::Bind_(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, +bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.dtype(), value.dtype()); if (const VarNode* v = arg.as()) { @@ -73,18 +71,14 @@ bool ArgBinder::Bind_(const PrimExpr& arg, return false; } -void ArgBinder::Bind(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, +void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_let) { Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array& arg, - const Array& value, +void ArgBinder::BindArray(const Array& arg, const Array& value, const std::string& arg_name) { - CHECK_EQ(arg.size(), value.size()) - << "Argument " << arg_name << " array size mismatch"; + CHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; for (size_t i = 0; i < arg.size(); ++i) { std::ostringstream os; os << arg_name << "[" << i << "]"; @@ -92,16 +86,11 @@ void ArgBinder::BindArray(const Array& arg, } } -void ArgBinder::BindBuffer(const Buffer& arg, - const Buffer& value, - const std::string& arg_name, +void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match) { - CHECK_EQ(arg->scope, value->scope) - << "Argument " << arg_name - << " Buffer bind scope mismatch"; + CHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; CHECK_EQ(arg->dtype, value->dtype) - << "Argument " << arg_name - << " Buffer bind data type mismatch"; + << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " << " required_alignment=" << arg->data_alignment @@ -121,9 +110,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(&analyzer_, - truncmod(offset, factor) == zero, - arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset", + &asserts_); } } @@ -132,8 +120,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { CHECK(is_one(analyzer_.Simplify(value->shape[i]))) - << "Argument " << arg_name << " shape mismatch" - << arg->shape << " vs " << value->shape; + << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; } for (size_t i = 0; i < arg->shape.size(); ++i) { 
std::ostringstream os; @@ -159,22 +146,17 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind k return TVMStructGet(t, arr, 0, kind); } -void ArgBinder::BindDLTensor(const Buffer& buffer, - const PrimExpr& device_type, - const PrimExpr& device_id, - const Var& handle, +void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, + const PrimExpr& device_id, const Var& handle, const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = EvaluateNode::make(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); - PrimExpr a_ndim = make_const(tvm_ndim_type, - static_cast(buffer->shape.size())); + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; - ndim_err_msg << arg_name - << ".ndim is expected to equal " - << buffer->shape.size(); + ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str()); asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); // type checks @@ -182,14 +164,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - IntImm(DataType::UInt(8), dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - IntImm(DataType::UInt(8), dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - IntImm(DataType::UInt(16), dtype.lanes())); - if (!(dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || - dtype == DataType::Int(1))) { + IntImm(DataType::UInt(8), dtype.code()) && + TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == + IntImm(DataType::UInt(8), dtype.bits()) && + TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == + IntImm(DataType::UInt(16), dtype.lanes())); + if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str()); asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop)); @@ -200,9 +180,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs - init_nest_.emplace_back(AttrStmtNode::make( - vptr, tir::attr::storage_alignment, - IntImm(DataType::Int(32), buffer->data_alignment), nop)); + init_nest_.emplace_back(AttrStmtNode::make(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), + nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); @@ -210,28 +190,24 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, init_nest_.emplace_back(LetStmtNode::make( v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { - if (dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || - dtype == DataType::Int(1)) { + if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; } std::ostringstream field_name; field_name << v_shape->name_hint << '[' << k << ']'; - Bind_(buffer->shape[k], - 
cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_shape, - IntImm(DataType::Int(32), k), const_true(1))), - field_name.str(), true); + Bind_( + buffer->shape[k], + cast(buffer->shape[k].dtype(), + LoadNode::make(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), + field_name.str(), true); } // strides field Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmtNode::make( - v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), - nop)); - PrimExpr is_null = CallNode::make( - DataType::Bool(1), intrinsic::tvm_handle_is_null, - {v_strides}, CallNode::PureIntrinsic); + v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); + PrimExpr is_null = CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, + CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -239,10 +215,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = cast( - stype, - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr svalue = cast(stype, LoadNode::make(tvm_shape_type, v_strides, + IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -251,9 +225,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, << " expected to be compact array"; if (conds.size() != 0) { auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str()); - Stmt check = - AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), - stride_msg, EvaluateNode::make(0)); + Stmt check = AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), + stride_msg, EvaluateNode::make(0)); check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); } @@ -264,9 +237,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, size_t k = i - 1; std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; - PrimExpr value = cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr value = cast( + buffer->shape[k].dtype(), + LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -283,8 +256,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, field_name << v_strides->name_hint << '[' << k << ']'; Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))), + LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), + const_true(1))), field_name.str(), true); } } @@ -293,7 +266,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, if (const auto* const_offset = buffer->elem_offset.as()) { Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), + TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), arg_name + ".byte_offset", true); } else { if 
(Bind_(buffer->elem_offset, @@ -305,18 +278,15 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(&analyzer_, - truncmod(offset, factor) == zero, - arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset", + &asserts_); } } } // device info. - Bind_(device_type, - TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), + Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), arg_name + ".device_type", true); - Bind_(device_id, - TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), + Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/tir/transforms/arg_binder.h b/src/tir/transforms/arg_binder.h index 1769950b8979b..657ebdbec1345 100644 --- a/src/tir/transforms/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -24,13 +24,13 @@ #ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_ #define TVM_TIR_TRANSFORMS_ARG_BINDER_H_ -#include -#include #include +#include +#include #include -#include #include +#include namespace tvm { namespace tir { @@ -63,10 +63,7 @@ class ArgBinder { * \param def_map A definition map that contains definition of known variables. * ArgBinder will update this def_map when adding new definitions. */ - explicit ArgBinder( - std::unordered_map* def_map) - : def_map_(def_map) { - } + explicit ArgBinder(std::unordered_map* def_map) : def_map_(def_map) {} /*! * \brief Try to bind arg to value, generate constraint if necessary. * \param arg The argument to be binded. @@ -74,9 +71,7 @@ class ArgBinder { * \param arg_name argument name. * \param with_let Whether add lets during bind */ - void Bind(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, + void Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_let = false); /*! * \brief Bind array to array @@ -84,19 +79,17 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array& arg, - const Array& value, + void BindArray(const Array& arg, const Array& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. - * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1. + * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as + * arg's higher dimensions are of 1. */ - void BindBuffer(const Buffer& arg, - const Buffer& value, - const std::string& arg_name, + void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match); /*! * \brief Bind symbolic buffer to a DLTensor handle. @@ -106,20 +99,13 @@ class ArgBinder { * \param handle The DLTensor handle. * \param arg_name argument name. */ - void BindDLTensor(const Buffer& buffer, - const PrimExpr& device_type, - const PrimExpr& device_id, - const Var& handle, - const std::string& arg_name); + void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, + const Var& handle, const std::string& arg_name); /*! \return The defs generated in binding. 
*/ - const std::vector& defs() const { - return defs_; - } + const std::vector& defs() const { return defs_; } /*! \return The asserts generated in binding */ - const std::vector& asserts() const { - return asserts_; - } + const std::vector& asserts() const { return asserts_; } /*! * \brief Initialization nest generated * This is only non-empty when BindDLTensor is called. @@ -131,19 +117,13 @@ class ArgBinder { * Let statement is usually generated when bind to DLTensor and memory load is involved. * \return The initialization nest generated during binding. */ - const std::vector& init_nest() const { - return init_nest_; - } + const std::vector& init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ - const Map& def_handle_dtype() const { - return def_handle_dtype_; - } + const Map& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function - bool Bind_(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, + bool Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets); /*! \brief The definition map, can be uses to substitute */ std::unordered_map* def_map_; diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 4b1c0094b194d..2e1e5b97fcf97 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -22,15 +22,16 @@ */ // Instrument checkers for out of the bounds access. -#include #include +#include #include #include -#include #include -#include +#include + #include #include +#include namespace tvm { namespace tir { @@ -41,20 +42,19 @@ class BoundCollector : public StmtVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tir::attr::buffer_bound) { - if (const VarNode *key = op->node.as()) { + if (const VarNode* key = op->node.as()) { mem_to_shape[key] = op->value; } } StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. - std::unordered_map mem_to_shape; + std::unordered_map mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker( - const std::unordered_map &mem_to_shape) + explicit BoundChecker(const std::unordered_map& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -86,10 +86,8 @@ class BoundChecker : public StmtExprMutator { PrimExpr condition = MakeCondition(); if (!condition.as()) { Stmt nop = EvaluateNode::make(1); - Stmt then_case = - StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); - Stmt else_case = - AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop); + Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); + Stmt else_case = AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop); Stmt body = IfThenElseNode::make(condition, then_case, else_case); return body; } @@ -109,9 +107,7 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, - const Array& new_shape, - const DataType& type) { + void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { // Sanity check at first. if (!new_shape.size()) { return; @@ -126,11 +122,11 @@ class BoundChecker : public StmtExprMutator { // Scalarize the shape. 
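// Worked example (hypothetical numbers): for new_shape = {2, 3} and a scalar
// element type (type.lanes() == 1), the product built below evaluates to
// (1 * 2) * (1 * 3) = 6, so an access is in bounds iff 0 <= index < 6.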
PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[0])); + CastNode::make(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overflow at first. shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[i]))); + CastNode::make(DataType::UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } @@ -140,23 +136,21 @@ return false; } - if (const RampNode *ramp_index = index.as()) { - return ramp_index->base.defined() && - ramp_index->base.dtype().is_scalar() && - ramp_index->stride.defined() && - ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0); + if (const RampNode* ramp_index = index.as()) { + return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() && + ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() && + (ramp_index->lanes > 0); } return true; } bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const { - return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && - IndexIsValid(index) && !unsafe_rewritten_; + return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && + !unsafe_rewritten_; } void Collect(PrimExpr index, Var buffer_var) { - store_scope_bound_collector_.push_back( - std::make_pair(index, mem_to_shape_[buffer_var.get()])); + store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()])); } PrimExpr MakeCondition() { @@ -166,13 +160,12 @@ PrimExpr index = buffer_to_mem.first; PrimExpr upper_bound = buffer_to_mem.second; - if (const RampNode *ramp_index = index.as()) { + if (const RampNode* ramp_index = index.as()) { // In case index is base + stride * i. // Non-inclusive range. - index = AddNode::make( - ramp_index->base, - MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(), - ramp_index->lanes - 1))); + index = AddNode::make(ramp_index->base, MulNode::make(ramp_index->stride, + make_const(ramp_index->stride.dtype(), + ramp_index->lanes - 1))); } // Try to simplify index and bound. @@ -188,8 +181,7 @@ PrimExpr current_condition = AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound)); - condition = - !i ? current_condition : AndNode::make(condition, current_condition); + condition = !i ? current_condition : AndNode::make(condition, current_condition); } return condition; } @@ -201,9 +193,9 @@ // Pool which collects the pair of index and shape for specific store/load. std::vector> store_scope_bound_collector_; // Error message. - const char *const error_message_ = "OUT OF THE BOUNDS"; + const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape.
- std::unordered_map mem_to_shape_; + std::unordered_map mem_to_shape_; // internal analyzer arith::Analyzer analyzer_; }; @@ -230,7 +222,7 @@ Pass InstrumentBoundCheckers() { } TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") -.set_body_typed(InstrumentBoundCheckers); + .set_body_typed(InstrumentBoundCheckers); } // namespace transform diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index c17d66562ef15..9e5e4ae6cfec7 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -22,14 +22,13 @@ * * \file combine_context_call.cc */ +#include +#include +#include #include #include #include #include -#include -#include -#include - #include @@ -44,7 +43,7 @@ class ContextCallCombiner final : public StmtExprMutator { if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); PrimExpr ctx = op->args[0]; - auto it = ctx_map_.find(ctx); + auto it = ctx_map_.find(ctx); if (it != ctx_map_.end()) { return it->second; } else { @@ -65,8 +64,7 @@ class ContextCallCombiner final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::coproc_uop_scope) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) { // Map of comparison expression to variable std::unordered_map temp; std::swap(temp, ctx_map_); @@ -91,14 +89,11 @@ class ContextCallCombiner final : public StmtExprMutator { } } - Stmt Combine(Stmt stmt) { - return BuildContext(ctx_map_, this->VisitStmt(stmt)); - } + Stmt Combine(Stmt stmt) { return BuildContext(ctx_map_, this->VisitStmt(stmt)); } private: static Stmt BuildContext( - const std::unordered_map& cmap, - Stmt body) { + const std::unordered_map& cmap, Stmt body) { for (const auto& kv : cmap) { body = LetStmtNode::make(kv.second, kv.first, body); } @@ -108,7 +103,6 @@ class ContextCallCombiner final : public StmtExprMutator { std::unordered_map ctx_map_; }; - namespace transform { Pass CombineContextCall() { @@ -120,8 +114,7 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall") -.set_body_typed(CombineContextCall); +TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 174564f26245f..41e1124b0df04 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -21,11 +21,13 @@ * \file coproc_sync.cc */ #include -#include #include #include +#include + #include #include + #include "ir_util.h" #include "storage_access.h" @@ -89,11 +91,9 @@ class CoProcTouchedBuffer : public StmtExprVisitor { // Synchronization planning with co-processor. 
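// Roughly, the planner turns an unsynchronized producer/consumer sequence
// into one with an explicit sync intrinsic. The intrinsic name is
// coproc_name + ".coproc_sync" (e.g. "vta.coproc_sync" on VTA); the access
// pattern below is illustrative only:
//
//   A[i] = ...;               // written by the host
//   // <- PlanSync inserts vta.coproc_sync() here
//   ... = A[i];               // read back by the co-processor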
class CoProcSyncPlanner : public StorageAccessVisitor { public: - explicit CoProcSyncPlanner( - const std::unordered_set& touched, - const std::string& coproc_name) - : touched_(touched), coproc_name_(coproc_name) { - } + explicit CoProcSyncPlanner(const std::unordered_set& touched, + const std::string& coproc_name) + : touched_(touched), coproc_name_(coproc_name) {} void Plan(const Stmt& stmt) { this->VisitStmt(stmt); @@ -107,22 +107,19 @@ class CoProcSyncPlanner : public StorageAccessVisitor { std::unordered_map > sync_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { return PlanSync(seq, loop, false); } private: // Plan write synchronization if write is not coherent - std::vector PlanSync( - std::vector seq, const ForNode* loop, - bool force_sync_at_end) { + std::vector PlanSync(std::vector seq, const ForNode* loop, + bool force_sync_at_end) { // detect write barriers // access by the co-processor. std::vector co_access; @@ -131,8 +128,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { auto find_conflict = [&](const AccessEntry& acc) { for (const AccessEntry& x : co_access) { if (x.buffer.same_as(acc.buffer) && - ((acc.type == kRead && x.type == kWrite) || - acc.type == kWrite)) { + ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) { return true; } } @@ -143,7 +139,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { bool sync_write = false; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_write = true; break; + sync_write = true; + break; } if (acc.type == kSync) { co_access.clear(); @@ -169,7 +166,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { const StmtEntry& s = seq[i]; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_at_end = true; break; + sync_at_end = true; + break; } } if (sync_.count(s.stmt) || sync_at_end) break; @@ -197,10 +195,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return {EvaluateNode::make(CallNode::make( - DataType::Int(32), - sync_name, - {}, CallNode::Intrinsic))}; + return { + EvaluateNode::make(CallNode::make(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -210,9 +206,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { // Detect memory barriers when coproc read/write memory class CoProcBarrierDetector : public StorageAccessVisitor { public: - explicit CoProcBarrierDetector( - const std::unordered_set& touched, - const std::string& coproc_name) + explicit CoProcBarrierDetector(const std::unordered_set& touched, + const std::string& coproc_name) : touched_(touched) { read_barrier_name_ = coproc_name + ".coproc_read_barrier"; write_barrier_name_ = coproc_name + ".coproc_write_barrier"; @@ -233,14 +228,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { std::unordered_map > barrier_after_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + 
std::vector Summarize(std::vector seq, const ForNode* loop) final { if (read_barrier_) { return PlanReadBarrier(seq, loop); } else { @@ -250,17 +243,15 @@ class CoProcBarrierDetector : public StorageAccessVisitor { private: // Plan write barrier at Read after write point. - std::vector PlanWriteBarrier( - std::vector seq, const ForNode* loop) { + std::vector PlanWriteBarrier(std::vector seq, const ForNode* loop) { std::vector read_seq; std::unordered_map > write_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = write_set.find(acc.buffer.get()); + auto it = write_set.find(acc.buffer.get()); if (it != write_set.end()) { CHECK_NE(i, 0U); - barrier_after_[seq[i - 1].stmt].push_back( - MakeBarrier(write_barrier_name_, it->second)); + barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second)); write_set.erase(it); } }; @@ -284,23 +275,21 @@ class CoProcBarrierDetector : public StorageAccessVisitor { fupdate(seq.size(), acc); } } - for (const auto &kv : write_set) { + for (const auto& kv : write_set) { read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end()); } return read_seq; } - std::vector PlanReadBarrier( - std::vector seq, const ForNode* loop) { + std::vector PlanReadBarrier(std::vector seq, const ForNode* loop) { std::vector write_seq; std::unordered_map > read_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = read_set.find(acc.buffer.get()); + auto it = read_set.find(acc.buffer.get()); if (it != read_set.end()) { CHECK_NE(i, seq.size()); - barrier_before_[seq[i].stmt].push_back( - MakeBarrier(read_barrier_name_, it->second)); + barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second)); read_set.erase(it); } }; @@ -325,7 +314,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { fupdate(0, acc); } } - for (const auto &kv : read_set) { + for (const auto& kv : read_set) { write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end()); } return write_seq; @@ -340,13 +329,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { } Range none; Range r = arith::Union(wset).cover_range(none); - CHECK(r.defined()) - << "Cannot deduce write range of " << wvec[0].buffer; + CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; return EvaluateNode::make(CallNode::make( - DataType::Int(32), func, - {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); + DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, + CallNode::Intrinsic)); } // Write barrier name bool read_barrier_{false}; @@ -355,12 +343,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor { const std::unordered_set& touched_; }; - class CoProcInstDepDetector : public StmtVisitor { public: - explicit CoProcInstDepDetector( - const IterVar& coproc_axis, - const std::string& coproc_name) + explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { sync_push_name_ = coproc_name + ".coproc_dep_push"; sync_pop_name_ = coproc_name + ".coproc_dep_pop"; @@ -375,8 +360,7 @@ class CoProcInstDepDetector : public StmtVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::coproc_scope && - op->node.same_as(coproc_axis_)) { + if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) { const IntImmNode* ctx_id = op->value.as(); CHECK(ctx_id != nullptr); curr_state_.clear(); @@ 
-399,9 +383,7 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state_.node = op; CHECK(first_state_.node != nullptr); // loop carry dependency - InjectSync(last_state_, first_state_, - &(curr_state_.exit_push), - &(curr_state_.enter_pop)); + InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop)); curr_state_.enter_ctx = first_state_.enter_ctx; curr_state_.exit_ctx = last_state_.exit_ctx; } @@ -423,12 +405,8 @@ curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert( - first_state_.enter_ctx.begin(), - first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert( - last_state_.exit_ctx.begin(), - last_state_.exit_ctx.end()); + curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); + curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } first_state_.clear(); last_state_.clear(); @@ -439,12 +417,8 @@ curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert( - first_state_.enter_ctx.begin(), - first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert( - last_state_.exit_ctx.begin(), - last_state_.exit_ctx.end()); + curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); + curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } } // update in the trace. @@ -487,15 +461,14 @@ // record the push/pop sequence that could be possibly un-matched. // return the push/pop message at enter/exit of the Block // after considering the existing unmatched events and added events - void InjectSync(const SyncState& prev, - const SyncState& next, + void InjectSync(const SyncState& prev, const SyncState& next, std::vector >* prev_exit_push, std::vector >* next_enter_pop) { prev_exit_push->clear(); next_enter_pop->clear(); // quick path - if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && - prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) { + if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 && + next.enter_ctx.size() == 1) { int from = *prev.exit_ctx.begin(); int to = *next.enter_ctx.begin(); if (from != to) { @@ -520,15 +493,11 @@ // policy 1 std::vector prev_after, next_before; for (const std::pair& p : pending) { - if (std::find(prev.exit_push.begin(), - prev.exit_push.end(), p) == - prev.exit_push.end()) { + if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) { vpush.push_back(p); prev_after.emplace_back(MakePush(p.first, p.second)); } - if (std::find(next.enter_pop.begin(), - next.enter_pop.end(), p) == - next.enter_pop.end()) { + if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) { vpop.push_back(p); next_before.emplace_back(MakePop(p.first, p.second)); } @@ -549,18 +518,18 @@ } } if (prev_after.size() != 0) { - auto &v1 = insert_after_[prev.node]; + auto& v1 = insert_after_[prev.node]; v1.insert(v1.end(), prev_after.begin(), prev_after.end()); } if (next_before.size() != 0) { - auto &v2 = insert_before_[next.node]; + auto& v2 = insert_before_[next.node]; v2.insert(v2.end(), next_before.begin(), next_before.end()); } } void
MatchFixEnterPop(const SyncState& state) { if (state.enter_pop.size() == 0) return; - auto &vec = insert_before_[state.node]; + auto& vec = insert_before_[state.node]; for (const std::pair& p : state.enter_pop) { vec.push_back(MakePush(p.first, p.second)); } @@ -568,7 +537,7 @@ class CoProcInstDepDetector : public StmtVisitor { void MatchFixExitPush(const SyncState& state) { if (state.exit_push.size() == 0) return; - auto &vec = insert_after_[state.node]; + auto& vec = insert_after_[state.node]; for (const std::pair& p : state.exit_push) { vec.push_back(MakePop(p.first, p.second)); } @@ -587,16 +556,16 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return EvaluateNode::make(CallNode::make( - DataType::Int(32), sync_push_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return EvaluateNode::make( + CallNode::make(DataType::Int(32), sync_push_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return EvaluateNode::make(CallNode::make( - DataType::Int(32), sync_pop_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return EvaluateNode::make( + CallNode::make(DataType::Int(32), sync_pop_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } // sync states. SyncState first_state_, last_state_, curr_state_; @@ -605,7 +574,6 @@ class CoProcInstDepDetector : public StmtVisitor { std::string sync_push_name_, sync_pop_name_; }; - class CoProcSyncInserter : public StmtMutator { public: Stmt Insert(Stmt stmt) { @@ -614,7 +582,7 @@ class CoProcSyncInserter : public StmtMutator { if (visitor.coproc_.size() == 0) return stmt; std::unordered_set touched; - for (const auto &kv : visitor.touched_) { + for (const auto& kv : visitor.touched_) { if (kv.second.normal && kv.second.coproc) { touched.insert(kv.first); } @@ -641,8 +609,7 @@ class CoProcSyncInserter : public StmtMutator { vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } // Detect barrier - CoProcInstDepDetector sync_detector( - *visitor.coproc_.begin(), coproc_name); + CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name); sync_detector.Plan(stmt); for (const auto& kv : sync_detector.insert_before_) { auto& vec = insert_before_[kv.first]; @@ -661,9 +628,8 @@ class CoProcSyncInserter : public StmtMutator { Stmt new_stmt = StmtMutator::VisitStmt(stmt); return SeqStmt::Flatten( - it_before != insert_before_.end() ? it_before->second : std::vector(), - new_stmt, - it_after != insert_after_.end() ? it_after->second : std::vector()); + it_before != insert_before_.end() ? it_before->second : std::vector(), new_stmt, + it_after != insert_after_.end() ? 
it_after->second : std::vector()); } private: @@ -673,10 +639,7 @@ class CoProcSyncInserter : public StmtMutator { std::unordered_map > insert_after_; }; - -Stmt CoProcSync(Stmt stmt) { - return CoProcSyncInserter().Insert(std::move(stmt)); -} +Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } namespace transform { @@ -689,8 +652,7 @@ Pass CoProcSync() { return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CoProcSync") -.set_body_typed(CoProcSync); +TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync); } // namespace transform diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index 7ff2e3f7d17ea..0decb94df03be 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -21,18 +21,15 @@ * \file decorate_device_scope.cc */ #include -#include #include +#include #include namespace tvm { namespace tir { Stmt DecorateDeviceScope(Stmt&& stmt) { - Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), - tir::attr::device_scope, - 0, - stmt); + Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt); return body; } @@ -47,8 +44,7 @@ Pass DecorateDeviceScope() { return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); } -TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope") -.set_body_typed(DecorateDeviceScope); +TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 86bbefc22830b..5af1a39d27c60 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -21,12 +21,13 @@ * \brief Replace certain copy with copy intrinsics. 
* \file copy_intrin_rewrite.cc */ -#include -#include -#include #include +#include +#include #include #include +#include + #include "../../arith/pattern_match.h" namespace tvm { @@ -36,11 +37,9 @@ using runtime::PackedFunc; class CopyIntrinInjector : public StmtMutator { public: - CopyIntrinInjector(const std::string& pragma_key, - const PackedFunc& flower_copy_fromto) - : pragma_key_(attr::pragma_scope_prefix+ pragma_key), - flower_copy_fromto_(flower_copy_fromto) { - } + CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto) + : pragma_key_(attr::pragma_scope_prefix + pragma_key), + flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { @@ -48,15 +47,14 @@ class CopyIntrinInjector : public StmtMutator { storage_scope_[buf] = op->value.as()->value; } else if (op->attr_key == pragma_key_) { Stmt ret; - CHECK(MatchCopyPattern(op->body, &ret)) - << "Cannot match copy pattern of " << op->body; + CHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; } return StmtMutator::VisitStmt_(op); } private: - bool MatchCopyPattern(Stmt stmt, Stmt *out) { + bool MatchCopyPattern(Stmt stmt, Stmt* out) { using namespace arith; Stmt body = stmt; @@ -72,9 +70,8 @@ class CopyIntrinInjector : public StmtMutator { // Expr sel_cond, sel_true_value, sel_false_value; // match select or if PVar sel_cond, sel_true_value, sel_false_value; - bool has_cond = - if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || - select(sel_cond, sel_true_value, sel_false_value).Match(store->value); + bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || + select(sel_cond, sel_true_value, sel_false_value).Match(store->value); const CastNode* cast = store->value.as(); const LoadNode* load = store->value.as(); @@ -95,11 +92,9 @@ class CopyIntrinInjector : public StmtMutator { for (const ForNode* op : loops) { loop_vars.push_back(op->loop_var); } - Array store_strides = - arith::DetectLinearEquation(store->index, loop_vars); - Array load_strides = - arith::DetectLinearEquation(load->index, loop_vars); - if (load_strides.size() == 0 || store_strides.size() == 0) return false; + Array store_strides = arith::DetectLinearEquation(store->index, loop_vars); + Array load_strides = arith::DetectLinearEquation(load->index, loop_vars); + if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; const size_t loop_var_size = loop_vars.size(); if (loop_var_size == 0) { @@ -114,8 +109,7 @@ class CopyIntrinInjector : public StmtMutator { PrimExpr pad_value; PrimExpr src_elem_offset = load_strides[loop_var_size]; if (has_cond) { - Array clip_bound = - arith::DetectClipBound(sel_cond.Eval(), loop_vars); + Array clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); pad_value = sel_false_value.Eval(); if (clip_bound.size() == 0) return false; CHECK_EQ(src_shape.size(), loop_vars.size()); @@ -150,27 +144,15 @@ class CopyIntrinInjector : public StmtMutator { Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); if (loop_var_size == 0) { - src_strides.push_back(make_const(DataType::Int(32), 1)); - dst_strides.push_back(make_const(DataType::Int(32), 1)); + src_strides.push_back(make_const(DataType::Int(32), 1)); + dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = 
BufferNode::make( - store->buffer_var, - store->value.dtype(), - dst_shape, - dst_strides, - store_strides[loop_var_size], - store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), - 0, 0, kDefault); - Buffer src = BufferNode::make( - load->buffer_var, - load->dtype, - src_shape, - src_strides, - src_elem_offset, - load->buffer_var->name_hint, - GetStorageScope(load->buffer_var.get()), - 0, 0, kDefault); + Buffer dst = BufferNode::make(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, + store_strides[loop_var_size], store->buffer_var->name_hint, + GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + Buffer src = BufferNode::make(load->buffer_var, load->dtype, src_shape, src_strides, + src_elem_offset, load->buffer_var->name_hint, + GetStorageScope(load->buffer_var.get()), 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; @@ -194,28 +176,23 @@ class CopyIntrinInjector : public StmtMutator { arith::Analyzer analyzer_; }; -Stmt InjectCopyIntrin(Stmt stmt, - const std::string& pragma_key, +Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, const PackedFunc& flower_copy_fromto) { return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } - namespace transform { -Pass InjectCopyIntrin(std::string pragma_key, - PackedFunc flower_copy_fromto) { +Pass InjectCopyIntrin(std::string pragma_key, PackedFunc flower_copy_fromto) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = CopyIntrinInjector( - pragma_key, flower_copy_fromto)(std::move(n->body)); + n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin") -.set_body_typed(InjectCopyIntrin); +TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin").set_body_typed(InjectCopyIntrin); } // namespace transform diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 4e5d08c696365..018997848fd07 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -22,11 +22,12 @@ * \file inject_double_buffer.cc */ #include -#include -#include #include -#include "ir_util.h" +#include +#include + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -52,7 +53,6 @@ class DoubleBufferDetector : public StmtExprVisitor { std::unordered_set touched_; }; - class StripDoubleBufferWrite : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -66,8 +66,7 @@ class StripDoubleBufferWrite : public StmtMutator { class DoubleBufferInjector : public StmtExprMutator { public: - explicit DoubleBufferInjector(int split_loop) - : split_loop_(split_loop) {} + explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} Stmt Inject(Stmt stmt) { DoubleBufferDetector detector; @@ -99,8 +98,8 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { - it->second.stride = arith::ComputeReduce( - op->extents, PrimExpr()) * op->dtype.lanes(); + it->second.stride = + arith::ComputeReduce(op->extents, PrimExpr()) * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = 
stmt.as(); Array new_extents{make_const(op->extents[0].dtype(), 2)}; @@ -109,13 +108,11 @@ class DoubleBufferInjector : public StmtExprMutator { } CHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back(AttrStmtNode::make( - op->buffer_var, attr::storage_scope, - StringImmNode::make(it->second.scope), - EvaluateNode::make(0))); - alloc_nest.emplace_back(AllocateNode::make( - op->buffer_var, op->dtype, new_extents, op->condition, - EvaluateNode::make(0))); + alloc_nest.emplace_back(AttrStmtNode::make(op->buffer_var, attr::storage_scope, + StringImmNode::make(it->second.scope), + EvaluateNode::make(0))); + alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents, + op->condition, EvaluateNode::make(0))); return op->body; } else { return StmtExprMutator::VisitStmt_(op); @@ -134,8 +131,7 @@ class DoubleBufferInjector : public StmtExprMutator { << "It is better to split with multiple of 2"; CHECK(is_zero(old_loop->min)); PrimExpr zero = old_loop->min; - PrimExpr new_ext = - old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); + PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); PrimExpr factor = make_const(new_ext.dtype(), split_loop_); PrimExpr outer_ext = new_ext / factor; PrimExpr tail_base = outer_ext * factor; @@ -146,9 +142,8 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } - Stmt loop = ForNode::make( - outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - SeqStmt::Flatten(loop_seq)); + Stmt loop = ForNode::make(outer_var, zero, outer_ext, old_loop->for_type, + old_loop->device_api, SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); @@ -156,8 +151,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; tail_seq.emplace_back( - IfThenElseNode::make(idx < old_loop->extent, - Substitute(tail_body, vmap))); + IfThenElseNode::make(idx < old_loop->extent, Substitute(tail_body, vmap))); } stmt = SeqStmt::Flatten(loop, tail_seq); } @@ -179,10 +173,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(in_double_buffer_scope_); CHECK(e.stride.defined()); - return StoreNode::make(op->buffer_var, - op->value, - e.switch_write_var * e.stride + op->index, - op->predicate); + return StoreNode::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, + op->predicate); } else { return stmt; } @@ -196,10 +188,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(e.stride.defined()); CHECK(e.switch_read_var.defined()); - return LoadNode::make(op->dtype, - op->buffer_var, - e.switch_read_var * e.stride + op->index, - op->predicate); + return LoadNode::make(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, + op->predicate); } else { return expr; } @@ -213,8 +203,7 @@ class DoubleBufferInjector : public StmtExprMutator { private: Stmt MakeProducer(const AttrStmtNode* op) { const Var buffer = Downcast(op->node); - CHECK_NE(loop_nest_.size(), 0U) - << "Double buffer scope must be inside a loop"; + CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); if (it == 
dbuffer_info_.end()) { LOG(WARNING) << "Skip double buffer scope " << op->node; @@ -226,8 +215,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr one = make_const(e.loop->loop_var.dtype(), 1); PrimExpr two = make_const(e.loop->loop_var.dtype(), 2); PrimExpr loop_shift = e.loop->loop_var + one; - e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", - e.loop->loop_var.dtype()); + e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.dtype()); e.switch_read_var = indexmod(e.loop->loop_var, two); in_double_buffer_scope_ = true; Stmt body = this->VisitStmt(op->body); @@ -270,12 +258,10 @@ class DoubleBufferInjector : public StmtExprMutator { std::unordered_map dbuffer_info_; }; - Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } - namespace transform { Pass InjectDoubleBuffer(int split_loop) { @@ -287,8 +273,7 @@ Pass InjectDoubleBuffer(int split_loop) { return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer") -.set_body_typed(InjectDoubleBuffer); +TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); } // namespace transform diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc index e9dae0a5dfc90..3b626f0108a11 100644 --- a/src/tir/transforms/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -21,20 +21,21 @@ * \file inject_prefetch.cc */ // Inject prefetch op in HalideIR +#include +#include #include #include #include #include #include -#include -#include + #include namespace tvm { namespace tir { -using arith::IntSet; using arith::DomainTouched; +using arith::IntSet; class PrefetchInjector : public StmtMutator { public: @@ -68,7 +69,7 @@ class PrefetchInjector : public StmtMutator { } Stmt VisitStmt_(const ForNode* op) final { - auto &var = op->loop_var; + auto& var = op->loop_var; loop_nest_.push_back(var); if (op->for_type == ForType::Vectorized) { vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1); @@ -83,16 +84,13 @@ class PrefetchInjector : public StmtMutator { private: std::vector loop_nest_; - std::unordered_map vectorized_; + std::unordered_map vectorized_; static const Range none; }; const Range PrefetchInjector::none; -Stmt InjectPrefetch(Stmt stmt) { - return PrefetchInjector()(std::move(stmt)); -} - +Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } namespace transform { @@ -105,8 +103,7 @@ Pass InjectPrefetch() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch") -.set_body_typed(InjectPrefetch); +TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch").set_body_typed(InjectPrefetch); } // namespace transform diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 01fb6fe0bda87..834a7e908f76d 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -24,9 +24,11 @@ #include #include #include + #include -#include "ir_util.h" + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -34,8 +36,7 @@ namespace tir { // If expression is touched by var. 
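// [Editor's note: an illustrative sketch, not part of the patch.] "Touched"
// here means a value depends, directly or transitively, on the virtual thread
// variable. For example:
//
//   B[0] = vt * 2;       // B is written from the vthread var   -> B is touched
//   C[0] = B[0] + 1;     // C is written from a touched buffer  -> C is touched
//
// ExprTouched below records the direct uses; VarTouchedAnalysis then closes
// the relation over stores with a DFS (see TouchedVar).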
class ExprTouched final : public StmtExprVisitor { public: - explicit ExprTouched(const std::unordered_set &touched, - bool check_write) + explicit ExprTouched(const std::unordered_set& touched, bool check_write) : touched_var_(touched), check_write_(check_write) {} void VisitExpr(const PrimExpr& n) final { @@ -43,19 +44,17 @@ class ExprTouched final : public StmtExprVisitor { if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt& n) final { + void VisitStmt(const Stmt& n) final { // early stopping if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitStmt(n); } - void VisitExpr_(const LoadNode *op) final { + void VisitExpr_(const LoadNode* op) final { HandleUseVar(op->buffer_var.get()); StmtExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode *op) final { - HandleUseVar(op); - } - void VisitExpr_(const CallNode *op) final { + void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } + void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); @@ -84,9 +83,7 @@ class ExprTouched final : public StmtExprVisitor { used_vars_.push_back(var); } } - void HandleWriteVar(const VarNode* var) { - write_vars_.push_back(var); - } + void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); } // the fields. bool expr_touched_{false}; std::vector used_vars_; @@ -134,8 +131,7 @@ class VarTouchedAnalysis : public StmtVisitor { Record(op->buffer_var.get(), tc); this->VisitStmt(op->body); } - void Record(const VarNode* var, - const ExprTouched& tc) { + void Record(const VarNode* var, const ExprTouched& tc) { if (touched_var_.count(var)) return; if (tc.expr_touched_) { touched_var_.insert(var); @@ -148,14 +144,11 @@ class VarTouchedAnalysis : public StmtVisitor { } } - std::unordered_set - TouchedVar(const Stmt& stmt, - const VarNode* var) { + std::unordered_set TouchedVar(const Stmt& stmt, const VarNode* var) { touched_var_.insert(var); this->VisitStmt(stmt); // do a DFS to push affect around dependency. - std::vector pending( - touched_var_.begin(), touched_var_.end()); + std::vector pending(touched_var_.begin(), touched_var_.end()); while (!pending.empty()) { const VarNode* v = pending.back(); pending.pop_back(); @@ -173,29 +166,26 @@ class VarTouchedAnalysis : public StmtVisitor { // Whether variable is touched by the thread variable. std::unordered_set touched_var_; // x -> all the buffers x read from - std::unordered_map > affect_; + std::unordered_map > affect_; }; - // Inject virtual thread loop // rewrite the buffer access pattern when necessary. class VTInjector : public StmtExprMutator { public: // constructor - VTInjector(Var var, - int num_threads, - const std::unordered_set& touched_var, + VTInjector(Var var, int num_threads, const std::unordered_set& touched_var, bool allow_share) - : var_(var), num_threads_(num_threads), - touched_var_(touched_var), allow_share_(allow_share) { - } + : var_(var), + num_threads_(num_threads), + touched_var_(touched_var), + allow_share_(allow_share) {} // Inject VTLoop when needed. 
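// [Editor's note: a hypothetical before/after, assuming num_threads == 2 and
// a touched buffer A; names and shapes are illustrative only.]
//
//   // before injection              // after InjectVTLoop (allow_share false)
//   A = allocate([n])                A = allocate([2, n])
//   A[i] = f(vt, i)                  for (vt.s, 0, 2)
//                                      A[vt.s * stride + i] = f(vt.s, i)
//
// The vthread axis becomes an explicit serial loop (".s" suffix as in
// InjectVTLoop below), and every touched allocation gains a leading dimension
// of num_threads, with loads and stores offset by vt.s * stride.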
Stmt VisitStmt(const Stmt& s) final { CHECK(!visit_touched_var_); auto stmt = StmtExprMutator::VisitStmt(s); if (visit_touched_var_ || trigger_base_inject_) { - if (!vt_loop_injected_) { + if (!vt_loop_injected_) { return InjectVTLoop(stmt, false); } visit_touched_var_ = false; @@ -205,8 +195,7 @@ class VTInjector : public StmtExprMutator { } // Variable PrimExpr VisitExpr_(const VarNode* op) final { - CHECK(!alloc_remap_.count(op)) - << "Buffer address may get rewritten in virtual thread"; + CHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread"; if (touched_var_.count(op)) { visit_touched_var_ = true; } @@ -224,9 +213,8 @@ class VTInjector : public StmtExprMutator { } auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return LoadNode::make(op->dtype, op->buffer_var, - RewriteIndex(op->index, it->second), - op->predicate); + return LoadNode::make(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), + op->predicate); } else { return expr; } @@ -242,13 +230,10 @@ class VTInjector : public StmtExprMutator { visit_touched_var_ = true; PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); - PrimExpr stride = - it->second / make_const(offset.dtype(), dtype.lanes()); + PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return CallNode::make( - op->dtype, op->name, - {op->args[0], op->args[1], offset, extent, op->args[4]}, - op->call_type); + return CallNode::make(op->dtype, op->name, + {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { return allow_share_ ? GetRef(op) : var_; } else { @@ -269,10 +254,8 @@ class VTInjector : public StmtExprMutator { trigger_base_inject_ = !allow_share_; auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return StoreNode::make(op->buffer_var, - op->value, - RewriteIndex(op->index, it->second), - op->predicate); + return StoreNode::make(op->buffer_var, op->value, RewriteIndex(op->index, it->second), + op->predicate); } else { return stmt; } @@ -283,13 +266,11 @@ class VTInjector : public StmtExprMutator { if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && - (op->attr_key == attr::coproc_uop_scope || - op->attr_key == attr::coproc_scope)) { + (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) { return InjectVTLoop(GetRef(op), true); } else { Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return AttrStmtNode::make(op->node, op->attr_key, value, body); @@ -304,8 +285,7 @@ class VTInjector : public StmtExprMutator { } visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetStmtNode::make(op->var, value, body); @@ -323,12 +303,10 @@ class VTInjector : public StmtExprMutator { visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); ++max_loop_depth_; - if (extent.same_as(op->extent) && - body.same_as(op->body)) { + if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make( - op->loop_var, op->min, 
extent, op->for_type, op->device_api, body); + return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -347,8 +325,7 @@ class VTInjector : public StmtExprMutator { else_case = this->VisitStmt(op->else_case); max_loop_depth_ = std::max(temp, max_loop_depth_); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -391,8 +368,7 @@ class VTInjector : public StmtExprMutator { // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - PrimExpr stride = arith::ComputeReduce( - op->extents, PrimExpr()) * op->dtype.lanes(); + PrimExpr stride = arith::ComputeReduce(op->extents, PrimExpr()) * op->dtype.lanes(); Array other; other.push_back(make_const(op->extents[0].dtype(), num_threads_)); for (PrimExpr e : extents) { @@ -408,14 +384,10 @@ class VTInjector : public StmtExprMutator { // Mutate the body. body = this->VisitStmt(op->body); } - if (!changed && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { - return AllocateNode::make( - op->buffer_var, op->dtype, - extents, condition, body); + return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); } } @@ -445,9 +417,8 @@ class VTInjector : public StmtExprMutator { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, make_zero(idx.dtype()), - make_const(idx.dtype(), num_threads_), - ForType::Serial, DeviceAPI::None, stmt); + return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), + ForType::Serial, DeviceAPI::None, stmt); } } @@ -472,7 +443,6 @@ class VTInjector : public StmtExprMutator { std::unordered_map alloc_remap_; }; - class VirtualThreadInjector : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -513,8 +483,7 @@ Pass InjectVirtualThread() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread") -.set_body_typed(InjectVirtualThread); +TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); } // namespace transform diff --git a/src/tir/transforms/ir_util.cc b/src/tir/transforms/ir_util.cc index 9ff3fca772774..ff3e941fb332b 100644 --- a/src/tir/transforms/ir_util.cc +++ b/src/tir/transforms/ir_util.cc @@ -21,11 +21,13 @@ * \file ir_util.cc * \brief Helper functions to construct and compose IR nodes. 
*/ +#include "ir_util.h" + #include -#include -#include + #include -#include "ir_util.h" +#include +#include namespace tvm { namespace tir { @@ -84,7 +86,6 @@ Stmt MergeNest(const std::vector>& nest, Stmt body) { return body; } - class IRConvertSSA final : public StmtExprMutator { public: PrimExpr VisitExpr_(const VarNode* op) final { @@ -112,9 +113,8 @@ class IRConvertSSA final : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { - return LoadNode::make( - op->dtype, scope_[op->buffer_var.get()].back(), - op->index, op->predicate); + return LoadNode::make(op->dtype, scope_[op->buffer_var.get()].back(), op->index, + op->predicate); } else { return expr; } @@ -123,9 +123,8 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(op->buffer_var.get())) { - return StoreNode::make( - scope_[op->buffer_var.get()].back(), op->value, - op->index, op->predicate); + return StoreNode::make(scope_[op->buffer_var.get()].back(), op->value, op->index, + op->predicate); } else { return stmt; } @@ -152,8 +151,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return ForNode::make( - new_var, op->min, op->extent, op->for_type, op->device_api, op->body); + return ForNode::make(new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -167,9 +165,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return AllocateNode::make( - new_var, op->dtype, op->extents, op->condition, - op->body); + return AllocateNode::make(new_var, op->dtype, op->extents, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -184,15 +180,13 @@ class IRConvertSSA final : public StmtExprMutator { if (new_alloc.same_as(op->body)) return GetRef(op); alloc = new_alloc.as(); CHECK(alloc); - return AttrStmtNode::make( - alloc->buffer_var, op->attr_key, op->value, new_alloc); + return AttrStmtNode::make(alloc->buffer_var, op->attr_key, op->value, new_alloc); } } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { - return AttrStmtNode::make( - scope_[v].back(), op->attr_key, op->value, op->body); + return AttrStmtNode::make(scope_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } @@ -202,13 +196,11 @@ class IRConvertSSA final : public StmtExprMutator { } private: - std::unordered_map > scope_; + std::unordered_map> scope_; std::unordered_set defined_; }; -Stmt ConvertSSA(Stmt stmt) { - return IRConvertSSA()(std::move(stmt)); -} +Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 18f79773d5f4b..69b5a3973b362 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -24,9 +24,10 @@ #ifndef TVM_TIR_TRANSFORMS_IR_UTIL_H_ #define TVM_TIR_TRANSFORMS_IR_UTIL_H_ +#include #include #include -#include + #include namespace tvm { @@ -56,7 +57,7 @@ Stmt MergeNest(const std::vector >& nest, Stmt body); * \return if update happens, return the new array, else return the * original array */ -template +template inline Array UpdateArray(Array arr, 
F fupdate) { std::vector new_arr(arr.size()); bool changed = false; @@ -81,13 +82,10 @@ inline Array UpdateArray(Array arr, F fupdate) { * \param kind The data kind. * \return the get expression. */ -inline PrimExpr TVMStructGet( - DataType dtype, Var handle, int index, - intrinsic::TVMStructFieldKind kind) { - Array args ={ - handle, - make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind))}; +inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, + intrinsic::TVMStructFieldKind kind) { + Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind))}; return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); } @@ -101,7 +99,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { return CallNode::make( DataType::Handle(), intrinsic::tvm_address_of, {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}, + const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } @@ -116,11 +114,9 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return CallNode::make( - DataType::Handle(), intrinsic::tvm_address_of, - {LoadNode::make(dtype, handle, offset, - const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + return CallNode::make(DataType::Handle(), intrinsic::tvm_address_of, + {LoadNode::make(dtype, handle, offset, const_true(dtype.lanes()))}, + CallNode::PureIntrinsic); } /*! @@ -131,14 +127,10 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \param value The value to be set. * \return the set stmt. */ -inline Stmt TVMStructSet( - Var handle, int index, - intrinsic::TVMStructFieldKind kind, PrimExpr value) { - Array args ={ - handle, - make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind)), - value}; +inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind, + PrimExpr value) { + Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind)), value}; return EvaluateNode::make( CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } @@ -150,8 +142,7 @@ inline Stmt TVMStructSet( */ inline DataType APIType(DataType t) { if (t.is_handle()) return t; - CHECK_EQ(t.lanes(), 1) - << "Cannot pass vector type through packed API."; + CHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; if (t.is_uint() || t.is_int()) return DataType::Int(64); CHECK(t.is_float()); return DataType::Float(64); @@ -174,7 +165,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { return align; } - /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. 
diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index 86b8cde2524cd..bb4e5f7678a72 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -24,8 +24,9 @@ * \file lift_attr_scope.cc */ #include -#include #include +#include + #include "ir_util.h" namespace tvm { @@ -35,14 +36,12 @@ namespace tir { // to a few specified attr keys class AttrScopeLifter : public StmtMutator { public: - explicit AttrScopeLifter(std::string attr_key) - : attr_key_(attr_key) {} + explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {} Stmt Lift(Stmt stmt) { stmt = operator()(std::move(stmt)); if (attr_node_.defined()) { - stmt = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, stmt); + stmt = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, stmt); } return stmt; } @@ -52,14 +51,11 @@ class AttrScopeLifter : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (attr_node_.defined()) { - Stmt body = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, op->body); + Stmt body = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); - return AllocateNode::make( - op->buffer_var, op->dtype, - op->extents, op->condition, body); + return AllocateNode::make(op->buffer_var, op->dtype, op->extents, op->condition, body); } else { return stmt; } @@ -97,8 +93,7 @@ class AttrScopeLifter : public StmtMutator { // check if all decorations are common. for (size_t begin = 0; begin < attr_node.size();) { size_t end = begin + 1; - while (end < attr_node.size() && - attr_node[end].same_as(attr_node[begin]) && + while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) && ValueSame(attr_value[end], attr_value[begin])) { ++end; } @@ -116,8 +111,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt stmt = SeqStmt::Flatten(seq); if (attr_node[begin].defined()) { - stmt = AttrStmtNode::make( - attr_node[begin], attr_key_, attr_value[begin], stmt); + stmt = AttrStmtNode::make(attr_node[begin], attr_key_, attr_value[begin], stmt); } reorg.push_back(stmt); begin = end; @@ -137,32 +131,25 @@ class AttrScopeLifter : public StmtMutator { std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); Stmt else_case = this->VisitStmt(op->else_case); - if (attr_node_.defined() && - attr_value_.defined() && - first_node.defined() && - first_value.defined() && - attr_node_.same_as(first_node) && + if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && + first_value.defined() && attr_node_.same_as(first_node) && ValueSame(attr_value_, first_value)) { - if (then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { return IfThenElseNode::make(op->condition, then_case, else_case); } } else { if (first_node.defined()) { - then_case = AttrStmtNode::make( - first_node, attr_key_, first_value, then_case); + then_case = AttrStmtNode::make(first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { - else_case = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, else_case); + else_case = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); } - if (then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + if (then_case.same_as(op->then_case) && 
else_case.same_as(op->else_case)) { return GetRef(op); } else { return IfThenElseNode::make(op->condition, then_case, else_case); @@ -192,7 +179,6 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } - namespace transform { Pass LiftAttrScope(std::string attr_key) { @@ -204,8 +190,7 @@ Pass LiftAttrScope(std::string attr_key) { return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope") -.set_body_typed(LiftAttrScope); +TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope); } // namespace transform diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index dbceb37407f57..6392e7031cd82 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -20,24 +20,26 @@ /*! * \file loop_partition.cc */ +#include +#include #include #include -#include #include -#include -#include +#include + #include #include -#include "ir_util.h" + #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { -using arith::IntSet; using arith::DeduceBound; using arith::Intersect; +using arith::IntSet; using PartitionKey = std::pair; struct PartitionKeyHash { @@ -72,8 +74,7 @@ bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) class CandidateSelector final : public StmtExprVisitor { public: using VarIsUsed = bool; - explicit CandidateSelector(bool split_const_loop) - : split_const_loop_(split_const_loop) {} + explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {} void VisitStmt_(const ForNode* op) final { // partition const loop when sets split_const_loop_ @@ -92,7 +93,7 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { - const IterVarNode *iv = op->node.as(); + const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -156,16 +157,16 @@ class CandidateSelector final : public StmtExprVisitor { class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(Var current_var, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) - : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { - for (const auto& kv : hint_map) { - out_vars_.insert(kv.first); - } - for (const auto& kv : relax_map) { - out_vars_.insert(kv.first); - } - } + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) + : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + for (const auto& kv : hint_map) { + out_vars_.insert(kv.first); + } + for (const auto& kv : relax_map) { + out_vars_.insert(kv.first); + } + } void VisitStmt_(const ForNode* op) final { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; @@ -198,21 +199,18 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, - std::unordered_set({current_var_.get()}))) { + if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. 
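// [Editor's note: a worked example under assumed bounds, not part of the
// patch.] For a loop with i in [0, n) and cond = likely(i < n - 3),
// DeduceBound can prove cond true on (-inf, n - 4] and, via the inverse
// condition i >= n - 3, false on [n - 3, +inf). After intersecting with the
// loop range, LoopPartitioner emits roughly:
//
//   for (i, 0, n - 3)   body with likely(...) elided     // partitioned loop
//   for (i, n - 3, n)   body with the condition kept     // post doubt loop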
- IntSet interval = - DeduceBound(current_var_, cond, hint_map_, relax_map_); + IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is true within interval partitions[{cond.get(), true}] = interval; } PrimExpr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { - IntSet interval = - DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); + IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is false within interval partitions[{cond.get(), false}] = interval; @@ -261,7 +259,7 @@ class PartitionFinder : public StmtExprVisitor { class ConditionEliminator : public StmtExprMutator { public: explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) - : ps_(ps), cond_value_(cond_value) {} + : ps_(ps), cond_value_(cond_value) {} PrimExpr VisitExpr(const PrimExpr& e) final { if (ps_.find(e.get()) != ps_.end()) { @@ -275,12 +273,11 @@ class ConditionEliminator : public StmtExprMutator { bool cond_value_; }; - // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public StmtMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, - PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} + explicit ThreadPartitionInserter(const std::unordered_set& ps, PrimExpr cond) + : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -310,8 +307,7 @@ class ThreadPartitionInserter : public StmtMutator { // likely conditions class LoopPartitioner : public StmtMutator { public: - explicit LoopPartitioner(bool split_const_loop) - : selector(CandidateSelector(split_const_loop)) {} + explicit LoopPartitioner(bool split_const_loop) : selector(CandidateSelector(split_const_loop)) {} Stmt VisitAndMutate(Stmt stmt) { selector(stmt); @@ -320,15 +316,14 @@ class LoopPartitioner : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), op->loop_var, - op->min, op->min + op->extent - 1, op->body, false); + Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, + op->body, false); if (s.defined()) return s; } // normal path when loop partition fails // normal loop variable can be put into hint map. - hint_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); + hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = StmtMutator::VisitStmt_(op); hint_map_.erase(op->loop_var.get()); return res; @@ -339,7 +334,7 @@ class LoopPartitioner : public StmtMutator { return StmtMutator::VisitStmt_(op); } - const IterVarNode *iv = op->node.as(); + const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; if (selector.candidates.count(op)) { @@ -352,13 +347,11 @@ class LoopPartitioner : public StmtMutator { Stmt res; if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. 
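// [Editor's note.] The distinction being drawn: hint_map_ carries ranges that
// may be assumed when deducing a bound (serial loop variables, rank-0 thread
// axes), while relax_map_ carries ranges the deduced bound must hold over for
// *every* value of the variable. A rank-1 axis such as threadIdx.x can
// diverge across lanes, so a condition is only usable if it is provable for
// the whole [0, extent) range; hence it goes into relax_map_.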
- relax_map_.insert({var.get(), - IntSet::interval(make_zero(var.dtype()), op->value - 1)}); + relax_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = StmtMutator::VisitStmt_(op); relax_map_.erase(var.get()); } else { - hint_map_.insert({var.get(), - IntSet::interval(make_zero(var.dtype()), op->value - 1)}); + hint_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = StmtMutator::VisitStmt_(op); hint_map_.erase(var.get()); } @@ -366,13 +359,11 @@ class LoopPartitioner : public StmtMutator { } private: - Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, - PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope); + Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, + Stmt body, bool partition_thread_scope); - std::pair> - GetIntervalAndCondset(const Partition &partitions, - const arith::IntervalSet &for_interval, - bool cond_value); + std::pair> GetIntervalAndCondset( + const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value); inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); @@ -385,18 +376,15 @@ class LoopPartitioner : public StmtMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> -LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, - const arith::IntervalSet &for_interval, - bool cond_value) { +std::pair> LoopPartitioner::GetIntervalAndCondset( + const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) { Array sets; std::unordered_set cond_set; - for (const auto &kv : partitions) { + for (const auto& kv : partitions) { if (kv.first.second == cond_value) { arith::IntervalSet interval = Downcast(kv.second); - arith::IntervalSet intersection = arith::Intersect( - &analyzer_, interval, for_interval); + arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval); if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); @@ -453,13 +441,8 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Object* node, - const Stmt& stmt, - Var var, - PrimExpr min, - PrimExpr max, - Stmt body, - bool partition_thread_scope) { +Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min, + PrimExpr max, Stmt body, bool partition_thread_scope) { using namespace arith; // include hint of var. 
hint_map_.insert({var.get(), IntSet::interval(min, max)}); @@ -476,7 +459,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, std::unordered_set cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, true); + GetIntervalAndCondset(finder.partitions, for_interval, true); if (middle_interval.is_nothing()) { // if such interval doesn't exist, find an interval in which all // conditions on var are false @@ -507,8 +490,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, if (!analyzer_.CanProve(body_begin == min)) { PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = MaxNode::make(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; @@ -533,15 +515,13 @@ Stmt LoopPartitioner::TryPartition(const Object* node, // require the extent to be non-negative PrimExpr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = MinNode::make(post_doubt_begin, max+1); + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; + post_doubt_begin = MinNode::make(post_doubt_begin, max + 1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } if (!partition_thread_scope) { - Stmt post_body = - Substitute(body, {{Var{var}, var + post_doubt_begin}}); + Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); } } @@ -583,8 +563,8 @@ Stmt LoopPartitioner::TryPartition(const Object* node, return s; } -inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) { - const ForNode *for_node = static_cast(node); +inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { + const ForNode* for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore @@ -597,7 +577,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b class RemoveLikelyTags : public StmtExprMutator { public: - PrimExpr VisitExpr_(const CallNode *op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); @@ -613,7 +593,6 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) { return stmt; } - namespace transform { Pass LoopPartition(bool split_const_loop) { @@ -625,8 +604,7 @@ Pass LoopPartition(bool split_const_loop) { return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LoopPartition") -.set_body_typed(LoopPartition); +TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); } // namespace transform diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index ce81528b8b35c..92b463c914c0f 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -21,10 +21,11 @@ * \brief Pass 
for lowering custom datatypes */ +#include +#include #include #include -#include -#include + #include "../../target/datatype/registry.h" namespace tvm { @@ -79,9 +80,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { if (toBeLowered) { auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); - return AllocateNode::make( - allocate->buffer_var, new_allocate_type, allocate->extents, - allocate->condition, allocate->body); + return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body); } return stmt; } @@ -97,19 +97,19 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; } -#define DEFINE_MUTATE__(OP, NodeName) \ - inline PrimExpr VisitExpr_(const NodeName* op) final { \ - auto type_code = op->dtype.code(); \ +#define DEFINE_MUTATE__(OP, NodeName) \ + inline PrimExpr VisitExpr_(const NodeName* op) final { \ + auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ - if (toBeLowered) { \ - auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ - CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ - << static_cast(type_code) << " not found"; \ - return (*lower)(expr); \ - } \ - return expr; \ + PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ } DEFINE_MUTATE__(Add, AddNode); @@ -131,15 +131,13 @@ class CustomDatatypesLowerer : public StmtExprMutator { std::string target_; }; - namespace transform { Pass LowerCustomDatatypes() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerCustomDatatypes: Require the target attribute"; + CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; @@ -147,8 +145,7 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes") -.set_body_typed(LowerCustomDatatypes); +TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); } // namespace transform diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index dac426d8c2731..a8424624d5f46 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -21,20 +21,21 @@ * \file lower_device_storage_access.cc * \brief Lower the special device storage access. 
*/ -#include -#include -#include #include -#include #include -#include "ir_util.h" +#include +#include +#include +#include + #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { -using runtime::StorageScope; using runtime::StorageRank; +using runtime::StorageScope; class StorageAccessInfoLower : public StmtExprMutator { public: @@ -51,8 +52,7 @@ class StorageAccessInfoLower : public StmtExprMutator { << "Double allocation of " << it->second.scope.to_string(); if (info->head_address.defined()) { - return LetStmtNode::make( - op->buffer_var, info->head_address, op->body); + return LetStmtNode::make(op->buffer_var, info->head_address, op->body); } else { return op->body; } @@ -99,30 +99,23 @@ class StorageAccessInfoLower : public StmtExprMutator { PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr( - op->dtype, buffer_var, dtype, offset, - it->second.info); + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); } CHECK(op->dtype.is_handle()); // Change to address_of return AddressOffset(buffer_var, dtype, offset); } - PrimExpr MakeTaggedAccessPtr(DataType ptr_type, - Var buffer_var, - DataType dtype, - PrimExpr offset, + PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset, const MemoryInfo& info) { if (ptr_type.is_handle()) { - CHECK(info->head_address.defined()) - << buffer_var << " is not adddressable."; + CHECK(info->head_address.defined()) << buffer_var << " is not adddressable."; return AddressOffset(buffer_var, dtype, offset); } int dtype_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(info->unit_bits % dtype_bits, 0); - return cast(ptr_type, - analyzer_.Simplify(offset / make_const( - offset.dtype(), info->unit_bits / dtype_bits))); + return cast(ptr_type, analyzer_.Simplify( + offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } // The storage entry. struct StorageEntry { @@ -139,9 +132,7 @@ class StorageAccessInfoLower : public StmtExprMutator { arith::Analyzer analyzer_; }; -Stmt LowerStorageAccessInfo(Stmt stmt) { - return StorageAccessInfoLower()(std::move(stmt)); -} +Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } namespace transform { @@ -151,12 +142,11 @@ Pass LowerDeviceStorageAccessInfo() { n->body = StorageAccessInfoLower()(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") -.set_body_typed(LowerDeviceStorageAccessInfo); + .set_body_typed(LowerDeviceStorageAccessInfo); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index a909d4c6b83cd..7df8fd257ca58 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -21,23 +21,24 @@ * Lower intrinsic calls and ops to device specific ir when possible. 
* \file lower_intrin.cc */ +#include +#include #include +#include #include -#include -#include -#include #include -#include "../../arith/pattern_match.h" + #include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/pattern_match.h" namespace tvm { namespace tir { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: - using IRMutatorWithAnalyzer::VisitStmt_; using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; IntrinInjecter(arith::Analyzer* analyzer, std::string target_name) : IRMutatorWithAnalyzer(analyzer) { @@ -50,8 +51,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { PrimExpr r = ApplyPattern(op->name, GetRef(op)); if (r.defined()) return r; } @@ -78,16 +78,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); - if (support_bitwise_op_ && - is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to right shift if possible. return op->a >> make_const(dtype, shift); } if (analyzer_->CanProveGreaterEqual(op->b, 0)) { // Common path, positive divisor - if (analyzer_->CanProveGreaterEqual(op->a, 0) || - analyzer_->CanProveGreaterEqual(e, 0)) { + if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) { return truncdiv(op->a, op->b); } else { DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; @@ -100,7 +98,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + return tir::SelectNode::make(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -110,9 +108,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) PrimExpr rdiv = truncdiv(op->a, op->b); PrimExpr rmod = truncmod(op->a, op->b); - return tir::SelectNode::make( - (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), - rdiv, rdiv - make_const(dtype, 1)); + return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1)); } } @@ -125,11 +122,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); - if (support_bitwise_op_ && - is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to masking if possible. 
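// [Editor's note: a standalone sanity check of the identity relied on below,
// assuming two's-complement integers; the floormod helper is illustrative and
// not a TVM API.] For a power-of-two divisor, floor-mod equals a bitmask even
// for negative dividends, which is why plain truncating C `%` would be wrong:

#include <cassert>

int floormod(int a, int b) {  // floored modulo: result takes the divisor's sign
  int r = a % b;              // C's % truncates toward zero
  return (r != 0 && (r < 0) != (b < 0)) ? r + b : r;
}

int main() {
  for (int x = -64; x <= 64; ++x) {
    assert(floormod(x, 16) == (x & 15));  // e.g. x = -5: both sides give 11
  }
  return 0;
}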
- int64_t mask = ( - static_cast(1) << static_cast(shift)) - 1; + int64_t mask = (static_cast(1) << static_cast(shift)) - 1; return op->a & make_const(dtype, mask); } @@ -160,9 +155,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return tir::SelectNode::make( - (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), - rmod, rmod + op->b); + return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, + rmod + op->b); } } @@ -171,8 +165,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PVar x, y; PVar c; auto e = GetRef(op); - if (max(floordiv(x, y), c).Match(e) && - c.Eval()->value >= 0 && + if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); } @@ -232,15 +225,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return e; } - PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, - const AddNode* op) { + PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) { // emit fma instruction: a * b + c PrimExpr lhs = SwapBroadcastCast(a); PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(CallNode::make( - op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = + (*fma_)(CallNode::make(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -288,18 +280,15 @@ Pass LowerIntrin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerIntrin: Require the target attribute"; + CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = - IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); + n->body = IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin") -.set_body_typed(LowerIntrin); +TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); } // namespace transform diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9cb817d04b6d5..127b012de69fd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -21,28 +21,28 @@ * Lower allreduce to device implementable ir. 
* \file lower_thread_allreduce.cc */ +#include +#include +#include #include #include #include -#include -#include -#include #include -#include "ir_util.h" #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { class ThreadAllreduceBuilder final : public StmtExprMutator { public: - explicit ThreadAllreduceBuilder(int warp_size) - : warp_size_(warp_size) {} + explicit ThreadAllreduceBuilder(const TargetNode* target) + : target_(target), warp_size_(target->thread_warp_size) {} - Stmt VisitStmt_(const AttrStmtNode *op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { thread_extents_.push_back(op); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -58,7 +58,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return ret; } } else if (op->attr_key == attr::reduce_scope) { - const CommReducerNode *combiner = op->node.as(); + const CommReducerNode* combiner = op->node.as(); CHECK(combiner); reduce_combiner_.push_back(combiner); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -84,15 +84,19 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); - // use volatile access to shared buffer. - stmt = AttrStmtNode::make( - repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = AllocateNode::make( - repl->buffer_var, repl->dtype, - repl->extents, repl->condition, stmt); - stmt = AttrStmtNode::make( - repl->buffer_var, attr::storage_scope, - StringImmNode::make("shared"), stmt); + if (warp_allocs_.count(repl)) { + stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, + op->body); + stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, + StringImmNode::make("local"), stmt); + } else { + // use volatile access to shared buffer. + stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body); + stmt = + AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); + stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, + StringImmNode::make("shared"), stmt); + } return stmt; } else { return stmt; @@ -119,21 +123,22 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return scope.dim_index < other.scope.dim_index; } }; + // make allreduce. 
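// [Editor's note.] The tvm_thread_allreduce intrinsic packs its operands as
//
//   args = {size, value_0 .. value_{size-1}, cond, buf_0 .. buf_{size-1}}
//
// e.g. for size == 2: {2, v0, v1, cond, b0, b1}. The unpacking in
// MakeAllreduce below (args[1 + idx], args[size + 1], args[2 + size + idx])
// depends on exactly this layout.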
Stmt MakeAllreduce(const CallNode* call) { CHECK(!reduce_combiner_.empty()); - const CommReducerNode *combiner = reduce_combiner_.back(); + const CommReducerNode* combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const IntImmNode *size_of_args = call->args[0].as<IntImmNode>(); + const IntImmNode* size_of_args = call->args[0].as<IntImmNode>(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array<PrimExpr> inits = combiner->identity_element; std::vector<PrimExpr> values(size); std::vector<DataType> types(size); - PrimExpr cond = call->args[size+1]; + PrimExpr cond = call->args[size + 1]; for (size_t idx = 0; idx < size; ++idx) { - values[idx] = call->args[1+idx]; + values[idx] = call->args[1 + idx]; if (!is_one(cond)) { values[idx] = SelectNode::make(cond, values[idx], inits[idx]); } @@ -141,7 +146,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } std::vector<const VarNode*> buffers(size); for (size_t idx = 0; idx < size; ++idx) { - const VarNode* buffer = call->args[2+size+idx].as<VarNode>(); + const VarNode* buffer = call->args[2 + size + idx].as<VarNode>(); CHECK(buffer); buffers[idx] = buffer; } @@ -160,12 +165,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { e.scope = runtime::ThreadScope::make(iv->thread_tag); e.iv = iv; CHECK_LE(e.scope.rank, 1); - CHECK_GE(e.scope.dim_index, 0) - << "vthread do not work with cross thread reduction"; + CHECK_GE(e.scope.dim_index, 0) << "vthread does not work with cross thread reduction"; if (e.scope.rank == 1) { const auto* ptr = attr->value.as<IntImmNode>(); - CHECK(ptr) - << "Need constant extent for reduce set " << iv; + CHECK(ptr) << "Need constant extent for reduce set " << iv; e.extent = static_cast<int>(ptr->value); if (reduce_set.count(iv->var.get())) { vred.push_back(e); @@ -175,66 +178,200 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } } - CHECK_EQ(nmatch, reduce_set.size()) - << "Not all reduce index are presented in the context"; + CHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce indices are present in the context"; std::sort(vred.begin(), vred.end()); std::sort(vpar.begin(), vpar.end()); // the size of each index. int reduce_extent, group_extent; - int threadx_extent = 1; PrimExpr reduce_index = FlattenThread(vred, &reduce_extent); PrimExpr group_index = FlattenThread(vpar, &group_extent); - if (reduce_extent == 1) { - // special case, no reduction is needed. - std::vector<Stmt> stores(size); + std::vector<Stmt> seq; + std::vector<Var> shared_bufs(size); + std::vector<Stmt> local_vars; + // + // This is an optimization. For small reduction sizes, it may be beneficial + // for a single warp to perform the entire reduction. No trips to shared + // memory and no cross-warp synchronizations are required. + // The following code emits the reduction as follows: + // + // Allocate reduction vars v[i], i = 0..size-1 + // + // for offset from 16 down to 1, halving each step + // + // a <- load(v[i]) + // b <- shuffle_down(load(v[i]), offset) + // v[i] <- reduction(a, b) + // + // broadcast results from lane 0 to all other lanes and store + // the final reduction result to the proper location. + // + if (is_warp_reduction(types)) { + // TODO(tvm-team) sub-warp reduction support. + CHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction"; + // + // This is the index to the reduction variable, one reduction + // variable per warp. Local scope seems easier to reason about without + // relying on a pattern-match pass to fix it later.
+ PrimExpr index(0); + + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); + PrimExpr pred = const_true(types[idx].lanes()); + seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred)); + + // Uses a local variable to store the shuffled data. + // Later on, this allocation will be properly attached to this statement. + Var var("t" + std::to_string(idx), types[idx]); + Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, EvaluateNode::make(0)); + local_vars.push_back(s); + } + + // The mask for this reducer, as this reducer may sit inside + // a divergent control flow. Here it uses a variable to cache the currently + // active lanes. + // + Var mask_var("mask", DataType::UInt(32)); + { + PrimExpr pred = const_true(1); + PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, + CallNode::Intrinsic); + seq.emplace_back(StoreNode::make(mask_var, mask, index, pred)); + // Push allocation with an empty body. Later this will be fixed + // when the entire body is ready. + auto stmt = AllocateNode::make(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, + EvaluateNode::make(0)); + local_vars.push_back(stmt); + } + + // Emit reductions within a warp. + for (int offset = 16; offset > 0; offset /= 2) { + // Load reduction values, no synchronization needed. + Array<PrimExpr> a, b; + for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; + PrimExpr pred = const_true(types[i].lanes()); + PrimExpr val = LoadNode::make(types[i], var, index, pred); + a.push_back(val); + + // __shfl_*sync calls shall not appear in if_then_else expressions + // as this causes extra divergence. E.g. + // + // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0); + // + // behaves differently from + // + // int t = __shfl_sync(mask, v1, 0); + // v1 = (v2 < v3) ? v3 : t; + // + // The former may cause a deadlock as there is a divergent + // branch with a warp sync call inside. + // + const char* shfl_func = intrinsic::tvm_warp_shuffle_down; + PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); + const AllocateNode* repl = local_vars[i].as<AllocateNode>(); + Stmt s = StoreNode::make(repl->buffer_var, other, index, pred); + seq.push_back(s); + + PrimExpr load = LoadNode::make(types[i], repl->buffer_var, index, pred); + b.push_back(load); + } + + // Do reductions. + Array<PrimExpr> ret = (*combiner)(a, b); + + // Store the reduction result to itself. + std::vector<Stmt> stores(size); + for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; + PrimExpr pred = const_true(types[i].lanes()); + stores[i] = StoreNode::make(var, ret[i], index, pred); + } + seq.push_back(SeqStmt::Flatten(stores)); + } + + // Broadcast the reduction result from lane 0 to all other lanes. + // This avoids emitting predicated stores, as all threads are + // uniformly writing the same result. + // for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - Var buffer_var = Downcast<Var>(call->args[2+size+i]); - stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); + const char* shfl_func = intrinsic::tvm_warp_shuffle; + PrimExpr val = LoadNode::make(types[i], var, index, pred); + PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); + seq.push_back(StoreNode::make(var, splat, index, pred)); + } + + // Update existing allocations.
+ for (size_t i = 0; i < size; ++i) { + CHECK(!load_remap_.count(buffers[i])); + PrimExpr pred = const_true(types[i].lanes()); + Var var = shared_bufs[i]; + load_remap_[buffers[i]] = LoadNode::make(types[i], var, index, pred); + Array<PrimExpr> extents{PrimExpr(1)}; + auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0)); + alloc_remap_[buffers[i]] = node; + warp_allocs_.insert(node.get()); + } + } else { + int threadx_extent = 1; + if (reduce_extent == 1) { + // special case, no reduction is needed. + std::vector<Stmt> stores(size); + for (size_t i = 0; i < size; ++i) { + PrimExpr pred = const_true(types[i].lanes()); + Var buffer_var = Downcast<Var>(call->args[2 + size + i]); + stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); + } + return SeqStmt::Flatten(stores); + } + // Whether threadIdx.x is involved in the reduction. + if (vred[0].scope.dim_index == 0) { + threadx_extent = vred[0].extent; + } + // This sync is necessary because there might be an incomplete read of + // the previous iteration on the same buffer. + seq.emplace_back(SyncThread("shared")); + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); + PrimExpr pred = const_true(types[idx].lanes()); + seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); + } + seq.emplace_back(SyncThread("shared")); + seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, + reduce_extent, threadx_extent)); + for (size_t idx = 0; idx < size; ++idx) { + CHECK(!load_remap_.count(buffers[idx])); + PrimExpr pred = const_true(types[idx].lanes()); + load_remap_[buffers[idx]] = LoadNode::make( + types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); + alloc_remap_[buffers[idx]] = AllocateNode::make( + shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, + EvaluateNode::make(0)); } - return SeqStmt::Flatten(stores); - } - // Whether the threadIdx.x is involved in reduction. - if (vred[0].scope.dim_index == 0) { - threadx_extent = vred[0].extent; - } - std::vector<Stmt> seq; - std::vector<Var> shared_bufs(size); - // This sync is necessary because there might be incomplete read of - // previous iteration on the same buffer. - seq.emplace_back(SyncThread("shared")); - for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); - PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(StoreNode::make( - shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); } - seq.emplace_back(SyncThread("shared")); - seq.emplace_back(MakeBufAllreduce( - combiner, types, shared_bufs, - reduce_index, group_index, reduce_extent, threadx_extent)); - for (size_t idx = 0; idx < size; ++idx) { - CHECK(!load_remap_.count(buffers[idx])); - PrimExpr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = LoadNode::make( - types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = AllocateNode::make( - shared_bufs[idx], types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, - pred, EvaluateNode::make(0)); + + // Fix up all local allocations now that all statements are built.
+ Stmt body = SeqStmt::Flatten(seq); + for (auto var : local_vars) { + const AllocateNode* repl = var.as<AllocateNode>(); + if (repl) { + body = + AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); + body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, + StringImmNode::make("local"), body); + } } - return SeqStmt::Flatten(seq); + + return body; } + // make allreduce. - Stmt MakeBufAllreduce(const CommReducerNode *combiner, - const std::vector<DataType>& types, - const Array<Var>& shared_bufs, - PrimExpr reduce_index, - PrimExpr group_index, - int reduce_extent, - int threadx_extent) { + Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types, + const Array<Var>& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, + int reduce_extent, int threadx_extent) { // Get next power of two int reduce_align = 1; while (reduce_extent > reduce_align) { @@ -250,8 +387,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Array<PrimExpr> a, b; for (size_t i = 0; i < size; ++i) { b.push_back(LoadNode::make(types[i], shared_bufs[i], - BufIndex(reduce_index + offset, group_index, reduce_extent), - const_true())); + BufIndex(reduce_index + offset, group_index, reduce_extent), + const_true())); a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true())); } Array<PrimExpr> ret = (*combiner)(a, b); @@ -271,9 +408,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } CHECK(threadx_extent >= 1 && warp_size_ >= 1); // normal synchronization - while (reduce_align > threadx_extent || - reduce_align > warp_size_) { - reduce_align = reduce_align >> 1; + while (reduce_align > threadx_extent || reduce_align > warp_size_) { + reduce_align = reduce_align >> 1; PrimExpr cond = reduce_index < reduce_align; seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); @@ -295,8 +431,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Flatten the thread index. // Also return a warp number, - PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec, - int* out_total_extent) { + PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec, int* out_total_extent) { int& total_extent = *out_total_extent; total_extent = 1; if (tvec.size() == 0) { @@ -325,11 +460,58 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync)}, - CallNode::Intrinsic)); + return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImmNode::make(sync)}, CallNode::Intrinsic)); + } + + // Emit warp shuffle intrinsic calls. + PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { + PrimExpr pred = const_true(1); + PrimExpr index(0); + PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred); + PrimExpr width = IntImm(DataType::Int(32), warp_size_); + Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; + return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic); + } + + // Check if this is a reduction on threadIdx.x and its extent matches + // the warp size. + // + // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads. + bool is_warp_reduction(const std::vector<DataType>& types) const { + // Only cuda target supports warp reductions.
+ if (target_->target_name != "cuda") return false; + + // Supported types: + // {u}int, {u}long, {u}long long, float, double, half/half2 + if (std::any_of(types.begin(), types.end(), [](DataType ty) { + if (ty.is_float16()) return ty.lanes() > 2; + if (ty.is_vector()) return true; + return ty.bytes() < 4 || ty.bytes() > 8; + })) { + return false; + } + if (thread_extents_.empty()) { + return false; + } + + const AttrStmtNode* op = thread_extents_.back(); + DCHECK_EQ(op->attr_key, attr::thread_extent); + + IterVar iv = Downcast<IterVar>(op->node); + ThreadEntry e; + e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.extent = 0; + if (auto ptr = op->value.as<IntImmNode>()) { + e.extent = static_cast<int>(ptr->value); + } + + return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1; } + + // The target. + const TargetNode* target_ = nullptr; + // The warp size of the device. int warp_size_{1}; @@ -337,9 +519,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector<const AttrStmtNode*> thread_extents_; std::vector<const CommReducerNode*> reduce_combiner_; // The load remap - std::unordered_map<const VarNode*, PrimExpr> load_remap_; + std::unordered_map<const VarNode*, PrimExpr> load_remap_; // Allocate remap - std::unordered_map<const VarNode*, Stmt> alloc_remap_; + std::unordered_map<const VarNode*, Stmt> alloc_remap_; + // Allocate from warp reductions + std::unordered_set<const Object*> warp_allocs_; // Internal analyzer arith::Analyzer analyzer_; }; @@ -350,16 +534,15 @@ Pass LowerThreadAllreduce() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr<Target>(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); + CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; + const TargetNode* target_node = target.as<TargetNode>(); + n->body = ThreadAllreduceBuilder(target_node)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce") -.set_body_typed(LowerThreadAllreduce); +TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index ee6c44d213135..88c4363b5de18 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -21,10 +21,10 @@ * Lower TVM related builtin intrinsics such as packed call. * \file tir/transforms/lower_tvm_buildin.cc */ +#include #include #include #include -#include #include @@ -40,10 +40,7 @@ inline PrimExpr ConstInt32(size_t index) { inline PrimExpr StackAlloca(std::string type, size_t num) { Array<PrimExpr> args = {StringImmNode::make(type), ConstInt32(num)}; - return CallNode::make( - DataType::Handle(), - intrinsic::tvm_stack_alloca, - args, CallNode::Intrinsic); + return CallNode::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); } // Calculate the statistics of packed function.
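For the LowerThreadAllreduce change above, the emitted shuffle-down loop corresponds to the familiar CUDA idiom. The following is a minimal device-side sketch, not part of this patch, assuming a full 32-lane warp and addition as the combiner; warp_reduce_sum is a hypothetical name used only for illustration:

  // CUDA sketch of the warp allreduce emitted above: cache the active-lane
  // mask once, halve the shuffle offset each step, then broadcast lane 0.
  __device__ float warp_reduce_sum(float v) {
    unsigned mask = __activemask();  // same role as tvm_warp_activemask
    for (int offset = 16; offset > 0; offset /= 2) {
      // add the partial result held by the lane `offset` positions away
      v += __shfl_down_sync(mask, v, offset, 32);
    }
    // broadcast lane 0's total so every lane stores the same final value
    return __shfl_sync(mask, v, 0, 32);
  }

Note how the shuffle result is kept in a plain local (`v`) rather than folded into a conditional, matching the divergence caveat in the comments above.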
@@ -57,18 +54,14 @@ class BuiltinLower : public StmtExprMutator { stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->VisitStmt(stmt); if (max_shape_stack_ != 0) { - stmt = LetStmtNode::make( - stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); + stmt = LetStmtNode::make(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); } if (max_array_stack_ != 0) { - stmt = LetStmtNode::make( - stack_array_, StackAlloca("array", max_array_stack_), stmt); + stmt = LetStmtNode::make(stack_array_, StackAlloca("array", max_array_stack_), stmt); } if (max_arg_stack_ != 0) { - stmt = LetStmtNode::make( - stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); - stmt = LetStmtNode::make( - stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); + stmt = LetStmtNode::make(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); + stmt = LetStmtNode::make(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); } return stmt; } @@ -109,44 +102,34 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = EvaluateNode::make( - CallNode::make(DataType::Int(32), - intrinsic::tvm_throw_last_error, {}, - CallNode::Intrinsic)); + Stmt throw_last_error = EvaluateNode::make(CallNode::make( + DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({ - IfThenElseNode::make( - CallNode::make(DataType::Bool(1), - intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), - throw_last_error), - op->body}); + Stmt body = SeqStmt( + {IfThenElseNode::make(CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), + throw_last_error), + op->body}); Stmt alloca = LetStmtNode::make( op->buffer_var, - CallNode::make(op->buffer_var.dtype(), - "TVMBackendAllocWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - cast(DataType::UInt(64), total_bytes), - IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}, - CallNode::Extern), + CallNode::make( + op->buffer_var.dtype(), "TVMBackendAllocWorkspace", + {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), + cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, + CallNode::Extern), body); - PrimExpr free_op = CallNode::make(DataType::Int(32), - "TVMBackendFreeWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - op->buffer_var}, - CallNode::Extern); - Stmt free_stmt = IfThenElseNode::make( - free_op != make_zero(DataType::Int(32)), throw_last_error); + PrimExpr free_op = CallNode::make(DataType::Int(32), "TVMBackendFreeWorkspace", + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), op->buffer_var}, + CallNode::Extern); + Stmt free_stmt = + IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); - body = AttrStmtNode::make( - op->buffer_var, attr::storage_alignment, - make_const(DataType::Int(32), runtime::kTempAllocaAlignment), - body); + body = AttrStmtNode::make(op->buffer_var, attr::storage_alignment, + make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; } @@ -185,9 +168,8 @@ class BuiltinLower : public StmtExprMutator { 
PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq_.emplace_back( - StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin +i), const_true(1))); + prep_seq_.emplace_back(StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), + ConstInt32(stack_begin + i), const_true(1))); } return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } @@ -197,45 +179,36 @@ class BuiltinLower : public StmtExprMutator { run_array_stack_ += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, - make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, - make_const(DataType::UInt(16), dtype.lanes()))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, + make_const(DataType::UInt(8), dtype.bits()))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, + make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); PrimExpr byte_offset = op->args[5]; if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, - cast(DataType::UInt(64), byte_offset))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, + cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, - cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, - cast(DataType::Int(32), device_type_))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, + cast(DataType::Int(32), device_id_))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, + cast(DataType::Int(32), device_type_))); return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. 
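At runtime, the lowered form that the next hunk assembles is equivalent to staging every argument into the shared TVMValue/type-code stacks and then invoking the callee through the C runtime. A hand-written sketch under those assumptions (func_handle and tensor are hypothetical inputs; TVMFuncCall, TVMValue, and the type codes are the real C API):

  #include <tvm/runtime/c_runtime_api.h>

  // Sketch: what one lowered packed call amounts to at runtime.
  int call_packed_sketch(TVMFunctionHandle func_handle, DLTensor* tensor) {
    TVMValue stack_value[2];
    int stack_tcode[2];
    stack_value[0].v_int64 = 42;       // first argument: an int scalar
    stack_tcode[0] = kDLInt;
    stack_value[1].v_handle = tensor;  // second argument: a DLTensor handle
    stack_tcode[1] = kTVMDLTensorHandle;
    TVMValue ret_val;
    int ret_tcode;
    return TVMFuncCall(func_handle, stack_value, stack_tcode, 2, &ret_val, &ret_tcode);
  }

The begin/end stack indices passed to tvm_call_packed_lowered let several packed calls reuse one stack allocation, which is why the running stack counters below are saved and restored around each call.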
@@ -255,18 +228,15 @@ class BuiltinLower : public StmtExprMutator { if (t != api_type) { arg = CastNode::make(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet( - stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, - ConstInt32(arg_tcode), - stack_index, const_true(1))); + StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -275,19 +245,14 @@ class BuiltinLower : public StmtExprMutator { run_shape_stack_ = restore_shape_stack; run_array_stack_ = restore_array_stack; run_arg_stack_ = arg_stack_begin; - Array packed_args = { - op->args[0], - stack_value_, - stack_tcode_, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1) - }; - return CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed_lowered, - packed_args, CallNode::Intrinsic); + Array packed_args = {op->args[0], stack_value_, stack_tcode_, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + op->args.size() - 1)}; + return CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, + CallNode::Intrinsic); } - PrimExpr MakeCallTracePacked(const CallNode *op) { + PrimExpr MakeCallTracePacked(const CallNode* op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; @@ -304,15 +269,12 @@ class BuiltinLower : public StmtExprMutator { if (t != api_type) { arg = CastNode::make(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet( - stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, - ConstInt32(arg_tcode), - stack_index, const_true(1))); + StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -323,18 +285,13 @@ class BuiltinLower : public StmtExprMutator { // Update the top of the stack, so we can use more than one // packed function's arguments with the one stack. run_arg_stack_ = arg_stack_begin + args_size - 1; - Array packed_args = { - op->args[0], - stack_value_, - stack_tcode_, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1), - // Pass traced value. - op->args[args_size - 1] - }; - return CallNode::make( - op->dtype, intrinsic::tvm_call_trace_packed_lowered, - packed_args, CallNode::Intrinsic); + Array packed_args = {op->args[0], stack_value_, stack_tcode_, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + op->args.size() - 1), + // Pass traced value. 
+ op->args[args_size - 1]}; + return CallNode::make(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, + CallNode::Intrinsic); } private: @@ -379,8 +336,7 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin") -.set_body_typed(LowerTVMBuiltin); +TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 516b96cd9c156..4c8dec01245c6 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -25,20 +25,19 @@ */ // Thanks to Andrew Adams and Vinod Grover for // explaining the concept of warp shuffle. -#include #include - +#include +#include +#include +#include #include #include -#include #include -#include -#include #include -#include "../../arith/pattern_match.h" #include "../../arith/compute_expr.h" +#include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { @@ -101,13 +100,8 @@ namespace tir { // store warp_mem[m * warp_index + (width * m) * y + x] class WarpStoreCoeffFinder : private StmtVisitor { public: - WarpStoreCoeffFinder(const VarNode* buffer, - Var warp_index, - arith::Analyzer* analyzer) - : buffer_(buffer), - warp_index_(warp_index), - analyzer_(analyzer) { - } + WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) + : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { this->VisitStmt(stmt); @@ -116,7 +110,7 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const StoreNode *op) final { + void VisitStmt_(const StoreNode* op) final { if (op->buffer_var.get() == buffer_) { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); @@ -133,16 +127,14 @@ class WarpStoreCoeffFinder : private StmtVisitor { } void UpdatePattern(const PrimExpr& index) { - Array m = - arith::DetectLinearEquation(index, {warp_index_}); - CHECK_EQ(m.size(), 2U) - << "LowerWarpMemory failed due to store index=" << index; + Array m = arith::DetectLinearEquation(index, {warp_index_}); + CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); const auto* mcoeff_as_int = mcoeff.as(); CHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index - << ", require positive constant coefficient on warp index " << warp_index_ - << " but get " << mcoeff; + << ", require positive constant coefficient on warp index " << warp_index_ << " but get " + << mcoeff; if (warp_coeff_ != 0) { CHECK_EQ(warp_coeff_, mcoeff_as_int->value) @@ -162,13 +154,10 @@ class WarpStoreCoeffFinder : private StmtVisitor { arith::Analyzer* analyzer_; }; - // Visitor to find the warp index class WarpIndexFinder : private StmtVisitor { public: - explicit WarpIndexFinder(int warp_size) - : warp_size_(warp_size) { - } + explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {} // find the warp co-efficient and the shuffle width in the statement std::pair Find(const Stmt& stmt) { this->VisitStmt(stmt); @@ -179,21 +168,20 @@ class WarpIndexFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const AttrStmtNode *op) final { 
+ void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast<IterVar>(op->node); if (iv->thread_tag == "threadIdx.x") { auto* value_as_int = op->value.as<IntImmNode>(); - CHECK(value_as_int && - value_as_int->value <= warp_size_ && + CHECK(value_as_int && value_as_int->value <= warp_size_ && warp_size_ % value_as_int->value == 0) << "Expect threadIdx.x 's size to be no larger than, and a factor of" - << " warp size(" << warp_size_ << ")" << " to enable warp memory" + << " warp size(" << warp_size_ << ")" + << " to enable warp memory" << " but get " << op->value << " instead"; if (warp_index_.defined()) { CHECK(warp_index_.same_as(iv)) - << "Find two instance of " << warp_index_->thread_tag - << " in the same kernel. " + << "Found two instances of " << warp_index_->thread_tag << " in the same kernel. " << "Please create it using thread_axis once and reuse the axis " << "across multiple binds in the same kernel"; } else { @@ -221,27 +209,21 @@ class WarpAccessRewriter : protected StmtExprMutator { Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); - CHECK_GT(alloc_size, 0) - << "warp memory only support constant alloc size"; + CHECK_GT(alloc_size, 0) << "warp memory only supports constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); - warp_coeff_ = WarpStoreCoeffFinder( - buffer_, warp_index_, analyzer_).Find(op->body); + warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0) << "Warp memory must be multiple of the extent of threadIdx.x"; warp_group_ = alloc_size / (width_ * warp_coeff_); - return AllocateNode::make( - op->buffer_var, - op->dtype, - {make_const(DataType::Int(32), alloc_size / width_)}, - op->condition, - this->VisitStmt(op->body)); + return AllocateNode::make(op->buffer_var, op->dtype, + {make_const(DataType::Int(32), alloc_size / width_)}, op->condition, + this->VisitStmt(op->body)); } protected: PrimExpr VisitExpr_(const VarNode* op) override { - CHECK(op != buffer_) - << "Cannot access address of warp memory directly"; + CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } @@ -261,14 +243,13 @@ class WarpAccessRewriter : protected StmtExprMutator { std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariant: local index must not contain warp id CHECK(!ExprUseVar(local_index, warp_index_)) - << "LowerWarpMemory failed to rewrite load to shuffle for index " - << op->index << " local_index=" << local_index; - PrimExpr load_value = LoadNode::make( - op->dtype, op->buffer_var, local_index, op->predicate); - return CallNode::make(load_value.dtype(), - intrinsic::tvm_warp_shuffle, - {load_value, group, width_, warp_size_}, - CallNode::Intrinsic); + << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index + << " local_index=" << local_index; + PrimExpr load_value = LoadNode::make(op->dtype, op->buffer_var, local_index, op->predicate); + PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, + CallNode::Intrinsic); + return CallNode::make(load_value.dtype(), intrinsic::tvm_warp_shuffle, + {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); } @@ -301,10 +282,8 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr x =
analyzer_->canonical_simplify(indexmod(index, m)); PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_); y = y * m + x; - PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), - m); - return std::make_pair(analyzer_->canonical_simplify(y), - analyzer_->canonical_simplify(z)); + PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m); + return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); } } @@ -325,14 +304,12 @@ class WarpAccessRewriter : protected StmtExprMutator { arith::Analyzer* analyzer_; }; - // Bind bound information of variables to make analyzer more effective // TODO(tqchen): consider a pass to inline the bound info into the expr // so analysis can be context independent. class BindVarBoundInfo : public StmtVisitor { public: - explicit BindVarBoundInfo(arith::Analyzer* analyzer) - : analyzer_(analyzer) {} + explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} void VisitStmt_(const ForNode* op) final { const Var& loop_var = op->loop_var; @@ -341,8 +318,7 @@ class BindVarBoundInfo : public StmtVisitor { } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); if (!var_dom_.count(iv->var.get())) { @@ -364,9 +340,7 @@ class BindVarBoundInfo : public StmtVisitor { // Mutator to change the read pattern class WarpMemoryRewriter : private StmtMutator { public: - explicit WarpMemoryRewriter(int warp_size) - : warp_size_(warp_size) { - } + explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {} Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; @@ -396,8 +370,7 @@ class WarpMemoryRewriter : private StmtMutator { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); - return AttrStmtNode::make( - op->node, op->attr_key, StringImmNode::make("local"), op->body); + return AttrStmtNode::make(op->node, op->attr_key, StringImmNode::make("local"), op->body); } } return StmtMutator::VisitStmt_(op); @@ -416,16 +389,14 @@ Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; + CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory") -.set_body_typed(LowerWarpMemory); +TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 4e5ca2ddf40d1..b6314ad175070 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,22 +20,22 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. 
*/ -#include -#include -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include +#include +#include +#include -#include -#include #include +#include +#include -#include "ir_util.h" #include "arg_binder.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -45,15 +45,12 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { EvaluateNode::make(0)); } -PrimFunc MakePackedAPI(PrimFunc&& func, - int num_unpacked_args) { +PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); - CHECK(global_symbol) - << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; auto target = func->GetAttr<Target>(tvm::attr::kTarget); - CHECK(target.defined()) - << "MakePackedAPI: Require the target attribute"; + CHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; int target_device_type = target.value()->device_type; std::string name_hint = global_symbol.value(); @@ -85,15 +82,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func, // local function definitions // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { - Array<PrimExpr> call_args{ - v_packed_args, - IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; + Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - PrimExpr res = CallNode::make( - api_type, intrinsic::tvm_struct_get, call_args, - CallNode::PureIntrinsic); + PrimExpr res = + CallNode::make(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { res = CastNode::make(t, res); @@ -111,8 +105,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, std::ostringstream os; os << name_hint << ": num_args should be " << num_packed_args; - seq_init.emplace_back( - MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); + seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); } // Need to re-declare vars, in case some arguments also appear in the buffer.
@@ -131,24 +124,21 @@ PrimFunc MakePackedAPI(PrimFunc&& func, } if (i < num_packed_args) { // Value loads - seq_init.emplace_back(LetStmtNode::make( - v_arg, f_arg_value(v_arg.dtype(), i), nop)); + seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmtNode::make( - tcode, LoadNode::make( - DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back( + LetStmtNode::make(tcode, + LoadNode::make(DataType::Int(32), v_packed_arg_type_ids, + IntImm(DataType::Int(32), i), const_true(1)), + nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; seq_check.emplace_back( - AssertStmtNode::make(tcode == kTVMOpaqueHandle || - tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || - tcode == kTVMNullptr, + AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, tvm::tir::StringImmNode::make(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; @@ -188,35 +178,30 @@ PrimFunc MakePackedAPI(PrimFunc&& func, } for (const auto& kv : buffer_def) { - binder.BindDLTensor(kv.second, device_type, device_id, - kv.first, kv.first->name_hint); + binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint); } if (num_unpacked_args == 0) { func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - auto body = AttrStmtNode::make( - make_zero(DataType::Int(32)), attr::compute_scope, - StringImmNode::make(name_hint + "_compute_"), func_ptr->body); + auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope, + StringImmNode::make(name_hint + "_compute_"), func_ptr->body); // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImmNode::make("default"); - seq_check.push_back(AttrStmtNode::make( - node, attr::device_context_id, device_id, nop)); - seq_check.push_back(AttrStmtNode::make( - node, attr::device_context_type, device_type, nop)); + seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop)); + seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop)); if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { Stmt set_device = EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - {StringImmNode::make(runtime::symbol::tvm_set_device), - device_type, device_id}, CallNode::Intrinsic)); + DataType::Int(32), intrinsic::tvm_call_packed, + {StringImmNode::make(runtime::symbol::tvm_set_device), device_type, device_id}, + CallNode::Intrinsic)); body = SeqStmt({set_device, body}); } } - func_ptr->body = MergeNest( - {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); @@ -229,7 +214,6 @@ PrimFunc MakePackedAPI(PrimFunc&& func, LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); } - func_ptr->buffer_map = Map(); func_ptr->checked_type_ = func_ptr->func_type_annotation(); func_ptr->ret_type = PrimType(DataType::Int(32)); @@ -248,9 +232,8 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv 
: mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } @@ -263,12 +246,10 @@ Pass MakePackedAPI(int num_unpacked_args) { return m; }; - return tvm::transform::CreateModulePass( - pass_func, 0, "tir.MakePackedAPI", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI") -.set_body_typed(MakePackedAPI); +TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 4cf5ccdd081c7..ad86e452823fe 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -22,9 +22,10 @@ * \brief narrow the datatype of indexing vars */ +#include #include #include -#include + #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" @@ -55,8 +56,8 @@ namespace tir { // - Use DataTypeRewritter to rewrite the components of an indexing expression. using arith::Analyzer; -using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; +using arith::IRMutatorWithAnalyzer; // Determine the result dtype for Var, IntImm and Cast, // which will be stored in `vmap` eventually. @@ -70,24 +71,22 @@ using arith::ConstIntBound; // Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()` class DataTypeVisitor final : public StmtExprVisitor { public: - explicit DataTypeVisitor(int target_bits) - : bits_(target_bits), target_bits_(target_bits) {} + explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {} void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = max_bits_; - const PrimExprNode* op = e.as(); - if (bound_.find(op) == bound_.end()) { + if (bound_.find(e) == bound_.end()) { analyzer_.const_int_bound(e, &bound_); } - ConstIntBound bound = bound_[op]; + ConstIntBound bound = bound_[e]; int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; if (e.dtype().bits() <= target_bits_ || (bound->max_value <= ubound && bound->min_value >= lbound)) { bits = target_bits_; } - int tmp = bits > bits_ ? bits : bits_; + int tmp = bits > bits_ ? 
bits : bits_; std::swap(bits_, tmp); StmtExprVisitor::VisitExpr(e); std::swap(bits_, tmp); @@ -97,19 +96,16 @@ class DataTypeVisitor final : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); + analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); vextent_[op->loop_var.as()] = op->extent.dtype(); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); } else { @@ -187,12 +183,12 @@ class DataTypeVisitor final : public StmtExprVisitor { // the extent of vars to be rewritten std::unordered_map vextent_; // the memorized bound generated by ConstIntBoundAnalyzer - std::unordered_map bound_; + arith::ConstIntBoundAnalyzer::BoundMapType bound_; }; class DataTypeRewriter : public StmtExprMutator { public: - explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {} + explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {} Stmt operator()(Stmt s) { visitor_(s); @@ -212,19 +208,15 @@ class DataTypeRewriter : public StmtExprMutator { is_index_ = true; PrimExpr index = this->VisitExpr(op->index); is_index_ = false; - Stmt s = StoreNode::make(op->buffer_var, - op->value, - index, - op->predicate); + Stmt s = StoreNode::make(op->buffer_var, op->value, index, op->predicate); return StmtExprMutator::VisitStmt_(s.as()); } Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - CHECK(op != nullptr) - << "Expected type to be ForNode" - << ", but get " << s->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be ForNode" + << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), @@ -232,27 +224,20 @@ class DataTypeRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - CHECK(op != nullptr) - << "Expected type to be AttrStmtNode" - << ", but get " << s->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); - CHECK(iv != nullptr) - << "Expected type to be IterVarNode" - << ", but get " << op->node->GetTypeKey(); + CHECK(iv != nullptr) << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); } - return AttrStmtNode::make( - ivmap_[iv], - op->attr_key, - cast(var.dtype(), op->value), - op->body); + return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); } return StmtExprMutator::VisitStmt_(op); } @@ -298,9 +283,8 @@ class DataTypeRewriter : 
public StmtExprMutator { if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = StmtExprMutator::VisitExpr_(op); const CastNode* new_op = e.as(); - CHECK(new_op != nullptr) - << "Expected type to be CastNode" - << ", but get " << e->GetTypeKey(); + CHECK(new_op != nullptr) << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); return CastNode::make(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); @@ -336,40 +320,38 @@ class DataTypeRewriter : public StmtExprMutator { bool is_index_{false}; }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return FUNC(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return FUNC(a, b); \ + } \ } -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); - CHECK(op != nullptr) - << "Expected type to be CallNode" - << ", but get " << e->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); if (op->call_type == CallNode::PureIntrinsic) { if (op->name == intrinsic::tvm_if_then_else) { return if_then_else(op->args[0], op->args[1], op->args[2]); @@ -390,9 +372,7 @@ PrimExpr 
DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt NarrowDataType(Stmt stmt, int target_bits) { - return DataTypeRewriter(target_bits)(stmt); -} +Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } namespace transform { @@ -402,12 +382,10 @@ Pass NarrowDataType(int target_bits) { n->body = DataTypeRewriter(target_bits)(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.NarrowDataType", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") -.set_body_typed(NarrowDataType); +TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index fdcfc4d4702e6..efb9e6956b178 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -20,12 +20,12 @@ /*! * \file remap_thread_axis.cc */ +#include #include #include #include -#include -#include +#include namespace tvm { namespace tir { @@ -33,14 +33,9 @@ namespace tir { // Mutator to change the read pattern class ThreadAxisRewriter : private StmtExprMutator { public: - explicit ThreadAxisRewriter( - const std::unordered_map& tmap) - : tmap_(tmap) { - } + explicit ThreadAxisRewriter(const std::unordered_map& tmap) : tmap_(tmap) {} - Stmt Rewrite(Stmt stmt) { - return operator()(std::move(stmt)); - } + Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); } private: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -57,8 +52,7 @@ class ThreadAxisRewriter : private StmtExprMutator { CHECK(vmap_[v].same_as(new_iv->var)); } Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make( - new_iv, op->attr_key, op->value, body); + return AttrStmtNode::make(new_iv, op->attr_key, op->value, body); } } return StmtExprMutator::VisitStmt_(op); @@ -75,7 +69,6 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; - PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { @@ -83,8 +76,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) } auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(opt_thread_axis != nullptr) - << "Require attribute " << tir::attr::kDeviceThreadAxis; + CHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); @@ -99,7 +91,6 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); } - namespace transform { Pass RemapThreadAxis(Map thread_map) { @@ -109,8 +100,7 @@ Pass RemapThreadAxis(Map thread_map) { return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis") -.set_body_typed(RemapThreadAxis); +TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index ceaf27b816ffa..15a7e8638e5cc 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -22,11 +22,12 @@ * \brief Remove no op from the stmt */ #include -#include #include #include -#include +#include #include +#include + #include namespace tvm { @@ -105,7 +106,7 
@@ class NoOpRemover : public StmtMutator { auto n = CopyOnWrite(op); size_t top = 0; for (size_t i = 0; i < n->seq.size(); ++i) { - if (!is_no_op(n->seq[i])) { + if (!is_no_op(n->seq[i])) { n->seq.Set(top++, n->seq[i]); } } @@ -147,9 +148,7 @@ class NoOpRemover : public StmtMutator { } }; -Stmt RemoveNoOp(Stmt stmt) { - return NoOpRemover()(std::move(stmt)); -} +Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } namespace transform { @@ -162,8 +161,7 @@ Pass RemoveNoOp() { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp") -.set_body_typed(RemoveNoOp); +TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); } // namespace transform diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 6052cbf32b5dc..149cda946d881 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -29,16 +29,13 @@ namespace tvm { namespace tir { - // For now, rewrite unsafe select expression to if_then_else // TODO(tqchen) pattern matching to support masked load class UnsafeExprDetector : public ExprFunctor { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. - bool VisitExpr_(const SelectNode* op) { - return VisitExpr(op->condition); - } + bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { return VisitExpr(op->args[0]); @@ -75,21 +72,11 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const GENode* op) final { return BinaryOp(op); } bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); } bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); } - bool VisitExpr_(const NotNode* op) final { - return VisitExpr(op->a); - } - bool VisitExpr_(const LetNode* op) final { - return VisitExpr(op->body) || VisitExpr(op->value); - } - bool VisitExpr_(const CastNode* op) final { - return VisitExpr(op->value); - } - bool VisitExpr_(const BroadcastNode* op) final { - return VisitExpr(op->value); - } - bool VisitExpr_(const RampNode* op) final { - return VisitExpr(op->base) && VisitExpr(op->stride); - } + bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); } + bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } + bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } + bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); } bool VisitExpr_(const ShuffleNode* op) final { for (PrimExpr e : op->vectors) { if (VisitExpr(e)) return true; @@ -102,7 +89,7 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const StringImmNode* op) final { return false; } private: - template + template bool BinaryOp(const T* op) { return VisitExpr(op->a) || VisitExpr(op->b); } @@ -115,23 +102,17 @@ class UnsafeSelectRewriter : public StmtExprMutator { op = expr.as(); UnsafeExprDetector unsafe; bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); - if ((unsafe.VisitExpr(op->true_value) || - unsafe.VisitExpr(op->false_value)) && + if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return CallNode::make( - op->dtype, - 
intrinsic::tvm_if_then_else, - {op->condition, op->true_value, op->false_value}, - CallNode::Intrinsic); + return CallNode::make(op->dtype, intrinsic::tvm_if_then_else, + {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; } } }; -Stmt RewriteUnsafeSelect(Stmt stmt) { - return UnsafeSelectRewriter()(std::move(stmt)); -} +Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } namespace transform { @@ -144,8 +125,7 @@ Pass RewriteUnsafeSelect() { return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect") -.set_body_typed(RewriteUnsafeSelect); +TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); } // namespace transform diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 752939e625ae9..759b320131e5f 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -21,14 +21,13 @@ * \file simplify.cc * \brief Statement simplifier based on analyzer */ +#include #include +#include #include +#include #include -#include -#include -#include -#include #include "../../arith/ir_mutator_with_analyzer.h" namespace tvm { @@ -38,20 +37,15 @@ using namespace tir; class StmtSimplifier : public IRMutatorWithAnalyzer { public: - explicit StmtSimplifier(Analyzer* analyzer) - : IRMutatorWithAnalyzer(analyzer) {} + explicit StmtSimplifier(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitStmt; using Parent::VisitStmt_; - PrimExpr VisitExpr(const PrimExpr& expr) final { - return analyzer_->Simplify(expr); - } + PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); } - Stmt Simplify(Stmt stmt) { - return operator()(std::move(stmt)); - } + Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt_(const ForNode* op) final { analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); @@ -69,8 +63,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return this->VisitStmt(op->body); } Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -109,8 +102,7 @@ Pass Simplify() { return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -TVM_REGISTER_GLOBAL("tir.transform.Simplify") -.set_body_typed(Simplify); +TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); } // namespace transform diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index 4511838efe574..d9cd6d35497cf 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -17,10 +17,10 @@ * under the License. 
*/ +#include #include -#include #include -#include +#include namespace tvm { namespace tir { @@ -34,9 +34,7 @@ class AssertSkipper : public StmtMutator { } }; -Stmt SkipAssert(Stmt stmt) { - return AssertSkipper()(std::move(stmt)); -} +Stmt SkipAssert(Stmt stmt) { return AssertSkipper()(std::move(stmt)); } namespace transform { @@ -49,8 +47,7 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SkipAssert") -.set_body_typed(SkipAssert); +TVM_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); } // namespace transform diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 44f032ffa444f..9bdb0e2354787 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,14 +22,14 @@ * \brief Split device function from host. */ #include -#include -#include -#include +#include +#include +#include #include +#include +#include #include -#include -#include -#include +#include #include @@ -69,13 +69,11 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && - !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && - value.same_as(op->value)) { + if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { return LetStmtNode::make(op->var, value, body); @@ -102,13 +100,11 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && - !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && - value.same_as(op->value)) { + if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -127,12 +123,10 @@ class VarUseDefAnalysis : public StmtExprMutator { } void HandleDef(const VarNode* v) { - CHECK(!def_count_.count(v)) - << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - CHECK(!use_count_.count(v)) - << "variable " << v->name_hint - << " has been used before definition!"; + CHECK(!def_count_.count(v)) << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; + CHECK(!use_count_.count(v)) << "variable " << v->name_hint + << " has been used before definition!"; use_count_[v] = 0; def_count_[v] = 1; } @@ -161,7 +155,6 @@ class VarUseDefAnalysis : public StmtExprMutator { std::unordered_map def_count_; }; - Array UndefinedVars(const Stmt& stmt, const Array& args) { VarUseDefAnalysis m; for (Var arg : args) { @@ -171,16 +164,10 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { return m.undefined_; } - class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModule* device_mod, - Target device_target, - std::string name_prefix) - : device_mod_(device_mod), - device_target_(device_target), - name_prefix_(name_prefix) { - } + explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) + : device_mod_(device_mod), device_target_(device_target), 
name_prefix_(name_prefix) {} Stmt VisitStmt_(const AllocateNode* op) final { handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); @@ -188,8 +175,7 @@ class HostDeviceSplitter : public StmtMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { return SplitDeviceFunc(GetRef(op)); } @@ -216,8 +202,7 @@ class HostDeviceSplitter : public StmtMutator { // Create a new version of v. auto it = handle_data_type_.find(var.get()); if (it != handle_data_type_.end()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType((*it).second->dtype))); + tir::Var new_var(var->name_hint, PointerType(PrimType((*it).second->dtype))); params.push_back(new_var); remap_vars.Set(var, new_var); } else { @@ -237,8 +222,8 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); - device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, - runtime::String(kernel_symbol)); + device_func = + WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); @@ -252,9 +237,8 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - return EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - call_args, CallNode::Intrinsic)); + return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, + call_args, CallNode::Intrinsic)); } // target ir module @@ -268,19 +252,15 @@ class HostDeviceSplitter : public StmtMutator { std::unordered_map handle_data_type_; }; - PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "SplitHostDevice: Require the target attribute"; + CHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; - HostDeviceSplitter splitter( - device_mod, - target.value(), - static_cast(global_symbol.value())); + HostDeviceSplitter splitter(device_mod, target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); @@ -289,7 +269,6 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { return std::move(func); } - namespace transform { Pass SplitHostDevice() { @@ -308,12 +287,10 @@ Pass SplitHostDevice() { return mod; }; - return tvm::transform::CreateModulePass( - pass_func, 0, "tir.SplitHostDevice", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice") -.set_body_typed(SplitHostDevice); +TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/storage_access.cc 
b/src/tir/transforms/storage_access.cc index 1f28e138a1ef9..35888bd7f9e13 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -20,12 +20,15 @@ /*! * \file storage_access.cc */ +#include "storage_access.h" + #include + #include #include -#include "storage_access.h" -#include "ir_util.h" + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -89,8 +92,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = - StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::make(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); @@ -145,8 +147,8 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { if (s.access.size() != 0) { // relax the touched set to contain all ranges in the loop. std::unordered_map relax_map; - relax_map[op->loop_var.get()] = arith::IntSet::range( - Range::make_by_min_extent(op->min, op->extent)); + relax_map[op->loop_var.get()] = + arith::IntSet::range(Range::make_by_min_extent(op->min, op->extent)); for (AccessEntry& e : s.access) { if (e.buffer.defined()) { CHECK(e.touched.defined()); @@ -180,7 +182,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); StmtExprVisitor::VisitExpr_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); @@ -197,8 +199,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(op->args[1]); - e.touched = arith::IntSet::range( - Range::make_by_min_extent(offset, extent)); + e.touched = arith::IntSet::range(Range::make_by_min_extent(offset, extent)); e.scope = scope; if (flag->value & 1) { e.type = kRead; diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 12e76bd08732d..80bbff4c1fe4f 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -24,19 +24,21 @@ #ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ #define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ +#include #include #include -#include #include -#include + #include +#include + #include "../../runtime/thread_storage_scope.h" namespace tvm { namespace tir { -using runtime::StorageScope; using runtime::StorageRank; +using runtime::StorageScope; /*! * \brief Base class of storage access analysis */ @@ -85,31 +87,20 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final; protected: - StorageAccessVisitor() { - scope_.push_back(std::vector()); - } + StorageAccessVisitor() { scope_.push_back(std::vector()); } /*! \return number of conditions in the current scope. */ - int condition_counter() const { - return condition_counter_; - } + int condition_counter() const { return condition_counter_; } /*! \return whether we are in device environment. */ - bool in_device_env() const { - return in_device_env_; - } + bool in_device_env() const { return in_device_env_; } /*! 
\return environment threads */ - const Array& env_threads() const { - return env_threads_; - } + const Array& env_threads() const { return env_threads_; } /*! * \brief Whether we need to analyze the buffer in the current scope. * \param buffer The buffer to be checked * \param scope The scope of the buffer. * \return Whether the analysis of buffer is enabled. */ - virtual bool Enabled(const VarNode* buffer, - const StorageScope& scope) const { - return true; - } + virtual bool Enabled(const VarNode* buffer, const StorageScope& scope) const { return true; } /*! * \brief Summarize the sequence of operations into the parent. * * \param seq The sequence of the access operations. * \param loop Pass loop if it is a loop, otherwise nullptr. * \return The summarized sequence that represents accesses that * the parent should take care of to synchronize. */ - virtual std::vector Summarize( - std::vector seq, const ForNode* loop) = 0; + virtual std::vector Summarize(std::vector seq, const ForNode* loop) = 0; /*! * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 96686950b4973..96d0e307efe6d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -23,40 +23,39 @@ */ // The pass definition originates from Halide pipeline. -#include #include +#include +#include +#include +#include +#include #include +#include #include -#include #include -#include #include -#include -#include -#include + #include -#include "ir_util.h" -#include "arg_binder.h" + #include "../../arith/compute_expr.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" +#include "arg_binder.h" +#include "ir_util.h" namespace tvm { namespace tir { +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; -using intrinsic::tvm_address_of; class StorageFlattener : public StmtExprMutator { public: - explicit StorageFlattener(const Map& extern_buffer_map, - int cache_line_size, - bool create_bound_attributes, - IRVisitorWithAnalyzer* bound_analyzer) - : bound_analyzer_(bound_analyzer), - create_bound_attributes_(create_bound_attributes) { + explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, + bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; @@ -70,8 +69,7 @@ class StorageFlattener : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && - !it->second.same_as(op->buffer_var)) { + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); return StoreNode::make(buf_var, op->value, op->index, op->predicate); @@ -89,10 +87,8 @@ class StorageFlattener : public StmtExprMutator { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << buffer; - body = AttrStmtNode::make( - it->second.buffer->data, op->attr_key, op->value, std::move(body)); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + body = AttrStmtNode::make(it->second.buffer->data, op->attr_key, op->value, std::move(body)); 
return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -129,31 +125,24 @@ const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; if (is_opengl_) { - return EvaluateNode::make(CallNode::make( - DataType(), - CallNode::glsl_texture_store, - {e.buffer->data, op->value}, - CallNode::Intrinsic)); + return EvaluateNode::make(CallNode::make(DataType(), CallNode::glsl_texture_store, + {e.buffer->data, op->value}, CallNode::Intrinsic)); } else { Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back( - std::make_pair(e.buffer->data, e.buffer->shape)); + shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } // To create bound attributes, the collector should have at least one item. if (create_bound_attributes_ && shape_collector_.size()) { for (size_t i = 0; i < shape_collector_.size(); ++i) { - body = AttrStmtNode::make( - shape_collector_[i].first, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, shape_collector_[i].second), body); + body = AttrStmtNode::make(shape_collector_[i].first, tir::attr::buffer_bound, + MakeBound(e.buffer->dtype, shape_collector_[i].second), body); } } return body; @@ -176,14 +165,12 @@ } // deduce current storage scope. auto it = storage_scope_.find(op->buffer.get()); - CHECK(it != storage_scope_.end()) - << "Cannot find storage scope of " << op->buffer; + CHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; const std::string& strkey = it->second; if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { - skey.rank = runtime::DefaultStorageRank( - curr_thread_scope_.back().rank); + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } } else { skey = StorageScope::make(strkey); } @@ -221,11 +208,9 @@ strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = BufferNode::make( - Var(op->buffer->data->name_hint, DataType::Handle()), - op->buffer->dtype, shape, strides, PrimExpr(), - op->buffer->name, skey.to_string(), - align, 0, kDefault); + e.buffer = BufferNode::make(Var(op->buffer->data->name_hint, DataType::Handle()), + op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -240,26 +225,23 @@ } if (strides.size() != 0) { int first_dim = 0; - ret = AllocateNode::make( - e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = AllocateNode::make(e.buffer->data, storage_type, + {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } else { shape = e.buffer->shape; if (shape.size() == 0) { shape.push_back(make_const(DataType::Int(32), 1)); } - ret = AllocateNode::make( - e.buffer->data, 
storage_type, shape, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = AllocateNode::make(e.buffer->data, storage_type, shape, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmtNode::make( - e.buffer->data, attr::storage_scope, - StringImmNode::make(e.buffer->scope), ret); + ret = AttrStmtNode::make(e.buffer->data, attr::storage_scope, + StringImmNode::make(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, e.buffer->shape), ret); + MakeBound(e.buffer->dtype, e.buffer->shape), ret); } return ret; } @@ -269,8 +251,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && - !it->second.same_as(op->buffer_var)) { + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); return LoadNode::make(op->dtype, buf_var, op->index, op->predicate); @@ -295,38 +276,31 @@ class StorageFlattener : public StmtExprMutator { const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back( - std::make_pair(e.buffer->data, e.buffer->shape)); + shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); } - - Stmt VisitStmt_(const PrefetchNode *op) final { + Stmt VisitStmt_(const PrefetchNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); CHECK(op != nullptr); const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; CHECK_EQ(e.buffer->shape.size(), op->bounds.size()) - << "Prefetch dim should be the same as buffer dim"; + << "Prefetch dim should be the same as buffer dim"; - int block_size = 1, - elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); + int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); int starts = op->bounds.size() - 1; @@ -344,25 +318,23 @@ class StorageFlattener : public StmtExprMutator { for (int i = op->bounds.size() - 1; i > starts; --i) { args.push_back(op->bounds[i]->min); } - auto &func_name = op->buffer->name; - vars.push_back(Var( - "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); + auto& func_name = op->buffer->name; + vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); args.push_back(op->bounds[starts]->min + stride * vars.back()); for (int i = starts - 1; i >= 0; --i) { - vars.push_back(Var( - "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); + vars.push_back(Var("prefetch." + func_name + "." 
+ std::to_string(i), DataType::Int(32))); args.push_back(vars.back() + op->bounds[i]->min); } for (int i = starts; i >= 0; --i) { if (i < starts) { - stmt = ForNode::make( - vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = ForNode::make(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, + stmt); } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); - PrimExpr address = CallNode::make( - DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); - PrimExpr prefetch = CallNode::make( - op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + PrimExpr address = + CallNode::make(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); + PrimExpr prefetch = CallNode::make(op->buffer->dtype, CallNode::prefetch, + {address, 0, 3, 1}, CallNode::Intrinsic); stmt = EvaluateNode::make(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -372,9 +344,8 @@ class StorageFlattener : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - CHECK(op->call_type != CallNode::Halide) - << "Cannot handle Halide calls " - << " please run SchedulePostProcToPrimFunc first"; + CHECK(op->call_type != CallNode::Halide) << "Cannot handle Halide calls " + << " please run SchedulePostProcToPrimFunc first"; return StmtExprMutator::VisitExpr_(op); } @@ -390,7 +361,6 @@ class StorageFlattener : public StmtExprMutator { return Stmt(); } - private: // The specific tensor data layout is not determined before // StorageFlatten pass. We use buffer_bind_scope @@ -427,7 +397,7 @@ class StorageFlattener : public StmtExprMutator { // We do support a few relaxed case, such as bindingx // region with shape [1, 1, n, m] to buffer with shape [n, m] Stmt HandleBufferBindScope(const AttrStmtNode* op) { - Array arr = Downcast > (op->node); + Array arr = Downcast>(op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); const BufferNode* target = arr[1].as(); @@ -437,8 +407,7 @@ class StorageFlattener : public StmtExprMutator { auto key = GetRef(target); auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find buffer of " << key; + CHECK(it != buf_map_.end()) << "Cannot find buffer of " << key; const BufferEntry& be = it->second; CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); @@ -452,15 +421,14 @@ class StorageFlattener : public StmtExprMutator { } else { for (size_t i = 0; i < tuple->args.size(); i += 2) { begins.push_back(tuple->args[i]); - auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]); + auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]); extents.push_back(new_extent); } } Buffer slice = be.buffer.MakeSlice(begins, extents); if (buffer->strides.size() == 0) { CHECK_EQ(slice->strides.size(), 0U) - << "Trying to bind compact buffer to strided one strides=" - << slice->strides; + << "Trying to bind compact buffer to strided one strides=" << slice->strides; } else { slice = slice.MakeStrideView(); } @@ -508,26 +476,24 @@ class StorageFlattener : public StmtExprMutator { } }; - bool ShapeIsValid(const Array &shape) { + bool ShapeIsValid(const Array& shape) { // Zero-dimensional tensor does not need boundary check. 
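// Note on the two helpers below: ShapeIsValid only accepts shapes whose
// dimensions are all defined, scalar-typed, and not known-negative constants;
// for a shape (s0, ..., s_{n-1}) and an element type with L lanes, MakeBound
// then builds (L*s0) * (L*s1) * ... * (L*s_{n-1}), i.e. an element-count
// bound of L^n * prod(s_i) that is attached through the
// tir::attr::buffer_bound attribute.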
- if (!shape.size()) - return false; + if (!shape.size()) return false; for (size_t i = 0; i < shape.size(); ++i) { - if (!shape[i].defined() || !shape[i].dtype().is_scalar() || - is_negative_const(shape[i])) { + if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) { return false; } } return true; } - PrimExpr MakeBound(const DataType &type, const Array &shape) { + PrimExpr MakeBound(const DataType& type, const Array& shape) { // We have already checked the shape size to be greater than 0. PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); for (size_t i = 1; i < shape.size(); ++i) { - bound = MulNode::make( - bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); + bound = + MulNode::make(bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); } return bound; } @@ -538,8 +504,7 @@ // Buffer map std::unordered_map buf_map_; // Dimension alignment - std::unordered_map, - ObjectHash, ObjectEqual> dim_align_; + std::unordered_map, ObjectHash, ObjectEqual> dim_align_; // Storage scope std::unordered_map storage_scope_; // The current thread scope. @@ -557,35 +522,27 @@ bool create_bound_attributes_{false}; }; -PrimFunc StorageFlatten(PrimFunc func, - int cache_line_size, - bool create_bound_attributes) { +PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { auto fptr = func.CopyOnWrite(); IRVisitorWithAnalyzer bound_analyzer; bound_analyzer(fptr->body); - fptr->body = StorageFlattener(fptr->buffer_map, - cache_line_size, - create_bound_attributes, + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); return func; } - namespace transform { // TODO(tvm-team): consolidate configs to the PassContext -Pass StorageFlatten(int cache_line_size, - bool create_bound_attributes) { +Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return StorageFlatten( - std::move(f), cache_line_size, create_bound_attributes); + return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); } -TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten") -.set_body_typed(StorageFlatten); +TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten); } // namespace transform diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index ca2b5a9b45aa7..fc86f2bdf3484 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -22,19 +22,21 @@ * \brief Memory access pattern analysis and optimization. * Re-write data access to enable memory sharing when possible. 
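 * The rewrite works in two phases: LinearAccessPatternFinder builds a
 * linearized access sequence with per-scope liveness information, which
 * StoragePlanRewriter uses to reuse and merge allocations; VectorAllocRewriter
 * then retypes allocations whose accesses all use a single vector type.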
*/ -#include #include -#include -#include +#include +#include #include +#include #include -#include +#include + #include -#include #include -#include "ir_util.h" +#include + #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -125,8 +127,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - CHECK_LT(it->second.level, scope_.size()) - << "Load memory in places other than store."; + CHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; scope_[it->second.level].touched.push_back(buf); } } @@ -142,24 +143,23 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - CHECK_LT(it->second.level, scope_.size()) - << " buf=" << buf->name_hint; + CHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; scope_[it->second.level].touched.push_back(buf); } } - template + template void VisitNewScope(const T* op) { scope_.push_back(StmtEntry()); StmtEntry e; e.stmt = op; - int64_t begin_index = static_cast(linear_seq_.size()); + int64_t begin_index = static_cast(linear_seq_.size()); // before scope. linear_seq_.push_back(e); StmtExprVisitor::VisitStmt_(op); // after scope. e.touched = std::move(scope_.back().touched); scope_.pop_back(); - int64_t end_index = static_cast(linear_seq_.size()); + int64_t end_index = static_cast(linear_seq_.size()); CHECK_GT(end_index, begin_index); e.scope_pair_offset = begin_index - end_index; linear_seq_.push_back(e); @@ -179,24 +179,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = - StorageScope::make(op->value.as()->value); + alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } } - void VisitStmt_(const IfThenElseNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } - void VisitStmt_(const ForNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } - void VisitStmt_(const AssertStmtNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } // linearized access sequence. 
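// The sequence stores paired scope records: an entry is pushed when a scope
// is entered and again when it exits, with scope_pair_offset on the exit
// record holding begin_index - end_index so the matching entry record can be
// located by offset.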
std::vector linear_seq_; @@ -238,9 +231,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // class InplaceOpVerifier : public StmtExprVisitor { public: - bool Check(const Object* stmt, - const VarNode* dst, - const VarNode* src) { + bool Check(const Object* stmt, const VarNode* dst, const VarNode* src) { dst_ = dst; src_ = src; result_ = true; @@ -272,7 +263,8 @@ class InplaceOpVerifier : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { // assume all opaque access is unsafe if (op == dst_ || op == src_) { - result_ = false; return; + result_ = false; + return; } } @@ -293,9 +285,9 @@ class InplaceOpVerifier : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // always reject extern code - if (op->attr_key == attr::extern_scope || - op->attr_key == attr::volatile_scope) { - result_ = false; return; + if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) { + result_ = false; + return; } StmtExprVisitor::VisitStmt_(op); } @@ -304,17 +296,19 @@ class InplaceOpVerifier : public StmtExprVisitor { const VarNode* buf = op->buffer_var.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { - result_ = false; return; + result_ = false; + return; } // do not allow indirect memory load if (mem_nest_ != 0) { - result_ = false; return; + result_ = false; + return; } if (src_ == buf) { - if (store_ == nullptr || - store_->value.dtype() != op->dtype || + if (store_ == nullptr || store_->value.dtype() != op->dtype || !tir::ExprDeepEqual()(store_->index, op->index)) { - result_ = false; return; + result_ = false; + return; } } ++mem_nest_; @@ -322,7 +316,6 @@ class InplaceOpVerifier : public StmtExprVisitor { --mem_nest_; } - private: // result of the check bool result_{true}; @@ -358,10 +351,9 @@ class StoragePlanRewriter : public StmtExprMutator { for (StorageEntry* e : attach_map_.at(nullptr)) { // CHECK_EQ(e->scope.rank, 0); if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make( - e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, + StringImmNode::make(e->scope.to_string()), + EvaluateNode::make(0))); nest.push_back(e->new_alloc); } } @@ -374,20 +366,16 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return stmt; - return StoreNode::make(it->second->alloc_var, - op->value, - RemapIndex(op->value.dtype(), op->index, it->second), - op->predicate); + return StoreNode::make(it->second->alloc_var, op->value, + RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } PrimExpr VisitExpr_(const LoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; - return LoadNode::make(op->dtype, - it->second->alloc_var, - RemapIndex(op->dtype, op->index, it->second), - op->predicate); + return LoadNode::make(op->dtype, it->second->alloc_var, + RemapIndex(op->dtype, op->index, it->second), op->predicate); } PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); @@ -417,10 +405,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return CallNode::make( - op->dtype, op->name, - {op->args[0], se->alloc_var, offset, extent, 
op->args[4]}, - op->call_type); + return CallNode::make(op->dtype, op->name, + {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + op->call_type); } else { return StmtExprMutator::VisitExpr_(op); } @@ -429,17 +416,14 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread || + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakeAttach(svec, op->body)); + return AttrStmtNode::make(op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } @@ -448,31 +432,26 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->node.as()); if (it == alloc_map_.end()) return stmt; - return AttrStmtNode::make( - it->second->alloc_var, op->attr_key, op->value, op->body); + return AttrStmtNode::make(it->second->alloc_var, op->attr_key, op->value, op->body); } else { return StmtExprMutator::VisitStmt_(op); } } Stmt VisitStmt_(const ForNode* op) final { - CHECK(op->for_type != ForType::Vectorized) - << "VectorizeLoop before LiftStorageAlloc"; + CHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return ForNode::make( - op->loop_var, op->min, op->extent, op->for_type, op->device_api, - MakeAttach(svec, op->body)); + return ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, + MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } } - Stmt VisitStmt_(const AllocateNode* op) final { - return this->VisitStmt(op->body); - } + Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); } private: struct StorageEntry { @@ -517,15 +496,13 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector kill; }; - Stmt MakeAttach(const std::vector& svec, - Stmt body) { + Stmt MakeAttach(const std::vector& svec, Stmt body) { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make( - e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, + StringImmNode::make(e->scope.to_string()), + EvaluateNode::make(0))); nest.push_back(e->new_alloc); } } @@ -545,15 +522,14 @@ class StoragePlanRewriter : public StmtExprMutator { attach_map_[e->attach_scope_].push_back(e); } // find allocation via attach map. - for (auto &kv : attach_map_) { + for (auto& kv : attach_map_) { // find the element with the most amount of bytes. 
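// For specially tagged scopes the planner merges same-scope entries under the
// first entry seen: each merged child later receives a bits_offset into the
// parent allocation (see NewAllocTagMerged), which is why tagged entries are
// required to have a constant size.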
std::vector& vec = kv.second; // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; if (e->scope.tag.length() != 0) { - CHECK_NE(e->const_nbits, 0U) - << "Special tagged memory must be const size"; + CHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { vec[j]->merged_children.push_back(e); @@ -568,7 +544,8 @@ class StoragePlanRewriter : public StmtExprMutator { // already merged if (e->bits_offset != 0) continue; if (e->merged_children.size() != 0) { - NewAllocTagMerged(e); continue; + NewAllocTagMerged(e); + continue; } // Get the allocation size; e->alloc_var = e->allocs[0]->buffer_var; @@ -581,10 +558,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. PrimExpr sz = arith::ComputeReduce(e->allocs[0]->extents, - make_const(DataType::Int(32), 1)); - e->new_alloc = AllocateNode::make( - e->alloc_var, alloc_type, {sz}, - e->allocs[0]->condition, EvaluateNode::make(0)); + make_const(DataType::Int(32), 1)); + e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, + EvaluateNode::make(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -595,13 +571,12 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - PrimExpr sz = arith::ComputeReduce( - op->extents, make_const(DataType::Int(32), 1)); + PrimExpr sz = + arith::ComputeReduce(op->extents, make_const(DataType::Int(32), 1)); auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { - LOG(WARNING) << "The allocation requires : " << imm->value - << " * " << nbits + LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits << " bits, which is greater than the maximum of" " int32. The size is cast to int64." 
<< "\n"; @@ -625,9 +600,8 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = combo_size + make_const(DataType::Int(32), 1); } combo_size = analyzer_.Simplify(combo_size); - e->new_alloc = AllocateNode::make( - e->alloc_var, alloc_type, {combo_size}, const_true(), - EvaluateNode::make(0)); + e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {combo_size}, const_true(), + EvaluateNode::make(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -653,7 +627,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Always align to max_simd_bits // so we can remap types by keeping this property if (total_bits % align != 0) { - total_bits += align - (total_bits % align); + total_bits += align - (total_bits % align); } e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { @@ -663,15 +637,14 @@ class StoragePlanRewriter : public StmtExprMutator { child->alloc_var = e->alloc_var; total_bits += child->const_nbits; if (total_bits % align != 0) { - total_bits += align - (total_bits % align); + total_bits += align - (total_bits % align); } } uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); - PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), - (total_bits + type_bits - 1) / type_bits); - e->new_alloc = AllocateNode::make( - e->alloc_var, e->elem_type, {alloc_size}, const_true(), - EvaluateNode::make(0)); + PrimExpr alloc_size = + make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); + e->new_alloc = AllocateNode::make(e->alloc_var, e->elem_type, {alloc_size}, const_true(), + EvaluateNode::make(0)); if (info.defined()) { CHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); @@ -764,8 +737,7 @@ class StoragePlanRewriter : public StmtExprMutator { visitor.Check(s.stmt, var, src)) { uint64_t const_nbits = static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * - ae.alloc->dtype.lanes(); + ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = src_entry; @@ -786,8 +758,7 @@ class StoragePlanRewriter : public StmtExprMutator { // enter/exit new scope if (s.stmt->IsInstance()) { const auto* op = static_cast(s.stmt); - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { PlanNewScope(op); } else { @@ -816,10 +787,8 @@ class StoragePlanRewriter : public StmtExprMutator { } } // Allocate new storage entry. - StorageEntry* NewAlloc(const AllocateNode* op, - const Object* attach_scope, - const StorageScope& scope, - size_t const_nbits) { + StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope, + const StorageScope& scope, size_t const_nbits) { CHECK(op != nullptr); // Re-use not successful, allocate a new buffer. std::unique_ptr entry(new StorageEntry()); @@ -832,23 +801,21 @@ class StoragePlanRewriter : public StmtExprMutator { return e; } - StorageEntry* FindAlloc(const AllocateNode* op, - const Object* attach_scope, + StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, const StorageScope& scope) { CHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. 
const uint64_t match_range = 16; uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); - uint64_t const_nbits = static_cast( op->constant_allocation_size() * op_elem_bits); + uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); // disable reuse of small arrays, they will be lowered to registers in LLVM // These rules only apply if we are using non-special memory if (scope.tag.length() == 0) { if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) { return NewAlloc(op, attach_scope, scope, const_nbits); } - if (const_nbits > 0 && const_nbits <= 32) { + if (const_nbits > 0 && const_nbits <= 32) { return NewAlloc(op, attach_scope, scope, const_nbits); } } @@ -859,7 +826,7 @@ auto end = const_free_map_.upper_bound(const_nbits * match_range); // start looking at the buffer that is bigger than the required size first for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; // when not divided, no reuse, eg, float4 vs float3 @@ -871,7 +838,7 @@ // then start looking at smaller buffers. for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; if (e->elem_type != op->dtype.element_of()) continue; @@ -881,8 +848,7 @@ } } else { // Simple strategy: round robin. - for (auto it = sym_free_list_.begin(); - it != sym_free_list_.end(); ++it) { + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { StorageEntry* e = *it; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; @@ -904,8 +870,7 @@ // These rules only apply if we are using non-special memory if (e->scope.tag.length() == 0) { // Disable sharing of local memory. - if (e->scope.rank >= StorageRank::kWarp || - e->allocs[0]->dtype.is_handle()) return; + if (e->scope.rank >= StorageRank::kWarp || e->allocs[0]->dtype.is_handle()) return; // disable reuse of small arrays if (e->const_nbits > 0 && e->const_nbits <= 32) return; } @@ -936,7 +901,6 @@ arith::Analyzer analyzer_; }; - // Turn alloc into vector alloc // if all its access is the same vector type. 
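// Illustrative sketch (hypothetical IR, for exposition): a buffer declared as
//   allocate A[float32 * 64]
// whose every Load/Store uses type float32x4 is re-typed by the rewriter
// below into
//   allocate A[float32x4 * 16]
// The last extent is divided by the lane factor (4 here) only when
// analyzer_.modular_set proves it is divisible by that factor.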
class VectorAllocRewriter : public StmtExprMutator { @@ -964,19 +928,15 @@ class VectorAllocRewriter : public StmtExprMutator { op = stmt.as(); const auto& tvec = acc_map_[op->buffer_var.get()]; - if (tvec.size() == 1 && - tvec[0].element_of() == op->dtype.element_of() && - tvec[0].lanes() % op->dtype.lanes() == 0 && - tvec[0].lanes() != op->dtype.lanes()) { + if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() && + tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { int factor = tvec[0].lanes() / op->dtype.lanes(); Array extents = op->extents; arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); if (me->base % factor == 0 && me->coeff % factor == 0) { extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return AllocateNode::make( - op->buffer_var, tvec[0], extents, - op->condition, op->body); + return AllocateNode::make(op->buffer_var, tvec[0], extents, op->condition, op->body); } } return stmt; @@ -1000,7 +960,6 @@ Stmt StorageRewrite(Stmt stmt) { return VectorAllocRewriter()(std::move(stmt)); } - PrimFunc PointerValueTypeRewrite(PrimFunc f) { auto* n = f.CopyOnWrite(); VectorAllocRewriter rewriter; @@ -1014,8 +973,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { const auto& tvec = rewriter.acc_map_[var.get()]; if (tvec.size() == 1) { - tir::Var new_var(var->name_hint, - PointerType(PrimType(tvec[0]))); + tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0]))); args.push_back(new_var); remap_vars.Set(var, new_var); @@ -1023,8 +981,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { // always set data type to be non vectorized so // load/store can still work via scalarization if (tvec.size() != 0 && !var->type_annotation.defined()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType(tvec[0].with_lanes(1)))); + tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1)))); args.push_back(new_var); remap_vars.Set(var, new_var); } else { @@ -1042,7 +999,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { return f; } - namespace transform { Pass StorageRewrite() { @@ -1055,9 +1011,7 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite") -.set_body_typed(StorageRewrite); - +TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -1067,7 +1021,7 @@ Pass PointerValueTypeRewrite() { } TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") -.set_body_typed(PointerValueTypeRewrite); + .set_body_typed(PointerValueTypeRewrite); } // namespace transform diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 9924dd2bb0848..8650d2cc32c01 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -21,17 +21,17 @@ * \brief Infer TensorCore metadata from tensor intrinsic. 
* \file tensorcore_fragment.cc */ +#include #include -#include #include -#include +#include #include #include -#include "storage_access.h" -#include "ir_util.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" +#include "storage_access.h" namespace tvm { namespace tir { @@ -47,7 +47,7 @@ class FragmentGetter : public StmtExprVisitor { std::string layout; FragmentInfo() = default; FragmentInfo(int _m, int _n, int _k, const std::string& _layout) - : m(_m), n(_n), k(_k), layout(_layout) {} + : m(_m), n(_n), k(_k), layout(_layout) {} void VisitExpr_(const CallNode* op) final { @@ -136,13 +136,12 @@ // Check shape of fragment making sure it is a valid shape for tvm_mma_sync class FragmentChecker : public StmtExprVisitor { public: - explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {} void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->is_intrinsic(intrinsic::tvm_mma_sync) || - op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) { CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); @@ -170,13 +169,13 @@ return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; } // Fragment information - const FragmentGetter &fragment_getter; + const FragmentGetter& fragment_getter; }; // Store the metadata into attributes class InferFragmenter : public StmtMutator { public: - explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {} Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); @@ -186,15 +185,14 @@ FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); // Add shape attribute to all fragments - std::string shape = std::to_string(info.m) + ", " + - std::to_string(info.n) + ", " + - std::to_string(info.k); + std::string shape = + std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); PrimExpr shape_expr = StringImmNode::make(shape); Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { // Add shape attribute to matrix_a and matrix_b Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout, - StringImmNode::make(info.layout), shape_attr); + StringImmNode::make(info.layout), shape_attr); return layout_attr; } else { return shape_attr; } } @@ -205,7 +203,7 @@ private: // Fragment information - const FragmentGetter &fragment_getter; + const FragmentGetter& fragment_getter; }; Stmt InferFragment(Stmt stmt) { @@ -228,8 +226,7 @@ return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InferFragment") -.set_body_typed(InferFragment); +TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index a32fd647dbb6c..0379fd9f56218 --- 
a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -20,39 +20,35 @@ /*! * \file thread_storage_sync.cc */ -#include +#include #include -#include +#include #include #include -#include #include #include +#include "../../runtime/thread_storage_scope.h" #include "ir_util.h" #include "storage_access.h" -#include "../../runtime/thread_storage_scope.h" namespace tvm { namespace tir { class ThreadSyncPlanner : public StorageAccessVisitor { public: - explicit ThreadSyncPlanner(StorageScope sync_scope) - : sync_scope_(sync_scope) {} + explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {} - // The syncs inserted before each statement + // The syncs inserted before each statement std::unordered_set syncs_inserted_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return in_device_env() && scope == sync_scope_; } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { // Unsynced reads and writes std::vector reads; std::vector writes; @@ -70,19 +66,23 @@ class ThreadSyncPlanner : public StorageAccessVisitor { for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { if (FindConflict(writes, acc, false)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kWrite) { if (FindConflict(reads, acc, false)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } // If sync is inserted. remove the irrelevant things. if (sync_before_stmt) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } // Add the read/write of current statement for (const AccessEntry& acc : s.access) { @@ -91,12 +91,12 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } else if (acc.type == kWrite) { writes.push_back(acc); } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } if (sync_before_stmt) { - CHECK_EQ(condition_counter(), 0) - << "Cannot insert syncs inside condition"; + CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); } } @@ -109,19 +109,21 @@ class ThreadSyncPlanner : public StorageAccessVisitor { for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { if (FindConflict(writes, acc, true)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kWrite) { if (FindConflict(reads, acc, true)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } if (sync_before_stmt) { - CHECK_EQ(condition_counter(), 0) - << "Cannot insert syncs inside condition"; + CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); break; } @@ -174,22 +176,16 @@ class ThreadSyncPlanner : public StorageAccessVisitor { private: // find conflicting entry in vec. 
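// Example of the intent (hypothetical accesses): a write to buf[tid] followed
// by a read of buf[tid ^ 1] in the same scope touches two different points of
// one buffer, so FindConflict reports a conflict and a storage sync is
// planned before the read; re-reading exactly buf[tid] is proven identical by
// ExprDeepEqual and needs no barrier.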
- bool FindConflict(const std::vector& vec, - const AccessEntry& e, - bool loop_carry) { + bool FindConflict(const std::vector& vec, const AccessEntry& e, bool loop_carry) { for (const AccessEntry& x : vec) { if (x.buffer.same_as(e.buffer)) { // Assumes no race between threads // Same index value means no conflicts // TODO(tqchen) more standard set based testing. - if (e.touched.is_single_point() && - x.touched.is_single_point()) { - if (ExprDeepEqual()(e.touched.point_value(), - x.touched.point_value())) continue; + if (e.touched.is_single_point() && x.touched.is_single_point()) { + if (ExprDeepEqual()(e.touched.point_value(), x.touched.point_value())) continue; } - if (x.double_buffer_write && - e.type == kRead && - !loop_carry) continue; + if (x.double_buffer_write && e.type == kRead && !loop_carry) continue; return true; } } @@ -203,8 +199,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { class ThreadSyncInserter : public StmtExprMutator { public: - ThreadSyncInserter(StorageScope sync_scope, - const std::unordered_set& syncs) + ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} Stmt VisitStmt(const Stmt& stmt) final { @@ -214,10 +209,9 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string())}, - CallNode::Intrinsic)); + barrier = EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImmNode::make(sync_scope_.to_string())}, + CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); @@ -258,8 +252,7 @@ class ThreadSyncInserter : public StmtExprMutator { return ret; } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = - StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::make(op->value.as()->value); return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -316,13 +309,10 @@ class ThreadSyncInserter : public StmtExprMutator { } } rw_stats_.clear(); - Stmt kinit = EvaluateNode::make( - CallNode::make( - DataType::Int(32), - intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); + Stmt kinit = EvaluateNode::make(CallNode::make( + DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); - body = AttrStmtNode::make( - op->node, op->attr_key, op->value, body); + body = AttrStmtNode::make(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { @@ -334,8 +324,7 @@ class ThreadSyncInserter : public StmtExprMutator { IterVar iv = Downcast(attr->node); runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag); if (s.rank == 0) { - num_blocks_ = (num_blocks_.defined() ? - attr->value * num_blocks_ : attr->value); + num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); } else if (s.rank == 1) { PrimExpr cond = iv->var == make_zero(iv->var.dtype()); is_lead_ = is_lead_.defined() ? 
(is_lead_ && cond) : cond; @@ -346,9 +335,8 @@ class ThreadSyncInserter : public StmtExprMutator { } return EvaluateNode::make( CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string()), - is_lead_, num_blocks_}, - CallNode::Intrinsic)); + {StringImmNode::make(sync_scope_.to_string()), is_lead_, num_blocks_}, + CallNode::Intrinsic)); } // data structure. StorageScope sync_scope_; @@ -384,8 +372,7 @@ Pass ThreadSync(std::string storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ThreadSync") -.set_body_typed(ThreadSync); +TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 4fc69a3d892a6..a69ccc59adf31 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -22,32 +22,31 @@ * \file unroll_loop.cc */ // Unrolls the loop as in Halide pipeline. +#include #include #include #include -#include #include -#include -#include +#include + #include +#include #include -#include "ir_util.h" + #include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { class LoopUnroller : public StmtExprMutator { public: - explicit LoopUnroller(int auto_max_step, - int auto_max_depth, - int auto_max_extent, + explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) { - } + explicit_unroll_(explicit_unroll) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -72,24 +71,19 @@ class LoopUnroller : public StmtExprMutator { op = stmt.as(); int value = GetExtent(op); // condition for auto unroll - bool auto_unroll = ( - op->for_type == ForType::Serial && - value >= 0 && - normal_loop_depth_ == 0 && - unroll_depth_ <= auto_max_depth_); + bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 && normal_loop_depth_ == 0 && + unroll_depth_ <= auto_max_depth_); - auto_unroll = auto_unroll && ( - value * step_count_ <= auto_max_step_|| - value <= auto_max_extent_); + auto_unroll = + auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); if (op->for_type == ForType::Unrolled) { - CHECK_GE(value, 0) - << "Cannot unroll non-constant loop"; + CHECK_GE(value, 0) << "Cannot unroll non-constant loop"; auto_unroll = true; } if (auto_unroll) { - step_count_ *= value; + step_count_ *= value; unroll_depth_ += 1; } else { normal_loop_depth_ += 1; @@ -102,9 +96,8 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->for_type != ForType::Unrolled) { - return ForNode::make( - op->loop_var, op->min, op->extent, - ForType::Unrolled, op->device_api, op->body); + return ForNode::make(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api, + op->body); } } return stmt; @@ -159,7 +152,7 @@ class LoopUnroller : public StmtExprMutator { int GetExtent(const ForNode* op) { // constant folding. 
PrimExpr extent = analyzer_.Simplify(op->extent); - const IntImmNode *v1 = extent.as(); + const IntImmNode* v1 = extent.as(); int value = -1; // integers that do not fit in int32_t are treated as symbolic, // as it's impossible to unroll such large loops @@ -186,17 +179,9 @@ class LoopUnroller : public StmtExprMutator { arith::Analyzer analyzer_; }; - -Stmt UnrollLoop(Stmt stmt, - int auto_max_step, - int auto_max_depth, - int auto_max_extent, +Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) { - Stmt ret = LoopUnroller( - auto_max_step, - auto_max_depth, - auto_max_extent, - explicit_unroll)(stmt); + Stmt ret = LoopUnroller(auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { @@ -206,24 +191,17 @@ Stmt UnrollLoop(Stmt stmt, namespace transform { -Pass UnrollLoop(int auto_max_step, - int auto_max_depth, - int auto_max_extent, - bool explicit_unroll) { +Pass UnrollLoop(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = UnrollLoop(std::move(f->body), - auto_max_step, - auto_max_depth, - auto_max_extent, + n->body = UnrollLoop(std::move(f->body), auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop") -.set_body_typed(UnrollLoop); +TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); } // namespace transform diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index e155c709c7460..9e553cb12ceb9 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -21,14 +21,16 @@ * \file vectorize_loop.cc */ // Loop vectorizer as in Halide pipeline. 
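// Rough illustration of the rewrite this pass performs (hand-written
// pseudo-TIR under assumed syntax, not actual printer output): a loop
//   for (i, 0, 4) { A[i] = B[i] + 1f }
// marked for vectorization becomes one vector statement in which the loop
// variable turns into a Ramp node and scalars are broadcast to match lanes:
//   A[ramp(0, 1, 4)] = B[ramp(0, 1, 4)] + broadcast(1f, 4)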
+#include #include #include -#include #include -#include -#include +#include + #include +#include #include + #include "../../arith/compute_expr.h" namespace tvm { @@ -41,9 +43,8 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { return BroadcastNode::make(op->value, lanes); } } - CHECK_EQ(e.dtype().lanes(), 1) - << "Cannot broadcast lane=" << e.dtype().lanes() - << " to " << lanes; + CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " + << lanes; return BroadcastNode::make(e, lanes); } @@ -64,9 +65,8 @@ class VecAllocAccess : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { - return LoadNode::make(op->dtype, op->buffer_var, - op->index * var_lanes_ + var_, - op->predicate); + return LoadNode::make(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, + op->predicate); } else { return expr; } @@ -76,10 +76,8 @@ class VecAllocAccess : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->buffer_var.get() == buf_) { - return StoreNode::make(op->buffer_var, - op->value, - op->index * var_lanes_ + var_, - op->predicate); + return StoreNode::make(op->buffer_var, op->value, op->index * var_lanes_ + var_, + op->predicate); } else { return stmt; } @@ -96,8 +94,7 @@ class VecAllocAccess : public StmtExprMutator { class Vectorizer : public StmtExprMutator { public: - Vectorizer(Var var, int var_lanes) - : var_(var), var_lanes_(var_lanes) { + Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { ramp_ = RampNode::make(0, 1, var_lanes); } @@ -112,17 +109,12 @@ class Vectorizer : public StmtExprMutator { } } - PrimExpr VisitExpr_(const AddNode* op) final { - return AddSubVec(op); - } - PrimExpr VisitExpr_(const SubNode* op) final { - return AddSubVec(op); - } + PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); } + PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); } PrimExpr VisitExpr_(const MulNode* op) final { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); @@ -130,60 +122,30 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { - return RampNode::make( - a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); + return RampNode::make(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { - return RampNode::make( - b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); + return RampNode::make(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } } return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } return BinaryVec(op); } - PrimExpr VisitExpr_(const DivNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const ModNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const FloorDivNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const FloorModNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const MinNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const MaxNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const EQNode* op) final { - return BinaryVec(op); - } - PrimExpr 
VisitExpr_(const NENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const LTNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const LENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const GTNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const GENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const AndNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const OrNode* op) final { - return BinaryVec(op); - } + PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); @@ -198,29 +160,23 @@ class Vectorizer : public StmtExprMutator { stride = BroadcastTo(stride, lanes); Array elems; for (int i = 0; i < lanes; ++i) { - elems.push_back( - RampNode::make(ShuffleNode::make_extract_element(base, i), - ShuffleNode::make_extract_element(stride, i), - op->lanes)); + elems.push_back(RampNode::make(ShuffleNode::make_extract_element(base, i), + ShuffleNode::make_extract_element(stride, i), op->lanes)); } return ShuffleNode::make_concat(elems); } - PrimExpr VisitExpr_(const SelectNode *op) final { + PrimExpr VisitExpr_(const SelectNode* op) final { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr t = this->VisitExpr(op->true_value); PrimExpr f = this->VisitExpr(op->false_value); - if (cond.same_as(op->condition) && - t.same_as(op->true_value) && - f.same_as(op->false_value)) { + if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { return GetRef(op); } else { - int lanes = std::max(std::max( - cond.dtype().lanes(), - t.dtype().lanes()), f.dtype().lanes()); + int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - PrimExpr VisitExpr_(const CastNode *op) final { + PrimExpr VisitExpr_(const CastNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); @@ -233,31 +189,28 @@ class Vectorizer : public StmtExprMutator { if (v == var_.get()) { return ramp_; } else if (lets_.count(v)) { - return lets_[v]; + return lets_[v]; } else { return GetRef(v); } } // IfThenElse expr - PrimExpr MutateIfThenElseExpr_(const CallNode *op) { + PrimExpr MutateIfThenElseExpr_(const CallNode* op) { PrimExpr cond = this->VisitExpr(op->args[0]); - if (cond.dtype().is_vector()) { + if (cond.dtype().is_vector()) { 
need_scalarize_ = true; return GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); - if (cond.same_as(op->args[0]) && - t.same_as(op->args[1]) && - f.same_as(op->args[2])) { + if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { return GetRef(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return CallNode::make( - op->dtype.with_lanes(lanes), op->name, - {cond, t, f}, op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type, + op->func, op->value_index); } } // Call @@ -279,8 +232,8 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make( - op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype, op->name, new_args, op->call_type, op->func, + op->value_index); } } else { int lane = 0; @@ -289,9 +242,8 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make( - op->dtype.with_lanes(lane), op->name, new_args, - op->call_type, op->func, op->value_index); + return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type, + op->func, op->value_index); } } } @@ -303,11 +255,8 @@ class Vectorizer : public StmtExprMutator { return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); - return LoadNode::make( - op->dtype.with_lanes(lanes), - op->buffer_var, - BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return LoadNode::make(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // Let @@ -320,8 +269,7 @@ class Vectorizer : public StmtExprMutator { return LetNode::make(v, value, this->VisitExpr(op->body)); } else { PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return LetNode::make(op->var, value, body); @@ -350,10 +298,8 @@ class Vectorizer : public StmtExprMutator { } else { int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); lanes = std::max(lanes, pred.dtype().lanes()); - return StoreNode::make(op->buffer_var, - BroadcastTo(value, lanes), - BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return StoreNode::make(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // For @@ -368,13 +314,10 @@ class Vectorizer : public StmtExprMutator { return Scalarize(GetRef(op)); } Stmt body = this->VisitStmt(op->body); - if (extent.same_as(op->extent) && - body.same_as(op->body)) { + if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make( - op->loop_var, op->min, extent, - op->for_type, op->device_api, body); + return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -389,8 +332,7 @@ class Vectorizer : public StmtExprMutator { if (op->else_case.defined()) { else_case = this->VisitStmt(op->else_case); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -421,12 
+363,9 @@ class Vectorizer : public StmtExprMutator { // place the vector lanes in least significant dimension. extents.push_back(var_lanes_); // rewrite access to buffer internally. - Stmt body = VecAllocAccess( - op->buffer_var.get(), var_, var_lanes_)(op->body); + Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); - return AllocateNode::make( - op->buffer_var, op->dtype, - extents, condition, body); + return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); } // scalarize the statement Stmt Scalarize(Stmt stmt) { @@ -473,24 +412,22 @@ class Vectorizer : public StmtExprMutator { if (!changed) return arr; return Array(new_arr); } - template + template PrimExpr BinaryVec(const T* op) { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } - template + template PrimExpr AddSubVec(const T* op) { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); @@ -500,12 +437,10 @@ class Vectorizer : public StmtExprMutator { if (a.dtype().lanes() == 1 && b_ramp) { return RampNode::make( arith::Compute(a, b_ramp->base), - arith::Compute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), - b_ramp->lanes); + arith::Compute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } if (b.dtype().lanes() == 1 && a_ramp) { - return RampNode::make( - arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + return RampNode::make(arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); @@ -529,9 +464,7 @@ class LoopVectorizer : public StmtMutator { } }; -Stmt VectorizeLoop(Stmt stmt) { - return LoopVectorizer()(std::move(stmt)); -} +Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } class VectorizeSkipper : public StmtMutator { public: @@ -539,18 +472,15 @@ class VectorizeSkipper : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { - return ForNode::make(op->loop_var, op->min, op->extent, - ForType::Serial, op->device_api, + return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body); } else { - return stmt; + return stmt; } } }; -Stmt SkipVectorize(Stmt stmt) { - return VectorizeSkipper()(std::move(stmt)); -} +Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } namespace transform { @@ -568,8 +498,7 @@ Pass VectorizeLoop(bool enable_vectorize) { return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop") -.set_body_typed(VectorizeLoop); +TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); } // namespace transform diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index f4c259fca3422..8e9d7bc69840f 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -25,7 +25,7 @@ TEST(Simplify, MinMax) { tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); - auto e1 = (tvm::max(x, 1)
- tvm::max(x, 1)) ; + auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)); auto e1s = ana.canonical_simplify(e1); CHECK(tvm::tir::is_zero(e1s)); @@ -37,7 +37,7 @@ TEST(Simplify, MinMax) { TEST(Simplify, Mul) { tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); - auto e = (x * x) - (x * x) ; + auto e = (x * x) - (x * x); auto es = ana.canonical_simplify(e); CHECK(tvm::tir::is_zero(es)); } @@ -53,7 +53,7 @@ TEST(Simplify, Mod) { auto es = ana.canonical_simplify(mod - x); CHECK(tvm::tir::is_zero(es)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index ccf1b251482fb..7b301bd13f684 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace test { @@ -33,23 +33,17 @@ struct TestAttrs : public AttrsNode { double learning_rate; TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") { - TVM_ATTR_FIELD(axis) - .set_default(10) - .set_lower_bound(1) - .set_upper_bound(10) - .describe("axis field"); - TVM_ATTR_FIELD(name) - .describe("name of the field"); + TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( + "axis field"); + TVM_ATTR_FIELD(name).describe("name of the field"); TVM_ATTR_FIELD(expr) .describe("expression field") .set_default(tir::make_const(DataType::Int(32), 1)); - TVM_ATTR_FIELD(learning_rate) - .describe("learning_rate") - .set_default(0.1); + TVM_ATTR_FIELD(learning_rate).describe("learning_rate").set_default(0.1); } }; -} -} +} // namespace test +} // namespace tvm TEST(Attrs, Basic) { using namespace tvm; @@ -84,12 +78,11 @@ TEST(Attrs, Basic) { // Check docstring std::ostringstream os; n->PrintDocString(os); - LOG(INFO) << "docstring\n"<< os.str(); + LOG(INFO) << "docstring\n" << os.str(); CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 4913731f1bd38..c9a91fc0afca3 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -20,12 +20,12 @@ #include #include #include -#include -#include #include +#include +#include -#include #include +#include TEST(BuildModule, Basic) { using namespace tvm; @@ -37,18 +37,17 @@ TEST(BuildModule, Basic) { auto A = placeholder(shape, DataType::Float(32), "A"); auto B = placeholder(shape, DataType::Float(32), "B"); - auto C = compute(A->shape, [&A, &B](PrimExpr i) { - return A[i] + B[i]; - }, "C"); + auto C = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "C"); - auto s = create_schedule({ C->op }); + auto s = create_schedule({C->op}); auto cAxis = C->op.as()->axis; IterVar bx, tx; s[C].split(cAxis[0], 64, &bx, &tx); - auto args = Array({ A, B, C }); + auto args = Array({A, B, C}); std::unordered_map binds; auto config = BuildConfig::Create(); @@ -94,19 +93,16 @@ TEST(BuildModule, Heterogeneous) { auto B = placeholder(shape, DataType::Float(32), "B"); auto C = placeholder(shape, DataType::Float(32), "C"); - auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) { - return A[i] + B[i]; - }, "elemwise_add"); + auto elemwise_add = compute( + A->shape, [&A, &B](PrimExpr i) { 
return A[i] + B[i]; }, "elemwise_add"); auto copy = placeholder(shape, DataType::Float(32), "__copy"); - auto elemwise_sub = compute(C->shape, [&copy, &C](PrimExpr i) { - return copy[i] - C[i]; - }, "elemwise_sub"); + auto elemwise_sub = compute( + C->shape, [&copy, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); With cuda_scope(target_cuda); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); - With llvm_scope(target_llvm); auto s2 = create_schedule({elemwise_sub->op}); @@ -117,8 +113,7 @@ TEST(BuildModule, Heterogeneous) { std::unordered_map binds; auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config); auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config); - Map inputs = {{target_cuda, lowered_s1}, - {target_llvm, lowered_s2}}; + Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target(), config); // Assertion for build. @@ -148,12 +143,9 @@ TEST(BuildModule, Heterogeneous) { "\"float32\"]]}}"; // Setup inputs. - auto a_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto b_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto c_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto a_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto b_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto pa = (float*)(a_val->data); auto pb = (float*)(b_val->data); @@ -174,8 +166,17 @@ TEST(BuildModule, Heterogeneous) { const runtime::PackedFunc* graph_runtime = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); - runtime::Module mod = (*graph_runtime)( - json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); + runtime::Module mod = + (*graph_runtime)(json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); + + // test FFI for module. + auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + int tcode = args[1]; + CHECK_EQ(args[0].type_code(), tcode); + }); + + test_ffi(runtime::Module(mod), static_cast(kTVMModuleHandle)); + test_ffi(Optional(mod), static_cast(kTVMModuleHandle)); PackedFunc set_input = mod.GetFunction("set_input", false); PackedFunc run = mod.GetFunction("run", false); @@ -194,7 +195,7 @@ TEST(BuildModule, Heterogeneous) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index c67df63e6e7e5..5d1f4720b965b 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include #include @@ -35,8 +35,7 @@ class TestErrorSwitch { public: // Need this so that destructors of temporary objects don't interrupt our // testing.
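// (That is: the copy constructor below transfers the "armed" flag and
// disarms the source via const_cast, so only the final owner's destructor
// can trigger the failure check.)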
- TestErrorSwitch(const TestErrorSwitch& other) - : should_fail(other.should_fail) { + TestErrorSwitch(const TestErrorSwitch& other) : should_fail(other.should_fail) { const_cast(other).should_fail = false; } @@ -50,8 +49,7 @@ class TestErrorSwitch { } }; -class TestArrayObj : public Object, - public InplaceArrayBase { +class TestArrayObj : public Object, public InplaceArrayBase { public: static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "test.TestArrayObj"; @@ -112,8 +110,7 @@ TEST(InplaceArrayBase, BadExceptionSafety) { TestErrorSwitch f2{true}; TestErrorSwitch f3{false}; std::vector fields{f1, f2, f3}; - auto ptr = - make_inplace_array_object(fields.size()); + auto ptr = make_inplace_array_object(fields.size()); try { ptr->WrongInit(fields.begin(), fields.end()); } catch (...) { @@ -133,8 +130,7 @@ // since it's not initialized. TestErrorSwitch f2{true}; std::vector fields{f1, f2}; - auto ptr = - make_inplace_array_object(fields.size()); + auto ptr = make_inplace_array_object(fields.size()); try { ptr->Init(fields.begin(), fields.end()); } catch (...) { @@ -223,8 +219,7 @@ TEST(Map, Iterator) { using namespace tvm; PrimExpr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2( - map1.begin(), map1.end()); + std::unordered_map map2(map1.begin(), map1.end()); CHECK(map2[a].as()->value == 2); } @@ -402,7 +397,6 @@ TEST(String, Cast) { String s2 = Downcast(r); } - TEST(Optional, Composition) { Optional opt0(nullptr); Optional opt1 = String("xyz"); @@ -468,6 +462,18 @@ TEST(Optional, PackedCall) { CHECK(packedfunc("xyz", false).operator String() == "xyz"); CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); + + // test FFI convention.
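+ // (Stated as an assumption about the FFI convention exercised here:
+ // args[1] carries the expected type code for args[0]; an NDArray is
+ // expected as kTVMNDArrayHandle, an lvalue object as kTVMObjectHandle,
+ // and a moved rvalue as kTVMObjectRValueRefArg.)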
+ auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + int tcode = args[1]; + CHECK_EQ(args[0].type_code(), tcode); + }); + String s = "xyz"; + auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0}); + test_ffi(Optional(nd), static_cast(kTVMNDArrayHandle)); + test_ffi(Optional(s), static_cast(kTVMObjectRValueRefArg)); + test_ffi(s, static_cast(kTVMObjectHandle)); + test_ffi(String(s), static_cast(kTVMObjectRValueRefArg)); } int main(int argc, char** argv) { diff --git a/tests/cpp/crt_memory_test.cc b/tests/cpp/crt_memory_test.cc index 1c129166f1220..c2582ba02525e 100644 --- a/tests/cpp/crt_memory_test.cc +++ b/tests/cpp/crt_memory_test.cc @@ -27,7 +27,7 @@ TEST(CRTMemory, Alloc) { for (int idx = 0; idx < 65536; idx++) { - void * a = vmalloc(1); + void* a = vmalloc(1); EXPECT_EQ(vleak_size, 1); vfree(a); EXPECT_EQ(vleak_size, 0); @@ -36,9 +36,9 @@ TEST(CRTMemory, Alloc) { TEST(CRTMemory, Realloc) { for (int idx = 0; idx < 65536; idx++) { - void * a = vrealloc(0, 1); + void* a = vrealloc(0, 1); EXPECT_EQ(vleak_size, 1); - void * b = vrealloc(a, 1); + void* b = vrealloc(a, 1); EXPECT_EQ(a, b); EXPECT_EQ(vleak_size, 1); vfree(a); @@ -46,7 +46,7 @@ TEST(CRTMemory, Realloc) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index e17cc73b18a10..a5d47dd4d989d 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -34,7 +34,6 @@ TEST(Expr, Basic) { CHECK(os.str() == "max(((x + 1) + 2), 100)"); } - TEST(ExprNodeRef, Basic) { using namespace tvm; using namespace tvm::tir; @@ -44,8 +43,7 @@ TEST(ExprNodeRef, Basic) { CHECK(GetRef(op).same_as(z)); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 3941de5eef178..052cba1b26268 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -19,10 +19,10 @@ #include #include -#include -#include #include +#include #include +#include #include TEST(IRF, Basic) { @@ -32,14 +32,10 @@ TEST(IRF, Basic) { auto z = x + 1; NodeFunctor f; - f.set_dispatch([](const ObjectRef& n, int b) { - return b; - }); - f.set_dispatch([](const ObjectRef& n, int b) { - return b + 2; - }); - CHECK_EQ(f(x, 2), 2); - CHECK_EQ(f(z, 2), 4); + f.set_dispatch([](const ObjectRef& n, int b) { return b; }); + f.set_dispatch([](const ObjectRef& n, int b) { return b + 2; }); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 4); } TEST(IRF, CountVar) { @@ -51,37 +47,31 @@ TEST(IRF, CountVar) { auto z = x + 1 + y + y; tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; - }); + }); CHECK_EQ(n_var, 2); } - TEST(IRF, ExprTransform) { using namespace tvm; using namespace tvm::tir; Var x("x"); auto z = x + 1; - class MyExprFunctor - : public tir::ExprFunctor { + class MyExprFunctor : public tir::ExprFunctor { public: - int VisitExpr_(const VarNode* op, int b) final { - return b; - } - int VisitExpr_(const IntImmNode* op, int b) final { - return op->value; - } + int VisitExpr_(const VarNode* op, int b) final { return b; } + int VisitExpr_(const IntImmNode* op, int b) final { return op->value; } int VisitExpr_(const AddNode* op, int b) final { return VisitExpr(op->a, b) + VisitExpr(op->b, b); } }; 
MyExprFunctor f; - CHECK_EQ(f(x, 2), 2); - CHECK_EQ(f(z, 2), 3); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 3); try { f(z - 1, 2); LOG(FATAL) << "should fail"; - } catch(dmlc::Error) { + } catch (dmlc::Error) { } } @@ -91,43 +81,33 @@ TEST(IRF, ExprVisit) { Var x("x"); auto z = x + 1; - class MyVisitor - : public tir::ExprFunctor, - public tir::StmtFunctor { + class MyVisitor : public tir::ExprFunctor, + public tir::StmtFunctor { public: int count = 0; // implementation - void VisitExpr_(const VarNode* op) final { - ++count; - } - void VisitExpr_(const IntImmNode* op) final { - } + void VisitExpr_(const VarNode* op) final { ++count; } + void VisitExpr_(const IntImmNode* op) final {} void VisitExpr_(const AddNode* op) final { VisitExpr(op->a); VisitExpr(op->b); } - void VisitStmt_(const EvaluateNode* op) final { - VisitExpr(op->value); - } + void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); } }; MyVisitor v; v.VisitStmt(EvaluateNode::make(z)); CHECK_EQ(v.count, 1); } - TEST(IRF, StmtVisitor) { using namespace tvm; using namespace tvm::tir; Var x("x"); - class MyVisitor - : public StmtExprVisitor { + class MyVisitor : public StmtExprVisitor { public: int count = 0; // implementation - void VisitExpr_(const VarNode* op) final { - ++count; - } + void VisitExpr_(const VarNode* op) final { ++count; } }; MyVisitor v; auto fmaketest = [&]() { @@ -145,24 +125,16 @@ TEST(IRF, StmtMutator) { using namespace tvm::tir; Var x("x"); - class MyVisitor - : public tir::StmtMutator, - public tir::ExprMutator { + class MyVisitor : public tir::StmtMutator, public tir::ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); protected: // implementation - PrimExpr VisitExpr_(const AddNode* op) final { - return op->a; - } - Stmt VisitStmt_(const SeqStmtNode* op) final { - return StmtMutator::VisitSeqStmt_(op, true); - } - PrimExpr VisitExpr(const PrimExpr& expr) final { - return ExprMutator::VisitExpr(expr); - } + PrimExpr VisitExpr_(const AddNode* op) final { return op->a; } + Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); } + PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); } }; auto fmakealloc = [&]() { auto z = x + 1; @@ -220,7 +192,8 @@ TEST(IRF, StmtMutator) { } { - auto body = EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = + EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } @@ -255,7 +228,7 @@ TEST(IRF, StmtMutator) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index 438f688da0a41..0df802497434d 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -19,8 +19,8 @@ #include #include -#include #include +#include namespace tvm { namespace test { @@ -59,7 +59,6 @@ class ObjAA : public ObjA { TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA); }; - TVM_REGISTER_OBJECT_TYPE(ObjBase); TVM_REGISTER_OBJECT_TYPE(ObjA); TVM_REGISTER_OBJECT_TYPE(ObjB); @@ -97,7 +96,7 @@ TEST(ObjectHierachy, Basic) { CHECK(refB.as() != nullptr); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); 
testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 787e0c4b8f4df..523df98913328 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -19,11 +19,11 @@ #include #include -#include #include +#include #include -#include #include +#include TEST(PackedFunc, Basic) { using namespace tvm; @@ -34,15 +34,15 @@ TEST(PackedFunc, Basic) { DLTensor a; Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 3); - CHECK(args.values[0].v_float64 == 1.0); - CHECK(args.type_codes[0] == kDLFloat); - CHECK(args.values[1].v_handle == &a); - CHECK(args.type_codes[1] == kTVMDLTensorHandle); - CHECK(args.values[2].v_handle == &x); - CHECK(args.type_codes[2] == kTVMOpaqueHandle); - *rv = Var("a"); - })(1.0, &a, handle); + CHECK(args.num_args == 3); + CHECK(args.values[0].v_float64 == 1.0); + CHECK(args.type_codes[0] == kDLFloat); + CHECK(args.values[1].v_handle == &a); + CHECK(args.type_codes[1] == kTVMDLTensorHandle); + CHECK(args.values[2].v_handle == &x); + CHECK(args.type_codes[2] == kTVMOpaqueHandle); + *rv = Var("a"); + })(1.0, &a, handle); CHECK(v->name_hint == "a"); } @@ -52,36 +52,32 @@ TEST(PackedFunc, Node) { using namespace tvm::runtime; Var x; Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - CHECK(args[0].IsObjectRef()); - Var b = args[0]; - CHECK(x.same_as(b)); - *rv = b; - })(x); + CHECK(args.num_args == 1); + CHECK(args[0].IsObjectRef()); + Var b = args[0]; + CHECK(x.same_as(b)); + *rv = b; + })(x); CHECK(t.same_as(x)); } TEST(PackedFunc, NDArray) { using namespace tvm; using namespace tvm::runtime; - auto x = NDArray::Empty( - {}, String2DLDataType("float32"), - TVMContext{kDLCPU, 0}); + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); reinterpret_cast(x->data)[0] = 10.0f; CHECK(x.use_count() == 1); - PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { - *rv = args[0]; - }); + PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - NDArray y = args[0]; - DLTensor* ptr = args[0]; - CHECK(ptr == x.operator->()); - CHECK(x.same_as(y)); - CHECK(x.use_count() == 2); - *rv = forward(y); - })(x); + NDArray y = args[0]; + DLTensor* ptr = args[0]; + CHECK(ptr == x.operator->()); + CHECK(x.same_as(y)); + CHECK(x.use_count() == 2); + *rv = forward(y); + })(x); CHECK(ret.use_count() == 2); CHECK(ret.same_as(x)); } @@ -90,48 +86,45 @@ TEST(PackedFunc, str) { using namespace tvm; using namespace tvm::runtime; PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - std::string x = args[0]; - CHECK(x == "hello"); - String y = args[0]; - CHECK(y == "hello"); - *rv = x; - })("hello"); + CHECK(args.num_args == 1); + std::string x = args[0]; + CHECK(x == "hello"); + String y = args[0]; + CHECK(y == "hello"); + *rv = x; + })("hello"); PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - runtime::String s = args[0]; - CHECK(s == "hello"); + CHECK(args.num_args == 1); + runtime::String s = args[0]; + CHECK(s == "hello"); })(runtime::String("hello")); } - TEST(PackedFunc, func) { using namespace tvm; using namespace tvm::runtime; - PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { - *rv = args[0].operator int() + 1; - }); + PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; }); // function as arguments int r0 = 
PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); CHECK_EQ(r0, 2); int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - // TVMArgValue -> TVMRetValue - *rv = args[1]; - })(2, 100); + // TVMArgValue -> TVMRetValue + *rv = args[1]; + })(2, 100); CHECK_EQ(r1, 100); int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - // re-assignment - *rv = args[0]; - // TVMRetValue -> Function argument - *rv = addone(args[0].operator PackedFunc()(args[1], 1)); - })(addone, 100); + // re-assignment + *rv = args[0]; + // TVMRetValue -> Function argument + *rv = addone(args[0].operator PackedFunc()(args[1], 1)); + })(addone, 100); CHECK_EQ(r2, 102); } @@ -140,14 +133,14 @@ TEST(PackedFunc, Expr) { using namespace tvm::runtime; // automatic conversion of int to expr PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { - PrimExpr x = args[0]; - *rv = x.as()->value + 1; + PrimExpr x = args[0]; + *rv = x.as()->value + 1; }); int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); CHECK_EQ(r0, 2); } @@ -155,12 +148,10 @@ TEST(PackedFunc, Type) { using namespace tvm; using namespace tvm::runtime; auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - DataType x = args[0]; - *rv = x; - }); - auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - *rv = args[0]; - }); + DataType x = args[0]; + *rv = x; + }); + auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); CHECK(get_type("int32").operator DataType() == DataType::Int(32)); CHECK(get_type("float").operator DataType() == DataType::Float(32)); CHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2)); @@ -174,9 +165,7 @@ TEST(TypedPackedFunc, HighOrder) { using BindFunc = TypedPackedFunc; BindFunc ftyped; ftyped = [](Int2Func f1, int value) -> Int1Func { - auto binded = [f1, value](int x) { - return f1(value, x); - }; + auto binded = [f1, value](int x) { return f1(value, x); }; Int1Func x(binded); return x; }; @@ -194,28 +183,23 @@ TEST(TypedPackedFunc, Deduce) { using tvm::runtime::detail::function_signature; TypedPackedFunc x; - auto f = [](int x) -> int { - return x + 1; - }; + auto f = [](int x) -> int { return x + 1; }; std::function y; - static_assert(std::is_same::FType, - int(float)>::value, "invariant1"); - static_assert(std::is_same::FType, - int(int)>::value, "invariant2"); - static_assert(std::is_same::FType, - void(float)>::value, "invariant3"); + static_assert(std::is_same::FType, int(float)>::value, + "invariant1"); + static_assert(std::is_same::FType, int(int)>::value, + "invariant2"); + static_assert(std::is_same::FType, void(float)>::value, + "invariant3"); } - TEST(PackedFunc, ObjectConversion) { using namespace tvm; using namespace tvm::tir; using namespace tvm::runtime; TVMRetValue rv; - auto x = NDArray::Empty( - {}, String2DLDataType("float32"), - TVMContext{kDLCPU, 0}); + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); // assign null rv = ObjectRef(); CHECK_EQ(rv.type_code(), kTVMNullptr); @@ -232,15 +216,15 @@ TEST(PackedFunc, ObjectConversion) { 
CHECK(!rv.IsObjectRef()); auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); - CHECK(args[0].operator NDArray().same_as(x)); - CHECK(args[0].operator ObjectRef().same_as(x)); - CHECK(args[1].operator ObjectRef().get() == nullptr); - CHECK(args[1].operator NDArray().get() == nullptr); - CHECK(args[1].operator Module().get() == nullptr); - CHECK(args[1].operator Array().get() == nullptr); - CHECK(!args[0].IsObjectRef()); - }); + CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); + CHECK(args[0].operator NDArray().same_as(x)); + CHECK(args[0].operator ObjectRef().same_as(x)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(args[1].operator Array().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); pf1(x, ObjectRef()); pf1(ObjectRef(x), NDArray()); @@ -259,14 +243,14 @@ TEST(PackedFunc, ObjectConversion) { CHECK(!rv.IsObjectRef()); auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args[0].type_code(), kTVMModuleHandle); - CHECK(args[0].operator Module().same_as(m)); - CHECK(args[0].operator ObjectRef().same_as(m)); - CHECK(args[1].operator ObjectRef().get() == nullptr); - CHECK(args[1].operator NDArray().get() == nullptr); - CHECK(args[1].operator Module().get() == nullptr); - CHECK(!args[0].IsObjectRef()); - }); + CHECK_EQ(args[0].type_code(), kTVMModuleHandle); + CHECK(args[0].operator Module().same_as(m)); + CHECK(args[0].operator ObjectRef().same_as(m)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); pf2(m, ObjectRef()); pf2(ObjectRef(m), Module()); } @@ -275,13 +259,12 @@ TEST(TypedPackedFunc, RValue) { using namespace tvm; using namespace tvm::runtime; { - auto inspect = [](TVMArgs args, TVMRetValue* rv) { for (int i = 0; i < args.size(); ++i) { CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg); } }; - PackedFunc finspect(inspect); + PackedFunc finspect(inspect); finspect(tir::Var("x")); } { @@ -325,7 +308,7 @@ TEST(TypedPackedFunc, RValue) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5cb79101a05ee..59d0a43782a8f 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -17,9 +17,10 @@ * under the License. 
*/ +#include "../src/arith/pattern_match.h" + #include #include -#include "../src/arith/pattern_match.h" TEST(Pattern, Basic) { using namespace tvm; @@ -64,8 +65,7 @@ TEST(Pattern, Basic) { CHECK((px >= py && px < pz).Match(x >= y && x < z)); CHECK((!(px > py || px != py)).Match(!(x > y || x != y))); { - CHECK(select(px >= pz, py, py + pz).Match( - tir::SelectNode::make((x + 1) >= 1, y, y + 1))); + CHECK(select(px >= pz, py, py + pz).Match(tir::SelectNode::make((x + 1) >= 1, y, y + 1))); CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics @@ -81,52 +81,44 @@ TEST(Pattern, Basic) { CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); // select { - CHECK(select(px > pz, py, py + pz).Match( - tir::SelectNode::make(x > 1, y, y + 1))); + CHECK(select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } - CHECK(!select(px > pz, py, py + pz).Match( - tir::SelectNode::make(x > 2, y, y + 1))); - CHECK(!select(px > pz, py, py).Match( - tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py).Match(tir::SelectNode::make(x > 2, y, y + 1))); { - CHECK(select(px, py, pz).Match( - tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(select(px, py, pz).Match(tir::SelectNode::make(x > 2, y, y + 1))); CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else { - CHECK(if_then_else(px > pz, py, py + pz).Match( - if_then_else(x > 1, y, y + 1))); + CHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } // cast pattern { - CHECK(!cast(PConst( - DataType::Int(32)), px).Match(tir::CastNode::make(DataType::Float(64), x))); + CHECK(!cast(PConst(DataType::Int(32)), px) + .Match(tir::CastNode::make(DataType::Float(64), x))); CHECK(cast(pt, px).Match(tir::CastNode::make(DataType::Float(64), x))); CHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); - CHECK((cast(pt, px) - cast(pt, py)).Match( - tir::CastNode::make(DataType::Float(64), x) - tir::CastNode::make(DataType::Int(64), x))); + CHECK((cast(pt, px) - cast(pt, py)) + .Match(tir::CastNode::make(DataType::Float(64), x) - + tir::CastNode::make(DataType::Int(64), x))); auto expr = tir::CastNode::make(DataType::Int(32), tir::CastNode::make(DataType::Float(64), x)); CHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern { - CHECK(ramp(px, PConst(1), planes).Match( - tir::RampNode::make(x, 1, 10))); + CHECK(ramp(px, PConst(1), planes).Match(tir::RampNode::make(x, 1, 10))); CHECK(planes.Eval() == 10); - CHECK(!ramp(px, PConst(1), planes).Match( - tir::RampNode::make(x, 2, 10))); + CHECK(!ramp(px, PConst(1), planes).Match(tir::RampNode::make(x, 2, 10))); } // broadcast pattern { - CHECK(broadcast(px, planes).Match( - tir::BroadcastNode::make(x, 10))); + CHECK(broadcast(px, planes).Match(tir::BroadcastNode::make(x, 10))); CHECK(planes.Eval() == 10); - CHECK(broadcast(px * py , planes).Match( - tir::BroadcastNode::make(x * 10, 10))); + CHECK(broadcast(px * py, planes).Match(tir::BroadcastNode::make(x * 10, 10))); } } @@ -148,7 +140,7 @@ TEST(Pattern, IntImm) { CHECK(!(v * c).Match((tx + 1) * 3)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 
7ccd1123f527d..d7ce0c0e3d6cd 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -18,61 +18,59 @@ */ #include -#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include +#include #include +#include +#include +#include +#include +#include +#include #include +#include #include +#include using namespace tvm; using namespace tvm::relay; TVM_REGISTER_GLOBAL("test.strategy") -.set_body_typed([](const Attrs& attrs, const Array& inputs, - const Type& out_type, const Target& target) { - FTVMCompute fcompute = [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) -> Array { + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { CHECK_EQ(inputs.size(), 2U); return {topi::add(inputs[0], inputs[1])}; - }; - FTVMSchedule fschedule = [](const Attrs& attrs, - const Array& outs, - const Target& target) { + }; + FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { With target_scope(target); return topi::generic::schedule_injective(target, outs); - }; + }; - auto n = make_object(); - auto strategy = tvm::relay::OpStrategy(std::move(n)); - strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); - return strategy; -}); + auto n = make_object(); + auto strategy = tvm::relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); + return strategy; + }); TVM_REGISTER_GLOBAL("relay.backend.lower_call") -.set_body_typed([](const relay::Call& call, const Array& inputs, - const Target& target) { - static auto fstrategy = Op::GetAttr("FTVMStrategy"); - Op op = Downcast(call->op); - auto out_type = call->checked_type(); - OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); - auto impl = strategy->specializations[0]->implementations[0]; - auto outs = impl.Compute(call->attrs, inputs, out_type); - auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); - if (!f) { - LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; - } - return (*f)(outs, impl); -}); + .set_body_typed([](const relay::Call& call, const Array& inputs, + const Target& target) { + static auto fstrategy = Op::GetAttr("FTVMStrategy"); + Op op = Downcast(call->op); + auto out_type = call->checked_type(); + OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); + auto impl = strategy->specializations[0]->implementations[0]; + auto outs = impl.Compute(call->attrs, inputs, out_type); + auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); + if (!f) { + LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; + } + return (*f)(outs, impl); + }); TEST(Relay, BuildModule) { auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32)); @@ -179,7 +177,7 @@ TEST(Relay, GetExprRefCount) { CHECK(ref_count[z.get()] == 1); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 3c416918e4414..cb7330dfab6da 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -19,30 +19,30 @@ #include #include -#include -#include 
-#include #include +#include #include +#include +#include TEST(Relay, SelfReference) { using namespace tvm; auto tensor_type = relay::TensorType({}, DataType::Bool()); auto x = relay::Var("x", relay::Type()); - auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); + auto f = relay::Function(tvm::Array{x}, x, relay::Type(), {}); CHECK(f->IsInstance()); auto y = relay::Var("y", tensor_type); - auto call = relay::Call(f, Array{ y }); - auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); + auto call = relay::Call(f, Array{y}); + auto fx = relay::Function(tvm::Array{y}, call, relay::Type(), {}); auto mod = IRModule::FromExpr(fx); mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); - auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); + auto expected = relay::FuncType(tvm::Array{tensor_type}, tensor_type, {}, {}); CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index d974f023d74b6..e01a6ea4b3daa 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -19,27 +19,25 @@ #include #include -#include #include -#include #include +#include #include +#include #include #include #include #include #include -TVM_REGISTER_GLOBAL("schedule") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { - *rv = topi::generic::schedule_injective(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("schedule").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); +}); TEST(Relay, Sequential) { using namespace tvm; auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32)); - auto c_data = - tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); // Create a function for optimization. auto c = relay::Constant(c_data); @@ -53,8 +51,7 @@ TEST(Relay, Sequential) { auto z2 = relay::Call(add_op, {z, z1}); // Let expression and variable a should be dead-code eliminated. auto z3 = relay::Let(a, c, z2); - relay::Function func = - relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); + relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); // Get schedule auto reg = tvm::runtime::Registry::Get("relay.op._Register"); @@ -67,11 +64,8 @@ TEST(Relay, Sequential) { // Run sequential passes.
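// (Per the comment above and the checks later in this test, the sequence is
// expected to type the module, drop the dead let-binding of `a`, and share
// the duplicated multiply via common-subexpression elimination. Minimal
// usage sketch, assuming an active PassContext:
//   relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
//   mod = seq(mod);
// )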
tvm::Array pass_seqs{ - relay::transform::InferType(), - relay::transform::DeadCodeElimination(), - relay::transform::EliminateCommonSubexpr(), - relay::transform::AlterOpLayout() - }; + relay::transform::InferType(), relay::transform::DeadCodeElimination(), + relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()}; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); auto mod = IRModule::FromExpr(func); auto pass_ctx = relay::transform::PassContext::Create(); @@ -96,8 +90,7 @@ TEST(Relay, Sequential) { y1 = relay::Call(add_op, {x1, y1}); auto zz = relay::Call(add_op, {y1, c1}); zz = relay::Call(add_op, {zz, zz}); - relay::Function expected_func = - relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); + relay::Function expected_func = relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. auto mod1 = IRModule::FromExpr(expected_func); diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/simple_passes_test.cc index be4c746748097..36b36452f4fc8 100644 --- a/tests/cpp/simple_passes_test.cc +++ b/tests/cpp/simple_passes_test.cc @@ -33,8 +33,7 @@ TEST(SimplePasses, HasSideEffect) { CHECK(!tvm::tir::HasSideEffect(A[0])); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index a9566cb6d0055..ea02ca656dce3 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -30,9 +30,8 @@ TEST(Tensor, Basic) { Tensor A = placeholder({m, l}, DataType::Float(32), "A"); Tensor B = placeholder({n, l}, DataType::Float(32), "B"); - auto C = compute({m, n}, [&](Var i, Var j) { - return A[i][j]; - }, "C"); + auto C = compute( + {m, n}, [&](Var i, Var j) { return A[i][j]; }, "C"); Tensor::Slice x = A[n]; } @@ -46,13 +45,12 @@ TEST(Tensor, Reduce) { te::Tensor B = te::placeholder({n, l}, DataType::Float(32), "B"); IterVar rv = reduce_axis(Range{0, l}, "k"); - auto C = te::compute({m, n}, [&](Var i, Var j) { - return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); - }, "C"); + auto C = te::compute( + {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc index 508705c2630a5..cf7434b4b036b 100644 --- a/tests/cpp/threading_backend_test.cc +++ b/tests/cpp/threading_backend_test.cc @@ -17,13 +17,13 @@ * under the License. */ +#include +#include + #include #include #include -#include -#include - constexpr size_t N = 128; static FTVMParallelLambda atomic_add_task_id = [](int task_id, TVMParallelGroupEnv* penv, diff --git a/tests/cpp/topi_ewise_test.cc b/tests/cpp/topi_ewise_test.cc index a1ca6d7fd2291..10c7b9d7464b6 100644 --- a/tests/cpp/topi_ewise_test.cc +++ b/tests/cpp/topi_ewise_test.cc @@ -17,9 +17,9 @@ * under the License. 
*/ -#include -#include #include +#include +#include namespace topi { TEST(Tensor, Basic) { @@ -28,9 +28,9 @@ TEST(Tensor, Basic) { Tensor A = placeholder({m, l}, DataType::Float(32), "A"); auto C = topi::exp(A); } -} +} // namespace topi -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index fa809265d5ae9..c9c9f88571d00 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -17,11 +17,11 @@ * under the License. */ -#include - #include #include + #include +#include #include #ifdef USE_MICRO_STANDALONE_RUNTIME @@ -30,9 +30,10 @@ #if defined(__APPLE__) && defined(__MACH__) #include +#include +#include #include #include -#include #include #include #include @@ -41,9 +42,7 @@ #include #include #include - -#include -#include +#include TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { *rv = topi::generic::schedule_injective(args[0], args[1]); diff --git a/tests/lint/add_asf_header.py b/tests/lint/add_asf_header.py index a44fbd3df1b5c..21d25c25e5738 100644 --- a/tests/lint/add_asf_header.py +++ b/tests/lint/add_asf_header.py @@ -181,7 +181,9 @@ def add_header(fname, header): skipline = False ext = os.path.splitext(fname)[1][1:] - if lines[0][:2] == "#!": + if not lines: + skipline = False # File is empty + elif lines[0][:2] == "#!": skipline = True elif lines[0][:2] == "= package_version.parse('1.14.0'): + x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32) + in0 = x + in1 = [x, y] + in2 = (x, y, z) + in3 = m + in4 = [m, n] + in5 = (m, n, o) + _test_forward_add_n(in0) + _test_forward_add_n(in1) + _test_forward_add_n(in2) + _test_forward_add_n(in3) + _test_forward_add_n(in4) + _test_forward_add_n(in5) + + ####################################################################### # Logical operators # ----------------- @@ -1151,7 +1256,12 @@ def _test_logical_binary(logical_bin_op, data): with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')] - out = logical_bin_op(in_data[0], in_data[1], name='out') + if logical_bin_op == math_ops.logical_not: + out = math_ops.logical_or(in_data[0], in_data[1], name='out1') + out = logical_bin_op(out, name='out') + else: + out = logical_bin_op(in_data[0], in_data[1], name='out') + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) def _test_forward_logical_and(data): """ One iteration of logical and """ return _test_logical_binary(math_ops.logical_and, data) @@ -1162,6 +1272,10 @@ def _test_forward_logical_or(data): """ One iteration of logical or """ return _test_logical_binary(math_ops.logical_or, data) +def _test_forward_logical_not(data): + """ One iteration of logical not """ + return _test_logical_binary(math_ops.logical_not, data) + def test_all_logical(): data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'), np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')] @@ -1169,6 +1283,7 @@ def test_all_logical(): if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): _test_forward_logical_and(data) _test_forward_logical_or(data) + 
_test_forward_logical_not(data) ####################################################################### # Zeros like @@ -1185,6 +1300,39 @@ def test_forward_zeros_like(): """ ZEROS LIKE """ _test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + +####################################################################### +# Fill +# ---- + +def _test_fill(dims, value_data, value_dtype): + """ Use the fill op to create a tensor of value_data with constant dims.""" + + value_data = np.array(value_data, dtype=value_dtype) + # TF 1.13 TFLite convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + with tf.Graph().as_default(): + value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[]) + out = tf.fill(dims, value) + compare_tflite_with_tvm([value_data], ["value"], [value], [out]) + + with tf.Graph().as_default(): + input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims) + # Fill op gets converted to static tensor during conversion + out = tf.fill(dims, value_data) + out1 = tf.add(out, input1) + input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype) + compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1]) + + +def test_forward_fill(): + """ Test FILL op """ + + _test_fill((1, 2, 2, 4), 5, "int32") + _test_fill((1, 2, 2, 4), 5, "float32") + _test_fill((5, ), 5, "int32") + + ####################################################################### # Reduce # ------ @@ -1301,6 +1449,27 @@ def test_all_reduce(): ####################################################################### +# Select, Where +# ------------- + +def test_forward_select(): + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + out = tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + in_data2 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + + compare_tflite_with_tvm([in_data1, in_data2], [ + 'input1:0', 'input2:0'], [input1, input2], [out]) + + # Squeeze # ------- @@ -1664,16 +1833,30 @@ def test_detection_postprocess(): tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) - # check valid count is the same + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same assert tvm_output[3] == tflite_output[3] valid_count = tvm_output[3][0] - tvm_boxes = tvm_output[0][0][:valid_count] - tvm_classes = tvm_output[1][0][:valid_count] - tvm_scores = tvm_output[2][0][:valid_count] - # check the output data is correct - tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5) + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. 
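Two of the new operator tests above, Fill and Select/Where, have one-line NumPy semantics, which is ultimately what compare_tflite_with_tvm checks against. A minimal reference sketch; helper names here are illustrative only:

```python
import numpy as np

# Reference semantics for the two TFLite ops exercised above.
def fill_ref(dims, value):
    # tf.fill(dims, value): a tensor of shape `dims` filled with `value`.
    return np.full(dims, value)

def select_ref(mask, a, b):
    # tf.where(mask, a, b): elementwise pick from `a` where mask, else `b`.
    return np.where(mask, a, b)

x = np.random.randint(0, 10, size=(1, 4, 4, 3)).astype("int32")
y = np.random.randint(0, 10, size=(1, 4, 4, 3)).astype("int32")
out = select_ref(x > y, x + 1, y * 2)      # mirrors test_forward_select
assert out.shape == x.shape
assert (fill_ref((5,), 5) == 5).all()      # mirrors test_forward_fill
```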
+ for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) ####################################################################### @@ -1857,25 +2040,100 @@ def test_forward_qnn_mobilenet_v3_net(): tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + +####################################################################### +# Quantized SSD Mobilenet +# ----------------------- + +def test_forward_qnn_coco_ssd_mobilenet_v1(): + """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" + pytest.skip("LLVM bug - getExtendedVectorNumElements - " + + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a " + + "specific target, for example, llvm -mcpu=core-avx2") + + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", + "detect.tflite") + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = get_real_image_object_detection(300, 300) + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # We compare the bounding boxes whose prediction score is above 60%. This is typical in end + # to end application where a low prediction score is discarded. This is also needed because + # multiple low score bounding boxes can have same score and TFLite and TVM can have + # different orderings for same score bounding boxes. Another reason for minor differences in + # low score bounding boxes is the difference between TVM and TFLite for requantize operator. + if tvm_output[2][0][i] > 0.6: + # Check bounding box co-ords. The tolerances have to be adjusted, from 1e-5 to 1e-2, + # because of differences between TFLite and TVM for the requantize operator. 
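The per-detection comparison pattern above (and again in the FP32 test further down) is worth pulling out: the four outputs are boxes, classes, scores and valid_count, and only the first valid_count entries carry meaning. A condensed sketch of the pattern; the helper name is hypothetical and tolerances vary per test:

```python
import numpy as np
import tvm.testing

# Compare only the valid detections of two (boxes, classes, scores,
# valid_count) output tuples; entries past valid_count are garbage.
def compare_valid_detections(tvm_out, tflite_out, rtol=1e-5, atol=1e-5):
    assert tvm_out[3] == tflite_out[3]        # same number of valid boxes
    valid_count = int(tvm_out[3][0])
    for i in range(valid_count):
        # Box coordinates and scores: allow small numeric drift.
        tvm.testing.assert_allclose(tvm_out[0][0][i], tflite_out[0][0][i],
                                    rtol=rtol, atol=atol)
        tvm.testing.assert_allclose(tvm_out[2][0][i], tflite_out[2][0][i],
                                    rtol=rtol, atol=atol)
        # Class ids must match exactly.
        np.testing.assert_equal(tvm_out[1][0][i], tflite_out[1][0][i])
```

The quantized variant above additionally gates on score > 0.6 before comparing, for the ordering and requantize reasons spelled out in its comments.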
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), + np.squeeze(tflite_output[0][0][i]), + rtol=1e-2, atol=1e-2) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), + np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + + ####################################################################### # SSD Mobilenet # ------------- -def test_forward_ssd_mobilenet_v1(): - """Test the SSD Mobilenet V1 TF Lite model.""" - # SSD MobilenetV1 +def test_forward_coco_ssd_mobilenet_v1(): + """Test the FP32 Coco SSD Mobilenet V1 TF Lite model.""" tflite_model_file = tf_testing.get_workload_official( - "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz", - "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite") + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz", + "ssd_mobilenet_v1_coco_2018_01_28.tflite") + with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() + np.random.seed(0) data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2) - for i in range(2): - tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), - rtol=1e-5, atol=2e-5) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. 
+ for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) ####################################################################### # MediaPipe @@ -1896,6 +2154,7 @@ def test_forward_mediapipe_hand_landmark(): tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), rtol=1e-5, atol=1e-5) + ####################################################################### # Main # ---- @@ -1915,6 +2174,9 @@ def test_forward_mediapipe_hand_landmark(): # Cast test_forward_cast() + # BatchMatMul + test_forward_batch_matmul() + # Tile test_forward_tile() @@ -1932,12 +2194,14 @@ def test_forward_mediapipe_hand_landmark(): test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_select() # NN test_forward_convolution() test_forward_transpose_conv() test_forward_logistic() test_forward_pooling() + test_forward_l2_pool2d() test_forward_softmax() test_forward_tanh() test_forward_relu() @@ -1948,12 +2212,16 @@ def test_forward_mediapipe_hand_landmark(): # Elemwise test_all_elemwise() + test_forward_add_n() # Unary elemwise test_all_unary_elemwise() # Zeros Like test_forward_zeros_like() + # Fill + test_forward_fill() + # Reduce test_all_reduce() @@ -1969,7 +2237,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_mobilenet_v3() test_forward_inception_v3_net() test_forward_inception_v4_net() - test_forward_ssd_mobilenet_v1() + test_forward_coco_ssd_mobilenet_v1() test_forward_mediapipe_hand_landmark() # End to End quantized @@ -1979,3 +2247,4 @@ def test_forward_mediapipe_hand_landmark(): #This also fails with a segmentation fault in my run #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() + test_forward_qnn_coco_ssd_mobilenet_v1() diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 82ade4478beac..7ac3496c994fe 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -338,6 +338,106 @@ def check_target(device): check_target("cuda") check_target("vulkan") +def test_warp_reduction1(): + nthx = 32 + nthy = 4 + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, nthx), "threadIdx.x") + thread_y = te.thread_axis((0, nthy), "threadIdx.y") + + def check_target(device, m, n): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." 
% device) + return + + # compute + A = te.placeholder((m, n), name='A') + k = te.reduce_axis((0, n)) + B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name='B') + s = te.create_schedule(B.op) + + # schedule + k = s[B].op.reduce_axis[0] + ko, _ = s[B].split(k, nparts=nthx) + s[B].bind(ko, thread_x) + xo, xi = s[B].split(s[B].op.axis[0], factor=nthy) + s[B].bind(xi, thread_y) + s[B].bind(xo, block_x) + + print(tvm.lower(s, [A, B], simple_mode=True)) + + # validation + func = tvm.build(s, [A, B], "cuda", name="warp_reduction") + a_np = np.random.uniform(size=(m,n)).astype(A.dtype) + b_np = np.zeros((m,), dtype=A.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + b_np = np.max(a_np, axis=1) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) + + check_target("cuda", m=32, n=256) + check_target("cuda", m=10, n=20) + # This is a bug in normal reduction. + # check_target("cuda", m=10, n=37) + +def test_warp_reduction2(): + def fcombine(x, y): + return x[0] + y[0], x[1] * y[1] + + def fidentity(t0, t1): + return tvm.tir.const(0, t0), tvm.tir.const(1, t1) + + add_mul_reducer = te.comm_reducer(fcombine, fidentity, name='add_mul_reducer') + + # compute + m = 16 + n = 256 + A0 = te.placeholder((m, n), name='A0', dtype='float32') + A1 = te.placeholder((m, n), name='Al', dtype='float32') + k = te.reduce_axis((0, n), 'k') + T0, T1 = te.compute((m, ), lambda i: \ + add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name='T') + + nthdx, nthdy = 32, 2 + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, nthdx), "threadIdx.x") + thread_y = te.thread_axis((0, nthdy), "threadIdx.y") + + def check_target(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." 
% device) + return + + # schedule + s = te.create_schedule(T0.op) + ko, _ = s[T0].split(k, nparts=nthdx) + xo, xi = s[T0].split(s[T0].op.axis[0], factor=nthdy) + s[T0].bind(ko, thread_x) + s[T0].bind(xi, thread_y) + s[T0].bind(xo, block_x) + + # validation + ctx = tvm.context(device, 0) + a0_np = np.random.uniform(size=(m,n)).astype(A0.dtype) + a1_np = np.random.uniform(size=(m,n)).astype(A1.dtype) + t0_np = np.zeros((m,), dtype=A0.dtype) + t1_np = np.zeros((m,), dtype=A1.dtype) + a0 = tvm.nd.array(a0_np, ctx) + a1 = tvm.nd.array(a1_np, ctx) + t0 = tvm.nd.array(t0_np, ctx) + t1 = tvm.nd.array(t1_np, ctx) + func = tvm.build(s, [A0, A1, T0, T1], device, name="reduction") + func(a0, a1, t0, t1) + t0_np = np.sum(a0_np, axis=1) + t1_np = np.product(a1_np, axis=1) + tvm.testing.assert_allclose(t0.asnumpy(), t0_np, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3) + + check_target("cuda") + if __name__ == "__main__": test_rfactor_elemwise_threads() test_rfactor_threads() @@ -346,3 +446,5 @@ def check_target(device): test_reduce_prims() test_argmax() test_rfactor_argmax() + test_warp_reduction1() + test_warp_reduction2() diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 6cdf7d76be9eb..c9de6754aa89b 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -521,6 +521,34 @@ def test_any_pad(): verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3)) verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1)) +def verify_any_dilate(data_shape, strides, static_data_shape): + assert len(data_shape) == len(strides) + mod = tvm.IRModule() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.nn.dilate(data, strides) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + ref_shape = tuple((static_data_shape[i] - 1) * strides[i] + 1 + for i in range(len(static_data_shape))) + ref_out = np.zeros(shape=ref_shape, dtype=dtype) + ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_dilate(): + verify_any_dilate(any_dims(1), (1,), (1,)) + verify_any_dilate(any_dims(1), (1,), (5,)) + verify_any_dilate(any_dims(1), (5,), (5,)) + verify_any_dilate(any_dims(3), (1, 1, 1), (1, 2, 3)) + verify_any_dilate(any_dims(3), (1, 1, 2), (1, 2, 3)) + verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3)) + verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3)) + verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4)) + def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32" diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 5ef98a5b8f6d2..9faf6d903a9c4 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -68,8 +68,13 @@ def check_single_op(opfunc, ref): (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))), (tvm.relay.log2, lambda x: 1 / (np.log(2) * x)), (tvm.relay.log10, lambda x: 1 / (np.log(10) * x)), - (tvm.relay.cosh, lambda x: -1.0 * np.sinh(x)), - (tvm.relay.sinh, lambda x: np.cosh(x))]: + (tvm.relay.cosh, lambda x: np.sinh(x)), + (tvm.relay.sinh, lambda x: np.cosh(x)), + (tvm.relay.asin, 
lambda x: 1. / (1. - x**2) ** (1./2.)), + (tvm.relay.acos, lambda x: -1. / (1. - x**2.) ** (1./2.)), + (tvm.relay.acosh, lambda x: 1./ (x**2 - 1.)**(1./2.)), + (tvm.relay.asinh, lambda x: 1./ (x**2 + 1.)**(1./2.)), + (tvm.relay.atanh, lambda x: -1./ (x**2 - 1.))]: check_single_op(opfunc, ref) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index bbe2c69d62942..947a4bfd0b3b1 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -165,7 +165,10 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") dtype = "bool" if ref_func in [np.all, np.any] else dtype x = relay.var("x", relay.TensorType(data, dtype)) - z = test_func(x, axis, keepdims, exclude) + if test_func == relay.logsumexp: + z = test_func(x, axis, keepdims) + else: + z = test_func(x, axis, keepdims, exclude) zz = run_infer_type(z) if axis: assert "axis=" in z.astext() @@ -215,6 +218,14 @@ def _wrapper(data, axis=None, keepdims=False): return func(data, axis=axis).reshape(out_shape) return _wrapper + def _np_log_sum_exp(x, axis, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + x = x + max_x + if not keepdims: + x = np.squeeze(x, axis=axis) + return x + d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") for func in [[relay.sum, np.sum], [relay.max, np.max], @@ -225,6 +236,7 @@ def _wrapper(data, axis=None, keepdims=False): [relay.prod, np.prod], [relay.all, np.all], [relay.any, np.any], + [relay.logsumexp, _np_log_sum_exp], [relay.argmin, _with_keepdims(np.argmin)], [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index 03ab9eeb13218..fb60e98052064 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -144,7 +144,32 @@ def test_same_i_qnn_params(): op_res = intrp.evaluate(func)(x_data, y_data) np.testing.assert_equal(op_res.asnumpy(), golden_output) +def test_call_input(): + # This tests the case where the input to concatenate is not explicitly a + # tuple node but is instead a call node. 
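The analytic gradients registered above are easy to sanity-check outside the harness with a central difference. A throwaway NumPy sketch for the asin rule, d/dx asin(x) = 1/sqrt(1 - x^2), with points kept inside (-1, 1):

```python
import numpy as np

# Central difference: (f(x+eps) - f(x-eps)) / (2*eps) approximates f'(x).
def numeric_grad(f, x, eps=1e-6):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = np.linspace(-0.9, 0.9, 7)
analytic = 1.0 / np.sqrt(1.0 - x ** 2)
np.testing.assert_allclose(numeric_grad(np.arcsin, x), analytic, rtol=1e-4)
```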
+ x_data = np.ones(shape=(64,)).astype('uint8') + + x = relay.var("x", shape=(64,), dtype='uint8') + x_scale = relay.const(1, 'float32') + y_scale = relay.const(1, 'float32') + x_zero_point = relay.const(0, 'int32') + y_zero_point = relay.const(0, 'int32') + + tup = relay.split(x, 2, axis=0) + z = relay.qnn.op.concatenate(tup, + input_scales=(x_scale, y_scale), + input_zero_points=(x_zero_point, y_zero_point), + output_scale=y_scale, + output_zero_point=relay.const(0, 'int32'), + axis=0) + func = relay.Function([x], z) + + intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm") + op_res = intrp.evaluate(func)(x_data) + np.testing.assert_equal(op_res.asnumpy(), x_data) + if __name__ == '__main__': + test_call_input() test_same_io_qnn_params() test_different_io_qnn_params() test_few_same_io_qnn_params() diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 9b18f72cb6e77..bc0420f26d9ba 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -153,6 +153,56 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_alter_layout_lrn(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. + """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias") + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.max_pool2d(y, pool_size=(2, 2)) + y = relay.nn.lrn(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + new_attrs['kernel_layout'] = 'OIHW16i' + return relay.nn.conv2d(data, weight, **new_attrs) + + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + w = relay.layout_transform(weight, "OIHW", "OIHW16i") + y = relay.nn.conv2d(y, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + kernel_layout="OIHW16i", + data_layout="NCHW16c") + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c") + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.nn.lrn(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = before() + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_alter_layout_dual_path(): """ @@ -1027,6 +1077,7 @@ def expected(): test_alter_return_none() test_alter_layout() test_alter_layout_dual_path() + test_alter_layout_lrn() test_alter_layout_resnet() test_alter_layout_broadcast_op() test_alter_layout_broadcast_scalar_op() diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index b212b26c99a7e..a981667219cda 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass): return entry if isinstance(expr, relay.Function) else entry.body +def test_concatenate_const(): + def before(): + data = tvm.nd.array(np.array([1.0, 2.0, 3.0])) + const 
= relay.const(data) + concat = relay.op.concatenate([const, const], axis=0) + func = relay.Function([], concat) + return func + + def expected(): + data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])) + const = relay.const(data) + func = relay.Function([], const) + return func + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, zexpected) + + def test_fold_const(): c_data = np.array([1, 2, 3]).astype("float32") t = relay.TensorType([1, 2, 3], "float32") diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 2a4fd31041d76..d78b9eab873ef 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -1155,6 +1155,42 @@ def expected(): partitioned = seq(mod) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) +def test_constant_tuples(): + @reg.register("qnn.concatenate", "target.const_tuples") + def add(attrs, args): # pylint: disable=unused-variable + return True + + def create_graph(): + a = relay.var('a', shape=(10, 10), dtype="uint8") + b = relay.var('b', shape=(10, 10), dtype="uint8") + a1 = relay.abs(a) + + zeroi = relay.const(1, "int32") + zerof = relay.const(0, "float32") + con = relay.qnn.op.concatenate((a1, b), + input_scales=(zerof, zerof), + input_zero_points=(zeroi, zeroi), + output_scale=zerof, + output_zero_point=zeroi, + axis=1) + + f = relay.Function([a, b], con) + mod = tvm.IRModule.from_expr(f) + return mod + + seq = tvm.transform.Sequential([ + transform.AnnotateTarget("const_tuples"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ]) + + partitioned = seq(create_graph()) + concat = partitioned["const_tuples_0"].body + assert type(concat.args[1]) == relay.Tuple + assert type(concat.args[2]) == relay.Tuple + assert type(concat.args[3]) == relay.Constant + assert type(concat.args[4]) == relay.Constant + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() @@ -1171,3 +1207,4 @@ def expected(): test_multiple_use_of_an_output() test_duplicate_outputs() test_duplicate_merge_and_tuplegetitem() + test_constant_tuples() diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 0dcf1fb5344cc..179152273c006 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -126,7 +126,7 @@ def test_floormod_simplify(): x, y = te.var("x"), te.var("y") ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512, 16), flm((x*4) + y + 12, 16)) - + ck.verify(flm(flm((x*4), 16), 8), flm(x, 2) * 4) def test_canonical_mixed(): diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 6efb67b19bad7..372f0e9ce7274 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -64,14 +64,14 @@ def test_deduce(): e2 = (tvm.te.max(5, a * 4) < 0) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf" - assert str(res2.min_value) == "pos_inf" + assert str(res2.max_value) == "neg_inf: handle" + assert str(res2.min_value) == "pos_inf: handle" # expression containing variable a is on rhs e2 = (zero < tvm.te.max(5, a * 4)) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert 
str(res2.max_value) == "neg_inf" - assert str(res2.min_value) == "pos_inf" + assert str(res2.max_value) == "neg_inf: handle" + assert str(res2.min_value) == "pos_inf: handle" e3 = (-b)+a*c-d res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) @@ -88,8 +88,8 @@ def test_deduce(): # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) - assert str(res5.max_value) == "neg_inf" - assert str(res5.min_value) == "pos_inf" + assert str(res5.max_value) == "neg_inf: handle" + assert str(res5.min_value) == "pos_inf: handle" # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) @@ -111,13 +111,13 @@ def test_deduce(): # Unsatisfiable Mul in `EQ` e5 = (4 * a == b) res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {}) - assert str(res9.max_value) == "neg_inf" - assert str(res9.min_value) == "pos_inf" + assert str(res9.max_value) == "neg_inf: handle" + assert str(res9.min_value) == "pos_inf: handle" # Unsatisfiable Mul in `EQ` res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf" - assert str(res10.min_value) == "pos_inf" + assert str(res10.max_value) == "neg_inf: handle" + assert str(res10.min_value) == "pos_inf: handle" def test_check(): diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e57dcef75994b..9919c7b96cf1f 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -90,6 +90,20 @@ def test_mod(): flm = tvm.te.floormod ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9)) + + floordiv = tvm.te.floordiv + z = te.var("z") + ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3)) + ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, + (0, 7)) + ck1 = IntSetChecker() + ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2)) + ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3)) def test_max_min(): diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py index 28fdb11c3de4b..bec74fb6644ca 100644 --- a/tests/python/unittest/test_runtime_micro.py +++ b/tests/python/unittest/test_runtime_micro.py @@ -25,8 +25,10 @@ from tvm.micro import create_micro_mod from tvm.relay.testing import resnet -# Use the host emulated micro device. -DEV_CONFIG = micro.device.host.default_config() +# Use the host emulated micro device. +DEV_CONFIG_A = micro.device.host.generate_config() +DEV_CONFIG_B = micro.device.host.generate_config() +TARGET = 'c -device=micro_dev' def relay_micro_build(func, dev_config, params=None): """Create a graph runtime module with a micro device context from a Relay function. 
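Stepping back to the IntSet hunk above: the new floormod expectations encode real interval facts. When the whole input interval falls inside one period of the modulus the result interval is exact; as soon as it straddles a period boundary it wraps to the full [0, m-1]. A brute-force confirmation:

```python
import numpy as np

# floormod(x, 10) over an interval [lo, hi], matching the ck.verify cases:
# [3, 5] -> (3, 5), [13, 15] -> (3, 5), but [3, 11] wraps to (0, 9).
def flm_range(lo, hi, m=10):
    vals = np.arange(lo, hi + 1) % m
    return int(vals.min()), int(vals.max())

assert flm_range(3, 5) == (3, 5)
assert flm_range(13, 15) == (3, 5)
assert flm_range(3, 11) == (0, 9)
assert flm_range(-10, 10) == (0, 9)
```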
@@ -47,22 +49,41 @@ def relay_micro_build(func, dev_config, params=None): mod : tvm.runtime.Module graph runtime module for the target device """ - with tvm.target.build_config(disable_vectorize=True): - graph, c_mod, params = relay.build(func, target="c", params=params) - micro_mod = create_micro_mod(c_mod, dev_config) + disable_vectorize = tvm.target.build_config(disable_vectorize=True) + disable_fusion = relay.build_config(disabled_pass={'FuseOps'}) + with disable_vectorize, disable_fusion: + graph, c_mod, params = relay.build(func, target=TARGET, params=params) + micro_mod = micro.create_micro_mod(c_mod, dev_config) ctx = tvm.micro_dev(0) mod = graph_runtime.create(graph, micro_mod, ctx) mod.set_input(**params) return mod +GDB_INIT_TEMPLATE = """ +layout asm +target remote localhost:{gdb_port} +set $pc = UTVMInit +break UTVMDone +""" + + +def reset_gdbinit(): + if 'server_port' not in DEV_CONFIG_A: + return + gdb_init_dir = os.environ['MICRO_GDB_INIT_DIR'] + with open(f'{gdb_init_dir}/.gdbinit', 'w') as f: + gdb_port = DEV_CONFIG_A['server_port'] - 3333 + f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port)) + + def test_alloc(): """Test tensor allocation on the device.""" if not tvm.runtime.enabled("micro_dev"): return shape = (1024,) dtype = "float32" - with micro.Session(DEV_CONFIG): + with micro.Session(DEV_CONFIG_A): ctx = tvm.micro_dev(0) np_tensor = np.random.uniform(size=shape).astype(dtype) micro_tensor = tvm.nd.array(np_tensor, ctx) @@ -76,6 +97,8 @@ def test_add(): shape = (1024,) dtype = "float32" + reset_gdbinit() + # Construct TVM expression. tvm_shape = tvm.runtime.convert(shape) A = te.placeholder(tvm_shape, name="A", dtype=dtype) @@ -86,14 +109,24 @@ def test_add(): func_name = "fadd" c_mod = tvm.build(s, [A, B, C], target="c", name=func_name) - with micro.Session(DEV_CONFIG): - micro_mod = create_micro_mod(c_mod, DEV_CONFIG) + with micro.Session(DEV_CONFIG_A) as sess: + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) micro_func = micro_mod[func_name] ctx = tvm.micro_dev(0) - a = tvm.nd.array(np.random.uniform(size=shape).astype(dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=shape).astype(dtype), ctx) + + a_np = np.random.uniform(size=shape).astype(dtype) + a = tvm.nd.array(a_np, ctx) + b_np = np.random.uniform(size=shape).astype(dtype) + b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) micro_func(a, b, c) + + # ensure inputs weren't corrupted + tvm.testing.assert_allclose( + a.asnumpy(), a_np) + tvm.testing.assert_allclose( + b.asnumpy(), b_np) + # ensure output is correct tvm.testing.assert_allclose( c.asnumpy(), a.asnumpy() + b.asnumpy()) @@ -105,6 +138,8 @@ def test_workspace_add(): shape = (1024,) dtype = "float32" + reset_gdbinit() + # Construct TVM expression. 
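The double context manager introduced in relay_micro_build above is the heart of this change: vectorization is turned off, presumably because the plain C backend used for micro targets has no vector types, and operator fusion is disabled per build rather than globally. Condensed into a standalone helper; a sketch under those assumptions, with `func` being any Relay function:

```python
import tvm
from tvm import relay

# Both configs are entered together so neither leaks past the build call;
# the target string matches the TARGET constant defined in the hunk above.
def build_for_micro(func, params=None, target='c -device=micro_dev'):
    disable_vectorize = tvm.target.build_config(disable_vectorize=True)
    disable_fusion = relay.build_config(disabled_pass={'FuseOps'})
    with disable_vectorize, disable_fusion:
        return relay.build(func, target=target, params=params)
```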
tvm_shape = tvm.runtime.convert(shape) A = te.placeholder(tvm_shape, name="A", dtype=dtype) @@ -116,14 +151,19 @@ def test_workspace_add(): func_name = "fadd_two_workspace" c_mod = tvm.build(s, [A, C], target="c", name=func_name) - with micro.Session(DEV_CONFIG): - micro_mod = create_micro_mod(c_mod, DEV_CONFIG) + with micro.Session(DEV_CONFIG_A) as sess: + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) micro_func = micro_mod[func_name] ctx = tvm.micro_dev(0) - a = tvm.nd.array(np.random.uniform(size=shape).astype(dtype), ctx) + a_np = np.random.uniform(size=shape).astype(dtype) + a = tvm.nd.array(a_np, ctx) c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) micro_func(a, c) + # ensure input wasn't corrupted + tvm.testing.assert_allclose( + a.asnumpy(), a_np) + # ensure output is correct tvm.testing.assert_allclose( c.asnumpy(), a.asnumpy() + 2.0) @@ -141,47 +181,74 @@ def test_graph_runtime(): z = relay.add(xx, relay.const(1.0)) func = relay.Function([x], z) - with micro.Session(DEV_CONFIG): - mod = relay_micro_build(func, DEV_CONFIG) + with micro.Session(DEV_CONFIG_A): + mod = relay_micro_build(func, DEV_CONFIG_A) x_in = np.random.uniform(size=shape[0]).astype(dtype) mod.run(x=x_in) result = mod.get_output(0).asnumpy() + tvm.testing.assert_allclose( + mod.get_input(0).asnumpy(), x_in) tvm.testing.assert_allclose( result, x_in * x_in + 1.0) -def test_multiple_modules(): - """Test loading multiple modules on the device simultaneously.""" +def test_conv2d(): if not tvm.runtime.enabled("micro_dev"): return - shape = (1024,) - dtype = "float32" - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - # Construct Relay subtract program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.subtract(x, relay.const(1.0)) - sub_const_func = relay.Function([x], ret) + from tvm.relay import create_executor + from tvm.relay import transform - with micro.Session(DEV_CONFIG): - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) - sub_const_mod = relay_micro_build(sub_const_func, DEV_CONFIG) + dshape = (1, 4, 16, 16) + dtype = 'int8' + func_name = 'fused_nn_conv2d' - x_in = np.random.uniform(size=shape[0]).astype(dtype) - add_const_mod.run(x=x_in) - add_result = add_const_mod.get_output(0).asnumpy() - sub_const_mod.run(x=x_in) - sub_result = sub_const_mod.get_output(0).asnumpy() + reset_gdbinit() - tvm.testing.assert_allclose( - add_result, x_in + 1.0) - tvm.testing.assert_allclose( - sub_result, x_in - 1.0) + # Construct Relay program. 
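For reference, the Relay conv2d assembled in the test_conv2d hunk just below keeps its spatial shape by construction; the usual convolution size formula makes the expected output shape explicit before the checked_type lookups that follow:

```python
# With a 3x3 kernel, padding (1, 1) and stride 1, spatial dims are kept:
#   out = (in + 2 * pad - kernel) // stride + 1 = (16 + 2 - 3) + 1 = 16
# so dshape (1, 4, 16, 16) with 4 output channels yields (1, 4, 16, 16).
def conv_out_dim(in_dim, kernel, pad, stride=1):
    return (in_dim + 2 * pad - kernel) // stride + 1

assert conv_out_dim(16, kernel=3, pad=1) == 16
```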
+ x = relay.var("x", shape=dshape, dtype=dtype) + conv_expr = relay.nn.conv2d( + x, relay.var("w"), + kernel_size=(3, 3), + padding=(1, 1), + channels=4) + func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr) + mod = tvm.IRModule.from_expr(func) + mod = transform.InferType()(mod) + + x_shape = list(map(lambda x: x.value, mod['main'].params[0].checked_type.shape)) + w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape)) + out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape)) + + with tvm.target.build_config(disable_vectorize=True): + graph, c_mod, params = relay.build(mod, target="c") + + with micro.Session(DEV_CONFIG_A): + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) + candidate_func_name = func_name + for i in range(100): + try: + micro_func = micro_mod[candidate_func_name] + break + except tvm.TVMError as e: + candidate_func_name = f'{func_name}_{i}' + else: + assert False + ctx = tvm.micro_dev(0) + + x_data = tvm.nd.array(np.random.uniform(size=x_shape).astype(dtype), ctx) + w_data = tvm.nd.array(np.random.uniform(size=w_shape).astype(dtype), ctx) + result = tvm.nd.array(np.zeros(shape=out_shape, dtype=dtype), ctx) + micro_func(x_data, w_data, result) + + out_data = np.zeros(out_shape, dtype=dtype) + params = { 'x': x_data.asnumpy(), 'w': w_data.asnumpy() } + intrp = create_executor('debug') + expected_result = intrp.evaluate(mod['main'])(x_data, w_data) + + tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy()) def test_interleave_sessions(): @@ -196,8 +263,8 @@ def test_interleave_sessions(): ret = relay.add(x, relay.const(1.0)) add_const_func = relay.Function([x], ret) - sess_a = micro.Session(DEV_CONFIG) - sess_b = micro.Session(DEV_CONFIG) + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) with sess_a: np_tensor_a = np.random.uniform(size=shape).astype(dtype) micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) @@ -205,13 +272,13 @@ def test_interleave_sessions(): np_tensor_b = np.random.uniform(size=shape).astype(dtype) micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) with sess_a: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) add_const_mod.run(x=micro_tensor_a) add_result = add_const_mod.get_output(0).asnumpy() tvm.testing.assert_allclose( add_result, np_tensor_a + 1.0) with sess_b: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B) add_const_mod.run(x=micro_tensor_b) add_result = add_const_mod.get_output(0).asnumpy() tvm.testing.assert_allclose( @@ -230,15 +297,15 @@ def test_nested_sessions(): ret = relay.add(x, relay.const(1.0)) add_const_func = relay.Function([x], ret) - sess_a = micro.Session(DEV_CONFIG) - sess_b = micro.Session(DEV_CONFIG) + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) with sess_a: np_tensor_a = np.random.uniform(size=shape).astype(dtype) micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) with sess_b: np_tensor_b = np.random.uniform(size=shape).astype(dtype) micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) add_const_mod.run(x=micro_tensor_a) add_result = add_const_mod.get_output(0).asnumpy() tvm.testing.assert_allclose( @@ -257,12 +324,12 @@ def test_inactive_session_use(): ret = 
relay.add(x, relay.const(1.0)) add_const_func = relay.Function([x], ret) - sess_a = micro.Session(DEV_CONFIG) - sess_b = micro.Session(DEV_CONFIG) + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) with sess_a: np_tensor_a = np.random.uniform(size=shape).astype(dtype) micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) with sess_b: # These objects belong to `sess_a`. @@ -272,12 +339,42 @@ def test_inactive_session_use(): add_result, np_tensor_a + 1.0) +# TODO add workspace alloc/free stress test + if __name__ == "__main__": test_alloc() + print() + print('finished alloc test') + input('[press enter to continue]') test_add() + print() + print('finished add test') + input('[press enter to continue]') test_workspace_add() + print() + print('finished workspace add test') + input('[press enter to continue]') test_graph_runtime() + print() + print('finished graph runtime test') + input('[press enter to continue]') + test_conv2d() + print() + print('finished conv2d test') + input('[press enter to continue]') test_multiple_modules() + print() + print('finished multiple modules test') + input('[press enter to continue]') test_interleave_sessions() + print() + print('finished interleaved sessions test') + input('[press enter to continue]') test_nested_sessions() + print() + print('finished nested sessions test') + input('[press enter to continue]') test_inactive_session_use() + print() + print('finished use inactive session test') + input('[press enter to continue]') diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index b61e6bb9fa016..17321bdeb2937 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -18,10 +18,12 @@ from tvm import te import tvm.testing import os +import stat import logging import time import multiprocessing +import pytest import numpy as np from tvm import rpc from tvm.contrib import util @@ -77,11 +79,9 @@ def remotethrow(name): f1 = client.get_function("rpc.test.addone") assert f1(10) == 11 f3 = client.get_function("rpc.test.except") - try: + + with pytest.raises(tvm.error.RPCError): f3("abc") - assert False - except tvm.error.TVMError as e: - assert "abc" in str(e) f2 = client.get_function("rpc.test.strcat") assert f2("abc", 11) == "abc:11" @@ -101,6 +101,58 @@ def remote_array_func(y): fremote = remote.get_function("rpc.test.remote_array_func") fremote(r_cpu) + +def test_rpc_large_array(): + # testcase of large array creation + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a_np = np.ones((5041, 720)).astype('float32') + b_np = np.ones((720, 192)).astype('float32') + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + np.testing.assert_equal(a.asnumpy(), a_np) + np.testing.assert_equal(b.asnumpy(), b_np) + + +def test_rpc_echo(): + def check(remote): + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + with pytest.raises(RuntimeError): + raise_err = remote.get_function( + "testing.test_raise_error_callback")("RuntimeError") + raise_err() + + remote.cpu().sync() + with pytest.raises(AttributeError): + f3 = remote.system_lib()["notexist"] + + + temp = rpc.server._server_env([]) + server = 
rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + check(rpc.LocalSession()) + + check(client) + # Test minrpc server. + temp = util.tempdir() + minrpc_exec = temp.relpath("minrpc") + tvm.rpc.with_minrpc("g++")(minrpc_exec, []) + check(rpc.PopenSession(minrpc_exec)) + # minrpc on the remote + server = rpc.Server("localhost") + client = rpc.connect( + server.host, server.port, + session_constructor_args=["rpc.PopenSession", + open(minrpc_exec, "rb").read()]) + check(client) + + def test_rpc_file_exchange(): if not tvm.runtime.enabled("rpc"): return @@ -114,14 +166,20 @@ def test_rpc_file_exchange(): def test_rpc_remote_module(): if not tvm.runtime.enabled("rpc"): return - server = rpc.Server("localhost") - client = rpc.connect(server.host, server.port) # graph - n = tvm.runtime.convert(1024) + n = tvm.runtime.convert(102) A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + def check_remote(remote): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") return @@ -133,13 +191,45 @@ def check_remote(remote): f.export_library(path_dso) remote.upload(path_dso) f1 = remote.load_module("dev_lib.so") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + def check_minrpc(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + if tvm.get_global_func("rpc.PopenSession", allow_missing=True) is None: + return + # export to minrpc + temp = util.tempdir() + f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") + path_minrpc = temp.relpath("dev_lib.minrpc") + f.export_library(path_minrpc, rpc.with_minrpc("g++")) + + with pytest.raises(RuntimeError): + rpc.PopenSession("filenotexist") + + # start the minrpc session. 
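The minrpc hunk continues below. Separately, the proxy-style connect used in test_rpc_remote_module above deserves a minimal standalone form, since the same session_constructor_args mechanism reappears twice more in this file; a sketch with the keys and constructor exactly as in the hunks:

```python
from tvm import rpc

# Two local servers; server0 is the entry point and, via the session
# constructor, immediately dials server1, so the client ends up two hops away.
server0 = rpc.Server("localhost", key="x0")
server1 = rpc.Server("localhost", key="x1")
client = rpc.connect(
    server0.host, server0.port, key="x0",
    session_constructor_args=["rpc.Connect", server1.host, server1.port, "x1"])
assert client.get_function("testing.echo")("ping") == "ping"
```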
+ remote = tvm.rpc.PopenSession(path_minrpc) + ctx = remote.cpu(0) + f1 = remote.system_lib() + + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) + time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1) + cost = time_f(a, b).mean + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + + # change to not executable + os.chmod(path_minrpc, stat.S_IRUSR) + with pytest.raises(RuntimeError): + rpc.PopenSession(path_minrpc) + + def check_remote_link_cl(remote): """Test function to run remote code such as cl @@ -174,8 +264,8 @@ def check_remote_link_cl(remote): fhost = remote.load_module("myadd.o") fdev = remote.load_module("myadd.cl") fhost.import_module(fdev) - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) # Option 2: export library as a tar ball then handled by remote compiler @@ -183,13 +273,15 @@ def check_remote_link_cl(remote): f.export_library(path_tar) remote.upload(path_tar) fhost = remote.load_module("myadd.tar") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote(client) check_remote(rpc.LocalSession()) + check_remote(client) + check_minrpc() + def test_rpc_return_func(): @@ -204,6 +296,37 @@ def addone(x): assert fadd(12) == 22 +def test_rpc_session_constructor_args(): + # start server + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + def check_multi_hop(): + # use server0 as proxy to connect to server1 + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + + fecho = client.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + nd = tvm.nd.array([1,2,3], ctx=client.cpu(0)) + assert(nd.asnumpy()[1] == 2) + + def check_error_handling(): + with pytest.raises(tvm.error.RPCError): + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=["rpc.NonExistingConstructor"]) + + check_multi_hop() + check_error_handling() + + def test_rpc_return_ndarray(): # Use closure to check the ref counter correctness nd = tvm.nd.array(np.zeros(10).astype("float32")) @@ -221,6 +344,7 @@ def my_module(name): # start server server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") + m = client.get_function("rpc.test.remote_return_nd") get_arr = m("get_arr") ref_count = m("ref_count") @@ -315,6 +439,7 @@ def target(host, port, device_key, timeout): time.sleep(0.5) summary = client.summary() + assert summary['queue_info'][device_key]['free'] == 0 assert summary['queue_info'][device_key]['pending'] == 1 @@ -334,6 +459,8 @@ def target(host, port, device_key, timeout): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) + test_rpc_echo() + test_rpc_session_constructor_args() test_rpc_return_ndarray() test_rpc_return_func() test_bigendian_rpc() @@ -344,3 
+471,4 @@ def target(host, port, device_key, timeout): test_local_func() test_rpc_tracker_register() test_rpc_tracker_request() + test_rpc_large_array() diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 49a7933a0cacc..50705e86037f5 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -55,7 +55,12 @@ def check_cuda(dtype, n, lanes): check_cuda("float32", 64, 2) check_cuda("float32", 64, 3) check_cuda("float32", 64, 4) + check_cuda("int8", 64, 2) + check_cuda("int8", 64, 3) check_cuda("int8", 64, 4) + check_cuda("uint8", 64, 2) + check_cuda("uint8", 64, 3) + check_cuda("uint8", 64, 4) check_cuda("float16", 64, 2) check_cuda("float16", 64, 4) check_cuda("float16", 64, 6) @@ -112,15 +117,17 @@ def check_cuda(dtype, n, lanes): b = tvm.nd.empty((n,), B.dtype, ctx) fun(a,b) tvm.testing.assert_allclose(a.asnumpy(), b.asnumpy()) + check_cuda("int8", 64, 2) + check_cuda("int8", 64, 3) + check_cuda("int8", 64, 4) check_cuda("int8", 64, 8) check_cuda("int8", 64, 16) -def test_cuda_make_int8x4(): - def check_cuda(n, value): +def test_cuda_make_int8(): + def check_cuda(n, value, lanes): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled..") return - lanes = 4 dtype = 'int8' ctx = tvm.gpu(0) A = te.compute((n, lanes), lambda i,j: tvm.tir.const(value, dtype=dtype)) @@ -133,9 +140,15 @@ def check_cuda(n, value): a = tvm.nd.empty(np_a.shape, dtype, ctx) fun(a) np.testing.assert_equal(a.asnumpy(), np_a) - check_cuda(64, 0xAB) - check_cuda(64, 0) - check_cuda(64, -3) + check_cuda(64, 0xAB, 4) + check_cuda(64, 0, 4) + check_cuda(64, -3, 4) + check_cuda(64, 0xAB, 3) + check_cuda(64, 0, 3) + check_cuda(64, -3, 3) + check_cuda(64, 0xAB, 2) + check_cuda(64, 0, 2) + check_cuda(64, -3, 2) def test_cuda_inf_nan(): @@ -579,6 +592,8 @@ def check_cuda(dtype, n, l, padding, lanes): (0, 0)), mode='constant', constant_values=0) tvm.testing.assert_allclose(b.asnumpy(), ref) + check_cuda("int8", 64, 16, 3, 2) + check_cuda("uint8", 64, 16, 3, 2) check_cuda("int8", 64, 16, 3, 4) check_cuda("uint8", 64, 16, 3, 4) check_cuda("int32", 64, 16, 3, 4) @@ -589,7 +604,7 @@ def check_cuda(dtype, n, l, padding, lanes): test_cuda_vectorize_add() test_cuda_multiply_add() test_cuda_vectorize_load() - test_cuda_make_int8x4() + test_cuda_make_int8() test_cuda_inf_nan() test_cuda_shuffle() test_vectorized_casts() diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index a7e1e57481a72..c6591721d247a 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -43,6 +43,18 @@ def test_llvm_intrin(): fcode = tvm.build(mod, None, "llvm") +def test_llvm_void_intrin(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("uint8", name="A") + # Create an intrinsic that returns void. + x = tvm.tir.call_llvm_intrin('', 'llvm.va_start', tvm.tir.const(1, 'uint32'), A) + ib.emit(x) + body = ib.get() + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) + fcode = tvm.build(mod, None, "llvm") + + def test_llvm_overloaded_intrin(): # Name lookup for overloaded intrinsics in LLVM 4- requires a name # that includes the overloaded types. 
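On the CUDA side, the new lane counts (2 and 3, alongside the existing 4) and the constants fed to test_cuda_make_int8 (0xAB, 0, -3) all probe byte-level splatting. The arithmetic being exercised, as a plain sketch of how a splatted int8x4 constant packs into one 32-bit word, assuming little-endian byte order:

```python
# An int8x4 splat of value v is four copies of its low byte in one word.
def splat_int8x4(v):
    b = v & 0xFF
    return b | (b << 8) | (b << 16) | (b << 24)

assert splat_int8x4(0xAB) == 0xABABABAB
assert splat_int8x4(0) == 0
assert splat_int8x4(-3) == 0xFDFDFDFD   # -3 as a byte is 0xFD
```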
diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 9e4d45e9efaa9..9b8d4061afb4b 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -286,8 +286,8 @@ def intrin_func(ins, outs, sp): stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 - assert str(stmt.body.body.value.args[3]) == "(i*i)" - assert str(stmt.body.body.value.args[4]) == "(i + j)" + assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" + assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" if __name__ == "__main__": test_singleton() diff --git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py new file mode 100644 index 0000000000000..3893bb6befda4 --- /dev/null +++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import te + +def test_bound_tile_mod(): + def compute(M_tiles, N_tiles, factor, dtype): + # Algo + M = M_tiles * factor + N = N_tiles * factor + + A = tvm.te.placeholder((N, M), name='A', dtype=dtype) + C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C') + s = tvm.te.create_schedule(C.op) + + return s, A, C + + def schedule(s, factor, padding, A, C): + C_local = s.cache_write(C, "local") + + n, m = C.op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + nio, nii = s[C].split(ni, 2) + n = s[C].fuse(nii, mi) + C_shared = s.cache_write(C, "shared") + bn, bm, ni, mi = C_shared.op.axis + s[C_shared].storage_align(ni, factor * 2, padding) + + n, m = s[C].op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + s[C].set_scope("global") + niio, niii = s[C].split(ni, 32) + s[C_shared].compute_at(s[C], niio) + + return s + + s, A, C = compute(2, 2, 128, "float16") + s = schedule(s, 128, 8, A, C) + bounds = tvm.te.schedule.InferBound(s) + check = (bounds[s.stages[2].op.axis[2]].extent == 16) + if(not check): + print(tvm.lower(s, [A, C], simple_mode=True)) + assert(check) + +if __name__ == "__main__": + test_bound_tile_mod() diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 61a522ccafff3..26bf80f5e1a55 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -63,6 +63,12 @@ def test_unary_intrin(): (tvm.tir.sinh, lambda x : np.sinh(x)), (tvm.tir.cosh, lambda x : np.cosh(x)), (tvm.tir.log1p, lambda x : np.log1p(x)), + (tvm.tir.asin, lambda x : np.arcsin(x)), + (tvm.tir.acos, lambda x : np.arccos(x)), + (tvm.tir.atan, lambda x : np.arctan(x)), + (tvm.tir.asinh, lambda x : np.arcsinh(x)), + (tvm.tir.acosh, lambda x : np.arccosh(x)), + (tvm.tir.atanh, lambda x : np.arctanh(x)), ] def run_test(tvm_intrin, np_func): m = te.var("m",) @@ -72,7 +78,7 @@ def run_test(tvm_intrin, np_func): f = tvm.build(s, [A, B], "llvm") ctx = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx) b = tvm.nd.array( \ np.random.uniform(size=n).astype(A.dtype), ctx) f(a, b) diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 468ab1dbad6ac..36c9c764f6abc 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -103,7 +103,7 @@ def test_basic(): a = te.var('a') b = te.var('b') c = a + b - assert str(c) == '(%s + %s)' % (a.name, b.name) + assert str(c) == '(%s: int32 + %s: int32)' % (a.name, b.name) def test_stmt(): @@ -138,11 +138,11 @@ def test_any(): assert False except ValueError: pass - assert str(tvm.tir.any(x < y)) == '(%s < %s)' % (x.name, y.name) - assert str(tvm.tir.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % ( + assert str(tvm.tir.any(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name) + assert str(tvm.tir.any(x < y, x > z)) == '((%s: int32 < %s: int32) || (%s > %s: int32))' % ( x.name, y.name, x.name, z.name) assert str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) == \ - '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % ( + '(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) @@ -160,29 +160,29 @@ def test_all(): assert False except ValueError: pass - assert str(tvm.tir.all(x < y)) == '(%s < %s)' % (x.name, y.name) - assert str(tvm.tir.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % ( + assert 
str(tvm.tir.all(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name) + assert str(tvm.tir.all(x < y, x > z)) == '((%s: int32 < %s: int32) && (%s > %s: int32))' % ( x.name, y.name, x.name, z.name) assert str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) == \ - '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % ( + '(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) def test_bitwise(): x = te.var('x') y = te.var('y') - assert str(x << y) == 'shift_left(x, y)' - assert str(x >> y) == 'shift_right(x, y)' - assert str(x & y) == 'bitwise_and(x, y)' - assert str(x | y) == 'bitwise_or(x, y)' - assert str(x ^ y) == 'bitwise_xor(x, y)' - assert str(10 & x) == 'bitwise_and(10, x)' - assert str(10 | x) == 'bitwise_or(10, x)' - assert str(10 ^ x) == 'bitwise_xor(10, x)' - assert str(10 >> x) == 'shift_right(10, x)' - assert str(10 << x) == 'shift_left(10, x)' - assert str(10 % x) == 'floormod(10, x)' - assert str(~x) == 'bitwise_not(x)' + assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 % x) == 'floormod(10, x: int32)' + assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin", index=0)' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -239,12 +239,12 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == 'isnan(x)' + assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin", index=0)' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == 'isnan(float32(y))' + assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin", index=0)' z = te.var('z', 'int32') - assert str(tvm.tir.isnan(z)) == '(bool)0' + assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') assert str(tvm.tir.isnan(k).dtype) == 'uint1x2' diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 51be480a7cba8..bd553772e087d 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -47,6 +47,42 @@ def test_lower_warp_memory_local_scope(): assert(fdevice.body.body.value.value == "local") assert(fdevice.body.body.body.extents[0].value == 2) +def test_lower_warp_memory_correct_indices(): + n = 32 + A = te.placeholder((2, n, n), name='A', 
dtype="float32") + C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C') + + s = te.create_schedule(C.op) + bk_x = te.thread_axis("blockIdx.x") + th_y = te.thread_axis("threadIdx.y") + th_x = te.thread_axis("threadIdx.x") + B = s.cache_read(A, "warp", [C]) + cx, ci, cj = C.op.axis + bx, bi, bj = B.op.axis + s[C].bind(cj, th_x) + s[C].bind(cx, bk_x) + s[B].compute_at(s[C], cx) + s[B].bind(bi, th_y) + s[B].bind(bj, th_x) + + bounds = tvm.te.schedule.InferBound(s) + ir = tvm.te.schedule.ScheduleOps(s, bounds) + inner_func = ir.body.body.body.body + store_A_warp = inner_func.body.seq[0].body.body + indices = list(store_A_warp.args) + + # A.warp is actually many buffers, one for each warp, although they are all called A.warp. + # 1. If we are accessing from different threads within the same warp (different + # threadIdx.x), we need to distinguish between elements using threadIdx.x, + # so threadIdx.x is one of the indices. + # 2. If we are accessing from different warps (different threadIdx.y), we are actually + # accessing different buffers, so there is no need to distinguish between elements, + # and therefore threadIdx.y is NOT an index. + # (A standalone sketch of this check appears after the task_python_docs.sh hunk below.) + idx_names = list(map(lambda x: x.name, + filter(lambda x: type(x) is tvm.tir.expr.Var, indices))) # list() so both asserts can scan every name + assert "threadIdx.x" in idx_names + assert "threadIdx.y" not in idx_names + def test_lower_warp_memory_cuda_end_to_end(): def check_cuda(dtype): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): @@ -182,6 +218,7 @@ def check_cuda(dtype): if __name__ == "__main__": test_lower_warp_memory_local_scope() + test_lower_warp_memory_correct_indices() test_lower_warp_memory_cuda_end_to_end() test_lower_warp_memory_cuda_half_a_warp() test_lower_warp_memory_cuda_2_buffers() diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 819961dc6ebdb..41006f41f754c 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -43,9 +43,6 @@ cd .. make doc rm -f docs/doxygen/html/*.map docs/doxygen/html/*.md5 -# JS doc -jsdoc -c web/.jsdoc_conf.json web/tvm_runtime.js web/README.md - # Java doc make javadoc @@ -54,7 +51,6 @@ rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo mv docs/doxygen/html _docs/doxygen -mv out _docs/jsdoc mv jvm/core/target/site/apidocs _docs/javadoc echo "Start creating the docs tarball.."
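The indexing rule documented in test_lower_warp_memory_correct_indices (two hunks above) is also easy to verify interactively by printing the scheduled statement. A minimal self-contained sketch, an editor's illustration that reuses only APIs already used by that test:

import tvm
from tvm import te

n = 32
A = te.placeholder((2, n, n), name="A", dtype="float32")
C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name="C")
s = te.create_schedule(C.op)
B = s.cache_read(A, "warp", [C])  # warp-scope cache of A
cx, ci, cj = C.op.axis
s[C].bind(cx, te.thread_axis("blockIdx.x"))
s[C].bind(cj, te.thread_axis("threadIdx.x"))
s[B].compute_at(s[C], cx)
bx, bi, bj = B.op.axis
s[B].bind(bi, te.thread_axis("threadIdx.y"))
s[B].bind(bj, te.thread_axis("threadIdx.x"))

bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
# In the printed body, stores into A.warp are indexed with threadIdx.x
# (lanes sharing one warp buffer) but not threadIdx.y (warp-private copies).
print(stmt)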
diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index fae07d34e992c..5529632e50ea1 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -54,6 +54,12 @@ cd tests/test_tvm_dso cargo run cd - +# # run wasm32 test +# cd tests/test_wasm32 +# cargo build +# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm +# cd - + # run nn graph test cd tests/test_nn cargo run diff --git a/tests/scripts/task_sphinx_precheck.sh b/tests/scripts/task_sphinx_precheck.sh index 0c82b2cb97ee4..a5a0a1764a09e 100755 --- a/tests/scripts/task_sphinx_precheck.sh +++ b/tests/scripts/task_sphinx_precheck.sh @@ -42,7 +42,7 @@ cd docs make clean TVM_TUTORIAL_EXEC_PATTERN=none make html 2>/tmp/$$.log.txt -grep -v -E "__mro__|RemovedIn|UserWarning|FutureWarning|Keras" < /tmp/$$.log.txt > /tmp/$$.logclean.txt || true +grep -v -E "__mro__|UserWarning|FutureWarning|tensorflow|Keras|pytorch|TensorFlow" < /tmp/$$.log.txt > /tmp/$$.logclean.txt || true echo "---------Sphinx Log----------" cat /tmp/$$.logclean.txt echo "-----------------------------" diff --git a/tests/web/test_packed_func.js b/tests/web/test_packed_func.js deleted file mode 100644 index d239f7346e74e..0000000000000 --- a/tests/web/test_packed_func.js +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. 
-const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -function testGetGlobal() { - var targs = [10, 10.0, "hello"] - tvm.registerFunc("my_packed_func", function () { - tvm.assert(Array.from(arguments).toString() == targs, "assert fail"); - return 10 - }); - var f = tvm.getGlobalFunc("my_packed_func") - tvm.assert(tvm.isPackedFunc(f)); - y = f.apply(null, targs); - tvm.assert(y == 10); - f.release(); -} - - -function testReturnFunc() { - function addy(y) { - function add(x) { - return x + y; - } - return add; - } - var myf = tvm.convertFunc(addy); - var f = myf(10); - tvm.assert(tvm.isPackedFunc(f)); - tvm.assert(f(11) == 21); - myf.release(); - f.release(); -} - -function testByteArray() { - var a = new Uint8Array(3); - a[0] = 1; - a[1] = 2; - function myfunc(ss){ - tvm.assert(ss instanceof Uint8Array); - tvm.assert(ss.toString() == a); - } - f = tvm.convertFunc(myfunc); - f(a); - f.release(); -} - -testGetGlobal(); -testReturnFunc(); -testByteArray(); diff --git a/tests/webgl/test_local_gemm.py b/tests/webgl/test_local_gemm.py deleted file mode 100644 index 6bd22bf0057b8..0000000000000 --- a/tests/webgl/test_local_gemm.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -import numpy as np - -def test_local_gemm(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - nn = 1024 - n = te.var('n') - n = tvm.runtime.convert(nn) - m = n - l = n - A = te.placeholder((n, l), name='A', dtype='int32') - B = te.placeholder((m, l), name='B', dtype='int32') - k = te.reduce_axis((0, l), name='k') - C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), - name='CC') - - s = te.create_schedule(C.op) - s[C].opengl() - print(tvm.lower(s, [A, B, C], simple_mode=True)) - - f = tvm.build(s, [A, B, C], "opengl", name="gemm") - print("------opengl code------") - print(f.imported_modules[0].get_source(fmt="gl")) - - ctx = tvm.opengl() - n, m, l = nn, nn, nn - a_np = np.random.uniform(low=0, high=10, size=(n, l)).astype(A.dtype) - b_np = np.random.uniform(low=0, high=10, size=(m, l)).astype(B.dtype) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - f(a, b, c) - - tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T)) - -if __name__ == "__main__": - test_local_gemm() diff --git a/tests/webgl/test_local_save_load.py b/tests/webgl/test_local_save_load.py deleted file mode 100644 index cca68020c0c23..0000000000000 --- a/tests/webgl/test_local_save_load.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -def test_local_save_load(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - n = te.var("n") - A = te.placeholder((n,), name='A', dtype='int32') - B = te.placeholder((n,), name='B', dtype='int32') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - - f = tvm.build(s, [A, B, C], "opengl", target_host="llvm", name="myadd") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((n), dtype=C.dtype), ctx) - f(a, b, c) - - temp = util.tempdir() - path_so = temp.relpath("myadd.so") - f.export_library(path_so) - f1 = tvm.runtime.load_module(path_so) - f1(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - test_local_save_load() diff --git a/tests/webgl/test_local_topi_conv2d_nchw.py b/tests/webgl/test_local_topi_conv2d_nchw.py deleted file mode 100644 index 0d9b7776096a2..0000000000000 --- a/tests/webgl/test_local_topi_conv2d_nchw.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Example code to do convolution. -Copied from topi/tests/python/test_topi_conv2d_nchw.py. 
-Should be removed once we fix OpenGL testing on Jenkins.""" -import os -import numpy as np -import tvm -from tvm import te -import topi -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name='A') - W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W') - B = topi.nn.conv2d_nchw(A, W, stride, padding) - C = topi.nn.relu(B) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s1 = topi.generic.schedule_conv2d_nchw([B]) - s2 = topi.generic.schedule_conv2d_nchw([C]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.target.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - - -def test_conv2d_nchw(): - # ResNet18 worklaods - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - # Vgg16 workloads - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) - # Super resolution workloads - verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1) - verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1) - -if __name__ == "__main__": - test_conv2d_nchw() diff --git a/tests/webgl/test_local_topi_dense.py b/tests/webgl/test_local_topi_dense.py deleted file mode 100644 index 60dfe1ff690f0..0000000000000 --- a/tests/webgl/test_local_topi_dense.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for dense operator -Copied from topi/tests/python/test_topi_dense.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -from topi.util import get_const_tuple -from tvm.contrib.pickle_memoize import memoize - - -def verify_dense(batch, in_dim, out_dim, use_bias=True): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - C = te.placeholder((out_dim,), name='C') - D = topi.nn.dense(A, B, C if use_bias else None) - D = topi.nn.relu(D) - dtype = A.dtype - - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense") - def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) - else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) - return (a_np, b_np, c_np, d_np) - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_dense(D) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(c_np, ctx) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx) - f = tvm.build(s, [A, B, C, D], device, name="dense") - f(a, b, c, d) - tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_dense(): - verify_dense(1, 1024, 1000, use_bias=True) - verify_dense(1, 1024, 1000, use_bias=False) - - -if __name__ == "__main__": - test_dense() diff --git a/tests/webgl/test_local_topi_pooling.py b/tests/webgl/test_local_topi_pooling.py deleted file mode 100644 index 3adae7bba51c2..0000000000000 --- a/tests/webgl/test_local_topi_pooling.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Test code for pooling -Copied from topi/tests/python/test_topi_pooling.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -import math -from topi.util import get_const_tuple - -def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): - iw = ih - kw = kh - sw = sh - ph, pw = padding - A = te.placeholder((n, ic, ih, iw), name='A') - B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, - pool_type=pool_type, ceil_mode=ceil_mode) - B = topi.nn.relu(B) - dtype = A.dtype - - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) - - - a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) - - if pool_type == 'avg': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - elif pool_type =='max': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) - print(tvm.lower(s, [A, B], simple_mode=True)) - - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - - - -def verify_global_pool(n, c, h, w, pool_type): - A = te.placeholder((n, c, h, w), name='A') - B = topi.nn.global_pool(A, pool_type=pool_type) - B = topi.nn.relu(B) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - if pool_type == 'avg': - b_np = np.mean(a_np, axis=(2,3), keepdims=True) - elif pool_type =='max': - b_np = np.max(a_np, axis=(2,3), keepdims=True) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_global_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_global_pool(): - 
verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - - -if __name__ == "__main__": - test_pool() - test_global_pool() diff --git a/tests/webgl/test_local_topi_softmax.py b/tests/webgl/test_local_topi_softmax.py deleted file mode 100644 index c0ddbf21419ac..0000000000000 --- a/tests/webgl/test_local_topi_softmax.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for softmax -Copied from topi/tests/python/test_topi_softmax.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" - -import os -import numpy as np -import tvm -from tvm import te -import topi -import logging -from topi.util import get_const_tuple - -def verify_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - -def test_softmax(): - verify_softmax(32, 10) - verify_softmax(3, 4) - - -def verify_log_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.log_softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.log_softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="log_softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - - -def test_log_softmax(): - verify_log_softmax(32, 10) - verify_log_softmax(3, 4) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_softmax() - test_log_softmax() 
diff --git a/tests/webgl/test_remote_save_load.py b/tests/webgl/test_remote_save_load.py deleted file mode 100644 index 34bbb3fa0f002..0000000000000 --- a/tests/webgl/test_remote_save_load.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The following instruction is based on web/README.md. - -Setup an RPC server: -$ python -m tvm.exec.rpc_proxy --example-rpc=1 - -Go to http://localhost:9190 in browser. - -Click "Connect To Proxy". - -Run this test script: -$ python tests/webgl/test_remote_save_load.py -""" - -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -proxy_host = "localhost" -proxy_port = 9090 - -def try_remote_save_load(): - if not tvm.runtime.enabled("rpc"): - return - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - # Build the module. - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B, C], "opengl", target_host=target_host, name="myadd") - - remote = rpc.connect(proxy_host, proxy_port, key="js") - - temp = util.tempdir() - ctx = remote.opengl(0) - path_obj = temp.relpath("myadd.bc") - path_dso = temp.relpath("myadd.js") - path_gl = temp.relpath("myadd.gl") - path_json = temp.relpath("myadd.tvm_meta.json") - - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - f.imported_modules[0].save(path_gl) - - remote.upload(path_dso, "myadd.dso") - remote.upload(path_gl) - remote.upload(path_json) - - remote.download("myadd.dso") - remote.download("myadd.gl") - remote.download("myadd.tvm_meta.json") - - print('Loading myadd.dso') - fhost = remote.load_module("myadd.dso") - - print('Loading myadd.gl') - fdev = remote.load_module("myadd.gl") - - print('import_module') - fhost.import_module(fdev) - - print('running...') - a = tvm.nd.array(np.random.uniform(size=16).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(16, dtype=A.dtype), ctx) - c = tvm.nd.array(np.zeros(16, dtype=C.dtype), ctx) - fhost(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - try_remote_save_load() diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html deleted file mode 100644 index f9268c65edf34..0000000000000 --- a/tests/webgl/test_static_webgl_library.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - -

[deleted file body elided: a 72-line static HTML test page whose markup did not survive extraction; the recoverable text is the page title "TVM RPC Test Page" and the body heading "TVM Test Page"; the file had no trailing newline]
diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py deleted file mode 100644 index 929da4ca294cd..0000000000000 --- a/tests/webgl/test_static_webgl_library.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Create a static WebGL library and run it in the browser.""" - -from __future__ import absolute_import, print_function - -import os, shutil, SimpleHTTPServer, SocketServer -import tvm -from tvm import te -from tvm.contrib import emscripten, util -import numpy as np - -def try_static_webgl_library(): - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - - # Change to lib/ which contains "libtvm_runtime.bc". - os.chdir(os.path.join(curr_path, "../../lib")) - - # Create OpenGL module. - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="float") - B = te.compute((n,), lambda *i: A[i], name="B") - - s = te.create_schedule(B.op) - s[B].opengl() - - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B], name="identity", target="opengl", - target_host=target_host) - - # Create a JS library that contains both the module and the tvm runtime.
- path_dso = "identity_static.js" - f.export_library(path_dso, emscripten.create_js, options=[ - "-s", "USE_GLFW=3", - "-s", "USE_WEBGL2=1", - "-lglfw", - ]) - - # Create "tvm_runtime.js" and "identity_static.html" in lib/ - shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"), - "tvm_runtime.js") - shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"), - "identity_static.html") - - port = 8080 - handler = SimpleHTTPServer.SimpleHTTPRequestHandler - httpd = SocketServer.TCPServer(("", port), handler) - print("Please open http://localhost:" + str(port) + "/identity_static.html") - httpd.serve_forever() - -if __name__ == "__main__": - try_static_webgl_library() diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index 98614c3d49031..1b36ace4608f7 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -28,8 +28,8 @@ #include #include -#include #include +#include namespace topi { @@ -49,8 +49,8 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { CHECK_GE(output_shape.size(), t->shape.size()) - << "Not a broadcast, output dimensionality smaller than input.\noutput: " - << output_shape << "\nvs\ninput: " << t; + << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape + << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); CHECK_EQ(output_shape.size(), bh.common_shape.size()); for (size_t i = 0; i < output_shape.size(); ++i) { @@ -59,57 +59,39 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, auto l = [&](tvm::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; - return tvm::te::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, - name, - tag); + return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + l, name, tag); } -#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::PrimExpr Name(const tvm::PrimExpr& a, \ - const tvm::PrimExpr& b) { \ - ComputeRule; \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, \ - std::string tag = kBroadcast) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return detail::WithBroadcast(l, A, B, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::PrimExpr& B, \ - std::string name = "T_" #Name, \ - std::string tag = kElementWise) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute(A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \ - return l(A(i), B); \ - }, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \ - const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, \ - std::string tag = kElementWise) { \ - auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute(B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \ - return l(A, B(i)); \ - }, name, tag); \ +#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kBroadcast) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return detail::WithBroadcast(l, A, B, name, 
tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \ } - -#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::te::Tensor& B) { \ - return topi::OpName(A, B); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \ - const tvm::te::Tensor& B) { \ - return topi::OpName(A, B); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::PrimExpr& B) { \ - return topi::OpName(A, B); \ +#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \ + return topi::OpName(A, B); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \ + return topi::OpName(A, B); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \ + return topi::OpName(A, B); \ } /*! diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index f2ed029f5b33e..0c04eaf31535a 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -24,8 +24,8 @@ #ifndef TOPI_CONTRIB_CUBLAS_H_ #define TOPI_CONTRIB_CUBLAS_H_ -#include #include +#include namespace topi { namespace contrib { @@ -33,65 +33,51 @@ using namespace tvm; using namespace tvm::te; using namespace topi::detail; /*! -* \brief Create an op that multiplies lhs and rhs with cuBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor cublas_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies lhs and rhs with cuBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto n = transa ? lhs->shape[1] : lhs->shape[0]; auto m = transb ? rhs->shape[0] : rhs->shape[1]; return make_extern( - { { n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.cublas.matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImmNode::make("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } /*! 
-* \brief Create an op that multiplies batch matrices -* lhs and rhs with cuBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor cublas_batch_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies batch matrices + * lhs and rhs with cuBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto b = lhs->shape[0]; auto n = transa ? lhs->shape[2] : lhs->shape[1]; auto m = transb ? rhs->shape[1] : rhs->shape[2]; - return make_extern( - { { b, n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.cublas.batch_matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + return make_extern({{b, n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImmNode::make("tvm.contrib.cublas.batch_matmul"), + pack_buffer(ins[0]), pack_buffer(ins[1]), + pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index f0bf92678f9a4..3baf1057312ca 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -25,6 +25,7 @@ #define TOPI_CONTRIB_ROCBLAS_H_ #include + #include "topi/detail/extern.h" namespace topi { @@ -32,33 +33,26 @@ namespace contrib { using namespace tvm; using namespace tvm::te; /*! -* \brief Create an op that multiplies lhs and rhs with rocBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor rocblas_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies lhs and rhs with rocBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto n = transa ? lhs->shape[1] : lhs->shape[0]; auto m = transb ? 
rhs->shape[0] : rhs->shape[1]; return make_extern( - { { n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.rocblas.matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImmNode::make("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h index 1f0701e1aa380..145d249f90c69 100644 --- a/topi/include/topi/cuda/dense.h +++ b/topi/include/topi/cuda/dense.h @@ -24,14 +24,14 @@ #ifndef TOPI_CUDA_DENSE_H_ #define TOPI_CUDA_DENSE_H_ -#include -#include -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -39,21 +39,19 @@ using namespace tvm::te; namespace cuda { /*! -* \brief Implementation of dense for CUDA backend -* -* \param target The target device -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense_cuda(const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Implementation of dense for CUDA backend + * + * \param target The target device + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. + * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -68,10 +66,8 @@ inline tvm::te::Tensor dense_cuda(const Target& target, CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::cublas_matmul(data, weight, false, true); if (bias.defined()) { - mm = tvm::te::compute({ batch, out_dim }, - [&](Var i, Var j) { - return mm(i, j) + bias(j); - }, "tensor", kBroadcast); + mm = tvm::te::compute( + {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast); } return mm; @@ -81,16 +77,15 @@ inline tvm::te::Tensor dense_cuda(const Target& target, } /*! -* \brief Create a CUDA schedule for dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_dense(const Target &target, const Array& outs) { - if (target->target_name == "cuda" && - target->libs().count("cublas")) { + * \brief Create a CUDA schedule for dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ +inline Schedule schedule_dense(const Target& target, const Array& outs) { + if (target->target_name == "cuda" && target->libs().count("cublas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/cuda/injective.h b/topi/include/topi/cuda/injective.h index a7792a5f4f1bf..5a5c5af373493 100644 --- a/topi/include/topi/cuda/injective.h +++ b/topi/include/topi/cuda/injective.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_INJECTIVE_H_ #define TOPI_CUDA_INJECTIVE_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -63,7 +63,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/cuda/normalization.h b/topi/include/topi/cuda/normalization.h index bfc209db213be..f8f498eaffcf3 100644 --- a/topi/include/topi/cuda/normalization.h +++ b/topi/include/topi/cuda/normalization.h @@ -24,20 +24,20 @@ #ifndef TOPI_CUDA_NORMALIZATION_H_ #define TOPI_CUDA_NORMALIZATION_H_ +#include +#include #include #include -#include -#include namespace topi { using namespace tvm; using namespace tvm::te; namespace cuda { /*! -* \brief Create a CUDA schedule for LRN -* \param outs The output tensors. -* \return A schedule for the given ops. -*/ + * \brief Create a CUDA schedule for LRN + * \param outs The output tensors. + * \return A schedule for the given ops. + */ inline Schedule schedule_lrn(const Array& outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/cuda/pooling.h b/topi/include/topi/cuda/pooling.h index 75b66b3a7c9d3..87866f2c69332 100644 --- a/topi/include/topi/cuda/pooling.h +++ b/topi/include/topi/cuda/pooling.h @@ -24,12 +24,12 @@ #ifndef TOPI_CUDA_POOLING_H_ #define TOPI_CUDA_POOLING_H_ +#include +#include +#include +#include #include #include -#include -#include -#include -#include namespace topi { using namespace tvm; @@ -38,14 +38,14 @@ using namespace tvm::te; namespace cuda { /*! -* \brief Create a CUDA schedule for pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_pool(const Target &target, const Array& outs) { + * \brief Create a CUDA schedule for pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_pool(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -105,14 +105,14 @@ inline Schedule schedule_pool(const Target &target, const Array& outs) { } /*! -* \brief Create a CUDA schedule for global_pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_global_pool(const Target &target, const Array& outs) { + * \brief Create a CUDA schedule for global_pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ +inline Schedule schedule_global_pool(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -142,7 +142,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array& s[out].split(i, num_thread, &by, &ty); IterVar bx, tx; s[out].split(c, num_thread, &bx, &tx); - s[out].reorder({ by, bx, ty, tx }); + s[out].reorder({by, bx, ty, tx}); s[out].bind(ty, thread_y); s[out].bind(tx, thread_x); s[out].bind(by, block_y); diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h index add8d99aec914..35ce346eaaeeb 100644 --- a/topi/include/topi/cuda/reduction.h +++ b/topi/include/topi/cuda/reduction.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_REDUCTION_H_ #define TOPI_CUDA_REDUCTION_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -45,10 +45,8 @@ namespace cuda { * an index, such as argmax or argmin. * * \return The schedule given by sch -*/ -Schedule ScheduleReduce(const Target& target, - Operation op, - Schedule sch, + */ +Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, bool is_idx_reduce = false) { Tensor data_out; Tensor data_in; @@ -61,8 +59,8 @@ Schedule ScheduleReduce(const Target& target, } auto out_stage = sch[data_out]; - CHECK_GT(out_stage->op.as()->reduce_axis.size(), 0) << - "reduce_axis must be greater than zero"; + CHECK_GT(out_stage->op.as()->reduce_axis.size(), 0) + << "reduce_axis must be greater than zero"; bool all_reduce; int num_thread; @@ -120,10 +118,8 @@ Schedule ScheduleReduce(const Target& target, } } else { if (is_idx_reduce) { - sch[temp_idx_input].compute_at(stage_real, - stage_real->op.as()->axis[0]); - sch[temp_val_input].compute_at(stage_real, - stage_real->op.as()->axis[0]); + sch[temp_idx_input].compute_at(stage_real, stage_real->op.as()->axis[0]); + sch[temp_val_input].compute_at(stage_real, stage_real->op.as()->axis[0]); } } @@ -152,13 +148,13 @@ void TraverseBeforeReduce(Schedule s, Operation op) { } /*! -* \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each -* of the op's inputs. -* -* \param target The target to generate a schedule for. -* \param s The schedule we are building -* \param op The reduce op -*/ + * \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each + * of the op's inputs. + * + * \param target The target to generate a schedule for. + * \param s The schedule we are building + * \param op The reduce op + */ void TraverseAfterReduce(const Target& target, Schedule s, Operation op) { if (is_broadcast(op->tag)) { LOG(ERROR) << "Elementwise op after reduce is not yet supported"; @@ -178,13 +174,13 @@ void TraverseAfterReduce(const Target& target, Schedule s, Operation op) { } /*! -* \brief Create a CUDA schedule for a reduce operation. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a CUDA schedule for a reduce operation. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ Schedule schedule_reduce(const Target& target, Array outs) { CHECK_EQ(outs.size(), 1) << "outs must have size 1"; Array out_ops; diff --git a/topi/include/topi/cuda/softmax.h b/topi/include/topi/cuda/softmax.h index 4c88c3e9eddf3..a3aa857d8c0c4 100644 --- a/topi/include/topi/cuda/softmax.h +++ b/topi/include/topi/cuda/softmax.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_SOFTMAX_H_ #define TOPI_CUDA_SOFTMAX_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -44,7 +44,7 @@ namespace cuda { * * \return A schedule for the given ops. */ -inline Schedule schedule_softmax(const Target &target, const Array& outs) { +inline Schedule schedule_softmax(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/detail/array_utils.h b/topi/include/topi/detail/array_utils.h index 3a3453a7baf4e..d7204722c4f60 100644 --- a/topi/include/topi/detail/array_utils.h +++ b/topi/include/topi/detail/array_utils.h @@ -39,7 +39,7 @@ using namespace tvm::te; * * \return True iff the given array contains the given item. */ -template +template inline bool contains(Array array, T item) { for (auto& i : array) { if (i == item) { diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h index 8622920dc374f..ca30293278750 100644 --- a/topi/include/topi/detail/broadcast.h +++ b/topi/include/topi/detail/broadcast.h @@ -24,8 +24,8 @@ #ifndef TOPI_DETAIL_BROADCAST_H_ #define TOPI_DETAIL_BROADCAST_H_ -#include #include +#include #include #include @@ -77,10 +77,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else { - CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] - << " and " << shape2[s2_size - i] << " in: " - << tvm::Array(shape1.begin(), shape1.end()) - << " and " + CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " + << shape2[s2_size - i] + << " in: " << tvm::Array(shape1.begin(), shape1.end()) << " and " << tvm::Array(shape2.begin(), shape2.end()); } } @@ -97,10 +96,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, } inline tvm::Array InputIndexFromBroadcast( - const tvm::Array& ovars, - const tvm::te::Tensor& T, - const std::deque& my_vars, - const std::deque& all_vars) { + const tvm::Array& ovars, const tvm::te::Tensor& T, + const std::deque& my_vars, const std::deque& all_vars) { tvm::Array ivars; CHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. 
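`BroadcastShape` above walks the two shapes right-to-left and accepts a pair of dimensions when they are equal or one of them is 1; `InputIndexFromBroadcast` then maps each output index back into each input, emitting the index only for axes that input actually owns. A standalone sketch of the shape rule on concrete integers (the real code works on symbolic `PrimExpr` dimensions and also records iteration vars; `broadcast_shape` here is a hypothetical helper):

```cpp
#include <algorithm>
#include <cstdio>
#include <stdexcept>
#include <vector>

// NumPy-style broadcast of two shapes, right-aligned: dims must match or be 1.
std::vector<int> broadcast_shape(const std::vector<int>& a, const std::vector<int>& b) {
  std::vector<int> out(std::max(a.size(), b.size()), 1);
  for (size_t i = 1; i <= out.size(); ++i) {  // walk dimensions right-to-left
    int da = i <= a.size() ? a[a.size() - i] : 1;
    int db = i <= b.size() ? b[b.size() - i] : 1;
    if (da != db && da != 1 && db != 1) throw std::runtime_error("incompatible broadcast dims");
    out[out.size() - i] = std::max(da, db);  // a size-1 dim stretches to the other
  }
  return out;
}

int main() {
  for (int d : broadcast_shape({2, 1, 5}, {3, 1})) std::printf("%d ", d);  // prints: 2 3 5
}
```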
@@ -125,21 +122,16 @@ inline tvm::Array InputIndexFromBroadcast( } template -inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, - const tvm::te::Tensor& A, - const tvm::te::Tensor& B, - const std::string& name = "tensor", - const std::string& tag = "") { +inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A, + const tvm::te::Tensor& B, const std::string& name = "tensor", + const std::string& tag = "") { auto bh = BroadcastShape(A->shape, B->shape); auto l = [&](tvm::Array ovars) { return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)), B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; - return tvm::te::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, - name, - tag); + return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + l, name, tag); } } // namespace detail diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index afa8833b3490b..9bd1251199878 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -24,10 +24,10 @@ #ifndef TOPI_DETAIL_CONSTANT_UTILS_H_ #define TOPI_DETAIL_CONSTANT_UTILS_H_ -#include #include -#include #include +#include +#include #include #include @@ -44,10 +44,7 @@ using namespace tvm::te; * * \return true if the given expr is a constant int or uint, false otherwise. */ -inline bool IsConstInt(PrimExpr expr) { - return - expr->IsInstance(); -} +inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } /*! * \brief Get the value of the given constant integer expression. An error @@ -74,13 +71,11 @@ inline int64_t GetConstInt(PrimExpr expr) { * * \return A vector of the integer values */ -inline std::vector GetConstIntValues( - Array exprs, const std::string& var_name) { +inline std::vector GetConstIntValues(Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { - CHECK(IsConstInt(expr)) << "All elements of " - << var_name << " must be constant integers"; + CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers"; result.push_back(GetConstInt(expr)); } return result; @@ -95,8 +90,8 @@ inline std::vector GetConstIntValues( * * \return A vector of the int64_t values */ -inline std::vector GetConstInt64Values( - Array exprs, const std::string& var_name) { +inline std::vector GetConstInt64Values(Array exprs, + const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { @@ -107,8 +102,8 @@ inline std::vector GetConstInt64Values( } /*! 
- * \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again - * \note This is stronger equality check than tvm::tir::Equal + * \brief Check whether the two expressions are equal or not; if not, simplify the expressions and + * check again. \note This is a stronger equality check than tvm::tir::Equal * * \param lhs First expreesion * \param rhs Second expreesion @@ -120,7 +115,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr zero(0); - result = expr_equal(tvm::arith::Analyzer().Simplify(lhs-rhs), zero); + result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero); } return result; } diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index ab83200ff387f..e6ede6a32d77d 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -25,9 +25,9 @@ #define TOPI_DETAIL_EXTERN_H_ #include -#include -#include +#include +#include namespace topi { namespace detail { @@ -43,13 +43,11 @@ using namespace tvm::te; * * \return The Buffer object */ -inline Buffer DeclExternBuffer(Array shape, - DataType dtype, - std::string name) { +inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - -1, 0, kDefault); + return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, + kDefault); } /*! @@ -76,15 +74,12 @@ using FExtern = std::function, Array)>; * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * element of out_types. */ -inline Array make_extern(const Array< Array >& out_shapes, +inline Array make_extern(const Array >& out_shapes, const std::vector& out_types, - const Array& inputs, - FExtern fextern, - std::string name, - std::string tag, - ::tvm::Map attrs) { + const Array& inputs, FExtern fextern, std::string name, + std::string tag, ::tvm::Map attrs) { CHECK_EQ(out_shapes.size(), out_types.size()) - << "make_extern: out_shapes and out_types must have equal size"; + << "make_extern: out_shapes and out_types must have equal size"; Array input_placeholders; for (auto t : inputs) { @@ -98,9 +93,8 @@ inline Array make_extern(const Array< Array >& out_shapes, auto body = fextern(input_placeholders, output_placeholders); auto body_stmt = tvm::tir::EvaluateNode::make(body); - auto op = ExternOpNode::make( - name, tag, attrs, inputs, - input_placeholders, output_placeholders, body_stmt); + auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders, + body_stmt); Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { @@ -119,27 +113,25 @@ inline Array make_extern(const Array< Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = tvm::tir::CallNode::make( - DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + auto shape = + tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tir::CallNode::make( - DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
+ strides = + tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); } else { strides = 0; } - Array pack_args{ - buf->data, - shape, - strides, - make_const(DataType::Int(32), static_cast(buf->shape.size())), - make_const(buf->dtype, 0), - buf->elem_offset - }; + Array pack_args{buf->data, + shape, + strides, + make_const(DataType::Int(32), static_cast(buf->shape.size())), + make_const(buf->dtype, 0), + buf->elem_offset}; return tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, - pack_args, tvm::tir::CallNode::CallType::Intrinsic); + pack_args, tvm::tir::CallNode::CallType::Intrinsic); } /*! @@ -152,8 +144,8 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, - args, tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, + tvm::tir::CallNode::CallType::Intrinsic); } } // namespace detail diff --git a/topi/include/topi/detail/pad_utils.h b/topi/include/topi/detail/pad_utils.h index 1f2a7c5d4185b..7c416ecefb3c4 100644 --- a/topi/include/topi/detail/pad_utils.h +++ b/topi/include/topi/detail/pad_utils.h @@ -18,16 +18,17 @@ */ /*! -* \file pad_utils.h -* \brief Padding helpers -*/ + * \file pad_utils.h + * \brief Padding helpers + */ #ifndef TOPI_DETAIL_PAD_UTILS_H_ #define TOPI_DETAIL_PAD_UTILS_H_ -#include +#include +#include +#include -#include "tvm/tir/expr.h" -#include "tvm/tir/op.h" +#include namespace topi { namespace detail { @@ -50,7 +51,7 @@ inline Array GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { auto pad_top = indexdiv(pad_h + 1, 2); auto pad_left = indexdiv(pad_w + 1, 2); - return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left }; + return {pad_top, pad_left, pad_h - pad_top, pad_w - pad_left}; } } // namespace detail diff --git a/topi/include/topi/detail/ravel_unravel.h b/topi/include/topi/detail/ravel_unravel.h index ca46da0a56f2f..c87f2c997ca6e 100644 --- a/topi/include/topi/detail/ravel_unravel.h +++ b/topi/include/topi/detail/ravel_unravel.h @@ -18,9 +18,9 @@ */ /*! -* \file ravel_unravel.h -* \brief Index ravel and unraval operations -*/ + * \file ravel_unravel.h + * \brief Index ravel and unravel operations + */ #ifndef TOPI_DETAIL_RAVEL_UNRAVEL_H_ #define TOPI_DETAIL_RAVEL_UNRAVEL_H_ @@ -34,13 +34,13 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Flatten the indices to 1D -* -* \param indices The input coordinates -* \param shape Shape of the tensor -* -* \return The index after flattening -*/ + * \brief Flatten the indices to 1D + * + * \param indices The input coordinates + * \param shape Shape of the tensor + * + * \return The index after flattening + */ inline PrimExpr RavelIndex(Array indices, Array shape) { CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; CHECK_GT(indices.size(), 0) << "indices must not be empty"; @@ -56,13 +56,13 @@ inline PrimExpr RavelIndex(Array indices, Array shape) { } /*!
-* \brief Convert flattened index to coordinate array -* -* \param idx The 1D index -* \param shape Shape of the tensor -* -* \return The coordinate corresponding to the 1D index -*/ + * \brief Convert flattened index to coordinate array + * + * \param idx The 1D index + * \param shape Shape of the tensor + * + * \return The coordinate corresponding to the 1D index + */ inline Array UnravelIndex(PrimExpr idx, Array shape) { std::vector indices; diff --git a/topi/include/topi/detail/tensor_utils.h b/topi/include/topi/detail/tensor_utils.h index 6ac3982c3cf24..d144c75695edf 100644 --- a/topi/include/topi/detail/tensor_utils.h +++ b/topi/include/topi/detail/tensor_utils.h @@ -24,7 +24,6 @@ #ifndef TOPI_DETAIL_TENSOR_UTILS_H_ #define TOPI_DETAIL_TENSOR_UTILS_H_ - #include namespace topi { @@ -63,7 +62,7 @@ inline bool is_empty_shape(const Array& x) { * \return The interpolated value in the given index. */ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, - const PrimExpr max_y, const PrimExpr max_x) { + const PrimExpr max_y, const PrimExpr max_x) { auto in_y = indices[2]; auto yf = tvm::floor(in_y); auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); @@ -85,9 +84,7 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& auto C = input(indices[0], indices[1], y1, x0); auto D = input(indices[0], indices[1], y1, x1); - return A * ( 1 - x_lerp) * ( 1 - y_lerp) + - B * x_lerp * (1 - y_lerp) + - C * (1 - x_lerp) * y_lerp + + return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp + D * x_lerp * y_lerp; } diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 11eda8651cb4a..70daac2a3339b 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -24,10 +24,12 @@ #ifndef TOPI_ELEMWISE_H_ #define TOPI_ELEMWISE_H_ -#include #include +#include + #include #include + #include "broadcast.h" namespace topi { @@ -35,13 +37,11 @@ using namespace tvm; using namespace tvm::te; // Unary intrinsic operators -#define TOPI_DECLARE_UNARY_OP(OpName) \ - inline Tensor OpName(const Tensor& x, \ - std::string name = "T_" #OpName, \ - std::string tag = kElementWise) { \ - return compute(x->shape, [&](const Array& i) { \ - return ::tvm::OpName(x(i)); \ - }, name, tag); \ +#define TOPI_DECLARE_UNARY_OP(OpName) \ + inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ + std::string tag = kElementWise) { \ + return compute( \ + x->shape, [&](const Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ } TOPI_DECLARE_UNARY_OP(exp); @@ -61,7 +61,12 @@ TOPI_DECLARE_UNARY_OP(cosh); TOPI_DECLARE_UNARY_OP(tan); TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(sinh); +TOPI_DECLARE_UNARY_OP(acos); +TOPI_DECLARE_UNARY_OP(acosh); +TOPI_DECLARE_UNARY_OP(asin); +TOPI_DECLARE_UNARY_OP(asinh); TOPI_DECLARE_UNARY_OP(atan); +TOPI_DECLARE_UNARY_OP(atanh); TOPI_DECLARE_UNARY_OP(isnan); TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(isfinite); @@ -71,9 +76,7 @@ TOPI_DECLARE_UNARY_OP(isinf); * \brief Fast_tanh_float implementation from Eigen * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26 */ -inline Tensor fast_tanh_float(const Tensor& in, - std::string name, - std::string tag) { +inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) { // Clamp the inputs to the range [-9, 9] since anything outside // this range is +/-1.0f in single-precision. 
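The coefficients and Horner evaluation in this hunk implement Eigen's rational approximation: after the clamp to [-9, 9], the compute body evaluates

$$\tanh(x) \approx \frac{x\left(\alpha_1 + \alpha_3 x^2 + \alpha_5 x^4 + \alpha_7 x^6 + \alpha_9 x^8 + \alpha_{11} x^{10} + \alpha_{13} x^{12}\right)}{\beta_0 + \beta_2 x^2 + \beta_4 x^4 + \beta_6 x^6}$$

an odd-over-even rational fit (degree 13 over degree 6) that stays close to single-precision accuracy over the clamped range, where true tanh is already saturated at ±1.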
auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0)); @@ -93,178 +96,171 @@ inline Tensor fast_tanh_float(const Tensor& in, auto beta_4 = make_const(in->dtype, 1.18534705686654e-04); auto beta_6 = make_const(in->dtype, 1.19825839466702e-06); - return compute(x->shape, - [&](const Array& i) { - auto x2 = x(i) * x(i); - auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x(i) * p; - - auto q = x2 * beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - return p / q; - }, - name, tag); + return compute( + x->shape, + [&](const Array& i) { + auto x2 = x(i) * x(i); + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x(i) * p; + + auto q = x2 * beta_6 + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + return p / q; + }, + name, tag); } /*! -* \brief Creates an operation that returns hyperbolic tanh of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is tanh -*/ -inline Tensor fast_tanh(const Tensor& x, - std::string name = "T_fast_tanh", + * \brief Creates an operation that returns hyperbolic tanh of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is tanh + */ +inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { // invoke fast_tanh_float implementation return fast_tanh_float(x, name, tag); } else { // fallback to default implementation - return compute(x->shape, [&](const Array& i) { - return ::tvm::tanh(x(i)); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ::tvm::tanh(x(i)); }, name, tag); } } /*! -* \brief Creates an operation that returns identity of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the identity operation -*/ -inline Tensor identity(const Tensor& x, - std::string name = "T_identity", + * \brief Creates an operation that returns identity of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the identity operation + */ +inline Tensor identity(const Tensor& x, std::string name = "T_identity", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return x(i); }, name, tag); } /*! 
-* \brief Creates an operation that returns the negation of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the negation operation -*/ -inline Tensor negative(const Tensor& x, - std::string name = "T_negative", + * \brief Creates an operation that returns the negation of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the negation operation + */ +inline Tensor negative(const Tensor& x, std::string name = "T_negative", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return -x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return -x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the logical NOT of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the logical NOT operation -*/ -inline Tensor logical_not(const Tensor& x, - std::string name = "T_logical_not", + * \brief Creates an operation that returns the logical NOT of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the logical NOT operation + */ +inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return !x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return !x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the bitwise NOT of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the bitwise NOT operation -*/ -inline Tensor bitwise_not(const Tensor& x, - std::string name = "T_bitwise_not", + * \brief Creates an operation that returns the bitwise NOT of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the bitwise NOT operation + */ +inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return ~x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ~x(i); }, name, tag); } /*! 
-* \brief Returns the sign of the tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the sign -*/ -inline Tensor sign(const Tensor& x, - std::string name = "T_sign", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - PrimExpr zero = make_zero(x->dtype); - PrimExpr one = make_const(x->dtype, 1); - PrimExpr minus_one = make_const(x->dtype, -1); - auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero); - auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1); - return s2; - }, name, tag); + * \brief Returns the sign of the tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the sign + */ +inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + PrimExpr zero = make_zero(x->dtype); + PrimExpr one = make_const(x->dtype, 1); + PrimExpr minus_one = make_const(x->dtype, -1); + auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero); + auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1); + return s2; + }, + name, tag); } /*! -* \brief Creates an operation that returns rsqrt of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the rsqrt operation -*/ -inline Tensor rsqrt(const Tensor& x, - std::string name = "tensor", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - PrimExpr one = make_const(x->dtype, 1); - return one/tvm::sqrt(x(i)); - }, name, tag); + * \brief Creates an operation that returns rsqrt of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the rsqrt operation + */ +inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + PrimExpr one = make_const(x->dtype, 1); + return one / tvm::sqrt(x(i)); + }, + name, tag); } /*! 
-* \brief Creates an operation that clips each element of a tensor to -* the interval [a_min, a_max] -* -* \param x The input tensor -* \param a_min The inclusive lower bound of the interval -* \param a_max The inclusive upper bound of the interval -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the clip operation -*/ -inline Tensor clip(const Tensor& x, - const PrimExpr& a_min, - const PrimExpr& a_max, - std::string name = "T_clip", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - auto min_val = tvm::cast(x->dtype, a_min); - auto max_val = tvm::cast(x->dtype, a_max); - return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) - }, name, tag); + * \brief Creates an operation that clips each element of a tensor to + * the interval [a_min, a_max] + * + * \param x The input tensor + * \param a_min The inclusive lower bound of the interval + * \param a_max The inclusive upper bound of the interval + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the clip operation + */ +inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max, + std::string name = "T_clip", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + auto min_val = tvm::cast(x->dtype, a_min); + auto max_val = tvm::cast(x->dtype, a_max); + return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) + }, + name, tag); } /*! @@ -279,22 +275,23 @@ inline Tensor clip(const Tensor& x, * * \return A Tensor whose op member is the cast operation */ -inline Tensor cast(const Tensor& x, - DataType type, - std::string name = "T_cast", +inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - auto expr = x(i); - if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { - if (expr.dtype().lanes() == type.lanes()) { - return expr; - } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { - return tvm::tir::BroadcastNode::make(expr, type.lanes()); - } - } - - return tvm::cast(type, x(i)); - }, name, tag); + return compute( + x->shape, + [&](const Array& i) { + auto expr = x(i); + if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { + if (expr.dtype().lanes() == type.lanes()) { + return expr; + } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { + return tvm::tir::BroadcastNode::make(expr, type.lanes()); + } + } + + return tvm::cast(type, x(i)); + }, + name, tag); } /*! @@ -309,12 +306,13 @@ inline Tensor cast(const Tensor& x, */ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", std::string tag = kElementWise) { - return compute(x->shape, - [&](const Array& i) { - return tvm::tir::CallNode::make(type, "reinterpret", {x(i)}, - tvm::tir::CallNode::PureIntrinsic); - }, - name, tag); + return compute( + x->shape, + [&](const Array& i) { + return tvm::tir::CallNode::make(type, "reinterpret", {x(i)}, + tvm::tir::CallNode::PureIntrinsic); + }, + name, tag); } /*! 
@@ -326,63 +324,58 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te * * \return A Tensor whose op member is the sum operation */ -inline Tensor elemwise_sum(const Array& xs, - std::string name = "T_elemwise_sum", +inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwise_sum", std::string tag = kElementWise) { CHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; - return compute(xs[0]->shape, [&](const Array& i) { - auto sum_expr = xs[0](i); - for (size_t j = 1; j < xs.size(); j++) { - sum_expr = sum_expr + xs[j](i); - } - return sum_expr; - }, name, tag); + return compute( + xs[0]->shape, + [&](const Array& i) { + auto sum_expr = xs[0](i); + for (size_t j = 1; j < xs.size(); j++) { + sum_expr = sum_expr + xs[j](i); + } + return sum_expr; + }, + name, tag); } /*! -* \brief Creates an operation that fill a tensor with fill_value -* -* \param shape The shape of a tensor -* \param dtype The Type of fill_value -* \param fill_value The value to be filled -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the full operation -*/ -inline Tensor full(const Array& shape, - DataType dtype, - const PrimExpr fill_value, - std::string name = "T_full", - std::string tag = kElementWise) { + * \brief Creates an operation that fills a tensor with fill_value + * + * \param shape The shape of a tensor + * \param dtype The type of fill_value + * \param fill_value The value to be filled + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the full operation + */ +inline Tensor full(const Array& shape, DataType dtype, const PrimExpr fill_value, + std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } - return compute(shape, [&](const Array& i) { - return ev; - }, name, tag); + return compute( + shape, [&](const Array& i) { return ev; }, name, tag); } /*! -* \brief Creates an operation that construct a tensor with same shape as input tensor, -* then fill a tensor with fill_value -* -* \param x The input tensor -* \param fill_value The value to be filled -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op memeber is the full_like operation -*/ -inline Tensor full_like(const Tensor& x, - const PrimExpr fill_value, - std::string name = "T_full_like", - std::string tag = kElementWise) { + * \brief Creates an operation that constructs a tensor with the same shape as the input tensor, + * then fills it with fill_value + * + * \param x The input tensor + * \param fill_value The value to be filled + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the full_like operation + */ +inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, + std::string name = "T_full_like", std::string tag = kElementWise) { PrimExpr ev = cast(x->dtype, fill_value); - return compute(x->shape, [&](const Array& i) { - return ev; - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ev; }, name, tag); } /*!
@@ -406,9 +399,7 @@ inline Tensor full_like(const Tensor& x, * Approximation for fractional part: * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2)) */ -inline Tensor fast_exp_float32(const Tensor& _x, - std::string name, - std::string tag) { +inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) { auto x_hi = make_const(DataType::Float(32), 88.3762626647950f); auto x_lo = make_const(DataType::Float(32), -88.3762626647949f); auto log2e = make_const(DataType::Float(32), 1.44269504088896341f); @@ -423,25 +414,25 @@ inline Tensor fast_exp_float32(const Tensor& _x, auto one_half = make_const(DataType::Float(32), 0.5f); auto b = make_const(DataType::Float(32), 127.0f); - return compute(_x->shape, - [&](const Array& i) { - // clamp x - auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); - // integer part - auto n = ::tvm::floor(x * log2e + one_half); - // fractional part - auto f = x - n * ln2; - auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3])* f+ p[4]) * f - + p[5]) * f * f + f + one; - // Return 2^m * exp(r). - auto ef = tvm::reinterpret(DataType::Float(32), - ::tvm::cast(DataType::Int(32), n + b) << 23); - return ::tvm::max(ef * y, _x(i)); // NOLINT(*) - }, - name, tag); + return compute( + _x->shape, + [&](const Array& i) { + // clamp x + auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); + // integer part + auto n = ::tvm::floor(x * log2e + one_half); + // fractional part + auto f = x - n * ln2; + auto y = + (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one; + // Return 2^m * exp(r). + auto ef = + tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23); + return ::tvm::max(ef * y, _x(i)); // NOLINT(*) + }, + name, tag); } - /*! * \brief Fast exponential function implementation * @@ -452,16 +443,14 @@ inline Tensor fast_exp_float32(const Tensor& _x, * \return A Tensor whose op member is exponent operation * */ -inline Tensor fast_exp(const Tensor& x, - std::string name = "T_fast_exp", - std::string tag = kElementWise) { +inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", + std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { auto ret = fast_exp_float32(x, name, tag); return ret; } else { - return compute(x->shape, [&](const Array& i) { - return ::tvm::exp(x(i)); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ::tvm::exp(x(i)); }, name, tag); } } @@ -469,9 +458,7 @@ inline Tensor fast_exp(const Tensor& x, * \brief Fast_tanh_float implementation from Eigen * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290 */ -inline Tensor fast_erf_float32(const Tensor& data, - std::string name, - std::string tag) { +inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) { auto plus_4 = make_const(DataType::Float(32), 4.f); auto minus_4 = make_const(DataType::Float(32), -4.f); @@ -491,28 +478,31 @@ inline Tensor fast_erf_float32(const Tensor& data, auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f); auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f); - return compute(data->shape, [&](const Array &i) { - // clamp x - auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); - auto x2 = x * x; - - // Evaluate the numerator polynomial p. 
- auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - - // Evaluate the denominator polynomial p. - auto q = x2 * beta_8 + beta_6; - q = x2 * q + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - - return p / q; - }, name, tag); + return compute( + data->shape, + [&](const Array& i) { + // clamp x + auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); + auto x2 = x * x; + + // Evaluate the numerator polynomial p. + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial q. + auto q = x2 * beta_8 + beta_6; + q = x2 * q + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + return p / q; + }, + name, tag); } /*! @@ -524,8 +514,7 @@ inline Tensor fast_erf_float32(const Tensor& data, * * \return A Tensor whose op member is erf operation */ -inline Tensor fast_erf(const Tensor& x, - std::string name = "T_fast_erf", +inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf", std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { auto ret = fast_erf_float32(x, name, tag); diff --git a/topi/include/topi/generic/default.h b/topi/include/topi/generic/default.h index 640ab95451419..c44bc69190f45 100644 --- a/topi/include/topi/generic/default.h +++ b/topi/include/topi/generic/default.h @@ -24,11 +24,11 @@ #ifndef TOPI_GENERIC_DEFAULT_H_ #define TOPI_GENERIC_DEFAULT_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -36,13 +36,13 @@ using namespace tvm::te; namespace generic { /*! -* \brief Create a generic default schedule for the given output tensors. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a generic default schedule for the given output tensors. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ inline Schedule default_schedule(const Target& target, Array outs) { Array out_ops; for (auto t : outs) { @@ -53,14 +53,14 @@ inline Schedule default_schedule(const Target& target, Array outs) { } /*! -* \brief Create a generic default schedule for the given output tensors, and apply -* auto inline -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a generic default schedule for the given output tensors, and apply + * auto inline + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ inline Schedule default_schedule_auto_inline(const Target& target, Array outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/generic/extern.h b/topi/include/topi/generic/extern.h index e08158f297bef..6e5507b4e77b8 100644 --- a/topi/include/topi/generic/extern.h +++ b/topi/include/topi/generic/extern.h @@ -24,12 +24,12 @@ #ifndef TOPI_GENERIC_EXTERN_H_ #define TOPI_GENERIC_EXTERN_H_ -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -37,13 +37,13 @@ using namespace tvm::te; namespace generic { /*!
-* \brief Schedule an extern op followed by injective operations -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the op. -*/ + * \brief Schedule an extern op followed by injective operations + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the op. + */ inline Schedule schedule_extern(const Target& target, Array outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/generic/injective.h b/topi/include/topi/generic/injective.h index 7a5aff7eaf800..69962dc645c0f 100644 --- a/topi/include/topi/generic/injective.h +++ b/topi/include/topi/generic/injective.h @@ -24,11 +24,11 @@ #ifndef TOPI_GENERIC_INJECTIVE_H_ #define TOPI_GENERIC_INJECTIVE_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -57,7 +57,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 7569bb0a7ba2c..7fbe7eb83b3b3 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -24,12 +24,12 @@ #ifndef TOPI_NN_H_ #define TOPI_NN_H_ -#include #include +#include #include +#include #include #include -#include #include #include @@ -62,43 +62,38 @@ tvm::PrimExpr Map(const tvm::Array& exprs, T op) { * \return A Tensor whose op member is the relu operation */ template -inline tvm::te::Tensor relu(const tvm::te::Tensor& t, - T threshold = static_cast(0), - std::string name = "T_relu", - std::string tag = kElementWise) { +inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast(0), + std::string name = "T_relu", std::string tag = kElementWise) { return tvm::te::compute( t->shape, [&](const tvm::Array& i) { auto threshold_const = tvm::tir::make_const(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, - name, - tag); + name, tag); } /*! 
-* \brief Creates an operation that performs a leaky rectified linear unit -* -* \param t The input tensor -* \param alpha The slope for the small gradient when t < 0 -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the leaky relu operation -*/ -inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, - double alpha = 0.1, - std::string name = "T_leaky_relu", - std::string tag = kElementWise) { + * \brief Creates an operation that performs a leaky rectified linear unit + * + * \param t The input tensor + * \param alpha The slope for the small gradient when t < 0 + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the leaky relu operation + */ +inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, + std::string name = "T_leaky_relu", + std::string tag = kElementWise) { return tvm::te::compute( - t->shape, - [&](const tvm::Array& i) { - auto value = t(i); - auto calpha = tvm::tir::make_const(value.dtype(), alpha); - return tvm::tir::SelectNode::make(value > 0, value, value * calpha); - }, - name, - tag); + t->shape, + [&](const tvm::Array& i) { + auto value = t(i); + auto calpha = tvm::tir::make_const(value.dtype(), alpha); + return tvm::tir::SelectNode::make(value > 0, value, value * calpha); + }, + name, tag); } /*! @@ -112,27 +107,20 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, * * \return A Tensor whose op member is the parametric relu operation */ -inline tvm::te::Tensor prelu(const tvm::te::Tensor &x, - const tvm::te::Tensor &slope, - const int axis = 1, - std::string name = "T_prelu", - std::string tag = kBroadcast) { - CHECK((size_t)axis < x->shape.size()) << - "Wrong axis (" << axis << ")value. "; - CHECK(topi::detail::GetConstInt(slope->shape[0]) == - topi::detail::GetConstInt(x->shape[axis])) - << "Wrong slope shape received."; +inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope, + const int axis = 1, std::string name = "T_prelu", + std::string tag = kBroadcast) { + CHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; + CHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis])) + << "Wrong slope shape received."; - return tvm::te::compute(x->shape, - [&](const tvm::Array &indices) { - auto xval = x(indices); - return tvm::tir::SelectNode::make( - xval > 0, - xval, - xval * slope(indices[axis])); - }, - name, - tag); + return tvm::te::compute( + x->shape, + [&](const tvm::Array& indices) { + auto xval = x(indices); + return tvm::tir::SelectNode::make(xval > 0, xval, xval * slope(indices[axis])); + }, + name, tag); } /*! 
@@ -172,13 +160,10 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor &x, * * */ -inline tvm::te::Tensor pad(const tvm::te::Tensor& t, - const tvm::Array& pad_before, - tvm::Array pad_after = tvm::Array(), - PrimExpr pad_value = PrimExpr(), - std::string name = "T_pad", - std::string tag = kElementWise, - std::string pad_mode = "constant") { +inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& pad_before, + tvm::Array pad_after = tvm::Array(), + PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", + std::string tag = kElementWise, std::string pad_mode = "constant") { if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); @@ -190,10 +175,10 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, tvm::Array output_shape; tvm::Array pad_before_int32; tvm::Array pad_after_int32; - for (const auto &ele : pad_before) { + for (const auto& ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } - for (const auto &ele : pad_after) { + for (const auto& ele : pad_after) { pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } for (size_t i = 0; i < t->shape.size(); ++i) { @@ -228,28 +213,23 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); } if (pad_mode == "edge") { - pad_idx.push_back(tvm::if_then_else( - ovars[i] < pad_before[i], - 0, - tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], - t->shape[i] - 1, - ovars[i] - pad_before[i]))); + pad_idx.push_back( + tvm::if_then_else(ovars[i] < pad_before[i], 0, + tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], + t->shape[i] - 1, ovars[i] - pad_before[i]))); } else if (pad_mode == "reflect") { - pad_idx.push_back(tvm::if_then_else( - ovars[i] < pad_before[i], - pad_before[i] - ovars[i], - tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], - t->shape[i] * 2 - ovars[i] + pad_before[i] - 2, - ovars[i] - pad_before[i]))); + pad_idx.push_back( + tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i], + tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], + t->shape[i] * 2 - ovars[i] + pad_before[i] - 2, + ovars[i] - pad_before[i]))); } } if (sel.size() != 0) { if (pad_mode == "constant") { - return tvm::if_then_else( - detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value); + return tvm::if_then_else(detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value); } else if (pad_mode == "edge" || pad_mode == "reflect") { - return tvm::if_then_else( - detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx)); + return tvm::if_then_else(detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx)); } } return t(indices); @@ -277,34 +257,27 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, * \return A Tensor whose op member is the 2-D convolution operation (NCHW * layout) */ -inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_conv2d_nchw", - std::string tag = kConv2dNCHW) { +inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, + std::string name = "T_conv2d_nchw", + std::string tag = kConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; 
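The output_shape arithmetic that follows is the standard convolution size formula: for input extent $H$, kernel extent $K$, padding $p$ and stride $s$, each spatial output extent is

$$H_{\text{out}} = \left\lfloor \frac{H - K + 2p}{s} \right\rfloor + 1,$$

which is exactly what the `indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1` expressions compute (e.g. $H = 224$, $K = 7$, $p = 3$, $s = 2$ gives $112$).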
tvm::Array output_shape{ - I->shape[0], // B - W->shape[0], // O - indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H - indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W + I->shape[0], // B + W->shape[0], // O + indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H + indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { - return tvm::sum( - T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), - {i, kh, kw}); + return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } @@ -328,14 +301,10 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D convolution operation * (HWCN layout) */ -inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_conv2d_hwcn", - std::string tag = kConv2dHWCN) { +inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, + std::string name = "T_conv2d_hwcn", + std::string tag = kConv2dHWCN) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; @@ -343,22 +312,19 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, tvm::Array output_shape{ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W - I->shape[2], // B - W->shape[3] // O + I->shape[2], // B + W->shape[3] // O }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { - return tvm::sum( - T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), - {i, kh, kw}); + return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } - /*! 
* \brief Creates an operation that performs a 2-D depthwise convolution with * an NCHW-layout @@ -379,67 +345,59 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D depthwise convolution operation * (NCHW layout) */ -inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_depthwise_conv2d_nchw", - std::string tag = kDepthwiseConv2dNCHW) { +inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_depthwise_conv2d_nchw", + std::string tag = kDepthwiseConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier tvm::Array output_shape{ - I->shape[0], // B - W->shape[1], // O + I->shape[0], // B + W->shape[1], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) * - W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), + W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } -inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_depthwise_conv2d_nhwc", - std::string tag = kDepthwiseConv2dNHWC) { +inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_depthwise_conv2d_nhwc", + std::string tag = kDepthwiseConv2dNHWC) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier tvm::Array output_shape{ - I->shape[0], // B + I->shape[0], // B indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H - indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W - W->shape[3], // O + indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W + W->shape[3], // O }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); + auto T = + (pad_h == 0 && pad_w == 0) ? 
I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) { return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) * - W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), + W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), {kh, kw, i}); }; return tvm::te::compute(output_shape, l, name, tag); @@ -465,22 +423,19 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D groupconvolution operation * (NCHW layout) */ -inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_group_conv2d_ngchw", - std::string tag = kGroupConv2d) { +inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_group_conv2d_ngchw", + std::string tag = kGroupConv2d) { CHECK_EQ(5, I->shape.size()); CHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::Array output_shape{ - I->shape[0], // B - I->shape[1], // G - W->shape[2], // O + I->shape[0], // B + I->shape[1], // G + W->shape[2], // O indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W }; @@ -497,9 +452,8 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, tvm::tir::Var o = args[2]; tvm::tir::Var h = args[3]; tvm::tir::Var w = args[4]; - return tvm::sum( - I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), - {i, kh, kw}); + return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), + {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } diff --git a/topi/include/topi/nn/batch_matmul.h b/topi/include/topi/nn/batch_matmul.h index 12075e6d67ea2..80525c4279767 100644 --- a/topi/include/topi/nn/batch_matmul.h +++ b/topi/include/topi/nn/batch_matmul.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_BATCH_MATMUL_H_ #define TOPI_NN_BATCH_MATMUL_H_ -#include #include +#include #include @@ -35,15 +35,14 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Creates an operation that calculates matrix multiplication in batch. -* -* \param x Tensor with shape [batch, M, K] -* \param y Tensor with shape [batch, N, K] -* -* \return Tensor with shape [batch, M, N] -*/ -inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, - const tvm::te::Tensor& y) { + * \brief Creates an operation that calculates matrix multiplication in batch. 
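The batch_matmul reformatting below keeps its one subtlety intact: both operands carry the reduction axis $K$ last, so the compute body evaluates

$$\mathrm{out}(b, i, j) = \sum_{k=0}^{K-1} x(b, i, k)\, y(b, j, k),$$

a per-batch $x \cdot y^{T}$; callers are expected to pass the second operand already transposed to shape [batch, N, K].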
+ * + * \param x Tensor with shape [batch, M, K] + * \param y Tensor with shape [batch, N, K] + * + * \return Tensor with shape [batch, M, N] + */ +inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, const tvm::te::Tensor& y) { CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data"; CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data"; @@ -54,10 +53,8 @@ inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, auto k = tvm::te::reduce_axis(Range(0, K), "k"); auto result = tvm::te::compute( - { batch, M, N }, - [&](Var b, Var i, Var j) { - return tvm::sum(x(b, i, k) * y(b, j, k), { k }); - }, "tensor", "batch_matmul"); + {batch, M, N}, [&](Var b, Var i, Var j) { return tvm::sum(x(b, i, k) * y(b, j, k), {k}); }, + "tensor", "batch_matmul"); return result; } diff --git a/topi/include/topi/nn/bias_add.h b/topi/include/topi/nn/bias_add.h index 209c30ca875b4..18e95deaccb12 100644 --- a/topi/include/topi/nn/bias_add.h +++ b/topi/include/topi/nn/bias_add.h @@ -24,10 +24,10 @@ #ifndef TOPI_NN_BIAS_ADD_H_ #define TOPI_NN_BIAS_ADD_H_ -#include -#include #include +#include #include +#include #include @@ -35,16 +35,15 @@ namespace topi { namespace nn { /*! -* \brief Creates an operation that calculates data + bias -* -* \param data Tensor with shape [batch, in_dim] -* \param bias Tensor with shape [batch]. -* \param axis The axis to add the bias to. -* \return Tensor with shape [batch, in_dim] -*/ -inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, - const tvm::te::Tensor& bias, - int axis) { + * \brief Creates an operation that calculates data + bias + * + * \param data Tensor with shape [batch, in_dim] + * \param bias Tensor with shape [batch]. + * \param axis The axis to add the bias to. + * \return Tensor with shape [batch, in_dim] + */ +inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, const tvm::te::Tensor& bias, + int axis) { int data_ndim = data->shape.size(); if (axis < 0) { axis += data_ndim; diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h index c69fc5406e33b..c0626cd43c7f8 100644 --- a/topi/include/topi/nn/bnn.h +++ b/topi/include/topi/nn/bnn.h @@ -24,10 +24,10 @@ #ifndef TOPI_NN_BNN_H_ #define TOPI_NN_BNN_H_ -#include -#include -#include #include +#include +#include +#include #include @@ -37,71 +37,67 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Binarization and bit-packing along a certain axis. -* -* \param data N-D tensor, can be any layout -* \param axis The axis along which to do binarization and bit-packing. This axis -* must have a size equal to an integer multiple of 32. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return Output tensor with dtype uint32 -*/ -inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, - int axis, - std::string name = "PackedInput", - std::string tag = "binarize_pack") { + * \brief Binarization and bit-packing along a certain axis. + * + * \param data N-D tensor, can be any layout + * \param axis The axis along which to do binarization and bit-packing. This axis + * must have a size equal to an integer multiple of 32. 
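binarize_pack, whose rewrite continues below, turns 32 consecutive elements along the packed axis into one uint32, sign by sign, with the earliest element in the most significant bit. A standalone sketch of that bit layout on a plain array (`pack32` is a hypothetical helper mirroring the compute body):

```cpp
#include <cstdint>
#include <cstdio>

// Pack the signs of 32 consecutive floats into one uint32, MSB-first:
// bit (31 - j) is 1 iff x[j] >= 0. This shift-then-or loop is equivalent
// to the or-then-shift loop in binarize_pack, which skips the final shift.
uint32_t pack32(const float* x) {
  uint32_t packed = 0;
  for (int j = 0; j < 32; ++j) {
    packed = (packed << 1) | (x[j] >= 0.0f ? 1u : 0u);
  }
  return packed;
}

int main() {
  float v[32] = {1.0f, -2.0f, 3.0f};   // remaining elements are 0.0f, which packs as 1
  std::printf("0x%08x\n", pack32(v));  // bits 101111...1 -> 0xbfffffff
}
```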
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return Output tensor with dtype uint32
+ */
+inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis,
+                                     std::string name = "PackedInput",
+                                     std::string tag = "binarize_pack") {
   auto ishape = data->shape;
   CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
-    << "binarize_pack: axis size must be a multiple of 32";
+      << "binarize_pack: axis size must be a multiple of 32";
   arith::Analyzer analyzer;
   auto n = ishape.size();
   Array<PrimExpr> oshape;
   for (size_t i = 0; i < n; ++i) {
-    oshape.push_back(i == static_cast<size_t>(axis) ?
-                     analyzer.Simplify(indexdiv(ishape[i], 32)) :
-                     ishape[i]);
+    oshape.push_back(i == static_cast<size_t>(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32))
+                                                    : ishape[i]);
   }
 
   return tvm::te::compute(
-    oshape,
-    [&](const Array<Var>& indices) {
-      Array<PrimExpr> start_idx;
-      for (size_t i = 0; i < n; ++i) {
-        start_idx.push_back(i == static_cast<size_t>(axis) ?
-                            indices[i] * 32 :
-                            static_cast<PrimExpr>(indices[i]));
-      }
-      auto packed = make_const(DataType::UInt(32), 0);
-      for (size_t j = 0; j < 32; ++j) {
-        Array<PrimExpr> idx;
+      oshape,
+      [&](const Array<Var>& indices) {
+        Array<PrimExpr> start_idx;
         for (size_t i = 0; i < n; ++i) {
-          idx.push_back(i == static_cast<size_t>(axis) ?
-                        start_idx[i] + static_cast<int>(j) :
-                        start_idx[i]);
+          start_idx.push_back(i == static_cast<size_t>(axis) ? indices[i] * 32
+                                                             : static_cast<PrimExpr>(indices[i]));
         }
-        auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
-        packed = (packed | sign);
-        if (j == 31) {
-          return packed;
+        auto packed = make_const(DataType::UInt(32), 0);
+        for (size_t j = 0; j < 32; ++j) {
+          Array<PrimExpr> idx;
+          for (size_t i = 0; i < n; ++i) {
+            idx.push_back(i == static_cast<size_t>(axis) ? start_idx[i] + static_cast<int>(j)
+                                                         : start_idx[i]);
+          }
+          auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
+          packed = (packed | sign);
+          if (j == 31) {
+            return packed;
+          }
+          packed = packed << 1;
         }
-        packed = packed << 1;
-      }
-      return packed;  // never reached, but suppress compiler warning
-    }, name, tag);
+        return packed;  // never reached, but suppress compiler warning
+      },
+      name, tag);
 }
 
 /*!
-* \brief Binary matrix multiplication using xor and bit-count -* -* \param data Tensor with shape [batch, in_dim], dtype is uint32 -* \param weight Tensor with shape [out_dim, in_dim], dtype is uint32 -* -* \return Tensor with shape [batch, out_dim], dtype is float32 -*/ -inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, - const tvm::te::Tensor& weight) { + * \brief Binary matrix multiplication using xor and bit-count + * + * \param data Tensor with shape [batch, in_dim], dtype is uint32 + * \param weight Tensor with shape [out_dim, in_dim], dtype is uint32 + * + * \return Tensor with shape [batch, out_dim], dtype is float32 + */ +inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) { CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; @@ -113,16 +109,13 @@ inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, auto k = tvm::te::reduce_axis(Range(0, in_dim), "k"); auto matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k }); - }, "tensor", "binary_dense"); + {batch, out_dim}, + [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor", + "binary_dense"); return tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return 32 * in_dim - 2.0f * matmul(i, j); - }, "tensor", kElementWise); + {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor", + kElementWise); } } // namespace nn diff --git a/topi/include/topi/nn/dense.h b/topi/include/topi/nn/dense.h index 57f071a2ebfd4..4ee36c275ef38 100644 --- a/topi/include/topi/nn/dense.h +++ b/topi/include/topi/nn/dense.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_DENSE_H_ #define TOPI_NN_DENSE_H_ -#include #include +#include #include @@ -35,19 +35,17 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Creates an operation that calculates data * weight^T + bias -* -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense(const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Creates an operation that calculates data * weight^T + bias + * + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. 
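+ * For example, passing DataType::Int(32) for int8 data and weight makes the
+ * multiply-accumulate run in 32-bit, since both operands are cast to
+ * out_dtype before the reduction.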
+ * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight, + const tvm::te::Tensor& bias, const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -60,18 +58,17 @@ inline tvm::te::Tensor dense(const tvm::te::Tensor& data, auto k = tvm::te::reduce_axis(Range(0, in_dim), "k"); auto matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return tvm::sum(tvm::cast(out_dtype, data(i, k)) * - tvm::cast(out_dtype, weight(j, k)), { k }); - }, "tensor", "dense"); + {batch, out_dim}, + [&](Var i, Var j) { + return tvm::sum(tvm::cast(out_dtype, data(i, k)) * tvm::cast(out_dtype, weight(j, k)), {k}); + }, + "tensor", "dense"); if (bias.defined()) { matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return matmul(i, j) + tvm::cast(out_dtype, bias(j)); - }, "tensor", kBroadcast); + {batch, out_dim}, + [&](Var i, Var j) { return matmul(i, j) + tvm::cast(out_dtype, bias(j)); }, "tensor", + kBroadcast); } return matmul; diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h index 32ee1392e8465..0d3ab89bbae69 100644 --- a/topi/include/topi/nn/dilate.h +++ b/topi/include/topi/nn/dilate.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_DILATE_H_ #define TOPI_NN_DILATE_H_ -#include -#include #include +#include +#include #include @@ -36,13 +36,13 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Create a new expression of the logical and of all -* conditions in the arguments. -* -* \param args The arguments to find the logical conjunction of -* -* \return The logical conjunction expression -*/ + * \brief Create a new expression of the logical and of all + * conditions in the arguments. + * + * \param args The arguments to find the logical conjunction of + * + * \return The logical conjunction expression + */ PrimExpr all(Array args) { CHECK_GT(args.size(), 0) << "all requires at least one argument"; @@ -54,53 +54,50 @@ PrimExpr all(Array args) { } /*! -* \brief Dilate data with zeros -* -* \param x The input tensor, this can have any number of -* dimensions and any layout. -* \param strides Dilation stride for each dimension. Stride 1 -* means no dilation. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return The output tensor. -*/ -inline Tensor dilate(const Tensor& x, - Array strides, - std::string name = "tensor", + * \brief Dilate data with zeros + * + * \param x The input tensor, this can have any number of + * dimensions and any layout. + * \param strides Dilation stride for each dimension. Stride 1 + * means no dilation. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The output tensor. 
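+ *
+ * Each output dimension has extent (x->shape[i] - 1) * strides[i] + 1; for
+ * example, dilating a [3, 3] tensor with strides {2, 2} yields a [5, 5]
+ * tensor whose even coordinates carry the input values and whose remaining
+ * entries are zero.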
+ */ +inline Tensor dilate(const Tensor& x, Array strides, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); - CHECK_EQ(n, strides.size()) - << "strides size (" << strides.size() - << ") must match dimension of x (" << n << ")"; + CHECK_EQ(n, strides.size()) << "strides size (" << strides.size() + << ") must match dimension of x (" << n << ")"; Array out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { - out_shape.push_back(analyzer.Simplify( - (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); + out_shape.push_back( + analyzer.Simplify((x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); } return tvm::te::compute( - out_shape, - [&](const Array& indices) { - Array not_zero; - Array index_tuple; - for (size_t i = 0; i < n; ++i) { - if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { - index_tuple.push_back(indices[i]); - } else { - index_tuple.push_back(indexdiv(indices[i], strides[i])); - not_zero.push_back((indexmod(indices[i], strides[i])) == 0); + out_shape, + [&](const Array& indices) { + Array not_zero; + Array index_tuple; + for (size_t i = 0; i < n; ++i) { + if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { + index_tuple.push_back(indices[i]); + } else { + index_tuple.push_back(indexdiv(indices[i], strides[i])); + not_zero.push_back((indexmod(indices[i], strides[i])) == 0); + } + } + if (not_zero.size() > 0) { + auto all_not_zero = all(not_zero); + return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0)); } - } - if (not_zero.size() > 0) { - auto all_not_zero = all(not_zero); - return tvm::if_then_else( - all_not_zero, x(index_tuple), make_const(x->dtype, 0)); - } - return x(index_tuple); - }, name, tag); + return x(index_tuple); + }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h index 81cef2eda17be..1ac5de4a2ed17 100644 --- a/topi/include/topi/nn/flatten.h +++ b/topi/include/topi/nn/flatten.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_FLATTEN_H_ #define TOPI_NN_FLATTEN_H_ -#include -#include #include +#include +#include #include #include @@ -37,25 +37,23 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. -* This requires the input tensor to have constant sized dimensions. -* -* \param x The input tensor. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A 2-D tensor. -*/ -inline Tensor flatten(const Tensor& x, - std::string name = "tensor", - std::string tag = kInjective) { + * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. + * This requires the input tensor to have constant sized dimensions. + * + * \param x The input tensor. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A 2-D tensor. 
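+ *
+ * For instance, a constant-shaped [2, 3, 4] input collapses to [2, 12],
+ * with the trailing axes linearized in row-major order.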
+ */
+inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) {
   auto ishape = x->shape;
   PrimExpr dim = 1;
   for (size_t i = 1; i < ishape.size(); ++i) {
     dim = dim * ishape[i];
   }
-  Array<PrimExpr> oshape({ ishape[0], dim });
+  Array<PrimExpr> oshape({ishape[0], dim});
 
   std::vector<PrimExpr> extra_shape;
   for (size_t i = 1; i < ishape.size(); ++i) {
@@ -64,17 +62,19 @@ inline Tensor flatten(const Tensor& x,
   std::reverse(extra_shape.begin(), extra_shape.end());
 
   return tvm::te::compute(
-    oshape, [&](Var i, Var j) {
-      PrimExpr idx = j;
-      std::vector<PrimExpr> index;
-      for (auto s : extra_shape) {
-        index.push_back(indexmod(idx, s));
-        idx = indexdiv(idx, s);
-      }
-      index.push_back(i);
-      std::reverse(index.begin(), index.end());
-      return x(index);
-    }, name, tag);
+      oshape,
+      [&](Var i, Var j) {
+        PrimExpr idx = j;
+        std::vector<PrimExpr> index;
+        for (auto s : extra_shape) {
+          index.push_back(indexmod(idx, s));
+          idx = indexdiv(idx, s);
+        }
+        index.push_back(i);
+        std::reverse(index.begin(), index.end());
+        return x(index);
+      },
+      name, tag);
 }
 
 }  // namespace nn
diff --git a/topi/include/topi/nn/local_response_norm.h b/topi/include/topi/nn/local_response_norm.h
index 14dec390e24a7..4e8dfd99a517c 100644
--- a/topi/include/topi/nn/local_response_norm.h
+++ b/topi/include/topi/nn/local_response_norm.h
@@ -24,8 +24,8 @@
 #ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_
 #define TOPI_NN_LOCAL_RESPONSE_NORM_H_
 
-#include
 #include
+#include
 
 #include
 
@@ -35,60 +35,45 @@ using namespace tvm;
 using namespace tvm::te;
 
 /*!
-* \brief Local response normalization inference operator
-*
-* \param data The input tensor. 4-D shape NCHW or NHWC
-* \param size Integer to define normalisation window size
-* \param axis Input data layout channel axis
-* \param alpha Float scaling factor
-* \param beta Exponent value
-* \param bias Offset to avoid dividing by zero
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the Local response normalization operation
-*/
-inline Tensor lrn(const Tensor& data,
-                  int size,
-                  int axis = 1,
-                  float alpha = 0.0001,
-                  float beta = 0.75,
-                  float bias = 2,
-                  std::string name = "tensor",
+ * \brief Local response normalization inference operator
+ *
+ * \param data The input tensor.
4-D shape NCHW or NHWC + * \param size Integer to define normalisation window size + * \param axis Input data layout channel axis + * \param alpha Float scaling factor + * \param beta Exponent value + * \param bias Offset to avoid dividing by zero + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the Local response normalization operation + */ +inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001, + float beta = 0.75, float bias = 2, std::string name = "tensor", std::string tag = kBroadcast) { CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; CHECK_EQ(size % 2, 1) << "size should be odd number"; CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; auto input_shape = data->shape; - Array pad_before{ 0, 0, 0, 0}; - Array pad_after{ 0, 0, 0, 0}; - pad_before.Set(axis, static_cast(size/2)); - pad_after.Set(axis, static_cast(size/2)); + Array pad_before{0, 0, 0, 0}; + Array pad_after{0, 0, 0, 0}; + pad_before.Set(axis, static_cast(size / 2)); + pad_after.Set(axis, static_cast(size / 2)); auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs"); Tensor sqr_sum; if (axis == 1) { - sqr_sum = tvm::te::compute(input_shape, - [&](Var i, Var l, Var j, Var k) { - return tvm::sum(pad_data(i, l + rxs, j, k) * - pad_data(i, l + rxs, j, k), - {rxs}); - }); + sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) { + return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs}); + }); } else if (axis == 3) { - sqr_sum = tvm::te::compute(input_shape, - [&](Var i, Var l, Var j, Var k) { - return tvm::sum(pad_data(i, l, j, k + rxs) * - pad_data(i, l, j, k + rxs), - {rxs}); - }); + sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) { + return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs}); + }); } - auto sqrt_sum_up = tvm::te::compute( - input_shape, - [&](Var i, Var j, Var k, Var l) { - return tvm::pow(bias + - (div(alpha * sqr_sum(i, j, k, l), size)), - beta); - }); + auto sqrt_sum_up = tvm::te::compute(input_shape, [&](Var i, Var j, Var k, Var l) { + return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta); + }); return topi::divide(data, sqrt_sum_up); } } // namespace nn diff --git a/topi/include/topi/nn/mapping.h b/topi/include/topi/nn/mapping.h index 17d14045ac4b5..d4a3a4766bb0e 100644 --- a/topi/include/topi/nn/mapping.h +++ b/topi/include/topi/nn/mapping.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_MAPPING_H_ #define TOPI_NN_MAPPING_H_ -#include #include +#include #include @@ -35,49 +35,39 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Scale and shift with NCHW order -* -* \param x The input tensor. -* \param scale Scale tensor, 1-D of size channel -* \param shift Shift tensor, 1-D of size channel -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the scale shift operation -*/ -inline Tensor scale_shift_nchw(const Tensor& x, - const Tensor& scale, - const Tensor& shift, - std::string name = "ScaleShift", - std::string tag = kBroadcast) { + * \brief Scale and shift with NCHW order + * + * \param x The input tensor. 
+ * \param scale Scale tensor, 1-D of size channel + * \param shift Shift tensor, 1-D of size channel + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the scale shift operation + */ +inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift, + std::string name = "ScaleShift", std::string tag = kBroadcast) { return tvm::te::compute( - x->shape, - [&](Var b, Var c, Var h, Var w) { - return x(b, c, h, w) * scale(c) + shift(w); - }, name, tag); + x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(w); }, + name, tag); } /*! -* \brief Scale and shift with NHWC order -* -* \param x The input tensor. -* \param scale Scale tensor, 1-D of size channel -* \param shift Shift tensor, 1-D of size channel -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the scale shift operation -*/ -inline Tensor scale_shift_nhwc(const Tensor& x, - const Tensor& scale, - const Tensor& shift, - std::string name = "ScaleShift", - std::string tag = kBroadcast) { + * \brief Scale and shift with NHWC order + * + * \param x The input tensor. + * \param scale Scale tensor, 1-D of size channel + * \param shift Shift tensor, 1-D of size channel + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the scale shift operation + */ +inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift, + std::string name = "ScaleShift", std::string tag = kBroadcast) { return tvm::te::compute( - x->shape, - [&](Var b, Var h, Var w, Var c) { - return x(b, h, w, c) * scale(c) + shift(w); - }, name, tag); + x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(w); }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 324ecadfe95a5..ffc4f9856a655 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -45,31 +45,25 @@ enum PoolType : int { kMaxPool, }; - /*! -* \brief Perform pooling on height and width dimension of data. -* -* \param x The input tensor -* \param kernel_size Vector of two ints: {kernel_height, kernel_width} -* \param stride_size Vector of two ints: {stride_height, stride_width} -* \param padding_size Vector of two ints: {padding_height, padding_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param height_axis index of the height dimension -* \param width_axis index of the width dimension -* \param count_include_pad Whether include padding in the calculation -* -* \return The output tensor in same layout order -*/ -inline Tensor pool_impl(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const size_t height_axis, - const size_t width_axis, - bool count_include_pad) { + * \brief Perform pooling on height and width dimension of data. 
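+ *
+ * The spatial extents of the result follow the usual pooling arithmetic:
+ * out_height = (height - kernel_height + pad_top + pad_bottom) / stride_height + 1,
+ * and likewise for width; with ceil_mode the padding is grown so that this
+ * division rounds up.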
+ * + * \param x The input tensor + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param height_axis index of the height dimension + * \param width_axis index of the width dimension + * \param count_include_pad Whether include padding in the calculation + * + * \return The output tensor in same layout order + */ +inline Tensor pool_impl(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const size_t height_axis, + const size_t width_axis, bool count_include_pad) { CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; @@ -103,10 +97,10 @@ inline Tensor pool_impl(const Tensor& x, pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); arith::Analyzer analyzer; - auto out_height = analyzer.Simplify( - indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); - auto out_width = analyzer.Simplify( - indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); + auto out_height = + analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); + auto out_width = + analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); @@ -115,69 +109,72 @@ inline Tensor pool_impl(const Tensor& x, out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); - const int64_t *padding_h0 = as_const_int(pad_top); - const int64_t *padding_w0 = as_const_int(pad_left); - const int64_t *padding_h1 = as_const_int(pad_bottom); - const int64_t *padding_w1 = as_const_int(pad_right); + const int64_t* padding_h0 = as_const_int(pad_top); + const int64_t* padding_w0 = as_const_int(pad_left); + const int64_t* padding_h1 = as_const_int(pad_bottom); + const int64_t* padding_w1 = as_const_int(pad_right); const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { - auto temp = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::max(temp(indices), { dheight, dwidth }); - }, "tensor", "pool_max"); + auto temp = do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + indices.Set(height_axis, output[height_axis] * stride_height + dheight); + indices.Set(width_axis, output[width_axis] * stride_width + dwidth); + return tvm::max(temp(indices), {dheight, dwidth}); + }, + "tensor", "pool_max"); } else if (pool_type == kAvgPool) { // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. - auto pool_sum = tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::sum(temp(indices), { dheight, dwidth }); - }, "tensor", "pool_sum"); + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + indices.Set(height_axis, output[height_axis] * stride_height + dheight); + indices.Set(width_axis, output[width_axis] * stride_width + dwidth); + return tvm::sum(temp(indices), {dheight, dwidth}); + }, + "tensor", "pool_sum"); // TVM compute for dividing the reduced window sum by kernel size. - return tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - return div(pool_sum(indices), (kernel_height * kernel_width)); - } else { - PrimExpr h_start = output[height_axis] * stride_height - pad_top; - PrimExpr w_start = output[width_axis] * stride_width - pad_left; - PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); - PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); - h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); - w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); - PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + return div(pool_sum(indices), (kernel_height * kernel_width)); + } else { + PrimExpr h_start = output[height_axis] * stride_height - pad_top; + PrimExpr w_start = output[width_axis] * stride_width - pad_left; + PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); + PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); + h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); + w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); + PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), + make_const(DataType::DataType::Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; } } -inline Tensor pool_grad_impl(const Tensor& out_grad, - const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, bool ceil_mode, - const size_t height_axis, 
const size_t width_axis, +inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, + const Array& kernel_size, const Array& stride_size, + const Array& padding_size, PoolType pool_type, + bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; @@ -237,38 +234,35 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); - auto windowh = tvm::te::reduce_axis( - Range(0, (kernel_height + stride_height - 1) / stride_height)); - auto windoww = tvm::te::reduce_axis( - Range(0, (kernel_width + stride_width - 1) / stride_width)); + auto windowh = + tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); + auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); auto argmax = MakeArgmaxReducer(); - auto pad_x = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - - auto mp_argmax = - tvm::te::compute( - out_shape, - [&](const Array& inds) { - Array window_inds{inds.begin(), inds.end()}; - window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); - window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); - auto idx = detail::RavelIndex(window_inds, ravel_shape); - return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr); - }, - "maxpool_grad_argmax", kCommReduceIdx); + auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + + auto mp_argmax = tvm::te::compute( + out_shape, + [&](const Array& inds) { + Array window_inds{inds.begin(), inds.end()}; + window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); + window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); + auto idx = detail::RavelIndex(window_inds, ravel_shape); + return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr); + }, + "maxpool_grad_argmax", kCommReduceIdx); auto mp_inds = mp_argmax[0]; return tvm::te::compute( x->shape, [&](const Array& inds) { - Array pad_inds {inds.begin(), inds.end()}; + Array pad_inds{inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left); auto idx = detail::RavelIndex(pad_inds, ravel_shape); - Array out_idx {inds.begin(), inds.end()}; + Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); @@ -280,19 +274,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( - tvm::if_then_else(tir::AndNode::make( - tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[width_axis] >= out_idx_lower_w), - mp_inds(out_idx) == idx), + tvm::if_then_else( + tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, + out_idx[width_axis] >= out_idx_lower_w), + mp_inds(out_idx) == idx), out_grad(out_idx), make_const(x->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); } else if (pool_type == kAvgPool) { - auto windowh = tvm::te::reduce_axis( - Range(0, (kernel_height + stride_height - 1) / 
stride_height)); - auto windoww = tvm::te::reduce_axis( - Range(0, (kernel_width + stride_width - 1) / stride_width)); + auto windowh = + tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); + auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); return tvm::te::compute( x->shape, [&](const Array& inds) { @@ -304,12 +297,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); - PrimExpr out_idx_lower_h = tir::SelectNode::make( - pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), - (pad_h_idx - kernel_height) / stride_height + 1); - PrimExpr out_idx_lower_w = tir::SelectNode::make( - pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), - (pad_w_idx - kernel_width) / stride_width + 1); + PrimExpr out_idx_lower_h = + tir::SelectNode::make(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), + (pad_h_idx - kernel_height) / stride_height + 1); + PrimExpr out_idx_lower_w = + tir::SelectNode::make(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), + (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements if (count_include_pad) { @@ -321,17 +314,16 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); h_start = tir::MaxNode::make(h_start, make_const(DataType::Int(32), 0)); w_start = tir::MaxNode::make(w_start, make_const(DataType::Int(32), 0)); - divide_factor = - tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::Int(32), 1)); + divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), + make_const(DataType::Int(32), 1)); } - return tvm::sum(tvm::if_then_else( - tir::AndNode::make( - tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[height_axis] < out_height), - tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w, - out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), + return tvm::sum( + tvm::if_then_else( + tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, + out_idx[height_axis] < out_height), + tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w, + out_idx[width_axis] < out_width)), + out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -341,15 +333,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, } } -inline bool find_depth_height_width(const std::string& layout, - int* depth_axis, - int* height_axis, +inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis, int* width_axis) { *depth_axis = -1, *height_axis = -1, *width_axis = -1; int curr_idx = 0; for (size_t i = 0; i < layout.size(); ++i) { - if ((layout[i] >= 'A' && layout[i] <= 'Z') || - (layout[i] >= 'a' && layout[i] <= 'z')) { + if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) { if (layout[i] == 'D') { if (*depth_axis != -1) return false; *depth_axis = curr_idx; @@ -370,21 +359,18 @@ inline bool find_depth_height_width(const std::string& layout, return true; } -inline bool find_height_width(const std::string& layout, - int* height_axis, - int* width_axis) { +inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) { int dummy; - 
CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); + CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); if (*height_axis != -1 && *width_axis != -1) { return true; } return false; } -inline bool find_width(const std::string& layout, - int* width_axis) { +inline bool find_width(const std::string& layout, int* width_axis) { int dummy; - CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false); + CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false); if (*width_axis != -1) { return true; } @@ -392,48 +378,42 @@ inline bool find_width(const std::string& layout, } /*! -* \brief Perform pooling on height and width dimension of data. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, etc. are valid for pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of two ints: {kernel_height, kernel_width} -* \param stride_size Vector of two ints: {stride_height, stride_width} -* \param padding_size Vector of two ints: {padding_height, padding_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCHW16c can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `H` and `W`, one can pass `NCHWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCHW", + * \brief Perform pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. 
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; - CHECK(find_height_width(layout, &height_axis, &width_axis)) - << "Unsupported layout " << layout; - return pool_impl(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, height_axis, width_axis, - count_include_pad); + CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis, + width_axis, count_include_pad); } /*! @@ -476,34 +456,27 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& output_size, - PoolType pool_type, - const std::vector& axes) { + * \brief Perform adaptive pooling on N dimensional data + * + * \param x The input tensor + * \param output_size int vector of size in each dimension + * \param pool_type The type of pooling operator + * \param axes indices of each dimension + * + * \return The output tensor in same layout order + */ +inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_size, + PoolType pool_type, const std::vector& axes) { const auto n_dim = output_size.size(); CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; @@ -533,32 +506,41 @@ inline Tensor adaptive_pool_impl(const Tensor& x, }; if (pool_type == kMaxPool) { - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, true); - return tvm::max(x(indices), reduce_axes); // NOLINT(*) - }, "tensor", "adaptive_pool_max"); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, true); + return tvm::max(x(indices), reduce_axes); // NOLINT(*) + }, + "tensor", "adaptive_pool_max"); } else if (pool_type == kAvgPool) { - auto pool_sum = tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, true); - return tvm::sum(x(indices), reduce_axes); - }, "tensor", "adaptive_pool_sum"); - - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, false); - - PrimExpr divide_factor = tvm::cast(x->dtype, 1); - for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); - } + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array 
indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, true); + return tvm::sum(x(indices), reduce_axes); + }, + "tensor", "adaptive_pool_sum"); - return div(pool_sum(indices), divide_factor); - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, false); + + PrimExpr divide_factor = tvm::cast(x->dtype, 1); + for (size_t i = 0; i < n_dim; ++i) { + divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); + } + + return div(pool_sum(indices), divide_factor); + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; @@ -566,118 +548,107 @@ inline Tensor adaptive_pool_impl(const Tensor& x, } /*! -* \brief Adaptively perform pooling on height and width dimension of data. -* The pooling kernel and stride sizes are automatically chosen for desired output sizes. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, etc. are valid for pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* -* \param x The input tensor -* \param output_size Vector of two ints: {output_height, output_width} -* \param pool_type The type of pooling operator -* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCHW16c can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `H` and `W`, one can pass `NCHWc` as well. -* -* \return The output tensor in same layout order -*/ -inline Tensor adaptive_pool(const Tensor& x, - const Array& output_size, - PoolType pool_type, + * \brief Adaptively perform pooling on height and width dimension of data. + * The pooling kernel and stride sizes are automatically chosen for desired output sizes. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * + * \param x The input tensor + * \param output_size Vector of two ints: {output_height, output_width} + * \param pool_type The type of pooling operator + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. 
+ * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. + * + * \return The output tensor in same layout order + */ +inline Tensor adaptive_pool(const Tensor& x, const Array& output_size, PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; - CHECK(find_height_width(layout, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis}); } /*! -* \brief Adaptively perform pooling on three dimensional data. -* See the two dimensional version above for details. -* \param x The input tensor -* \param output_size Vector of three ints: {output_depth, output_height, output_width} -* \param pool_type The type of pooling operator -* \param layout The input layout. The default is "NCDHW". -*/ -inline Tensor adaptive_pool3d(const Tensor& x, - const Array& output_size, - PoolType pool_type, - const std::string& layout = "NCDHW") { + * \brief Adaptively perform pooling on three dimensional data. + * See the two dimensional version above for details. + * \param x The input tensor + * \param output_size Vector of three ints: {output_depth, output_height, output_width} + * \param pool_type The type of pooling operator + * \param layout The input layout. The default is "NCDHW". + */ +inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_size, + PoolType pool_type, const std::string& layout = "NCDHW") { int depth_axis = -1, height_axis = -1, width_axis = -1; CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis}); } /*! -* \brief Perform global pooling on height and width dimension of data. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, ... are valid for global_pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* -* \param x The input tensor represent as layout -* \param pool_type The type of pooling operator -* \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the sub-dimension. -* For example, `NCHW16c` can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of -* dimensions other than `H` and `W`, one can pass `NCHWc` as well. -* -* \return The output tensor in same layout with height and width dimension size of 1. 
-* e.g., for NCHW, the output shape will be [batch, channel, 1, 1] -*/ -inline Tensor global_pool(const Tensor& x, - PoolType pool_type, - const std::string& layout = "NCHW") { + * \brief Perform global pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, ... are valid for global_pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * + * \param x The input tensor represent as layout + * \param pool_type The type of pooling operator + * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the sub-dimension. + * For example, `NCHW16c` can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of + * dimensions other than `H` and `W`, one can pass `NCHWc` as well. + * + * \return The output tensor in same layout with height and width dimension size of 1. + * e.g., for NCHW, the output shape will be [batch, channel, 1, 1] + */ +inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") { return adaptive_pool(x, Array{1, 1}, pool_type, layout); } /*! -* \brief Perform pooling on N-dimension of data. -* -* \param x The input tensor -* \param kernel_size Vector of N ints -* \param stride_size Vector of N ints -* \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., -* head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param axis Vector of indices for the N dimensions -* \param count_include_pad Whether include padding in the calculation -* -* \return The output tensor in same layout order -*/ -inline Tensor pool_impl_nd(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::vector& axis, + * \brief Perform pooling on N-dimension of data. 
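+ *
+ * For example, 2-D pooling over an NCHW tensor corresponds to axis = {2, 3},
+ * kernel_size and stride_size of length two, and a four-element padding_size
+ * {pad_top, pad_left, pad_bottom, pad_right}.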
+ * + * \param x The input tensor + * \param kernel_size Vector of N ints + * \param stride_size Vector of N ints + * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., + * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param axis Vector of indices for the N dimensions + * \param count_include_pad Whether include padding in the calculation + * + * \return The output tensor in same layout order + */ +inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of" - " kernel"; + " kernel"; CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; Array daxis; @@ -696,8 +667,8 @@ inline Tensor pool_impl_nd(const Tensor& x, stride[i] = cast(DataType::Int(32), stride_size[i]); pad_head[i] = cast(DataType::Int(32), padding_size[i]); pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]); - const int64_t *padding0 = as_const_int(pad_head[i]); - const int64_t *padding1 = as_const_int(pad_tail[i]); + const int64_t* padding0 = as_const_int(pad_head[i]); + const int64_t* padding1 = as_const_int(pad_tail[i]); do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1)); if (ceil_mode) { @@ -713,69 +684,76 @@ inline Tensor pool_impl_nd(const Tensor& x, arith::Analyzer analyzer; auto out_dim = analyzer.Simplify( - indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); + indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); out_shape.Set(ii, out_dim); } if (pool_type == kMaxPool) { - auto temp = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); - } + auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } - return tvm::max(temp(indices), daxis); - }, "tensor", "pool_max"); + return tvm::max(temp(indices), daxis); + }, + "tensor", "pool_max"); } else if (pool_type == kAvgPool) { // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. 
- auto pool_sum = tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); - } - return tvm::sum(temp(indices), daxis); - }, "tensor", "pool_sum"); + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } + return tvm::sum(temp(indices), daxis); + }, + "tensor", "pool_sum"); // TVM compute for dividing the reduced window sum by kernel size. - return tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - auto kernel_size = make_const(DataType::Int(32), 1); - for (int i = 0; i < k_size; i++) { - kernel_size *= kernel[i]; - } - return div(pool_sum(indices), kernel_size); - } else { - std::vector start(k_size); - std::vector end(k_size); - auto kernel_size = make_const(DataType::Int(32), 1); - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]); - start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); - kernel_size *= (end[i] - start[i]); - } - - PrimExpr divide_factor = tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + auto kernel_size = make_const(DataType::Int(32), 1); + for (int i = 0; i < k_size; i++) { + kernel_size *= kernel[i]; + } + return div(pool_sum(indices), kernel_size); + } else { + std::vector start(k_size); + std::vector end(k_size); + auto kernel_size = make_const(DataType::Int(32), 1); + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + start[i] = output[ii] * stride[i] - pad_head[i]; + end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]); + start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); + kernel_size *= (end[i] - start[i]); + } + + PrimExpr divide_factor = + tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; @@ -783,94 +761,85 @@ inline Tensor pool_impl_nd(const Tensor& x, } /*! -* \brief Perform pooling on the width dimension of data. -* Width axis is determined by the layout string -* in which 'W' means width. -* Width dimension cannot be split. -* For example, NCW, NCW16c, etc. are valid for pool, -* while NCW16w is not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of three ints: {kernel_width} -* \param stride_size Vector of three ints: {stride_width} -* \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'W' appears. 
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCW16c can describe a 4-D tensor of -* [batch_size, channel, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `W`, one can pass `NCWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool1d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCW", + * \brief Perform pooling on the width dimension of data. + * Width axis is determined by the layout string + * in which 'W' means width. + * Width dimension cannot be split. + * For example, NCW, NCW16c, etc. are valid for pool, + * while NCW16w is not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of three ints: {kernel_width} + * \param stride_size Vector of three ints: {stride_width} + * \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'W' appears. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCW16c can describe a 4-D tensor of + * [batch_size, channel, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `W`, one can pass `NCWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool1d(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; - CHECK(find_width(layout, &width_axis)) - << "Unsupported layout " << layout; + CHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, axis, count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, + count_include_pad); } /*! -* \brief Perform pooling on depth, height and width dimension of data. -* It decides the depth, height and width dimension according to the layout string, -* in which 'D', 'W' and 'H' means depth, width and height respectively. -* Depth, Width and height dimension cannot be split. -* For example, NCDHW, NCDHW16c, etc. are valid for pool, -* while NCDHW16d, NCDHW16w or NCDHW16h are not. -* See \a layout for more information of the layout string convention. 
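+ *
+ * A minimal call sketch (argument values are illustrative only):
+ *
+ * \code
+ * // 2x2x2 max pooling, stride 2, no padding, over an NCDHW tensor x
+ * auto y = topi::nn::pool3d(x, {2, 2, 2}, {2, 2, 2}, {0, 0, 0, 0, 0, 0},
+ *                           topi::nn::kMaxPool, false);
+ * \endcode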
-* \param x The input tensor.
-* \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
-* \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
-* \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
-* tail_pad_depth, tail_pad_height, tail_pad_width}
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
-* where upper case indicates a dimension and
-* the corresponding lower case (with factor size) indicates the split dimension.
-* For example, NCDHW16c can describe a 6-D tensor of
-* [batch_size, channel, depth, height, width, channel_block].
-* (in which factor size `16` will not be used in pooling but for other operators,
-* it can be used to decide the output shape).
-* Since pooling does not care about the factor size of dimensions
-* other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
-* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
-*
-*
-* \return The output tensor in the same layout
-*/
-inline Tensor pool3d(const Tensor& x,
-                     const Array& kernel_size,
-                     const Array& stride_size,
-                     const Array& padding_size,
-                     PoolType pool_type,
-                     bool ceil_mode,
-                     const std::string& layout = "NCDHW",
+ * \brief Perform pooling on depth, height and width dimension of data.
+ *        It decides the depth, height and width dimensions according to the layout string,
+ *        in which 'D', 'W' and 'H' mean depth, width and height respectively.
+ *        Depth, width and height dimensions cannot be split.
+ *        For example, NCDHW, NCDHW16c, etc. are valid for pool,
+ *        while NCDHW16d, NCDHW16w or NCDHW16h are not.
+ *        See \a layout for more information on the layout string convention.
+ * \param x The input tensor.
+ * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
+ * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
+ * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
+ *        tail_pad_depth, tail_pad_height, tail_pad_width}
+ * \param pool_type The type of pooling operator
+ * \param ceil_mode Whether to use ceil when calculating the output size
+ * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
+ *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ *        where upper case indicates a dimension and
+ *        the corresponding lower case (with factor size) indicates the split dimension.
+ *        For example, NCDHW16c can describe a 6-D tensor of
+ *        [batch_size, channel, depth, height, width, channel_block].
+ *        (in which factor size `16` will not be used in pooling but for other operators,
+ *        it can be used to decide the output shape).
+ *        Since pooling does not care about the factor size of dimensions
+ *        other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
+ * \param count_include_pad Whether to include padding in the calculation when pool_type is 'avg'
+ *
+ *
+ * \return The output tensor in the same layout
+ */
+inline Tensor pool3d(const Tensor& x, const Array& kernel_size,
+                     const Array& stride_size, const Array& padding_size,
+                     PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW",
                      bool count_include_pad = true) {
   int depth_axis = -1, height_axis = -1, width_axis = -1;
   CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
-    << "Unsupported layout " << layout;
+      << "Unsupported layout " << layout;
   std::vector axis = {depth_axis, height_axis, width_axis};
-  return pool_impl_nd(x, kernel_size, stride_size, padding_size,
-                      pool_type, ceil_mode, axis, count_include_pad);
+  return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis,
+                      count_include_pad);
}

}  // namespace nn
diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h
index dc76a9e3e61ab..9ae9d6a73fa54 100644
--- a/topi/include/topi/nn/softmax.h
+++ b/topi/include/topi/nn/softmax.h
@@ -24,9 +24,9 @@
 #ifndef TOPI_NN_SOFTMAX_H_
 #define TOPI_NN_SOFTMAX_H_
-#include
 #include
 #include
+#include
 #include
 #include
@@ -37,18 +37,16 @@ using namespace tvm;
using namespace tvm::te;
/*!
-* \brief Softmax activation
-*
-* \param x The input tensor. Can be any dimension
-* \param axis The channel axis along which softmax is performed
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the softmax operation
-*/
-inline Tensor softmax(const Tensor &x,
-                      int axis = -1,
-                      std::string name = "tensor",
+ * \brief Softmax activation
+ *
+ * \param x The input tensor. Can be any dimension
+ * \param axis The channel axis along which softmax is performed
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the softmax operation
+ */
+inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor",
                       std::string tag = "softmax_output") {
   auto input_shape = x->shape;
   auto ndim = input_shape.size();
@@ -64,8 +62,7 @@ inline Tensor softmax(const Tensor &x,
   tvm::Map attrs;
   attrs.Set("axis", Integer(axis));
-  auto insert_reduce_index = [axis, ndim](const Array &indices,
-                                          const IterVar &reduce_index) {
+  auto insert_reduce_index = [axis, ndim](const Array& indices, const IterVar& reduce_index) {
     Array eval_range;
     int arg_counter = 0;
     for (size_t i = 0; i < ndim; ++i) {
@@ -77,61 +74,54 @@ inline Tensor softmax(const Tensor &x,
     return eval_range;
   };
-  auto get_non_reduce_indices = [axis, ndim](const Array &indices) {
+  auto get_non_reduce_indices = [axis, ndim](const Array& indices) {
     Array non_reduce_indices;
     for (size_t i = 0; i < ndim; ++i) {
-      if (static_cast(i) != axis)
-        non_reduce_indices.push_back(indices[i]);
+      if (static_cast(i) != axis) non_reduce_indices.push_back(indices[i]);
     }
     return non_reduce_indices;
   };
-  auto _compute_max = [&](const Array &indices) {
+  auto _compute_max = [&](const Array& indices) {
     auto eval_range = insert_reduce_index(indices, k1);
     return topi::MaxOp(x(eval_range), {k1});
   };
-  auto _compute_exp = [&](const Tensor &max_elem,
-                          const Array &indices) {
+  auto _compute_exp = [&](const Tensor& max_elem, const Array& indices) {
     auto non_reduce_indices = get_non_reduce_indices(indices);
     return tvm::exp(x(indices) - max_elem(non_reduce_indices));
  };
-  auto _compute_expsum = [&](const Tensor &exp,
-                             const Array &indices) {
-
+
auto _compute_expsum = [&](const Tensor& exp, const Array& indices) { auto eval_range = insert_reduce_index(indices, k2); return tvm::sum(exp(eval_range), {k2}); }; - auto _normalize = [&](const Tensor &exp, const Tensor &expsum, - const Array &indices) { + auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return exp(indices) / expsum(non_reduce_indices); }; auto max_elem = tvm::te::compute(reduced_shape, _compute_max); - auto exp = tvm::te::compute(input_shape, [&](const Array &indices) { - return _compute_exp(max_elem, indices); - }); - auto expsum = tvm::te::compute(reduced_shape, [&](const Array &indices) { - return _compute_expsum(exp, indices); - }); - return tvm::te::compute(input_shape, [&](const Array &indices) { - return _normalize(exp, expsum, indices); - }, name, tag, attrs); + auto exp = tvm::te::compute( + input_shape, [&](const Array& indices) { return _compute_exp(max_elem, indices); }); + auto expsum = tvm::te::compute( + reduced_shape, [&](const Array& indices) { return _compute_expsum(exp, indices); }); + return tvm::te::compute( + input_shape, [&](const Array& indices) { return _normalize(exp, expsum, indices); }, + name, tag, attrs); } /*! -* \brief Log softmax activation -* -* \param x The input tensor. 2-D where log softmax is performed along the second dimension -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the log softmax operation -*/ -inline Tensor log_softmax(const Tensor& x, - std::string name = "tensor", + * \brief Log softmax activation + * + * \param x The input tensor. 2-D where log softmax is performed along the second dimension + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the log softmax operation + */ +inline Tensor log_softmax(const Tensor& x, std::string name = "tensor", std::string tag = "log_softmax_output") { CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input"; @@ -139,19 +129,16 @@ inline Tensor log_softmax(const Tensor& x, PrimExpr n = x->shape[1]; auto k = tvm::te::reduce_axis(Range(0, n), "k"); - auto max_elem = tvm::te::compute( - { m }, [&](Var i) { - return tvm::max(x(i, k), Array{ k }); }); + auto max_elem = + tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array{k}); }); k = tvm::te::reduce_axis(Range(0, n), "k"); - auto expsum = tvm::te::compute( - { m }, [&](Var i) { - return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); }); + auto expsum = + tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); }); return tvm::te::compute( - x->shape, [&](Var i, Var j) { - return x(i, j) - max_elem(i) - tvm::log(expsum(i)); - }, name, tag); + x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name, + tag); } } // namespace nn diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index 81c6963835e5a..c45bb501a2416 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -24,18 +24,18 @@ #ifndef TOPI_REDUCTION_H_ #define TOPI_REDUCTION_H_ -#include #include +#include +#include #include #include #include -#include -#include +#include #include +#include #include #include -#include namespace topi { using namespace tvm; @@ -45,21 +45,21 @@ using namespace tvm::te; using FReduce = std::function& axis)>; /*! 
\brief The operation to use for CommReduceIdx */ -using FCommReduce = std::function< - Array(Array exprs, const Array& axis, PrimExpr* condition)>; +using FCommReduce = std::function(Array exprs, const Array& axis, + PrimExpr* condition)>; /*! -* \brief Convert a reduction axis which could be empty or have negative -* elements into a real axis with valid dimension indices. -* -* \param ndim Number of dimensions in the target. -* \param axis The axis parameter. -* -* \return A non-empty sorted array of valid dimension indices, with no duplicates. -* If the input axis is empty, the result will be an axis including all dimensions. -* If any input element is negative, it will be treated as an offset from the -* last dimension (same as python indexing rules). -*/ + * \brief Convert a reduction axis which could be empty or have negative + * elements into a real axis with valid dimension indices. + * + * \param ndim Number of dimensions in the target. + * \param axis The axis parameter. + * + * \return A non-empty sorted array of valid dimension indices, with no duplicates. + * If the input axis is empty, the result will be an axis including all dimensions. + * If any input element is negative, it will be treated as an offset from the + * last dimension (same as python indexing rules). + */ inline std::vector GetRealAxis(int ndim, const Array& axis) { std::vector real_axis; if (!axis.defined() || axis.size() == 0) { @@ -78,8 +78,7 @@ inline std::vector GetRealAxis(int ndim, const Array& axis) { real_axis.push_back(static_cast(val)); } std::sort(real_axis.begin(), real_axis.end()); - real_axis.resize( - std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); + real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); } return real_axis; } @@ -89,17 +88,14 @@ inline Array MakeReduceAxes(const std::vector& real_axis, const Te Array reduce_axes; for (auto i : real_axis) { std::string name = "k" + std::to_string(i); - reduce_axes.push_back( - tvm::te::reduce_axis(Range(0, data->shape[i]), name)); + reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name)); } return reduce_axes; } /*! \brief Calculate the target shape for a reduce op */ -inline Array MakeReduceTargetShape(const std::vector& real_axis, - const Tensor& data, - bool keepdims, - bool atleast1d) { +inline Array MakeReduceTargetShape(const std::vector& real_axis, const Tensor& data, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); Array target_shape; if (keepdims) { @@ -137,9 +133,7 @@ inline Array MakeReduceTargetShape(const std::vector& real_axis, * * \return The result tensor. */ -inline Tensor DoCommReduce(const Tensor& data, - FReduce func, - const Array& target_shape, +inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array& target_shape, const std::vector& reduce_axes, const std::vector& squeeze_axes) { auto r_axes = MakeReduceAxes(reduce_axes, data); @@ -182,45 +176,39 @@ inline Tensor DoCommReduce(const Tensor& data, * * \return The result tensor. 
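As a concrete instance of the axis rules above (shapes illustrative):

#include <topi/reduction.h>
#include <tvm/te/operation.h>

// Sketch: a negative axis is an offset from the last dimension, so for a
// rank-3 input, axis {-1} and axis {2} name the same reduction.
inline void RealAxisSketch() {
  using namespace tvm;
  using namespace tvm::te;
  Tensor A = placeholder({4, 3, 5}, DataType::Float(32), "A");
  Tensor r0 = topi::sum(A, {-1});                     // shape [4, 3]
  Tensor r1 = topi::sum(A, {2});                      // same axis, shape [4, 3]
  Tensor r2 = topi::sum(A, {-1}, /*keepdims=*/true);  // shape [4, 3, 1]
}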
 */
-inline Tensor CommReduce(const Tensor& data,
-                         const Array& axis,
-                         FReduce func,
-                         bool keepdims,
-                         bool atleast1d) {
+inline Tensor CommReduce(const Tensor& data, const Array& axis, FReduce func,
+                         bool keepdims, bool atleast1d) {
   auto ndim = data->shape.size();
   CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
   auto real_axis = GetRealAxis(static_cast(ndim), axis);
   auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
   return DoCommReduce(data, func, target_shape, real_axis,
-    keepdims ? std::vector() : real_axis);
+                      keepdims ? std::vector() : real_axis);
}

/*!
-* \brief Create an index reduction operation.
-*
-* \param data The input tensor.
-* \param axis The axes along which the reduction is performed.
-* \param func The reduction function
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return The result tensor.
-*/
-inline Tensor CommReduceIdx(const Tensor& data,
-                            const Array& axis,
-                            FCommReduce func,
-                            bool keepdims,
-                            bool atleast1d) {
+ * \brief Create an index reduction operation.
+ *
+ * \param data The input tensor.
+ * \param axis The axes along which the reduction is performed.
+ * \param func The reduction function
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return The result tensor.
+ */
+inline Tensor CommReduceIdx(const Tensor& data, const Array& axis, FCommReduce func,
+                            bool keepdims, bool atleast1d) {
   auto ndim = data->shape.size();
   CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
   auto real_axis = GetRealAxis(static_cast(ndim), axis);
   auto reduce_axes = MakeReduceAxes(real_axis, data);
   auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
-  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
-  (const Array& indices) {
+  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
+                  &data](const Array& indices) {
     Array eval_range;
     Array eval_indices;
     int arg_counter = 0;
@@ -247,18 +235,16 @@ inline Tensor CommReduceIdx(const Tensor& data,
       ravel_shape.push_back(data->shape[i]);
     }
     auto idx = detail::RavelIndex(eval_indices, ravel_shape);
-    return func({ idx, data(eval_range) }, reduce_axes, nullptr);
+    return func({idx, data(eval_range)}, reduce_axes, nullptr);
   };
-  auto temp_idx_val = tvm::te::compute(target_shape, compute,
-                                       data->op->name + "_red_temp", kCommReduceIdx);
+  auto temp_idx_val =
+      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
   auto temp_idx = temp_idx_val[0];
   auto temp_val = temp_idx_val[1];
   return tvm::te::compute(
-    target_shape,
-    [&temp_idx](const Array& indices) { return temp_idx(indices); },
-    data->op->name + "_red",
-    kCommReduceIdx);
+      target_shape, [&temp_idx](const Array& indices) { return temp_idx(indices); },
+      data->op->name + "_red", kCommReduceIdx);
}

/*! \brief A combiner function for a reduction */
@@ -276,11 +262,10 @@ using FIdentity = std::function(std::vector types)>;
 *
 * \return A reducer function which creates a reduce expression over an axis.
 */
-inline FCommReduce MakeCommReducer(FCombine fcombine,
-                                   FIdentity fidentity,
+inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
                                    std::string name = "reduce") {
-  return [fcombine, fidentity, name]
-  (Array exprs, const Array& axis, PrimExpr* condition) {
+  return [fcombine, fidentity, name](Array exprs, const Array& axis,
                                      PrimExpr* condition) {
     Array lhs, rhs;
     std::vector dtypes;
@@ -299,16 +284,14 @@ inline FCommReduce MakeCommReducer(FCombine fcombine,
     Array outputs;
     for (size_t i = 0; i < exprs.size(); ++i) {
       outputs.push_back(
-        tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i)));
+          tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i)));
     }
     return outputs;
   };
}

/*! \brief Wrap tvm::min to ensure we get the correct overload */
-inline PrimExpr MinOp(PrimExpr source, Array axis) {
-  return tvm::min(source, axis);
-}
+inline PrimExpr MinOp(PrimExpr source, Array axis) { return tvm::min(source, axis); }

/*! \brief Wrap tvm::max to ensure we get the correct overload */
inline PrimExpr MaxOp(PrimExpr source, Array axis) {
@@ -321,21 +304,19 @@ inline PrimExpr ProdOp(PrimExpr source, Array axis) {
}

/*!
-* \brief Creates an operation that sums array elements over a given axis
-*
-* \param data The input tensor
-* \param axis The axis to sum over. If axis is empty, the operation will
-* sum over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the sum operation
-*/
-inline Tensor sum(const Tensor& data,
-                  const Array& axis,
-                  bool keepdims = false,
+ * \brief Creates an operation that sums array elements over a given axis
+ *
+ * \param data The input tensor
+ * \param axis The axis to sum over. If axis is empty, the operation will
+ * sum over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the sum operation
+ */
+inline Tensor sum(const Tensor& data, const Array& axis, bool keepdims = false,
                   bool atleast1d = false) {
   return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
}
@@ -347,8 +328,7 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) {
   std::vector reduce_axes;
   std::vector squeeze_axes;
-  for (int i_ax = ishape.size() - 1,
-           o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
+  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
     if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
       --o_ax;
       continue;
@@ -369,106 +349,96 @@
}

/*!
-* \brief Creates an operation that computes the logical AND of elements
-* over a given axis
-*
-* \param data The input boolean tensor
-* \param axis The axes to reduce. If axis is empty, the operation will
-* perform logical AND over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
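To make the combiner/identity contract concrete, here is a sketch of a user-defined reducer built the same way as the argmin/argmax reducers below; the name MakeAbsArgmaxReducer and its semantics are illustrative, not part of the library:

#include <topi/reduction.h>

// Sketch: an (index, value) reducer keeping the element of largest magnitude.
// fcombine merges two (idx, val) pairs; fidentity supplies the neutral pair.
inline topi::FCommReduce MakeAbsArgmaxReducer() {
  auto fcombine = [](tvm::Array<tvm::tir::Var> lhs, tvm::Array<tvm::tir::Var> rhs) {
    tvm::Array<tvm::PrimExpr> result;
    tvm::PrimExpr keep_lhs = tvm::abs(lhs[1]) > tvm::abs(rhs[1]);
    result.push_back(tvm::if_then_else(keep_lhs, lhs[0], rhs[0]));  // idx
    result.push_back(tvm::if_then_else(keep_lhs, lhs[1], rhs[1]));  // val
    return result;
  };
  auto fidentity = [](std::vector<tvm::DataType> types) {
    tvm::Array<tvm::PrimExpr> result;
    result.push_back(tvm::tir::make_const(types[0], -1));  // idx: nothing seen yet
    result.push_back(tvm::tir::make_const(types[1], 0));   // val: any element replaces it
    return result;
  };
  return topi::MakeCommReducer(fcombine, fidentity, "abs_argmax");
}

Passed to CommReduceIdx, such a reducer would yield the position of the entry with the greatest absolute value along the reduced axis.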
-*
-* \return A Tensor whose op member is the all operation
-*/
-inline Tensor all(const Tensor& data,
-                  const Array& axis,
-                  bool keepdims = false,
+ * \brief Creates an operation that computes the logical AND of elements
+ * over a given axis
+ *
+ * \param data The input boolean tensor
+ * \param axis The axes to reduce. If axis is empty, the operation will
+ * perform logical AND over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the all operation
+ */
+inline Tensor all(const Tensor& data, const Array& axis, bool keepdims = false,
                   bool atleast1d = false) {
   return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}

/*!
-* \brief Creates an operation that computes the logical OR of elements
-* over a given axis
-*
-* \param data The input boolean tensor
-* \param axis The axes to reduce. If axis is empty, the operation will
-* perform logical OR over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the all operation
-*/
-inline Tensor any(const Tensor& data,
-                  const Array& axis,
-                  bool keepdims = false,
+ * \brief Creates an operation that computes the logical OR of elements
+ * over a given axis
+ *
+ * \param data The input boolean tensor
+ * \param axis The axes to reduce. If axis is empty, the operation will
+ * perform logical OR over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the any operation
+ */
+inline Tensor any(const Tensor& data, const Array& axis, bool keepdims = false,
                   bool atleast1d = false) {
   return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
}

/*!
-* \brief Creates an operation that finds the minimum of elements over
-* a given axis.
-*
-* \param data The input tensor
-* \param axis The axis to find the minimum over. If axis is empty, the
-* operation will find the minimum over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the min operation
-*/
-inline Tensor min(const Tensor& data,
-                  const Array& axis,
-                  bool keepdims = false,
+ * \brief Creates an operation that finds the minimum of elements over
+ * a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis to find the minimum over. If axis is empty, the
+ * operation will find the minimum over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
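A short usage sketch for these boolean reductions (mask shape illustrative):

#include <topi/reduction.h>
#include <tvm/te/operation.h>

// Sketch: per-row AND / OR over a boolean tensor.
inline void BoolReduceSketch() {
  using namespace tvm;
  using namespace tvm::te;
  Tensor mask = placeholder({8, 128}, DataType::Bool(), "mask");
  Tensor every = topi::all(mask, {1});  // shape [8]: AND across each row
  Tensor some = topi::any(mask, {1});   // shape [8]: OR across each row
}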
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the min operation
+ */
+inline Tensor min(const Tensor& data, const Array& axis, bool keepdims = false,
                   bool atleast1d = false) {
   return CommReduce(data, axis, MinOp, keepdims, atleast1d);
}

/*!
-* \brief Creates an operation that finds the maximum of elements over
-* a given axis.
-*
-* \param data The input tensor
-* \param axis The axis to find the maximum over. If axis is empty, the
-* operation will find the maximum over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the max operation
-*/
-inline Tensor max(const Tensor& data,
-                  const Array& axis,
-                  bool keepdims = false,
+ * \brief Creates an operation that finds the maximum of elements over
+ * a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis to find the maximum over. If axis is empty, the
+ * operation will find the maximum over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the max operation
+ */
+inline Tensor max(const Tensor& data, const Array& axis, bool keepdims = false,
                   bool atleast1d = false) {
   return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}

/*!
-* \brief Creates an operation that finds the indices of the minimum
-* values over a given axis.
-*
-* \param data The input tensor
-* \param axis The axis along which the argmin is performed. If axis is empty,
-* the operation will find the minimum index over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the argmin operation
-*/
-inline Tensor argmin(const Tensor& data,
-                     const Array& axis,
-                     bool keepdims = false,
+ * \brief Creates an operation that finds the indices of the minimum
+ * values over a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis along which the argmin is performed. If axis is empty,
+ * the operation will find the minimum index over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
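And for the index reductions, a sketch of what the caller gets back (shapes illustrative):

#include <topi/reduction.h>
#include <tvm/te/operation.h>

// Sketch: argmin/argmax return positions along the reduced axis, not values.
inline void ArgReduceSketch() {
  using namespace tvm;
  using namespace tvm::te;
  Tensor A = placeholder({4, 10}, DataType::Float(32), "A");
  Tensor imin = topi::argmin(A, {1});  // shape [4]: index of each row's minimum
  Tensor imax = topi::argmax(A, {1});  // shape [4]: index of each row's maximum
}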
+ *
+ * \return A Tensor whose op member is the argmin operation
+ */
+inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdims = false,
                      bool atleast1d = false) {
   auto fcombine = [](Array lhs, Array rhs) {
     Array result;
@@ -479,7 +449,7 @@ inline Tensor argmin(const Tensor& data,
   auto fidentity = [](std::vector types) {
     Array result;
     result.push_back(tvm::tir::make_const(types[0], -1));  // idx
-    result.push_back(tvm::max_value(types[1]));  // val
+    result.push_back(tvm::max_value(types[1]));            // val
     return result;
   };
   auto func = MakeCommReducer(fcombine, fidentity, "argmin");
@@ -496,50 +466,46 @@ inline FCommReduce MakeArgmaxReducer() {
   auto fidentity = [](std::vector types) {
     Array result;
     result.push_back(tvm::tir::make_const(types[0], -1));  // idx
-    result.push_back(tvm::min_value(types[1]));  // val
+    result.push_back(tvm::min_value(types[1]));            // val
     return result;
   };
   return MakeCommReducer(fcombine, fidentity, "argmax");
}

/*!
-* \brief Creates an operation that finds the indices of the maximum
-* values over a given axis.
-*
-* \param data The input tensor
-* \param axis The axis along which the argmax is performed. If axis is empty,
-* the operation will find the maximum index over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the argmax operation
-*/
-inline Tensor argmax(const Tensor& data,
-                     const Array& axis,
-                     bool keepdims = false,
+ * \brief Creates an operation that finds the indices of the maximum
+ * values over a given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis along which the argmax is performed. If axis is empty,
+ * the operation will find the maximum index over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the argmax operation
+ */
+inline Tensor argmax(const Tensor& data, const Array& axis, bool keepdims = false,
                      bool atleast1d = false) {
   auto reducer = MakeArgmaxReducer();
   return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

/*!
-* \brief Creates product operation over given axis.
-*
-* \param data The input tensor
-* \param axis The axis to do product over. If axis is empty, the
-* operation will do the product over all elements of the array.
-* \param keepdims If this is set to true, the axes which are reduced are
-* left in the result as dimensions with size one. This enables the result
-* to broadcast correctly against the input array.
-* \param atleast1d Whether the output need to be atleast1d.
-*
-* \return A Tensor whose op member is the prod operation
-*/
-inline Tensor prod(const Tensor& data,
-                   const Array& axis,
-                   bool keepdims = false,
+ * \brief Creates product operation over given axis.
+ *
+ * \param data The input tensor
+ * \param axis The axis to do product over. If axis is empty, the
+ * operation will do the product over all elements of the array.
+ * \param keepdims If this is set to true, the axes which are reduced are
+ * left in the result as dimensions with size one. This enables the result
+ * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output needs to be atleast1d.
+ *
+ * \return A Tensor whose op member is the prod operation
+ */
+inline Tensor prod(const Tensor& data, const Array& axis, bool keepdims = false,
                    bool atleast1d = false) {
   return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
}
diff --git a/topi/include/topi/rocm/dense.h b/topi/include/topi/rocm/dense.h
index 629b34e6ddaf3..72f8ee62e1557 100644
--- a/topi/include/topi/rocm/dense.h
+++ b/topi/include/topi/rocm/dense.h
@@ -24,14 +24,15 @@
 #ifndef TOPI_ROCM_DENSE_H_
 #define TOPI_ROCM_DENSE_H_
-#include
-#include
 #include
-#include "topi/detail/array_utils.h"
-#include "topi/nn/dense.h"
+#include
+#include
+
 #include "topi/contrib/rocblas.h"
-#include "topi/generic/extern.h"
 #include "topi/cuda/dense.h"
+#include "topi/detail/array_utils.h"
+#include "topi/generic/extern.h"
+#include "topi/nn/dense.h"
 namespace topi {
using namespace tvm;
@@ -39,21 +40,19 @@ using namespace tvm::te;
namespace rocm {
/*!
-* \brief Implementation of dense for rocm backend
-*
-* \param target The target device
-* \param data Tensor with shape [batch, in_dim]
-* \param weight Tensor with shape [out_dim, in_dim]
-* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
-* \param out_dtype Output data type. Used for mixed precision.
-*
-* \return Tensor with shape [batch, out_dim]
-*/
-inline tvm::te::Tensor dense_rocm(const Target& target,
-                                  const tvm::te::Tensor& data,
-                                  const tvm::te::Tensor& weight,
-                                  const tvm::te::Tensor& bias,
-                                  const DataType& out_dtype) {
+ * \brief Implementation of dense for rocm backend
+ *
+ * \param target The target device
+ * \param data Tensor with shape [batch, in_dim]
+ * \param weight Tensor with shape [out_dim, in_dim]
+ * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
+ * \param out_dtype Output data type. Used for mixed precision.
+ *
+ * \return Tensor with shape [batch, out_dim]
+ */
+inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& data,
+                                  const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
+                                  const DataType& out_dtype) {
   CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
   CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
   if (bias.defined()) {
@@ -68,10 +67,8 @@ inline tvm::te::Tensor dense_rocm(const Target& target,
     CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
     auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
     if (bias.defined()) {
-      mm = tvm::te::compute({ batch, out_dim },
-                            [&](Var i, Var j) {
-                              return mm(i, j) + bias(j);
-                            }, "tensor", kBroadcast);
+      mm = tvm::te::compute(
+          {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast);
     }

     return mm;
@@ -81,16 +78,15 @@ inline tvm::te::Tensor dense_rocm(const Target& target,
}

/*!
-* \brief Create a rocm schedule for dense
-*
-* \param target The target to generate a schedule for.
-* \param outs The output tensors.
-*
-* \return A schedule for the given ops.
-*/
-inline Schedule schedule_dense(const Target &target, const Array& outs) {
-  if (target->target_name == "rocm" &&
-      target->libs().count("rocblas")) {
+ * \brief Create a rocm schedule for dense
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
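A sketch of how this pair is exercised end to end; the target string is illustrative and assumes the Target::Create API used elsewhere in this codebase:

#include <topi/rocm/dense.h>
#include <tvm/te/operation.h>

// Sketch: with "rocblas" among the target libs, dense_rocm emits an external
// rocBLAS matmul, and schedule_dense falls through to the extern schedule.
inline void RocmDenseSketch() {
  using namespace tvm;
  using namespace tvm::te;
  Target target = Target::Create("rocm -libs=rocblas");
  Tensor data = placeholder({32, 64}, DataType::Float(32), "data");      // [batch, in_dim]
  Tensor weight = placeholder({16, 64}, DataType::Float(32), "weight");  // [out_dim, in_dim]
  Tensor out = topi::rocm::dense_rocm(target, data, weight, Tensor(), DataType::Float(32));
  Schedule s = topi::rocm::schedule_dense(target, {out});
}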
+ */ +inline Schedule schedule_dense(const Target& target, const Array& outs) { + if (target->target_name == "rocm" && target->libs().count("rocblas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/rocm/injective.h b/topi/include/topi/rocm/injective.h index f3a3f3b0cbd23..e7415bfd0ff24 100644 --- a/topi/include/topi/rocm/injective.h +++ b/topi/include/topi/rocm/injective.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_INJECTIVE_H_ #define TOPI_ROCM_INJECTIVE_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/injective.h" @@ -57,7 +57,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { return topi::cuda::schedule_injective(target, outs); } diff --git a/topi/include/topi/rocm/normalization.h b/topi/include/topi/rocm/normalization.h index 303f4a8302c71..832868348b678 100644 --- a/topi/include/topi/rocm/normalization.h +++ b/topi/include/topi/rocm/normalization.h @@ -24,22 +24,20 @@ #ifndef TOPI_ROCM_NORMALIZATION_H_ #define TOPI_ROCM_NORMALIZATION_H_ -#include -#include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for LRN -* \param outs The output tensors. -* \return A schedule for the given ops. -*/ -inline Schedule schedule_lrn(const Array& outs) { - return topi::cuda::schedule_lrn(outs); -} + * \brief Create a rocm schedule for LRN + * \param outs The output tensors. + * \return A schedule for the given ops. + */ +inline Schedule schedule_lrn(const Array& outs) { return topi::cuda::schedule_lrn(outs); } } // namespace rocm } // namespace topi diff --git a/topi/include/topi/rocm/pooling.h b/topi/include/topi/rocm/pooling.h index 7d1f36f2ee33d..0b68a0ac53664 100644 --- a/topi/include/topi/rocm/pooling.h +++ b/topi/include/topi/rocm/pooling.h @@ -24,12 +24,12 @@ #ifndef TOPI_ROCM_POOLING_H_ #define TOPI_ROCM_POOLING_H_ -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -38,26 +38,26 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_pool(const Target &target, const Array& outs) { + * \brief Create a rocm schedule for pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_pool(const Target& target, const Array& outs) { return topi::cuda::schedule_pool(target, outs); } /*! -* \brief Create a rocm schedule for global_pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_global_pool(const Target &target, const Array& outs) { + * \brief Create a rocm schedule for global_pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
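The pooling schedules follow the same pattern, simply forwarding to the CUDA implementations; a minimal sketch (the pooled output is assumed to come from one of the topi pooling ops):

#include <topi/rocm/pooling.h>

// Sketch: rocm pooling schedules delegate to their CUDA counterparts.
inline tvm::te::Schedule RocmPoolScheduleSketch(const tvm::Target& target,
                                                const tvm::te::Tensor& pool_out) {
  return topi::rocm::schedule_global_pool(target, {pool_out});
}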
+ */ +inline Schedule schedule_global_pool(const Target& target, const Array& outs) { return topi::cuda::schedule_global_pool(target, outs); } diff --git a/topi/include/topi/rocm/reduction.h b/topi/include/topi/rocm/reduction.h index ea4b656239288..512bf20b4bc19 100644 --- a/topi/include/topi/rocm/reduction.h +++ b/topi/include/topi/rocm/reduction.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_REDUCTION_H_ #define TOPI_ROCM_REDUCTION_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/reduction.h" @@ -37,13 +37,13 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for a reduce operation. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a rocm schedule for a reduce operation. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ Schedule schedule_reduce(const Target& target, Array outs) { return topi::cuda::schedule_reduce(target, outs); } diff --git a/topi/include/topi/rocm/softmax.h b/topi/include/topi/rocm/softmax.h index 63a0304d28184..de05c4cec9d33 100644 --- a/topi/include/topi/rocm/softmax.h +++ b/topi/include/topi/rocm/softmax.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_SOFTMAX_H_ #define TOPI_ROCM_SOFTMAX_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/softmax.h" @@ -45,7 +45,7 @@ namespace rocm { * * \return A schedule for the given ops. */ -inline Schedule schedule_softmax(const Target &target, const Array& outs) { +inline Schedule schedule_softmax(const Target& target, const Array& outs) { return topi::cuda::schedule_softmax(target, outs); } diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h index 8d353b949ab69..1e9ec446dfa3e 100644 --- a/topi/include/topi/tags.h +++ b/topi/include/topi/tags.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -43,16 +43,12 @@ constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nh constexpr auto kGroupConv2d = "group_conv2d"; inline bool is_broadcast(std::string tag) { - return - tag.rfind(kElementWise, 0) == 0 || - tag.rfind(kBroadcast, 0) == 0; + return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0; } inline bool is_injective(std::string tag) { - return - tag.rfind(kElementWise, 0) == 0 || - tag.rfind(kBroadcast, 0) == 0 || - tag.rfind(kInjective, 0) == 0; + return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0 || + tag.rfind(kInjective, 0) == 0; } } // namespace topi diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 0609020b5c810..e21fc2a5ea3e6 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -24,19 +24,19 @@ #ifndef TOPI_TRANSFORM_H_ #define TOPI_TRANSFORM_H_ -#include -#include -#include -#include #include +#include #include +#include +#include +#include -#include -#include -#include #include +#include #include +#include #include +#include namespace topi { using namespace tvm; @@ -44,30 +44,25 @@ using namespace tvm::te; using namespace topi::detail; /*! -* \brief Creates an operation to insert new dimensions of length 1 -* -* \param x The input tensor -* \param axis The index of the first new dimension (allows negative -* indices as offsets from the last dimension) -* \param num_newaxis The number of new dimensions to insert -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the dim expansion operation -*/ -inline Tensor expand_dims(const Tensor& x, - int axis, - int num_newaxis = 1, - std::string name = "T_expand_dims", - std::string tag = kBroadcast) { + * \brief Creates an operation to insert new dimensions of length 1 + * + * \param x The input tensor + * \param axis The index of the first new dimension (allows negative + * indices as offsets from the last dimension) + * \param num_newaxis The number of new dimensions to insert + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the dim expansion operation + */ +inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, + std::string name = "T_expand_dims", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; if (axis < 0) { // Calculate offset from last dimension axis = ndim + axis + 1; @@ -84,32 +79,32 @@ inline Tensor expand_dims(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - 
idx.push_back(indices[i]); - } - for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, + name, tag); } /*! -* \brief Permute the dimensions of an array -* -* \param x The input tensor -* \param axes The indices of the permutation. If this is empty, -* the dimensions will be reversed. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the transpose operation -*/ -inline Tensor transpose(const Tensor& x, - Array axes, - std::string name = "T_transpose", + * \brief Permute the dimensions of an array + * + * \param x The input tensor + * \param axes The indices of the permutation. If this is empty, + * the dimensions will be reversed. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the transpose operation + */ +inline Tensor transpose(const Tensor& x, Array axes, std::string name = "T_transpose", std::string tag = kInjective) { if (!axes.defined() || axes.size() == 0) { axes = Array(); @@ -127,11 +122,11 @@ inline Tensor transpose(const Tensor& x, axes.Set(i, new_axis); } CHECK((new_axis >= 0) && (new_axis < static_cast(x->shape.size()))) - << "axis=" << axis << " is invalid for the " - << static_cast(x->shape.size()) << "-dimensional input tensor"; + << "axis=" << axis << " is invalid for the " << static_cast(x->shape.size()) + << "-dimensional input tensor"; for (size_t j = 0; j < axes.size(); ++j) { - if (i !=j) { + if (i != j) { CHECK(new_axis != static_cast(axes[j]->value)) << "repeated axis in transpose"; } } @@ -139,33 +134,33 @@ inline Tensor transpose(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - std::vector idx; - for (size_t i = 0; i < axes.size(); ++i) { - idx.push_back(1); - } - for (size_t i = 0; i < axes.size(); ++i) { - int axis = static_cast(axes[i]->value); - idx[axis] = indices[i]; - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + std::vector idx; + for (size_t i = 0; i < axes.size(); ++i) { + idx.push_back(1); + } + for (size_t i = 0; i < axes.size(); ++i) { + int axis = static_cast(axes[i]->value); + idx[axis] = indices[i]; + } + return x(idx); + }, + name, tag); } /*! 
-* \brief flip/reverse elements of an array in a particular axis
-*
-* \param x The input tensor
-* \param axis The axis along which the tensors will be reveresed
-* (allows negative indices)
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the reverse operation
-*/
-inline Tensor flip(const Tensor& x,
-                   int axis = 0,
-                   std::string name = "T_flip",
+ * \brief flip/reverse elements of an array in a particular axis
+ *
+ * \param x The input tensor
+ * \param axis The axis along which the tensors will be reversed
+ * (allows negative indices)
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reverse operation
+ */
+inline Tensor flip(const Tensor& x, int axis = 0, std::string name = "T_flip",
                    std::string tag = kInjective) {
   size_t src_tensor_dim = x->shape.size();
   int axis_inp = axis;
@@ -175,42 +170,42 @@ inline Tensor flip(const Tensor& x,
   }

   CHECK((0 <= axis) && (axis < static_cast(x->shape.size())))
-    << "axis=" << axis_inp << " is invalid for the "
-    << static_cast(x->shape.size()) << "-dimensional input tensor";
+      << "axis=" << axis_inp << " is invalid for the " << static_cast(x->shape.size())
+      << "-dimensional input tensor";

   // Reverse the Input Tensor in the axis specified
   return compute(
-    x->shape, [&](const Array& indices) {
-      Array real_indices;
-      for (size_t i = 0; i < src_tensor_dim; ++i) {
-        if (i == static_cast(axis)) {
-          real_indices.push_back(x->shape[i] - indices[i] - 1);
-        } else {
-          real_indices.push_back(indices[i]);
+      x->shape,
+      [&](const Array& indices) {
+        Array real_indices;
+        for (size_t i = 0; i < src_tensor_dim; ++i) {
+          if (i == static_cast(axis)) {
+            real_indices.push_back(x->shape[i] - indices[i] - 1);
+          } else {
+            real_indices.push_back(indices[i]);
+          }
         }
-      }
-      return x(real_indices);
-    }, name, tag);
+        return x(real_indices);
+      },
+      name, tag);
}

/*!
-* \brief Reshape a tensor -* -* \param x The input tensor -* \param newshape The new shape -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reshape operation -*/ -inline Tensor reshape(const Tensor& x, - Array newshape, - std::string name = "T_reshape", + * \brief Reshape a tensor + * + * \param x The input tensor + * \param newshape The new shape + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reshape operation + */ +inline Tensor reshape(const Tensor& x, Array newshape, std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; Array target_shape; - for (const auto &ele : newshape) { + for (const auto& ele : newshape) { if (ele.as()) { target_shape.push_back(cast(DataType::Int(32), ele)); } else { @@ -219,16 +214,16 @@ inline Tensor reshape(const Tensor& x, } if (is_empty_shape(target_shape)) { - return compute(target_shape, - [&](const Array &indices) { return tvm::cast(x->dtype, 0); }, - name, tag); + return compute( + target_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); } else { return compute( - target_shape, [&](const Array& indices) { - return x(UnravelIndex( - RavelIndex(Array{indices.begin(), indices.end()}, target_shape), - x_shape)); - }, name, tag); + target_shape, + [&](const Array& indices) { + return x(UnravelIndex( + RavelIndex(Array{indices.begin(), indices.end()}, target_shape), x_shape)); + }, + name, tag); } } @@ -243,9 +238,7 @@ inline Tensor reshape(const Tensor& x, * \return A Tensor of coordinate arrays. */ -inline Tensor unravel_index(const Tensor& x, - const Tensor& shape, - std::string name = "T_unravel", +inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel", std::string tag = kInjective) { auto x_shape = x->shape; auto shape_shape = shape->shape; @@ -281,23 +274,20 @@ inline Tensor unravel_index(const Tensor& x, } /*! -* \brief Remove size 1 dimensions from the shape of a tensor. -* The removed dimensions must have a constant size of 1. -* -* \param x The input tensor -* \param axis Indices of the dimensions to remove. If this is empty, -* all entries with a constant size of 1 will be removed. + * \brief Remove size 1 dimensions from the shape of a tensor. + * The removed dimensions must have a constant size of 1. + * + * \param x The input tensor + * \param axis Indices of the dimensions to remove. If this is empty, + * all entries with a constant size of 1 will be removed. * \param atleast1d Whether the output need to be atleast1d. 
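For the two shape operators above, a small sketch (shapes illustrative):

#include <topi/transform.h>
#include <tvm/te/operation.h>

// Sketch: reshape preserves the flattened (row-major) element order;
// squeeze removes only dimensions whose static extent is 1.
inline void ShapeOpSketch() {
  using namespace tvm;
  using namespace tvm::te;
  Tensor A = placeholder({2, 1, 6}, DataType::Float(32), "A");
  Tensor r = topi::reshape(A, {3, 4});  // shape [3, 4]: same 12 elements
  Tensor q = topi::squeeze(A, {1});     // shape [2, 6]: unit axis dropped
}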
-* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the squeeze operation -*/ -inline Tensor squeeze(const Tensor& x, - Array axis, - bool atleast1d = false, - std::string name = "T_squeeze", - std::string tag = kInjective) { + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the squeeze operation + */ +inline Tensor squeeze(const Tensor& x, Array axis, bool atleast1d = false, + std::string name = "T_squeeze", std::string tag = kInjective) { auto ndim = x->shape.size(); std::vector axis_val; if (!axis.defined() || axis.size() == 0) { @@ -312,8 +302,7 @@ inline Tensor squeeze(const Tensor& x, if (val < 0) { val += static_cast(x->shape.size()); } - CHECK_EQ(GetConstInt(x->shape[val]), 1) << - "Dimension " << val << " must have size 1"; + CHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1"; axis_val.push_back(val); } } @@ -331,45 +320,42 @@ inline Tensor squeeze(const Tensor& x, } return compute( - out_shape, [&](const Array& indices) { - Array real_indices; - int flag = 0; - for (size_t i = 0; i < ndim; ++i) { - if (axis_set.count(static_cast(i)) == 0) { - real_indices.push_back(indices[i - flag]); - } else { - real_indices.push_back(0); - flag += 1; + out_shape, + [&](const Array& indices) { + Array real_indices; + int flag = 0; + for (size_t i = 0; i < ndim; ++i) { + if (axis_set.count(static_cast(i)) == 0) { + real_indices.push_back(indices[i - flag]); + } else { + real_indices.push_back(0); + flag += 1; + } } - } - return x(real_indices); - }, name, tag); + return x(real_indices); + }, + name, tag); } /*! -* \brief Join a sequence of tensors along an existing axis -* -* \param inputs The input tensors -* \param axis The axis along which the tensors will be joined -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the concatenate operation -*/ -inline Tensor concatenate(const Array& inputs, - int axis = 0, - std::string name = "T_concat", + * \brief Join a sequence of tensors along an existing axis + * + * \param inputs The input tensors + * \param axis The axis along which the tensors will be joined + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the concatenate operation + */ +inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); - CHECK(-ndim <= axis && axis < ndim) - << "concatenate only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; + CHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim; } - CHECK_LT(axis, inputs[0]->shape.size()) << - "axis out of bounds"; + CHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; Array axis_sizes; for (auto t : inputs) { @@ -387,96 +373,87 @@ inline Tensor concatenate(const Array& inputs, } return compute( - out_shape, [&](const Array& indices) { - auto ret = inputs[0](indices); - auto ind = indices[axis]; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - ind -= axis_sizes[i]; + out_shape, + [&](const Array& indices) { + auto ret = inputs[0](indices); + auto ind = indices[axis]; + for (size_t i = 0; i < 
inputs.size() - 1; ++i) {
+          ind -= axis_sizes[i];
+
+          Array idx;
+          for (size_t i = 0; i < static_cast(axis); ++i) {
+            idx.push_back(indices[i]);
+          }
+          idx.push_back(ind);
+          for (size_t i = axis + 1; i < indices.size(); ++i) {
+            idx.push_back(indices[i]);
+          }
-        Array idx;
-        for (size_t i = 0; i < static_cast(axis); ++i) {
-          idx.push_back(indices[i]);
+          ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret);
         }
-        idx.push_back(ind);
-        for (size_t i = axis + 1; i < indices.size(); ++i) {
-          idx.push_back(indices[i]);
-        }
-
-        ret = tvm::if_then_else(ind >= 0,
-                                inputs[i + 1](idx),
-                                ret);
-      }
-      return ret;
-    }, name, tag);
+        return ret;
+      },
+      name, tag);
}

/*!
-* \brief Join a sequence of tensors along a new axis.
-*
-* \param inputs The input tensors
-* \param axis The axis along which the tensors will be stacked
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the stack operation
-*/
-inline Tensor stack(const Array& inputs,
-                    int axis = 0,
-                    std::string name = "T_stack",
+ * \brief Join a sequence of tensors along a new axis.
+ *
+ * \param inputs The input tensors
+ * \param axis The axis along which the tensors will be stacked
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the stack operation
+ */
+inline Tensor stack(const Array& inputs, int axis = 0, std::string name = "T_stack",
                     std::string tag = kInjective) {
   int ndim = static_cast(inputs[0]->shape.size());
   CHECK(-ndim - 1 <= axis && axis <= ndim)
-    << "stack only accepts `axis` in [-ndim, ndim)"
-    << ", but got axis = " << axis
-    << ", and ndim = " << ndim;
+      << "stack only accepts `axis` in [-ndim - 1, ndim]"
+      << ", but got axis = " << axis << ", and ndim = " << ndim;
   if (axis < 0) {
     axis += ndim + 1;
   }
-  CHECK_LT(axis, inputs[0]->shape.size() + 1) <<
-    "axis out of bounds";
+  CHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds";

   const int stack_size = static_cast(inputs.size());
   Array out_shape;
-  for (size_t i = 0; i < static_cast(axis); ++i)
-    out_shape.push_back(inputs[0]->shape[i]);
+  for (size_t i = 0; i < static_cast(axis); ++i) out_shape.push_back(inputs[0]->shape[i]);
   out_shape.push_back(stack_size);
   for (size_t i = static_cast(axis); i < static_cast(ndim); ++i)
     out_shape.push_back(inputs[0]->shape[i]);

   return compute(
-    out_shape, [&](const Array& indices) {
-      Array idx;
-      for (size_t i = 0; i < indices.size(); ++i)
-        if (i != static_cast(axis))
-          idx.push_back(indices[i]);
-      auto ind = indices[axis];
-      auto ret = inputs[0](idx);
-      for (int i = 0; i < static_cast(inputs.size() - 1); ++i) {
-        ret = tvm::if_then_else(ind == i + 1,
-                                inputs[i + 1](idx),
-                                ret);
-      }
-      return ret;
-    }, name, tag);
+      out_shape,
+      [&](const Array& indices) {
+        Array idx;
+        for (size_t i = 0; i < indices.size(); ++i)
+          if (i != static_cast(axis)) idx.push_back(indices[i]);
+        auto ind = indices[axis];
+        auto ret = inputs[0](idx);
+        for (int i = 0; i < static_cast(inputs.size() - 1); ++i) {
+          ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret);
+        }
+        return ret;
+      },
+      name, tag);
}

/*!
-* \brief Split a tensor into multiple sub-tensors
-*
-* \param x The input tensor
-* \param split_indices The indices to split the input at. This must be in ascending
-* order.
-* \param axis The axis to split along.
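The difference between the two joining operators, as a sketch (shapes illustrative):

#include <topi/transform.h>
#include <tvm/te/operation.h>

// Sketch: concatenate grows an existing axis; stack inserts a new one.
inline void JoinSketch() {
  using namespace tvm;
  using namespace tvm::te;
  Tensor a = placeholder({4, 3}, DataType::Float(32), "a");
  Tensor b = placeholder({4, 3}, DataType::Float(32), "b");
  Tensor c = topi::concatenate({a, b}, 0);  // shape [8, 3]
  Tensor s = topi::stack({a, b}, 0);        // shape [2, 4, 3]
}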
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the split operation
-*/
-inline Array<Tensor> split(const Tensor& x,
-                           Array<Integer> split_indices,
-                           int axis,
-                           std::string name = "T_split",
-                           std::string tag = kInjective) {
+ * \brief Split a tensor into multiple sub-tensors
+ *
+ * \param x The input tensor
+ * \param split_indices The indices to split the input at. This must be in ascending
+ * order.
+ * \param axis The axis to split along.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the split operation
+ */
+inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int axis,
+                           std::string name = "T_split", std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -488,12 +465,11 @@ inline Array<Tensor> split(const Tensor& x,

   for (Integer idx : split_indices) {
     int val = static_cast<int>(idx->value);
-    CHECK_GT(val, begin_ids.back())
-        << "split_indices must be sorted";
+    CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted";
     begin_ids.push_back(val);
   }

-  Array< Array<PrimExpr> > out_shapes;
+  Array<Array<PrimExpr> > out_shapes;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
     int out_axis_size;
     if (i == begin_ids.size() - 1) {
@@ -516,9 +492,9 @@ inline Array<Tensor> split(const Tensor& x,

   Array<Tensor> result;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
-    result.push_back(
-      compute(
-        out_shapes[i], [&](const Array<Var>& indices) {
+    result.push_back(compute(
+        out_shapes[i],
+        [&](const Array<Var>& indices) {
          auto begin = begin_ids[i];
          Array<PrimExpr> real_indices;
          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
@@ -530,30 +506,28 @@ inline Array<Tensor> split(const Tensor& x,
          }

          return x(real_indices);
-        }, name, tag));
+        },
+        name, tag));
   }

   return result;
 }

 /*!
-* \brief strided_slice of a tensor
-*
-* \param x The input tensor
-* \param begin The indices to begin with in the slicing
-* \param end Indicies indicating end of the slice
-* \param strides Specifies the stride values, it can be negative
-* in that case, the input tensor will be reversed in that particular axis
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the split operation
-*/
-inline Tensor strided_slice(const Tensor& x,
-                            const Array<Integer>& begin,
-                            const Array<Integer>& end,
-                            const Array<Integer>& strides,
-                            std::string name = "T_strided_slice",
+ * \brief strided_slice of a tensor
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing
+ * \param end Indices indicating the end of the slice
+ * \param strides Specifies the stride values; it can be negative,
+ * in which case the input tensor will be reversed in that particular axis
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the strided_slice operation
+ */
+inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
+                            const Array<Integer>& strides, std::string name = "T_strided_slice",
                             std::string tag = kInjective) {
   size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
   // Setup the ranges.
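Seen from Python, the injective ops above compose directly. The following is a minimal sketch, assuming the pre-0.7 layout where `topi` is a top-level Python package; the shape comments follow the C++ semantics above:

import tvm
from tvm import te
import topi

A = te.placeholder((1, 4, 3), name="A")
B = te.placeholder((2, 3), name="B")

squeezed = topi.squeeze(A, axis=[0])              # (1, 4, 3) -> (4, 3)
joined = topi.concatenate([squeezed, B], axis=0)  # (4, 3) + (2, 3) -> (6, 3)
parts = topi.split(joined, [4], axis=0)           # -> [(4, 3), (2, 3)]

s = te.create_schedule([t.op for t in parts])
print(tvm.lower(s, [A, B] + list(parts), simple_mode=True))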
@@ -615,43 +589,43 @@ inline Tensor strided_slice(const Tensor& x, int64_t end_i = index_canonicalization(end_vec[i]); int interval = std::abs(end_i - begin_i); - int slice_size = static_cast((interval - + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); + int slice_size = + static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back(make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - stride_vec[i])); + strides_expr.push_back( + make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); out_shape.push_back(slice_size); } return compute( - out_shape, [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return x(real_indices); - }, name, tag); + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + } + return x(real_indices); + }, + name, tag); } /*! -* \brief Split a tensor into a number of sub-tensors -* -* \param x The input tensor -* \param num_sections The number of sections to split the tensor into. -* this must be an integer factor of the size of the axis being split. -* \param axis The axis to split along. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Array split_sections(const Tensor& x, - int num_sections, - int axis, + * \brief Split a tensor into a number of sub-tensors + * + * \param x The input tensor + * \param num_sections The number of sections to split the tensor into. + * this must be an integer factor of the size of the axis being split. + * \param axis The axis to split along. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Array split_sections(const Tensor& x, int num_sections, int axis, std::string name = "T_split_sections", std::string tag = kInjective) { if (axis < 0) { @@ -663,8 +637,8 @@ inline Array split_sections(const Tensor& x, CHECK_GT(num_sections, 0) << "Slice count must be > 0"; CHECK_EQ(src_axis_size % num_sections, 0) - << "num_sections must be an integer factor of the size of axis " << axis - << " (" << src_axis_size << ")"; + << "num_sections must be an integer factor of the size of axis " << axis << " (" + << src_axis_size << ")"; Array split_indices; auto seg_size = src_axis_size / num_sections; @@ -679,22 +653,19 @@ inline Array split_sections(const Tensor& x, } /*! -* \brief Take elements from an flattened input array when axis is None. -* -* \param a The source array. -* \param indices The indices of the values to extract. -* \param mode The mode of the operation. -* \param name The name of the operation. -* \param mode The mode of to handle out of bound indices. -* \param tag The tag to mark the operation. 
-*
-* \return A Tensor whose op member is the take operation
-*/
-inline Tensor take(const Tensor& a,
-                   const Tensor& indices,
-                   std::string mode = "clip",
-                   std::string name = "T_take",
-                   std::string tag = kInjective) {
+ * \brief Take elements from a flattened input array when axis is None.
+ *
+ * \param a The source array.
+ * \param indices The indices of the values to extract.
+ * \param mode The mode to handle out of bound indices.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return A Tensor whose op member is the take operation
+ */
+inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip",
+                   std::string name = "T_take", std::string tag = kInjective) {
   Array<PrimExpr> a_shape = a->shape;
   Array<PrimExpr> out_shape = indices->shape;
   PrimExpr a_size = 1;
@@ -704,44 +675,44 @@ inline Tensor take(const Tensor& a,
   if (mode == "clip") {
     return compute(
-        out_shape, [&](const Array<Var>& out_index) {
+        out_shape,
+        [&](const Array<Var>& out_index) {
          auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
          return a(UnravelIndex(idx, a_shape));
-        }, name, tag);
+        },
+        name, tag);
   } else if (mode == "fast") {
     LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
                     "Make sure input indices are in bound";
     return compute(
-        out_shape, [&](const Array<Var>& out_index) {
-          return a(UnravelIndex(indices(out_index), a_shape));
-        }, name, tag);
+        out_shape,
+        [&](const Array<Var>& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); },
+        name, tag);
   } else {  // mode == "wrap"
     return compute(
-        out_shape, [&](const Array<Var>& out_index) {
+        out_shape,
+        [&](const Array<Var>& out_index) {
          auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size);
          return a(UnravelIndex(idx, a_shape));
-        }, name, tag);
+        },
+        name, tag);
   }
 }

-
 /*!
-* \brief Mask the out-of-boundary elements of each sequence.
-*
-* \param data The source array.
-* \param valid_length The real length of each sequence.
-* \param mask_value The masking value.
-* \param axis The axis of the temporal dimension of the sequence
-* \param name The name of the operation.
-* \param tag The tag to mark the operation.
-*
-* \return A Tensor whose op member is the sequence_mask operation
-*/
-inline Tensor sequence_mask(const Tensor& data,
-                            const Tensor& valid_length,
-                            double mask_value,
-                            int axis,
-                            std::string name = "T_sequence_mask",
+ * \brief Mask the out-of-boundary elements of each sequence.
+ *
+ * \param data The source array.
+ * \param valid_length The real length of each sequence.
+ * \param mask_value The masking value.
+ * \param axis The axis of the temporal dimension of the sequence
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * + * \return A Tensor whose op member is the sequence_mask operation + */ +inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value, + int axis, std::string name = "T_sequence_mask", std::string tag = kInjective) { CHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1"; CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; @@ -749,38 +720,36 @@ inline Tensor sequence_mask(const Tensor& data, auto batch_dim = data->shape[1 - axis]; Array out_shape = data->shape; Tensor out = compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array len_index; auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); - PrimExpr ret = tvm::if_then_else( - tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tir::make_const(data->dtype, mask_value), data(out_index)); + PrimExpr ret = + tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), + tvm::tir::make_const(data->dtype, mask_value), data(out_index)); return ret; - }, name, tag); + }, + name, tag); return out; } /*! -* \brief Take elements from an array along an axis. -* -* \param a The source array. -* \param indices The indices of the values to extract. -* \param axis The axis over which to select values. By default, -* the flattened input array is used. -* \param mode The mode for handling out of bound indices. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the take operation -*/ -inline Tensor take(const Tensor& a, - const Tensor& indices, - int axis, - std::string mode = "clip", - std::string name = "T_take", - std::string tag = kInjective) { + * \brief Take elements from an array along an axis. + * + * \param a The source array. + * \param indices The indices of the values to extract. + * \param axis The axis over which to select values. By default, + * the flattened input array is used. + * \param mode The mode for handling out of bound indices. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the take operation + */ +inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip", + std::string name = "T_take", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(a->shape.size()); } @@ -801,30 +770,32 @@ inline Tensor take(const Tensor& a, } if (mode == "clip") { return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } - auto idx = tvm::min(tvm::max(0, indices(indices_position)), - axis_dim - 1); + auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1); real_indices.push_back(idx); for (size_t j = axis + indices_len; j < out_index.size(); ++j) { real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } else if (mode == "fast") { LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. 
" "Make sure input indices are in bound"; return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; @@ -836,12 +807,14 @@ inline Tensor take(const Tensor& a, real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } else { // mode == "wrap" return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; @@ -854,82 +827,78 @@ inline Tensor take(const Tensor& a, real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } } /*! -* \brief Return the elements, either from x or y, depending on the condition. -* -* \param condition The condition array. -* \param x First array to be selected. -* \param y Second array to be selected. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor selected from x or y depending on condition. -*/ -inline Tensor where(const Tensor& condition, - const Tensor& x, - const Tensor& y, - std::string name = "T_where", - std::string tag = kBroadcast) { + * \brief Return the elements, either from x or y, depending on the condition. + * + * \param condition The condition array. + * \param x First array to be selected. + * \param y Second array to be selected. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor selected from x or y depending on condition. 
+ */
+inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
+                    std::string name = "T_where", std::string tag = kBroadcast) {
   CHECK_EQ(x->shape.size(), y->shape.size())
-      << "x and y must have the same shape.Got different number of dimension: "
-      << x->shape.size() << " vs " << y->shape.size();
-  CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
-                               << x->dtype << " vs " << y->dtype;
+      << "x and y must have the same shape. Got different numbers of dimensions: "
+      << x->shape.size() << " vs " << y->shape.size();
+  CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+                               << y->dtype;
   Array<PrimExpr> oshape = x->shape;
   Tensor out;

   if (condition->shape.size() != 1) {
     CHECK_EQ(condition->shape.size(), x->shape.size())
-        << "condition array must be either have the same shape as x or to be a "
-           "1-D array.Got different number of dimension: "
-        << condition->shape.size() << " vs " << x->shape.size();
+        << "condition array must either have the same shape as x or be a "
+           "1-D array. Got different numbers of dimensions: "
+        << condition->shape.size() << " vs " << x->shape.size();
     out = compute(
-        oshape, [&](const Array<Var>& indices) {
-          return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices));
-        }, name, tag);
+        oshape,
+        [&](const Array<Var>& indices) {
+          return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices));
+        },
+        name, tag);
   } else {
     CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
-        << "If condition is 1-D, the first dimension must be the same as x: "
-        << condition->shape[0] << " vs " << x->shape[0];
+        << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0]
+        << " vs " << x->shape[0];
     out = compute(
-        oshape, [&](const Array<Var>& indices) {
-          Array<PrimExpr> condition_idx{indices[0]};
-          return tvm::tir::SelectNode::make(condition(condition_idx) != 0,
-                                            x(indices), y(indices));
-        }, name, tag);
+        oshape,
+        [&](const Array<Var>& indices) {
+          Array<PrimExpr> condition_idx{indices[0]};
+          return tvm::tir::SelectNode::make(condition(condition_idx) != 0, x(indices), y(indices));
+        },
+        name, tag);
   }
   return out;
 }

 /*!
-* \brief Creates an operation to repeat elements of an array -* -* \param x The input tensor -* \param repeats The number of repetitions for each element -* \param axis The axis along which to repeat values (allows -* negative indices as offsets from the last dimension) -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the repeat operation -*/ -inline Tensor repeat(const Tensor& x, - int repeats, - int axis, - std::string name = "T_repeat", + * \brief Creates an operation to repeat elements of an array + * + * \param x The input tensor + * \param repeats The number of repetitions for each element + * \param axis The axis along which to repeat values (allows + * negative indices as offsets from the last dimension) + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the repeat operation + */ +inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; if (axis < 0) { // Calculate offset from last dimension axis += ndim; @@ -944,32 +913,32 @@ inline Tensor repeat(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); - } - idx.push_back(indexdiv(indices[axis], repeats)); - for (size_t i = axis + 1; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(indexdiv(indices[axis], repeats)); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, + name, tag); } /*! 
-* \brief Creates an operation to tile elements of an array -* -* \param x The input tensor -* \param reps The number of times for repeating the tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the tile operation -*/ -inline Tensor tile(const Tensor& x, - Array reps, - std::string name = "T_tile", + * \brief Creates an operation to tile elements of an array + * + * \param x The input tensor + * \param reps The number of times for repeating the tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the tile operation + */ +inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); size_t rdim = reps.size(); @@ -983,56 +952,47 @@ inline Tensor tile(const Tensor& x, reps_shape.push_back(reps[i]); } } else if (ndim > rdim) { - for (size_t i = 0; i < ndim; ++i) - data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < (ndim - rdim); ++i) - reps_shape.push_back(1); - for (size_t i = 0; i < rdim; ++i) - reps_shape.push_back(reps[i]); + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } else { - for (size_t i = 0; i < (rdim - ndim); ++i) - data_shape.push_back(1); - for (size_t i = 0; i < ndim; ++i) - data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < rdim; ++i) - reps_shape.push_back(reps[i]); + for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } - for (size_t i = 0; i < tdim; ++i) - new_shape.push_back(data_shape[i] * reps_shape[i]); + for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); if (is_empty_shape(new_shape)) { - return compute(new_shape, - [&](const Array& indices) { return tvm::cast(x->dtype, 0);}, - name, tag); + return compute( + new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); } else { return compute( - new_shape, [&](const Array& indices) { - Array idx; - if (ndim >= rdim) { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[i], x->shape[i])); - } else { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); + } else { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); + } + return x(idx); + }, + name, tag); } } /*! -* \brief Gather elements from a n-dimension array. -* -* \param data The source array. -* \param indices The indices of the values to extract. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the gather_nd operation -*/ -inline Tensor gather_nd(const Tensor& data, - const Tensor& indices, - std::string name = "T_gather_nd", + * \brief Gather elements from a n-dimension array. + * + * \param data The source array. + * \param indices The indices of the values to extract. + * \param name The name of the operation. 
+ * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the gather_nd operation + */ +inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); @@ -1051,27 +1011,28 @@ inline Tensor gather_nd(const Tensor& data, out_shape.push_back(make_const(DataType::Int(32), 1)); } return compute( - out_shape, [&](const Array& out_index) { - Array indices_position; - indices_position.push_back(0); - for (size_t i = 0; i < ndim_i - 1; ++i) { - indices_position.push_back(out_index[i]); - } - Array real_indices; - for (size_t i = 0; i < indices_dim0; ++i) { - indices_position.Set(0, make_const(DataType::Int(32), i)); - if (indices->dtype.is_int()) { - real_indices.push_back(indices(indices_position)); - } else { - real_indices.push_back( - tvm::cast(tvm::DataType::Int(32), indices(indices_position))); - } - } - for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { - real_indices.push_back(out_index[i]); + out_shape, + [&](const Array& out_index) { + Array indices_position; + indices_position.push_back(0); + for (size_t i = 0; i < ndim_i - 1; ++i) { + indices_position.push_back(out_index[i]); + } + Array real_indices; + for (size_t i = 0; i < indices_dim0; ++i) { + indices_position.Set(0, make_const(DataType::Int(32), i)); + if (indices->dtype.is_int()) { + real_indices.push_back(indices(indices_position)); + } else { + real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); } - return data(real_indices); - }, name, tag); + } + for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { + real_indices.push_back(out_index[i]); + } + return data(real_indices); + }, + name, tag); } /*! @@ -1089,18 +1050,13 @@ inline Tensor gather_nd(const Tensor& data, * * \return A Tensor whose op member is the matmul operation */ -inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, - const tvm::te::Tensor& B, - bool trans_a = false, - bool trans_b = false, - std::string name = "T_matmul", - std::string tag = kMatMul) { - tvm::Array output_shape{A->shape[trans_a ? 1 : 0], - B->shape[trans_b ? 0 : 1]}; +inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B, + bool trans_a = false, bool trans_b = false, + std::string name = "T_matmul", std::string tag = kMatMul) { + tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { - return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), - {k}); + return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? 
B[j][k] : B[k][j]), {k}); }; return tvm::te::compute(output_shape, l, name, tag); } @@ -1116,45 +1072,35 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, - const tvm::te::Tensor& B, - int axes = 2, - std::string name = "T_tensordot", - std::string tag = kMatMul) { +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, + std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_GE(A->shape.size(), axes); CHECK_GE(B->shape.size(), axes); Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); - for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) - output_shape.push_back(*it); + for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); Array iter_vars; for (int i = 0; i < axes; ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i))); - auto func = - [&A, &B, &iter_vars, axes] - (const Array& input_indices) { - Array A_indices( - input_indices.begin(), - input_indices.begin() + (A->shape.size() - axes)); - for (auto& v : iter_vars) - A_indices.push_back(v); - - Array B_indices; - for (auto& v : iter_vars) - B_indices.push_back(v); - - auto it = input_indices.begin() + (A->shape.size() - axes); - for (; it != input_indices.end(); ++it) - B_indices.push_back(*it); - - // Some passes don't like reductions with empty axis, so avoid it here - if (iter_vars.empty()) - return A(A_indices) * B(B_indices); - else - return sum(A(A_indices) * B(B_indices), iter_vars); - }; + auto func = [&A, &B, &iter_vars, axes](const Array& input_indices) { + Array A_indices(input_indices.begin(), + input_indices.begin() + (A->shape.size() - axes)); + for (auto& v : iter_vars) A_indices.push_back(v); + + Array B_indices; + for (auto& v : iter_vars) B_indices.push_back(v); + + auto it = input_indices.begin() + (A->shape.size() - axes); + for (; it != input_indices.end(); ++it) B_indices.push_back(*it); + + // Some passes don't like reductions with empty axis, so avoid it here + if (iter_vars.empty()) + return A(A_indices) * B(B_indices); + else + return sum(A(A_indices) * B(B_indices), iter_vars); + }; return compute(output_shape, func, name, tag); } @@ -1171,11 +1117,8 @@ inline Tensor tensordot(const Tensor& A, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, - const tvm::te::Tensor& B, - Array A_axes, - Array B_axes, - std::string name = "T_tensordot", +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, Array A_axes, + Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_EQ(A_axes.size(), B_axes.size()); @@ -1191,47 +1134,42 @@ inline Tensor tensordot(const Tensor& A, output_shape.push_back(B->shape[i]); Array iter_vars; - for (unsigned i = 0; i < B_axes_val.size(); ++i) - iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); - - auto func = - [&A, &B, &iter_vars, A_axes_val, B_axes_val] - (const Array& input_indices) { - int idx_input = 0; - Array A_indices; - for (unsigned i = 0; i < A->shape.size(); ++i) { - auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); - if (axes_pos == A_axes_val.end()) - A_indices.push_back(input_indices[idx_input++]); - else - A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); - } + for (unsigned i = 0; i < B_axes_val.size(); ++i) + iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + 
std::to_string(i))); + + auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array& input_indices) { + int idx_input = 0; + Array A_indices; + for (unsigned i = 0; i < A->shape.size(); ++i) { + auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); + if (axes_pos == A_axes_val.end()) + A_indices.push_back(input_indices[idx_input++]); + else + A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); + } - Array B_indices; - for (unsigned i = 0; i < B->shape.size(); ++i) { - auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); - if (axes_pos == B_axes_val.end()) - B_indices.push_back(input_indices[idx_input++]); - else - B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]); - } - return sum(A(A_indices) * B(B_indices), iter_vars); - }; + Array B_indices; + for (unsigned i = 0; i < B->shape.size(); ++i) { + auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); + if (axes_pos == B_axes_val.end()) + B_indices.push_back(input_indices[idx_input++]); + else + B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]); + } + return sum(A(A_indices) * B(B_indices), iter_vars); + }; return compute(output_shape, func, name, tag); } -inline Tensor arange(const PrimExpr& start, - const PrimExpr& stop, - const PrimExpr& step, - DataType dtype, - std::string name = "T_arange", - std::string tag = kInjective) { - PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil( - tvm::cast(tvm::DataType::Float(32), stop - start) / step)); +inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step, + DataType dtype, std::string name = "T_arange", std::string tag = kInjective) { + PrimExpr num_elem = tvm::cast( + tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); Array shape; - return compute({num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start + step * indices[0]); - }, name, tag); + return compute( + {num_elem}, + [&](const Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, + tag); } /*! @@ -1243,8 +1181,7 @@ inline Tensor arange(const PrimExpr& start, * \param tag output tensor tag. 
* \return A tensor with shape in \p dst_layout */ -inline Tensor layout_transform(const Tensor& src, - const std::string& src_layout, +inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, const std::string& dst_layout, const std::string name = "T_layout_trans", const std::string tag = kInjective) { @@ -1256,20 +1193,21 @@ inline Tensor layout_transform(const Tensor& src, } CHECK(src_layout_struct.defined() && dst_layout_struct.defined()) - << "cannot convert from/to undefined layout"; + << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct); - CHECK(layout_converter.defined()) - << "cannot convert from " << src_layout << " to " << dst_layout; + CHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; Array dst_shape = layout_converter.ForwardShape(src->shape); return compute( - dst_shape, [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); - return src(src_indices); - }, name, tag); + dst_shape, + [&](const Array& dst_indices) { + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); + return src(src_indices); + }, + name, tag); } /*! @@ -1280,20 +1218,21 @@ inline Tensor layout_transform(const Tensor& src, * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor shape(const Tensor& src, - DataType dtype, - const std::string name = "T_shape", +inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); Array out_shape{ndim}; - return compute(out_shape, [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = 0; - for (int i = 0; i < ndim; ++i) { - ret = tvm::if_then_else(idx == i, src->shape[i], ret); - } - return tvm::cast(dtype, ret); - }, name, tag); + return compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = 0; + for (int i = 0; i < ndim; ++i) { + ret = tvm::if_then_else(idx == i, src->shape[i], ret); + } + return tvm::cast(dtype, ret); + }, + name, tag); } /*! @@ -1304,19 +1243,21 @@ inline Tensor shape(const Tensor& src, * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor ndarray_size(const Tensor& src, - const DataType& dtype, +inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, const std::string& name = "ndarray_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); Array out_ndarray_size = {1}; - return compute(out_ndarray_size, [&](const Array& indices) { - PrimExpr ret = 1; - for (int i = 0; i < ndim; ++i) { - ret *= src->shape[i]; - } - return tvm::cast(dtype, ret); - }, name, tag); + return compute( + out_ndarray_size, + [&](const Array& indices) { + PrimExpr ret = 1; + for (int i = 0; i < ndim; ++i) { + ret *= src->shape[i]; + } + return tvm::cast(dtype, ret); + }, + name, tag); } /*! @@ -1332,14 +1273,9 @@ inline Tensor ndarray_size(const Tensor& src, * \param tag output tensor tag. * \return one-hot tensor. 
*/ -inline Tensor one_hot(const Tensor& indices, - const PrimExpr on_value, - const PrimExpr off_value, - int depth, - int axis, - const DataType& dtype, - const std::string name = "T_one_hot", - const std::string tag = kInjective) { +inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, + int depth, int axis, const DataType& dtype, + const std::string name = "T_one_hot", const std::string tag = kInjective) { Array oshape; int ndim = indices->shape.size() + 1; int indices_index = 0; @@ -1354,19 +1290,23 @@ inline Tensor one_hot(const Tensor& indices, PrimExpr on_value_cast = cast(dtype, on_value); PrimExpr off_value_cast = cast(dtype, off_value); - return compute(oshape, [&](const Array& iter_vars) { - Array indices_indices; - for (size_t i = 0; i < iter_vars.size(); i++) { - if (static_cast(i) == true_axis) { - continue; - } + return compute( + oshape, + [&](const Array& iter_vars) { + Array indices_indices; + for (size_t i = 0; i < iter_vars.size(); i++) { + if (static_cast(i) == true_axis) { + continue; + } - indices_indices.push_back(iter_vars[i]); - } + indices_indices.push_back(iter_vars[i]); + } - auto idx = iter_vars[true_axis]; - return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); - }, name, tag); + auto idx = iter_vars[true_axis]; + return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, + off_value_cast); + }, + name, tag); } } // namespace topi diff --git a/topi/include/topi/vision/reorg.h b/topi/include/topi/vision/reorg.h index 06931e424de34..5bd79f67f052d 100644 --- a/topi/include/topi/vision/reorg.h +++ b/topi/include/topi/vision/reorg.h @@ -24,11 +24,11 @@ #ifndef TOPI_VISION_REORG_H_ #define TOPI_VISION_REORG_H_ -#include #include #include #include #include +#include #include #include @@ -39,18 +39,16 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Reorg operation -* -* \param data The input tensor. Can be any dimension -* \param stride The input integer used as stride in reorg operation -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reorg operation -*/ -inline Tensor reorg(const Tensor &data, - int stride = 1, - std::string name = "tensor", + * \brief Reorg operation + * + * \param data The input tensor. 
Can be any dimension + * \param stride The input integer used as stride in reorg operation + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reorg operation + */ +inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tensor", std::string tag = "reorg_output") { auto input_shape = data->shape; @@ -60,15 +58,14 @@ inline Tensor reorg(const Tensor &data, int w_in = GetConstInt(input_shape[3]); int out_c = c_in / (stride * stride); - auto out = tvm::te::compute(input_shape, - [&](Var b, Var k, Var j, Var i) { - return data(b * stride * stride, - indexmod(k, out_c) * stride * stride, - (j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride, - (i*stride + indexmod(indexdiv(k, out_c), stride))); - }, - name, - tag); + auto out = tvm::te::compute( + input_shape, + [&](Var b, Var k, Var j, Var i) { + return data(b * stride * stride, indexmod(k, out_c) * stride * stride, + (j * stride + indexdiv(indexdiv(k, out_c), stride)) * stride, + (i * stride + indexmod(indexdiv(k, out_c), stride))); + }, + name, tag); out_c = c_in * stride * stride; int out_h = h_in / stride; diff --git a/topi/include/topi/x86/bnn.h b/topi/include/topi/x86/bnn.h index 53b7a8e0739e1..a59d30da3dce5 100644 --- a/topi/include/topi/x86/bnn.h +++ b/topi/include/topi/x86/bnn.h @@ -24,10 +24,10 @@ #ifndef TOPI_X86_BNN_H_ #define TOPI_X86_BNN_H_ -#include #include -#include +#include #include +#include namespace topi { using namespace tvm; @@ -35,14 +35,14 @@ using namespace tvm::te; namespace x86 { /*! -* \brief Create a generic schedule for binarize_pack -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_binarize_pack(const Target &target, const Array& outs) { + * \brief Create a generic schedule for binarize_pack + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_binarize_pack(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -67,14 +67,14 @@ inline Schedule schedule_binarize_pack(const Target &target, const Array } /*! -* \brief Create a generic schedule for binary_dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_binary_dense(const Target &target, const Array& outs) { + * \brief Create a generic schedule for binary_dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_binary_dense(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/x86/default.h b/topi/include/topi/x86/default.h index 9b6efa511d8d9..07337810a6945 100644 --- a/topi/include/topi/x86/default.h +++ b/topi/include/topi/x86/default.h @@ -24,11 +24,11 @@ #ifndef TOPI_X86_DEFAULT_H_ #define TOPI_X86_DEFAULT_H_ -#include #include +#include +#include #include #include -#include namespace topi { using namespace tvm; @@ -36,16 +36,15 @@ using namespace tvm::te; namespace x86 { /*! -* \brief Helper to create a default x86 schedule for the given ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. 
-* \param auto_inline Whether to apply the auto inline step. -* -* \return A schedule for the given ops. -*/ -inline Schedule MakeDefaultSchedule(const Target &target, - const Array& outs, + * \brief Helper to create a default x86 schedule for the given ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * \param auto_inline Whether to apply the auto inline step. + * + * \return A schedule for the given ops. + */ +inline Schedule MakeDefaultSchedule(const Target& target, const Array& outs, bool auto_inline) { Array out_ops; for (auto t : outs) { @@ -66,7 +65,7 @@ inline Schedule MakeDefaultSchedule(const Target &target, if (axis.size() == 4) { auto n = axis[0]; auto c = axis[1]; - auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h + auto fused = detail::Fuse(s[x], {n, c}); // for nhwc layout, fuse n and h s[x].parallel(fused); } else { s[x].parallel(axis[0]); @@ -76,26 +75,26 @@ inline Schedule MakeDefaultSchedule(const Target &target, } /*! -* \brief Create a default x86 schedule for the given ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule(const Target &target, const Array& outs) { + * \brief Create a default x86 schedule for the given ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule default_schedule(const Target& target, const Array& outs) { return MakeDefaultSchedule(target, outs, false); } /*! -* \brief Create a default x86 schedule for the given ops, with auto inline -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule_auto_inline(const Target &target, const Array& outs) { + * \brief Create a default x86 schedule for the given ops, with auto inline + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule default_schedule_auto_inline(const Target& target, const Array& outs) { return MakeDefaultSchedule(target, outs, true); } diff --git a/topi/include/topi/x86/injective.h b/topi/include/topi/x86/injective.h index 182140d68c5c1..069a97170816a 100644 --- a/topi/include/topi/x86/injective.h +++ b/topi/include/topi/x86/injective.h @@ -24,10 +24,10 @@ #ifndef TOPI_X86_INJECTIVE_H_ #define TOPI_X86_INJECTIVE_H_ -#include #include -#include +#include #include +#include namespace topi { using namespace tvm; @@ -48,7 +48,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out if (axis.size() == 4) { auto n = axis[0]; auto c = axis[1]; - auto fused = detail::Fuse(sch[out], { n, c }); // for nhwc layout, fuse n and h + auto fused = detail::Fuse(sch[out], {n, c}); // for nhwc layout, fuse n and h sch[out].parallel(fused); } else { sch[out].parallel(axis[0]); @@ -57,14 +57,14 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out } /*! -* \brief Create an x86 schedule for the given injective ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_injective(const Target &target, const Array& outs) { + * \brief Create an x86 schedule for the given injective ops. 
+ * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index eb05dd839e32b..e121fbc7ec6d0 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -25,3 +25,4 @@ from .bitserial_conv2d import * from .bitserial_dense import * from .injective import * +from . import cortex_m7 diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 25b338e06b5f8..df63ae3e9e598 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -31,6 +31,7 @@ conv2d_spatial_pack_nhwc, \ schedule_conv2d_spatial_pack_nchw, \ schedule_conv2d_spatial_pack_nhwc +from .cortex_m7.conv2d import direct_simd @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu") @@ -425,3 +426,15 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + +@autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu") +def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with SIMD (v7e-m).""" + return direct_simd.conv2d_direct_simd_compute( + cfg, data, kernel, strides, padding, dilation, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu") +def schedule_conv2d_direct_simd(cfg, outs): + """Create schedule for conv2d_direct_simd""" + return direct_simd.conv2d_direct_simd_nhwc_schedule(cfg, outs) diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py index 3bb9dc73e2db8..a4d7ad83b1c86 100644 --- a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py +++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py @@ -152,13 +152,13 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, cfg["ann_reduce"].apply(s, conv, [kh, kw], axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)], - max_unroll=16, + max_unroll=None, cfg=cfg) cfg["ann_spatial"].apply(s, conv, [vh, vw, vc], axis_lens=[cfg['tile_oh'].size[-1], cfg['tile_ow'].size[-1], cfg['tile_co'].size[-1]], - max_unroll=16, + max_unroll=None, cfg=cfg) # schedule fusion diff --git a/topi/python/topi/arm_cpu/cortex_m7/__init__.py b/topi/python/topi/arm_cpu/cortex_m7/__init__.py new file mode 100644 index 0000000000000..631c5f7ff4471 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Schedules specialized for cortex-m7.""" + + +from . 
import conv2d diff --git a/topi/python/topi/arm_cpu/cortex_m7/conv2d/__init__.py b/topi/python/topi/arm_cpu/cortex_m7/conv2d/__init__.py new file mode 100644 index 0000000000000..cc4faf97b126d --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/conv2d/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Conv2d implementations for cortex-m7.""" + +from . import direct_simd diff --git a/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct.py b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct.py new file mode 100644 index 0000000000000..7d3e945fef142 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
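The direct template that follows declares its AutoTVM search space with `define_split`/`define_knob`/`define_reorder` and later applies the chosen entities to a schedule. Below is a self-contained sketch of that pattern on a toy matmul; the knob names mirror the template but the example itself is illustrative and not part of this commit. Run untuned, `autotvm.get_config()` falls back to default split factors:

from tvm import autotvm, te

N = L = M = 64
A = te.placeholder((N, L), name='A')
B = te.placeholder((L, M), name='B')
k = te.reduce_axis((0, L), name='k')
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')

s = te.create_schedule(C.op)
cfg = autotvm.get_config()
y, x = s[C].op.axis
cfg.define_split('tile_y', y, num_outputs=2)         # declare the space...
cfg.define_split('tile_x', x, num_outputs=2)
cfg.define_knob('auto_unroll_max_step', [0, 8, 32])
yo, yi = cfg['tile_y'].apply(s, C, y)                # ...then apply the choice
xo, xi = cfg['tile_x'].apply(s, C, x)
s[C].reorder(yo, xo, yi, xi)
s[C].pragma(yo, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)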
+# pylint: disable=invalid-name +"""Direct implementation of conv2d.""" + +import tvm +from tvm import autotvm +from tvm.autotvm.task import deserialize_args +from topi.nn.conv2d import conv2d_nchw, conv2d_nhwc +from topi.util import get_const_tuple, get_const_int, traverse_inline + +def conv2d_direct(*args, **kwargs): + """Schedule function for directly-scheduled conv2d.""" + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + layout = args[-2] + cfg = autotvm.get_config() + args = [cfg] + args + conv = conv2d_direct_compute(*args) + if layout == 'NHWC': + sched = conv2d_direct_nhwc_schedule(cfg, [data, kernel, conv]) + elif layout == 'NCHW': + sched = conv2d_direct_nchw_schedule(cfg, [data, kernel, conv]) + else: + raise RuntimeError(f'unsupported data layout "{layout}"') + return sched, [data, kernel, conv] + + +conv2d_direct.template_key = 'direct' +conv2d_direct.default_data_layout = 'NHWC' +conv2d_direct.default_kernel_layout = 'HWIO' + +@autotvm.register_topi_compute('conv2d_direct.micro_dev') +def conv2d_direct_compute(*args): + layout = args[-2] + if layout == 'NHWC': + return _conv2d_direct_nhwc_compute(*args) + if layout == 'NCHW': + return _conv2d_direct_nchw_compute(*args) + + raise RuntimeError(f'unsupported data layout "{layout}"') + + +def _conv2d_direct_nhwc_compute(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): + assert layout == 'NHWC' + conv = conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) + + # Config Space Definition + N, H, W, CI = get_const_tuple(data.shape) + KH, KW, _, CO = get_const_tuple(kernel.shape) + n, oh, ow, co = cfg.axis(N), cfg.axis(H), cfg.axis(W), cfg.axis(CO) + kh, kw, ci = cfg.reduce_axis(KH), cfg.reduce_axis(KW), cfg.reduce_axis(CI) + + # TODO should we add a max_factor attr to these splits? 
+ co, vc = cfg.define_split('tile_co', co, num_outputs=2) + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) + + cfg.define_reorder('reorder_0', + [n, co, oh, ow, ci, kh, kw, vh, vw, vc], + policy='candidate', candidate=[ + [n, co, oh, ow, ci, kh, kw, vh, vw, vc], + [n, co, oh, ow, ci, kh, kw, vc, vh, vw], + [n, co, oh, ow, ci, vh, vw, vc, kh, kw], + [n, co, oh, ow, ci, vc, vh, vw, kh, kw]]) + + cfg.define_annotate('ann_reduce', [kh, kw], policy='try_unroll') + cfg.define_annotate('ann_spatial', [vh, vw, vc], policy='try_unroll') + + cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32]) + cfg.define_knob('unroll_explicit', [0, 1]) + + return conv + + +def _conv2d_direct_nchw_compute(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): + assert layout == 'NCHW' + conv = conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) + + ########################### + # Config Space Definition # + ########################### + cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32]) + cfg.define_knob('unroll_explicit', [0, 1]) + + return conv + + +@autotvm.register_topi_schedule('conv2d_direct_nhwc.micro_dev') +def conv2d_direct_nhwc_schedule(cfg, outs): + """Schedule function for directly-scheduled conv2d on NHWC layout.""" + sched = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc' not in op.tag: + return + + ### extract tensors ### + output = op.output(0) + conv = op + data_vec = conv.input_tensors[0] + kernel = conv.input_tensors[1] # pylint: disable=unused-variable + last = outs[0] # pylint: disable=unused-variable + + # tile reduction axes + n, oh, ow, co = sched[conv].op.axis + kh, kw, ci = sched[conv].op.reduce_axis + # NOTE we can't inline data padding in the SIMD path, because it + # introduces conditionals in the inner loop. 
+ data_pad = data_vec.op + sched[data_pad].compute_inline() + + co, vc = cfg['tile_co'].apply(sched, conv, co) + oh, vh = cfg['tile_oh'].apply(sched, conv, oh) + ow, vw = cfg['tile_ow'].apply(sched, conv, ow) + cfg['reorder_0'].apply(sched, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc]) + cfg['ann_reduce'].apply(sched, conv, [kh, kw], + axis_lens=[get_const_int(kh.dom.extent), + get_const_int(kw.dom.extent)], + max_unroll=8, + cfg=cfg) + cfg['ann_spatial'].apply(sched, conv, [vh, vw, vc], + axis_lens=[cfg['tile_oh'].size[-1], + cfg['tile_ow'].size[-1], + cfg['tile_co'].size[-1]], + max_unroll=8, + cfg=cfg) + + kernel_scope = n # this is the scope to attach global config inside this kernel + + # tune unroll + sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + traverse_inline(sched, outs[-1].op, _callback) + return sched + + +@autotvm.register_topi_schedule('conv2d_direct_nchw.micro_dev') +def conv2d_direct_nchw_schedule(cfg, outs): + """Schedule function for Cortex-M7 direct implementation of conv2d.""" + # use default schedule + sched = tvm.create_schedule([x.op for x in outs]) + + conv = outs[-1].op + output = conv.output(0) + data_vec = conv.input_tensors[0] + data_pad = data_vec.op + sched[data_pad].compute_inline() + + # TODO add more schedule opts (similar to the NHWC template) + + n, _, _, _ = sched[conv].op.axis + kernel_scope = n # this is the scope to attach global config inside this kernel + + # tune unroll + sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + return sched diff --git a/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py new file mode 100644 index 0000000000000..fd411251272e3 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
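The SIMD compute function that follows derives the output extent as (in - dilated_kernel + pad_a + pad_b) // stride + 1. A quick numeric check of that formula with assumed example values (`conv_out_dim` is a hypothetical helper, not part of this commit):

def conv_out_dim(in_dim, kernel, pad_a, pad_b, stride, dilation=1):
    dilated = (kernel - 1) * dilation + 1
    return (in_dim - dilated + pad_a + pad_b) // stride + 1

assert conv_out_dim(32, 3, 1, 1, 1) == 32  # 3x3, pad 1, stride 1 preserves size
assert conv_out_dim(32, 3, 0, 0, 2) == 15  # no padding, stride 2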
+# pylint: disable=invalid-name, no-value-for-parameter +"""Direct implementation of conv2d.""" + +from tvm import autotvm +from tvm.autotvm.task import deserialize_args +from tvm import te +from topi.util import simplify, traverse_inline +from topi.nn.pad import pad +from topi.nn.util import get_pad_tuple + +from ..micro_kernel.gemm import ( + intrin_gemm_MxKxN, gemm_MxKxN_impl, +) + +def conv2d_direct_simd(*args, **kwargs): + """Defines the Cortex-M7 SIMD implementation of conv2d.""" + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + layout = args[-2] + cfg = autotvm.get_config() + args = [cfg] + args + assert layout == 'NHWC' + conv = conv2d_direct_simd_compute(*args) + sched = conv2d_direct_simd_nhwc_schedule(cfg, [data, kernel, conv]) + return sched, [data, kernel, conv] + + +conv2d_direct_simd.template_key = 'direct_simd' +conv2d_direct_simd.default_data_layout = 'NHWC' +conv2d_direct_simd.default_kernel_layout = 'HWOI' + +def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute function for Cortex-M7 SIMD implementation of conv2d.""" + assert isinstance(strides, int) or len(strides) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(strides, int): + stride_h = stride_w = strides + else: + stride_h, stride_w = strides + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch_size, in_height, in_width, in_channels = data.shape + kernel_h, kernel_w, out_channels, _ = kernel.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + padded_data = pad(data, pad_before, pad_after, name='padded_data') + + rc = te.reduce_axis((0, in_channels), name='rc') + ry = te.reduce_axis((0, kernel_h), name='ry') + rx = te.reduce_axis((0, kernel_w), name='rx') + + conv = te.compute( + (batch_size, out_height, out_width, out_channels), + lambda nn, yy, xx, ff: te.sum( + padded_data[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + kernel[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc]), + name='conv2d', tag='conv2d_nhwc') + + ########################### + # Config Space Definition # + ########################### + n, oh, ow, co = (cfg.axis(batch_size.value), + cfg.axis(out_height.value), + cfg.axis(out_width.value), + cfg.axis(out_channels.value)) + kh, kw, ci = (cfg.reduce_axis(kernel_h.value), + cfg.reduce_axis(kernel_w.value), + cfg.reduce_axis(in_channels.value)) + + assert in_channels.value % 4 == 0 + owo, owi = cfg.define_split('tile_ow', ow, policy='factors', num_outputs=2) + cio, cii = cfg.define_split('tile_ci', ci, policy='factors', num_outputs=2, + filter=lambda x: x.size[-1] % 4 == 0) + coo, coi = cfg.define_split('tile_co', co, policy='factors', num_outputs=2) + + cfg.define_reorder('reorder_0_simd', + [n, oh, owo, owi, coo, coi, kh, kw, cio, cii], + policy='candidate', candidate=[ + [n, oh, kh, kw, owo, coo, cio, owi, coi, cii], + [n, oh, kh, kw, coo, owo, cio, 
owi, coi, cii], + [n, kh, kw, oh, owo, coo, cio, owi, coi, cii], + [n, kh, kw, oh, coo, owo, cio, owi, coi, cii]]) + + cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32]) + cfg.define_knob('unroll_explicit', [0, 1]) + + return conv + + +def conv2d_direct_simd_nhwc_schedule(cfg, outs): + """Schedule function for Cortex-M7 SIMD implementation of conv2d.""" + sched = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc' not in op.tag: + return + + # extract tensors + output = op.output(0) + conv = op + data_vec = conv.input_tensors[0] + kernel = conv.input_tensors[1] # pylint: disable=unused-variable + last = outs[0] # pylint: disable=unused-variable + + # tile reduction axes + n, oh, ow, co = sched[conv].op.axis + kh, kw, ci = sched[conv].op.reduce_axis + + M = cfg['tile_ow'].size[-1] + K = cfg['tile_ci'].size[-1] + N = cfg['tile_co'].size[-1] + + owo, owi = cfg['tile_ow'].apply(sched, conv, ow) + cio, cii = cfg['tile_ci'].apply(sched, conv, ci) + coo, coi = cfg['tile_co'].apply(sched, conv, co) + + cfg['reorder_0_simd'].apply(sched, conv, [n, oh, owo, owi, coo, coi, kh, kw, cio, cii]) + + gemm, uniq_id = intrin_gemm_MxKxN(M, K, N, data_vec.dtype, output.dtype) + sched[output].tensorize(owi, gemm) + sched[output].pragma(n, 'import_c', gemm_MxKxN_impl(M, K, N, uniq_id)) + + # this is the scope to attach global config inside this kernel + kernel_scope = n + + # tune unroll + sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + traverse_inline(sched, outs[-1].op, _callback) + return sched diff --git a/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/__init__.py b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py new file mode 100644 index 0000000000000..9af7bef95b7c6 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-value-for-parameter +"""Defines gemm intrinsics for SIMD matrix multiplication.""" + +import random +import string + +import tvm +from tvm import te + +########################## +# MxKxN MatMul Intrinsic # +########################## + +# NOTE this is transposed matmul (A * B^T) +def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype): + """Defines a SIMD-accelerated transposed matmul.""" + # we generate a unique ID for every intrinsic definition, to prevent name + # collisions in the generated source (e.g., if there are multiple operators + # in the same module that use the same intrinsic) + # + # TODO(weberlo, areusch): to cut down on memory usage, we should cache each intrinsic + # instantiation and include it only once, eliminating the need for unique + # IDs + UNIQ_ID_LEN = 8 + uniq_id = ''.join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN)) + + if isinstance(M, tvm.tir.IntImm): + M = M.value + if isinstance(K, tvm.tir.IntImm): + K = K.value + if isinstance(N, tvm.tir.IntImm): + N = N.value + assert K % 4 == 0 + # TODO(weberlo, areusch): support more dtypes? + assert in_dtype == 'int8' + assert out_dtype == 'int32' + A = te.placeholder((M, K), name='a', dtype=in_dtype) + B = te.placeholder((N, K), name='b', dtype=in_dtype) + k = te.reduce_axis((0, K), name='k') + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(out_dtype) * B[j, k].astype(out_dtype), axis=k), + name='c') + A_buf = tvm.tir.decl_buffer( + A.shape, A.dtype, + name="A", + offset_factor=1, + strides=[te.var("A_s"), 1]) + B_buf = tvm.tir.decl_buffer( + B.shape, B.dtype, + name="B", + offset_factor=1, + strides=[te.var("B_s"), 1]) + C_buf = tvm.tir.decl_buffer( + C.shape, C.dtype, + name="C", + offset_factor=1, + strides=[te.var("C_s"), 1]) + def intrin_func(ins, outs): + aa, bb = ins + cc = outs[0] + def _reduce_update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_update_{uniq_id}", + aa.access_ptr("r"), + bb.access_ptr("r"), + cc.access_ptr("w"), + aa.strides[0], + bb.strides[0], + cc.strides[0])) + return ib.get() + def _reduce_reset(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_reset_{uniq_id}", + cc.access_ptr("w"), + cc.strides[0])) + return ib.get() + def _body(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_body_{uniq_id}", + aa.access_ptr("r"), + bb.access_ptr("r"), + cc.access_ptr("w"), + aa.strides[0], + bb.strides[0], + cc.strides[0])) + return ib.get() + return _body(), _reduce_reset(), _reduce_update() + with tvm.target.build_config(offset_factor=1): + intrin_decl = te.decl_tensor_intrin( + C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf}) + return intrin_decl, uniq_id + + +def gemm_MxKxN_impl(M, K, N, uniq_id): + """Emit C code for gemm impl.""" + # TODO(weberlo, areusch): are there any SIMD tricks to zero out arrays quickly? 
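The generated C below relies on two ARM helpers: CMSIS-NN's `read_and_pad`, which loads four `int8` values and sign-extends them into two 32-bit words, each packing a pair of `int16`s, and the DSP intrinsic `__SMLAD`, a dual signed 16-bit multiply-accumulate. A small Python model of the intrinsic (illustrative only, not part of this patch) shows why the reduction loop below runs K/2 times: every `__SMLAD` call consumes two of the K elements at once.

    def smlad(x, y, acc):
        """Model of ARM __SMLAD: multiply the two int16 halfwords of x and y
        pairwise and add both products to the accumulator. x and y are the
        unsigned 32-bit representations of the packed words."""
        def as_int16(v):
            v &= 0xFFFF
            return v - 0x10000 if v & 0x8000 else v
        return (acc
                + as_int16(x) * as_int16(y)                # low halfwords
                + as_int16(x >> 16) * as_int16(y >> 16))   # high halfwords

    # packing the pairs [2, 3] and [5, 7]:
    # smlad(2 | (3 << 16), 5 | (7 << 16), 0) == 2*5 + 3*7 == 31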
+ aa_pad_size = M * K + bb_pad_size = N * K + # code reference: CMSIS-NN paper (https://arxiv.org/abs/1801.06601) + cc_code = f""" +#ifdef __cplusplus +extern "C" +#endif +__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_{uniq_id}( + int8_t *aa, int8_t *bb, int32_t *cc, + int A_stride, int B_stride, int C_stride) {{ + int16_t aa_pad[{aa_pad_size}]; + int16_t bb_pad[{bb_pad_size}]; + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&aa[i*A_stride + j*4], (int32_t*) &aa_pad[i*{K} + j*4], (int32_t*) &aa_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {N}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&bb[i*B_stride + j*4], (int32_t*) &bb_pad[i*{K} + j*4], (int32_t*) &bb_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {N}; j++) {{ + int32_t sum = 0; + for (int l = 0; l < {K} / 2; l++) {{ + sum = __SMLAD( + *((int32_t*) &aa_pad[i*{K} + l*2]), + *((int32_t*) &bb_pad[j*{K} + l*2]), + sum); + }} + // NOTE: this is the line where `*_body` differs from `*_update`. here + // we're *setting* the result, instead of accumulating, because we know + // the `i` and `j` itervars span their entire respective axes. + cc[i*C_stride + j] = sum; + }} + }} + + return 0; +}} + +#ifdef __cplusplus +extern "C" +#endif +__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_{uniq_id}( + int8_t *aa, int8_t *bb, int32_t *cc, + int A_stride, int B_stride, int C_stride) {{ + int16_t aa_pad[{aa_pad_size}]; + int16_t bb_pad[{bb_pad_size}]; + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&aa[i*A_stride + j*4], (int32_t*) &aa_pad[i*{K} + j*4], (int32_t*) &aa_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {N}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&bb[i*B_stride + j*4], (int32_t*) &bb_pad[i*{K} + j*4], (int32_t*) &bb_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {N}; j++) {{ + int32_t sum = 0; + for (int l = 0; l < {K} / 2; l++) {{ + sum = __SMLAD( + *((int32_t*) &aa_pad[i*{K} + l*2]), + *((int32_t*) &bb_pad[j*{K} + l*2]), + sum); + }} + cc[i*C_stride + j] += sum; + }} + }} + + return 0; +}} + +#ifdef __cplusplus +extern "C" +#endif +__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{ + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {N}; j++) {{ + cc[i*C_stride + j] = 0; + }} + }} + return 0; +}} + """ + return cc_code diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 2b7a845cd9ec7..8ccd80f38a919 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -25,6 +25,7 @@ from .conv2d_hwcn import * from .conv2d_int8 import * from .conv2d_winograd import * +from .conv2d_nhwc_winograd import * from .depthwise_conv2d import * from .group_conv2d_nchw import * from . 
import conv2d_alter_op diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index 8d9e86c192a0a..c1e207cc29389 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -111,6 +111,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.contrib_conv2d_winograd_without_weight_transform( inputs[0], weight, **new_attrs) + if topi_tmpl in ('conv2d_nhwc_winograd_direct.cuda', 'conv2d_nhwc_winograd_tensorcore.cuda'): + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + assert data_layout == "NHWC" and kernel_layout == "HWIO" + N, H, W, CI = get_const_tuple(data.shape) + KH, KW, _, CO = get_const_tuple(kernel.shape) + + # Pre-compute weight transformation in winograd + if H % 8 == 0: + tile_size = 4 + else: + tile_size = 2 + kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) + weight = relay.nn.contrib_conv2d_winograd_weight_transform(kernel_transform, + tile_size=tile_size) + weight = relay.transpose(weight, axes=[0, 1, 3, 2]) + new_attrs['tile_size'] = tile_size + new_attrs['channels'] = CO + # Store the same config for the altered operator (workload) + new_data = data + new_weight = te.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO), + dtype=kernel.dtype) + if topi_tmpl == "conv2d_nhwc_winograd_direct.cuda": + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + "conv2d_nhwc_winograd_direct_without_weight_transform.cuda") + elif topi_tmpl == "conv2d_nhwc_winograd_tensorcore.cuda": + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + "conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs) + if topi_tmpl == "group_conv2d_NCHWc_int8.cuda": assert data_layout == "NCHW" and kernel_layout == "OIHW" N, CI, H, W = get_const_tuple(data.shape) diff --git a/topi/python/topi/cuda/conv2d_nhwc_winograd.py b/topi/python/topi/cuda/conv2d_nhwc_winograd.py new file mode 100644 index 0000000000000..2f5b85eed6201 --- /dev/null +++ b/topi/python/topi/cuda/conv2d_nhwc_winograd.py @@ -0,0 +1,639 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +# pylint: disable=too-many-arguments,too-many-locals +# pylint: disable=too-many-statements +"""Winograd template for cuda backend""" + +import tvm +from tvm import te +from tvm import autotvm +from .. 
import nn +from ..util import get_const_int, get_const_tuple, traverse_inline +from ..nn.winograd_util import winograd_transform_matrices +from .tensor_intrin import intrin_wmma_load_matrix_A +from .tensor_intrin import intrin_wmma_load_matrix_W +from .tensor_intrin import intrin_wmma_store_matrix +from .tensor_intrin import intrin_wmma_gemm + +def _infer_tile_size(data, kernel): + """Compute the tile size""" + N, H, W, CI = get_const_tuple(data.shape) + if H % 8 == 0: + return 4 + return 2 + + +def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack): + """Schedule for bgemm tensorcore""" + A = data_pack + B = kernel_pack + C = bgemm + _, _, P, out_dim = get_const_tuple(C.shape) + out_dtype = C.dtype + + # Explicit memory access + AS = s.cache_read(A, 'shared', [C]) + BS = s.cache_read(B, 'shared', [C]) + AF = s.cache_read(AS, 'wmma.matrix_a', [C]) + BF = s.cache_read(BS, 'wmma.matrix_b', [C]) + CF = s.cache_write(C, 'wmma.accumulator') + CS = s.cache_read(CF, 'shared', [C]) + + # Create tuning space + cfg.define_knob("block_row_warps", [1, 2, 4]) + cfg.define_knob("block_col_warps", [1, 2, 4]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + cfg.define_knob("offset", [0, 1, 2, 4, 8]) + cfg.define_knob("offsetCS", [0, 1, 2, 4, 8]) + cfg.define_knob("vec", [1, 2, 4, 8]) + + # Ensure that the default parameters are applicable when autotvm is not in use + if (P % 16 == 0 and out_dim % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (P % 32 == 0 and out_dim % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) + elif (P % 8 == 0 and out_dim % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + + warp_size = 32 + wmma_k = 16 + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + offsetAB = cfg["offset"].val + offsetCS = cfg["offsetCS"].val + wmma_m = cfg["wmma_m"].val + vec = cfg["vec"].val + + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + + # Define the stride of intrin functions + AS_align = chunk * wmma_k + offsetAB + BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB + CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS + AS_stride = [AS_align, 1] + BS_stride = [BS_align, 1] + AF_stride = [wmma_k, 1] + BF_stride = [wmma_n * warp_col_tiles, 1] + CF_stride = [warp_col_tiles * wmma_n, 1] + CS_stride = [CS_align, 1] + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + block_z = te.thread_axis('blockIdx.z') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + # Schedule for computation + block_factor_b = wmma_m * warp_row_tiles * block_row_warps + block_factor_o = wmma_n * warp_col_tiles * block_col_warps + alpha_1, alpha_2, b, o = C.op.axis + block_k = s[C].fuse(alpha_1, alpha_2) + block_i, bc = s[C].split(b, factor=block_factor_b) + block_j, oc = s[C].split(o, factor=block_factor_o) + s[C].reorder(block_k, block_i, block_j, bc, oc) + t = s[C].fuse(bc, oc) + t, vi = s[C].split(t, factor=vec) + t, tx = s[C].split(t, factor=warp_size) + t, ty = s[C].split(t, factor=block_row_warps) + t, tz = s[C].split(t, factor=block_col_warps) + s[C].bind(block_k, block_z) + s[C].bind(block_i, block_x) + s[C].bind(block_j, block_y) + s[C].bind(tz, 
thread_z)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].vectorize(vi)
+
+    # Schedule for wmma store
+    s[CS].compute_at(s[C], block_j)
+    _, _, bb, oo = CS.op.axis
+    s[CS].storage_align(bb, CS_align - 1, CS_align)
+    bb, bbi = s[CS].split(bb, factor=wmma_m)
+    oo, ooi = s[CS].split(oo, factor=wmma_n)
+    bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+    oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+    s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
+
+    # Schedule for wmma computation
+    s[CF].compute_at(s[CS], oo)
+    _, _, warp_i, warp_j = CF.op.axis
+    warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+    warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+    k, = CF.op.reduce_axis
+    k, _k = s[CF].split(k, factor=wmma_k)
+    ko, ki = s[CF].split(k, factor=chunk)
+    s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+    # Schedule for wmma_matrix_a load
+    s[AF].compute_at(s[CF], ki)
+    _, _, b, i = AF.op.axis
+    b, b_ii = s[AF].split(b, factor=wmma_m)
+    i, i_jj = s[AF].split(i, factor=wmma_k)
+    s[AF].reorder(b, i, b_ii, i_jj)
+
+    # Schedule for wmma_matrix_b load
+    s[BF].compute_at(s[CF], ki)
+    _, _, i, o = BF.op.axis
+    o, o_ii = s[BF].split(o, factor=wmma_n)
+    i, i_ii = s[BF].split(i, factor=wmma_k)
+    s[BF].reorder(i, o, i_ii, o_ii)
+
+    # Schedule for A's and B's shared memory loads
+    def shared_schedule(stage, strides):
+        s[stage].compute_at(s[CF], ko)
+        _, _, xo, yo = stage.op.axis
+        s[stage].storage_align(xo, strides - 1, strides)
+        t = s[stage].fuse(xo, yo)
+        t, vi = s[stage].split(t, factor=vec)
+        t, tx = s[stage].split(t, factor=warp_size)
+        t, ty = s[stage].split(t, factor=block_row_warps)
+        _, tz = s[stage].split(t, factor=block_col_warps)
+        s[stage].bind(ty, thread_y)
+        s[stage].bind(tz, thread_z)
+        s[stage].bind(tx, thread_x)
+        s[stage].vectorize(vi)
+
+    shared_schedule(AS, AS_align)
+    shared_schedule(BS, BS_align)
+
+    shape = (wmma_m, wmma_n, wmma_k)
+    in_dtype = 'float16'
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
+    BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *
+                                   BL_gemm[k_gemm, jj].astype(out_dtype),
+                                   axis=k_gemm), name='CL_compute')
+
+    # Lower the computation loops down to TensorCore hardware intrinsics
+    # by mapping the tensorcore to tensor intrinsics
+    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major",
+                                                    (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
+    s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major",
+                                                    (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16'))
+    s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride,
+                                          BF_stride, CF_stride, shape))
+    s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype,
+                                                  (wmma_m, wmma_n), (wmma_m, wmma_n)))
+
+
+def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm direct"""
+    b1, b2, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    # Create tuning space
+    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
+                     filter=lambda x: x.size[-3:] == [1, 1, 1])
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8])
+    cfg.define_knob("vector_bgemm", [1, 2, 4, 8])
+    offset_bgemm = cfg["offset_bgemm"].val
+    vector_bgemm = cfg["vector_bgemm"].val
+
+    C = bgemm
+    A0, B0 = kernel_pack, data_pack
+
+    # Designate the memory hierarchy
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    # Tile and bind spatial axes
+    b = s[bgemm].fuse(b1, b2)
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # Tile reduction axes
+    s[OL].compute_at(s[C], tx)
+    b1, b2, y, x = s[OL].op.axis
+    b = s[OL].fuse(b1, b2)
+    rc, = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    s[OL].reorder(rco, b, y, x, rci)
+
+    s[AA].compute_at(s[OL], rco)
+    _, _, k, n = s[AA].op.axis
+    AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * cfg["tile_x"].size[3]
+    s[AA].storage_align(k, AA_align - 1, AA_align)
+
+    s[BB].compute_at(s[OL], rco)
+    _, _, m, k = s[BB].op.axis
+    BB_align = offset_bgemm + cfg["tile_rc"].size[1]
+    s[BB].storage_align(m, BB_align - 1, BB_align)
+
+    # Schedule for A and B shared memory load
+    for load in [AA, BB]:
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, ti = s[load].split(fused, factor=vector_bgemm)
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+        s[load].vectorize(ti)
+
+
+def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                       use_tensorcore, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+    N, H, W, CI = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # Kernel tensor is raw tensor, do strict check
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1))
+        KH, KW, CI, CO = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transformed. This op is created by conv2d_alter_op.
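+        # The pre-transformed kernel has shape (alpha, alpha, CI, CO). The raw-kernel
+        # branch above sets alpha = KW + tile_size - 1, so the original kernel size is
+        # recovered as KH = KW = alpha + 1 - tile_size; e.g. a 3x3 kernel with
+        # tile_size = 4 gives alpha = 3 + 4 - 1 = 6, and 6 + 1 - 4 = 3.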
+        # Dilation is not supported
+        alpha, _, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # Determine whether the shape is compatible with Tensor Core
+    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                  (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                  (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+
+    if shape_judge and use_tensorcore:
+        trans_type = "float16"
+    else:
+        trans_type = data.dtype
+
+    # Compute transform matrix
+    A, _, _ = winograd_transform_matrices(m, r, out_dtype)
+    _, B, G = winograd_transform_matrices(m, r, data.dtype)
+
+    # Transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning; if so, we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
+                                     te.sum((kernel[r_kh][r_kw][ci][co]) *
+                                            G[eps][r_kh] * G[nu][r_kw],
+                                            axis=[r_kh, r_kw]), name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    # Pack input tile
+    input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW)),
+                                     idxmod(idxdiv(p, nW), nH) * m + eps,
+                                     idxmod(p, nW) * m + nu,
+                                     c], name='d')
+
+    # Transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                           te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # Convert data type of input feature maps and weights for tensorcore
+    Transdata = te.compute(
+        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type))
+    TransFilter = te.compute(
+        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type))
+
+    # Do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
+                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),
+                              axis=[ci]), name='bgemm')
+
+    # Inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((P, CO, m, m), lambda p, co, vh, vw:
+                         te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse')
+
+    # Output
+    output = te.compute((N, H, W, CO), lambda n, h, w, co:
+                        inverse[n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                co,
+                                idxmod(h, m),
+                                idxmod(w, m)],
+                        name='output', tag='conv2d_nhwc_winograd')
+    cfg.add_flop(2 * N * CO * H * W * CI * KH * KW)
+    return output
+
+
+def data_weight_transform(s, data_trans, input_tile, thread_num_trans, offset_trans, trans_tag):
+    """Schedule for data or kernel transform"""
+    kernel_align = thread_num_trans + offset_trans
+    indata_s = s.cache_read(input_tile, 'shared', [data_trans])
+    data_l = s.cache_write(data_trans, 'local')
+
+    # Schedule for data or kernel transform
+    eps, nu, p, c = s[data_trans].op.axis
+
+    block_x, thread_x = s[data_trans].split(c, thread_num_trans)
+    block_x = s[data_trans].fuse(p, block_x)
+    s[data_trans].reorder(block_x, thread_x, eps, nu)
+    s[data_trans].bind(thread_x, te.thread_axis("threadIdx.x"))
+    s[data_trans].bind(block_x, te.thread_axis("blockIdx.x"))
+
+    s[data_l].compute_at(s[data_trans], thread_x)
+    eps_l, nu_l, p_l, c_l = s[data_l].op.axis
+    r_a, r_b = s[data_l].op.reduce_axis
+    block_x_l, thread_x_l = s[data_l].split(c_l, thread_num_trans)
+    block_x_l = s[data_l].fuse(p_l, block_x_l)
+
+    s[data_l].reorder(block_x_l, thread_x_l, eps_l, nu_l, r_a, r_b)
+
+    for axis in [eps_l, nu_l, r_a, r_b]:
+        s[data_l].unroll(axis)
+
+    # Schedule for shared memory load
+    s[indata_s].compute_at(s[data_l], block_x_l)
+    if trans_tag == "data":
+        p_is, c_is, eps_is, nu_is = s[indata_s].op.axis
+        data_align = get_const_int(eps_is.dom.extent) * \
+                     get_const_int(nu_is.dom.extent) + offset_trans
+        s[indata_s].storage_align(c_is, data_align - 1, data_align)
+        block_x_is, thread_x_is = s[indata_s].split(c_is, thread_num_trans)
+        s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x"))
+    else:
+        eps_is, nu_is, ci_is, co_is = s[indata_s].op.axis
+        s[indata_s].storage_align(nu_is, kernel_align - 1, kernel_align)
+        block_x_is, thread_x_is = s[indata_s].split(co_is, thread_num_trans)
+        s[indata_s].reorder(ci_is, block_x_is, eps_is, nu_is, thread_x_is)
+        s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x"))
+
+
+def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed):
+    """Schedule winograd template"""
+    # Get stages
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    Transdata, TransFilter = s[bgemm].op.input_tensors
+    data_pack = s[Transdata].op.input_tensors[0]
+    kernel_pack = s[TransFilter].op.input_tensors[0]
+    s[Transdata].compute_inline()
+    s[TransFilter].compute_inline()
+
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # Create tuning space for the transform stages
+    cfg.define_knob("thread_num_inverse", [1, 32, 64, 128, 256])
+    cfg.define_knob("thread_num_data", [1, 32, 64, 128, 256])
+    cfg.define_knob("thread_num_kernel", [1, 32, 64, 128, 256])
+    cfg.define_knob("offset_inverse", [0, 2, 4])
+    cfg.define_knob("offset_data", [0, 1, 2, 4])
+    cfg.define_knob("offset_kernel", [0, 1, 2, 4])
+    cfg.define_knob("inverse_in_vector", [1, 2, 4])
+
+    thread_num_data = cfg["thread_num_data"].val
+    thread_num_kernel = cfg["thread_num_kernel"].val
+    thread_num_inverse = cfg["thread_num_inverse"].val
+    offset_data = cfg["offset_data"].val
+    offset_kernel = cfg["offset_kernel"].val
+    offset_inverse = cfg["offset_inverse"].val
+    inverse_in_vector = cfg["inverse_in_vector"].val
+
+    # Data transform
+    s[B].compute_inline()
+    data_weight_transform(s, data_pack, input_tile, thread_num_data, offset_data, trans_tag="data")
+    s[input_tile].compute_inline()
+    s[pad_data].compute_inline()
+
+    # Kernel transform
+    if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
+        kernel, G = s[kernel_pack].op.input_tensors
+        s[G].compute_inline()
+        data_weight_transform(s, kernel_pack, kernel, thread_num_kernel,
+                              offset_kernel, trans_tag="kernel")
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    b1, b2, y, x = s[bgemm].op.axis
+    alpha = get_const_int(b1.dom.extent)
+    _, _, P, CI = get_const_tuple(Transdata.shape)
+    _, _, _, CO = get_const_tuple(TransFilter.shape)
+
+    # Determine whether the shape is compatible with Tensor Core
+    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                  (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                  (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+
+    if shape_judge and use_tensorcore:
+        schedule_bgemm_tensorcore(cfg, s, bgemm, Transdata, TransFilter)
+    else:
+        schedule_bgemm_direct(cfg, s, bgemm, Transdata, TransFilter)
+
+    # Schedule inverse, output and fusion
+    if output.op in s.outputs:
+        OL = None
+    else:
+        OL = output
+        s[OL].set_scope('local')
+        output = s.outputs[0]
+
+    s[A].compute_inline()
+    inverse_s = s.cache_read(bgemm, 'shared', [inverse])
+
+    m = alpha - 3 + 1  # output tile size, assuming a 3x3 kernel (m = alpha - r + 1 with r = 3)
+    offset_inverse_in = offset_inverse
+    vector_width_inverse_in = inverse_in_vector
+
+    # Schedule for output
+    n, h, w, co = s[output].op.axis
+    ho, wo, hi, wi = s[output].tile(h, w, m, m)
+    s[output].reorder(n, ho, wo, co, hi, wi)
+    fused = s[output].fuse(n, ho, wo)
+
+    block_x_s, thread_x_s = s[output].split(co, thread_num_inverse)
+    block_x_s = s[output].fuse(fused, block_x_s)
+    s[output].reorder(block_x_s, thread_x_s, hi, wi)
+
+    if OL is not None:
+        s[OL].compute_inline()
+
+    # Schedule for inverse
+    s[inverse].compute_at(s[output], thread_x_s)
+    p_inv, co_inv, eps_inv, nu_inv = s[inverse].op.axis
+    block_x_inv, thread_x_inv = s[inverse].split(co_inv, thread_num_inverse)
+    r_a, r_b = s[inverse].op.reduce_axis
+    for axis in [eps_inv, nu_inv, r_a, r_b]:
+        s[inverse].unroll(axis)
+
+    # Schedule for shared memory load
+    s[inverse_s].compute_at(s[output], block_x_s)
+    eps_inv_s, nu_inv_s, p_inv_s, co_inv_s = s[inverse_s].op.axis
+    inverse_in_align = offset_inverse_in + thread_num_inverse
+    s[inverse_s].storage_align(p_inv_s, inverse_in_align - 1, inverse_in_align)
+    block_x_inv_s, thread_x_inv_s = s[inverse_s].split(co_inv_s, thread_num_inverse)
+    block_x_inv_s = s[inverse_s].fuse(p_inv_s, block_x_inv_s)
+    s[inverse_s].reorder(block_x_inv_s, eps_inv_s, nu_inv_s, thread_x_inv_s)
+    t = s[inverse_s].fuse(eps_inv_s, nu_inv_s, thread_x_inv_s)
+    t, ti = s[inverse_s].split(t, factor=vector_width_inverse_in)
+    t, tx = s[inverse_s].split(t, factor=thread_num_inverse)
+    s[inverse_s].bind(tx, te.thread_axis("threadIdx.x"))
+    s[inverse_s].vectorize(ti)
+
+    s[output].bind(thread_x_s, te.thread_axis("threadIdx.x"))
+    s[output].bind(block_x_s, te.thread_axis("blockIdx.x"))
+    return s
+
+
+@autotvm.register_topi_compute("conv2d_nhwc_winograd_direct.cuda")
+def conv2d_nhwc_winograd_direct(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with winograd for NHWC layout"""
+    return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                              use_tensorcore=False, pre_computed=False)
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct.cuda")
+def schedule_conv2d_nhwc_winograd_direct(cfg, outs):
+    """TOPI schedule callback"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'conv2d_nhwc_winograd' in op.tag:
+            schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False,
+                                        pre_computed=False)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+@autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore.cuda")
+def conv2d_nhwc_winograd_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d with winograd for NHWC layout"""
+    return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
+                              use_tensorcore=True,
pre_computed=False) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore.cuda") +def schedule_conv2d_nhwc_winograd_tensorcore(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True, + pre_computed=False) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_direct_without_weight_transform.cuda") +def conv2d_nhwc_winograd_direct_without_weight_transform(cfg, data, kernel, strides, + padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=False, pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct_without_weight_transform.cuda") +def schedule_conv2d_nhwc_winograd_direct_without_weight_transform(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False, + pre_computed=True) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") +def conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, data, kernel, strides, + padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=True, pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") +def schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True, + pre_computed=True) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py index 26c18eeaa3068..98399843e55c9 100644 --- a/topi/python/topi/cuda/pooling.py +++ b/topi/python/topi/cuda/pooling.py @@ -22,7 +22,7 @@ from ..util import traverse_inline -def schedule_adaptive_pool(outs): +def schedule_adaptive_pool(outs, layout='NCHW'): """Schedule for adaptive_pool. 
Parameters @@ -51,8 +51,12 @@ def _schedule(Pool): else: Out = outs[0].op.output(0) s[Pool].set_scope("local") + by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread) - bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread) + if layout == 'NHWC': + bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread) + else: + bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread) s[Out].reorder(by, bx, ty, tx) s[Out].bind(ty, thread_y) s[Out].bind(tx, thread_x) diff --git a/topi/python/topi/generic/default.py b/topi/python/topi/generic/default.py index d4c642ab88147..59e5a255c6e19 100644 --- a/topi/python/topi/generic/default.py +++ b/topi/python/topi/generic/default.py @@ -24,7 +24,7 @@ def default_schedule(outs, auto_inline): """Default schedule for llvm.""" target = tvm.target.Target.current(allow_none=False) outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - if target.target_name != "llvm": + if target.target_name not in ("llvm", "c"): raise RuntimeError("schedule not registered for '%s'" % target) s = te.create_schedule([x.op for x in outs]) if auto_inline: diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index c95cbf88238df..d715308573a44 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -194,6 +194,74 @@ def sinh(x): return te.compute(x.shape, lambda *i: te.sinh(x(*i))) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def acos(x): + """Take arc cos of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.acos(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def acosh(x): + """Take arc cosh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.acosh(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def asin(x): + """Take arc sin of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.asin(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def asinh(x): + """Take arc sinh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.asinh(x(*i))) + + @tvm.te.tag_scope(tag=tag.ELEMWISE) def atan(x): """Take atan of input x. @@ -210,6 +278,22 @@ def atan(x): """ return te.compute(x.shape, lambda *i: te.atan(x(*i))) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def atanh(x): + """Take atanh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.atanh(x(*i))) + @tvm.te.tag_scope(tag=tag.ELEMWISE) def floor(x): """Take floor of input x. 
diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index d8713110056a4..7c021785544c4 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -35,10 +35,8 @@ def _conv2d_nhwc_python(a_np, w_np, stride, padding): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str or a list/tuple of 2 or 4 ints - Padding size, or ['VALID', 'SAME'], or - [pad_height, pad_width] for 2 ints, or - [pad_top, pad_left, pad_bottom, pad_right] for 2 ints + padding : int or str or a list/tuple of two ints + Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] Returns ------- diff --git a/topi/src/broadcast.cc b/topi/src/broadcast.cc index b14754573c648..e13c09ebb9226 100644 --- a/topi/src/broadcast.cc +++ b/topi/src/broadcast.cc @@ -18,39 +18,33 @@ */ /*! -* \brief Registration of broadcast operators -* \file broadcast.cc -*/ -#include -#include - + * \brief Registration of broadcast operators + * \file broadcast.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName) \ - .set_body([](TVMArgs args, TVMRetValue *rv) { \ - bool lhs_is_tensor = args[0].IsObjectRef(); \ - bool rhs_is_tensor = args[1].IsObjectRef(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::PrimExpr()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::PrimExpr()); \ - } \ - }); \ +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs args, TVMRetValue* rv) { \ + bool lhs_is_tensor = args[0].IsObjectRef(); \ + bool rhs_is_tensor = args[1].IsObjectRef(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::te::Tensor()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::te::Tensor()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::PrimExpr()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::PrimExpr()); \ + } \ + }); TOPI_REGISTER_BCAST_OP("topi.add", topi::add); TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); @@ -77,9 +71,8 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_REGISTER_GLOBAL("topi.broadcast_to") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = broadcast_to(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/src/elemwise.cc b/topi/src/elemwise.cc index a19467c1399f6..10ac8f8c4ceef 100644 --- a/topi/src/elemwise.cc +++ b/topi/src/elemwise.cc @@ -18,162 +18,140 @@ */ /*! 
-* \brief Registration of elemwise operators -* \file elemwise.cc -*/ + * \brief Registration of elemwise operators + * \file elemwise.cc + */ +#include #include #include -#include - namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = exp(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.acos").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = acos(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.acosh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = acosh(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.asin").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = asin(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.asinh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = asinh(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.atanh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = atanh(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = exp(args[0]); }); -TVM_REGISTER_GLOBAL("topi.fast_exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_exp(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = erf(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = erf(args[0]); }); -TVM_REGISTER_GLOBAL("topi.fast_erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_erf(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.tan") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tan(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.tan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tan(args[0]); }); -TVM_REGISTER_GLOBAL("topi.cos") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = cos(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.cos").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cos(args[0]); }); -TVM_REGISTER_GLOBAL("topi.cosh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cosh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cosh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sin") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sin(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.sin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sin(args[0]); }); -TVM_REGISTER_GLOBAL("topi.sinh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sinh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sinh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.tanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.fast_tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_tanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.atan") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.atan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = atan(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sigmoid") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sigmoid").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sigmoid(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("topi.sqrt").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sqrt(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rsqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { -*rv = rsqrt(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.rsqrt").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = rsqrt(args[0]); +}); -TVM_REGISTER_GLOBAL("topi.log") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = log(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.log").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log(args[0]); }); -TVM_REGISTER_GLOBAL("topi.log2") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.log2").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log2(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.log10") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.log10").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log10(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.identity") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.identity").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = identity(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.negative") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.negative").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = negative(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.clip") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.clip").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = clip(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cast") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cast").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cast(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.reinterpret") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.reinterpret").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = reinterpret(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.elemwise_sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = elemwise_sum(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sign") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sign").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sign(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.full") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.full").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = full(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.full_like") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.full_like").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = full_like(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.logical_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.logical_not").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = logical_not(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.bitwise_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = bitwise_not(args[0]); - }); +}); } // namespace topi diff --git a/topi/src/nn.cc b/topi/src/nn.cc index 77b208db0dd0b..3ec47787ec6ec 100644 --- a/topi/src/nn.cc +++ b/topi/src/nn.cc @@ -18,23 +18,22 @@ */ /*! 
-* \brief Registration of NN operators -* \file nn.cc -*/ -#include -#include - + * \brief Registration of NN operators + * \file nn.cc + */ #include +#include #include #include #include #include #include +#include #include #include #include -#include -#include +#include +#include namespace topi { @@ -42,144 +41,113 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.relu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = relu(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = leaky_relu(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.prelu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = prelu(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.pad") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = pad(args[0], args[1], args[2], args[3]); - }); +}); /* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dense(args[0], args[1], args[2], args[3]); - }); +}); /* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::bias_add(args[0], args[1], args[2]); - }); +}); /* Ops from nn/batch_matmul.h */ -TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::batch_matmul(args[0], args[1]); - }); +}); /* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dilate(args[0], args[1]); - }); +}); /* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::flatten(args[0]); - }); +}); /* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::scale_shift_nchw(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::scale_shift_nhwc(args[0], args[1], args[2]); - }); +}); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); -TVM_REGISTER_GLOBAL("topi.nn.pool_grad") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4], - static_cast(static_cast(args[5])), - args[6], args[7], args[8]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::global_pool(args[0], - static_cast(static_cast(args[1])), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool(args[0], args[1], - static_cast(static_cast(args[2])), + static_cast(static_cast(args[5])), args[6], args[7], + args[8]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::global_pool(args[0], static_cast(static_cast(args[1])), args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::adaptive_pool(args[0], args[1], static_cast(static_cast(args[2])), args[3]); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool3d(args[0], args[1], - static_cast(static_cast(args[2])), +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::adaptive_pool3d(args[0], args[1], static_cast(static_cast(args[2])), args[3]); }); -TVM_REGISTER_GLOBAL("topi.nn.pool1d") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool1d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); -TVM_REGISTER_GLOBAL("topi.nn.pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool3d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); /* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.log_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::log_softmax(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::lrn(args[0], args[1], args[2], - static_cast(args[3]), - static_cast(args[4]), - static_cast(args[5])); - }); +TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::lrn(args[0], args[1], args[2], static_cast(args[3]), + static_cast(args[4]), static_cast(args[5])); +}); /* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::binarize_pack(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::binary_dense(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/src/reduction.cc 
b/topi/src/reduction.cc index e1fdada73eefd..b981495411bac 100644 --- a/topi/src/reduction.cc +++ b/topi/src/reduction.cc @@ -18,58 +18,49 @@ */ /*! -* \brief Registration of reduction operators -* \file reduction.cc -*/ -#include -#include - + * \brief Registration of reduction operators + * \file reduction.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sum").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.min") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.min").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.max") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.argmin") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.argmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.prod") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.all") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.all").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.any") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); } // namespace topi diff --git a/topi/src/schedule.cc b/topi/src/schedule.cc index 936f39031e6a5..b974acaf2dd59 100644 --- a/topi/src/schedule.cc +++ b/topi/src/schedule.cc @@ -18,212 +18,181 @@ */ /*! 
-* \brief Registration of TVM schedules -* \file schedule.cc -*/ + * \brief Registration of TVM schedules + * \file schedule.cc + */ #define TOPI_REDUCE_ATLEAST1D 0 -#include -#include -#include -#include -#include - -#include -#include -#include - #include #include +#include #include #include #include -#include - -#include -#include -#include - +#include +#include +#include +#include #include #include +#include #include #include #include -#include - -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.TEST_create_target") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tvm::Target::Create(args[0]); - }); +}); /* Generic schedules */ -TVM_REGISTER_GLOBAL("topi.generic.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[2]) { *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]); } else { *rv = topi::generic::default_schedule(args[0], args[1]); } - }); +}); -TVM_REGISTER_GLOBAL("topi.generic.schedule_extern") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::generic::schedule_extern(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::generic::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); + }); /* x86 schedules */ -TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_binarize_pack(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_binary_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[2]) { *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]); } else { *rv = topi::x86::default_schedule(args[0], args[1]); } - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = 
topi::x86::schedule_injective_from_existing(args[0], args[1]); + }); /* ROCm schedules */ -TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); + }); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_global_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_reduce(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_lrn(args[0]); - }); +}); /* CUDA schedules */ -TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = 
topi::cuda::schedule_injective_from_existing(args[0], args[1]); + }); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_global_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_reduce(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_lrn(args[0]); - }); +}); /* Utility functions */ -TVM_REGISTER_GLOBAL("topi.util.is_empty_shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.util.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::detail::is_empty_shape(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); - }); +}); /*! \brief Builder function for instantiating schedules. */ -using FTVMScheduleBuilder = std::function< - tvm::te::Schedule(const tvm::Target& target, const tvm::Array& outs)>; +using FTVMScheduleBuilder = std::function& outs)>; /*! 
* \brief Helper function for registering generic functions matching the @@ -242,7 +211,7 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { if (argNodeRef->type_index() == outs->type_index()) { outs = args[0]; } else { - outs = Array { args[0] }; + outs = Array{args[0]}; } *ret = builder(target, outs); @@ -250,49 +219,49 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { } TVM_REGISTER_GENERIC_FUNC(schedule_injective) -.set_default(WrapSchedule(topi::generic::schedule_injective)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective)); + .set_default(WrapSchedule(topi::generic::schedule_injective)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective)); TVM_REGISTER_GENERIC_FUNC(schedule_softmax) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax)); TVM_REGISTER_GENERIC_FUNC(schedule_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) -.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense)) + .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense)); TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul) -.set_default(WrapSchedule(topi::generic::default_schedule)); + .set_default(WrapSchedule(topi::generic::default_schedule)); TVM_REGISTER_GENERIC_FUNC(schedule_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool)); TVM_REGISTER_GENERIC_FUNC(schedule_global_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool)); TVM_REGISTER_GENERIC_FUNC(schedule_reduce) -.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce)); + .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce)); TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ 
"cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack)); TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense)); /*! \brief Builder function for instantiating schedules from existing schedules. */ -using FTVMScheduleFromExistingBuilder = std::function< - tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>; +using FTVMScheduleFromExistingBuilder = + std::function; /*! * \brief Helper function for registering generic functions matching the @@ -304,33 +273,30 @@ using FTVMScheduleFromExistingBuilder = std::function< * \return The wrapped schedule builder */ inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { - return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - *ret = builder(args[0], args[1]); - }); + return PackedFunc( + [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); }); } TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) -.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) -.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) -.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting( - topi::cuda::schedule_injective_from_existing)); + .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) + .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) + .register_func({"cuda", "gpu"}, + WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing)); /*! \brief Builder function for instantiating dense ops. */ -using FTVMDenseOpBuilder = std::function; +using FTVMDenseOpBuilder = std::function; /*! -* \brief Helper function for registering dense ops matching the -* FTVMDenseOpBuilder signature. The op builder function is wrapped -* with a PackedFunc suitable for passing to a tvm::GenericFunc. -* -* \param builder The op builder to wrap. -* -* \return The wrapped op builder -*/ + * \brief Helper function for registering dense ops matching the + * FTVMDenseOpBuilder signature. The op builder function is wrapped + * with a PackedFunc suitable for passing to a tvm::GenericFunc. + * + * \param builder The op builder to wrap. 
+ * + * \return The wrapped op builder + */ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { auto target = Target::Current(false); @@ -344,14 +310,12 @@ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { } TVM_REGISTER_GENERIC_FUNC(dense) -.set_default(WrapDenseOp([](const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { - return topi::nn::dense(data, weight, bias, out_dtype); -})) -.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda)) -.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm)); + .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { + return topi::nn::dense(data, weight, bias, out_dtype); + })) + .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda)) + .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm)); } // namespace topi diff --git a/topi/src/transform.cc b/topi/src/transform.cc index 4f0d4f8e68256..fa27b995c3652 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -18,67 +18,56 @@ */ /*! -* \brief Registration of transform operators -* \file transform.cc -*/ -#include -#include - + * \brief Registration of transform operators + * \file transform.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.expand_dims") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.expand_dims").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = expand_dims(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.transpose") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = transpose(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.flip") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = flip(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.reshape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = reshape(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.squeeze") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.squeeze").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = squeeze(args[0], ArrayOrInt(args[1])); - }); +}); -TVM_REGISTER_GLOBAL("topi.concatenate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = concatenate(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.stack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = stack(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.shape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = shape(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.ndarray_size") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ndarray_size(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.split") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { *rv = split_sections(args[0], args[1], args[2]); } else { @@ -86,13 +75,11 @@ TVM_REGISTER_GLOBAL("topi.split") } }); -TVM_REGISTER_GLOBAL("topi.layout_transform") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = layout_transform(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.take") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) { if (args.size() == 3) { std::string mode = args[2]; *rv = take(args[0], args[1], mode); @@ -101,56 +88,55 @@ TVM_REGISTER_GLOBAL("topi.take") std::string mode = args[3]; *rv = take(args[0], args[1], axis, mode); } - }); +}); -TVM_REGISTER_GLOBAL("topi.sequence_mask") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body([](TVMArgs args, TVMRetValue* rv) { double pad_val = args[2]; int axis = args[3]; *rv = sequence_mask(args[0], args[1], pad_val, axis); }); -TVM_REGISTER_GLOBAL("topi.where") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.where").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = where(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.arange") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.arange").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = arange(args[0], args[1], args[2], args[3]); }); -TVM_REGISTER_GLOBAL("topi.repeat") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.repeat").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = repeat(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.tile") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tile(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.gather_nd") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = gather_nd(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.unravel_index") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = unravel_index(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { - switch ( args.size() ) { - case 2: *rv = matmul(args[0], args[1]); break; - case 3: *rv = matmul(args[0], args[1], args[2]); break; - case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break; - default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; - }}); - -TVM_REGISTER_GLOBAL("topi.tensordot") -.set_body([](TVMArgs args, TVMRetValue *rv) { +}); + +TVM_REGISTER_GLOBAL("topi.matmul").set_body([](TVMArgs args, TVMRetValue* rv) { + switch (args.size()) { + case 2: + *rv = matmul(args[0], args[1]); + break; + case 3: + *rv = matmul(args[0], args[1], args[2]); + break; + case 4: + *rv = matmul(args[0], args[1], args[2], args[3]); + break; + default: + CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; + } +}); + +TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv) { if (args.size() == 2) { *rv = tensordot(args[0], args[1]); } else if (args.size() == 3) { @@ -159,19 +145,17 @@ TVM_REGISTER_GLOBAL("topi.tensordot") Array axes = args[3]; *rv = tensordot(args[0], args[1], 
args[2], axes); } - }); +}); -TVM_REGISTER_GLOBAL("topi.strided_slice") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = strided_slice(args[0], args[1], args[2], args[3]); - }); +}); -TVM_REGISTER_GLOBAL("topi.one_hot") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { int depth = args[3]; int axis = args[4]; DataType dtype = args[5]; *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); - }); +}); } // namespace topi diff --git a/topi/src/vision.cc b/topi/src/vision.cc index 1a4884e8d7c6f..0485177cf9d58 100644 --- a/topi/src/vision.cc +++ b/topi/src/vision.cc @@ -18,22 +18,20 @@ */ /*! -* \brief Registration of vision operators -* \file vision.cc -*/ + * \brief Registration of vision operators + * \file vision.cc + */ +#include #include #include -#include - namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.vision.reorg") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = vision::reorg(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py new file mode 100644 index 0000000000000..a7e55320d6ddc --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
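The test file added below never calls a kernel directly; it routes through a small dispatch table that maps a target name to a (compute, schedule) pair, resolved per device with topi.testing.dispatch. A condensed sketch of that pattern, reusing only names that appear in this diff (the wrapper function itself is illustrative, not part of the patch):

    import tvm
    import topi
    import topi.testing
    from tvm import te

    # Target name -> (compute, schedule), mirroring the tables in the new test.
    _impl = {
        "cuda": (topi.cuda.conv2d_nhwc_winograd_direct,
                 topi.cuda.schedule_conv2d_nhwc_winograd_direct),
    }

    def build_winograd_conv2d(device="cuda"):
        A = te.placeholder((1, 56, 56, 64), name="A")  # NHWC input
        W = te.placeholder((3, 3, 64, 64), name="W")   # HWIO kernel
        fcompute, fschedule = topi.testing.dispatch(device, _impl)
        with tvm.target.create(device):
            C = fcompute(A, W, 1, 1, 1, "float32")     # stride, padding, dilation, out dtype
            s = fschedule([C])
        return tvm.build(s, [A, W, C], device)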
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments +# pylint: disable=bad-whitespace +"""Example code to do convolution.""" + +import numpy as np +import tvm +import topi +import topi.testing +from tvm import te +from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import nvcc +from topi.nn.util import get_pad_tuple +from topi.util import get_const_tuple + + +_conv2d_nhwc_winograd_tensorcore = { + "cuda": (topi.cuda.conv2d_nhwc_winograd_tensorcore, + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore) +} + +_conv2d_nhwc_winograd_direct = { + "cuda": (topi.cuda.conv2d_nhwc_winograd_direct, + topi.cuda.schedule_conv2d_nhwc_winograd_direct) +} + + +def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, + devices='cuda', bgemm="direct"): + """Test the conv2d with winograd for nhwc layout""" + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + + A = te.placeholder((batch, in_height, in_width, in_channel), name='A') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W') + bias = te.placeholder((1, 1, 1, num_filter), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + if bgemm == "direct": + fcompute, fschedule = topi.testing.dispatch(device, + _conv2d_nhwc_winograd_direct) + elif bgemm == "tensorcore": + fcompute, fschedule = topi.testing.dispatch(device, + _conv2d_nhwc_winograd_tensorcore) + C = fcompute(A, W, stride, padding, dilation, 'float32') + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3) + + check_device(devices) + + +def test_conv2d_nhwc_winograd_direct(): + """Test the conv2d with winograd for nhwc layout""" + # resnet 18 workloads + 
print("test_winograd_direct...") + verify_conv2d_nhwc(1, 64, 56, 64, 3, 1, 1, bgemm="direct") + verify_conv2d_nhwc(1, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(1, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, 1) + verify_conv2d_nhwc(1, 48, 35, 64, 5, 1, 2) + + # weird workloads + verify_conv2d_nhwc(1, 1, 1, 1, 3, 1, 1) + verify_conv2d_nhwc(3, 3, 3, 3, 3, 1, 1) + verify_conv2d_nhwc(2, 13, 71, 59, 3, 1, 1) + + # Asymmetric padding + verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, "SAME") + verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, (1, 1), add_relu=True) + verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, "SAME", add_relu=True, add_bias=True) + verify_conv2d_nhwc(1, 48, 35, 48, 5, 1, "VALID") + +def test_conv2d_nhwc_winograd_tensorcore(): + """Test the conv2d with winograd for nhwc layout""" + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + return + verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore") + verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore") + verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore") + + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore") + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore") + + +if __name__ == "__main__": + test_conv2d_nhwc_winograd_direct() + test_conv2d_nhwc_winograd_tensorcore() diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py index 9bdbb1073fd05..9f71a316f7e1d 100644 --- a/topi/tests/python/test_topi_pooling.py +++ b/topi/tests/python/test_topi_pooling.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument """Test code for pooling""" import math import numpy as np @@ -44,6 +45,7 @@ } def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True): + """verify function of pool""" iw = ih kw = kh sw = sh @@ -76,15 +78,17 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ for i in range(oh): for j in range(ow): if count_include_pad: - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) + b_np[:, :, i, j] = \ + np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) else: - pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3)) - b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1) + pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3)) + b_np[:, :, i, j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) \ + / np.maximum(pad_count, 1) - elif pool_type =='max': + elif pool_type == 'max': for i in range(oh): for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) + b_np[:, :, i, j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) b_np = np.maximum(b_np, 0.0) def check_device(device): @@ -108,11 +112,11 @@ def check_device(device): def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True, add_relu=False): + """verify function of pool_grad""" iw = ih kw = kh sw = sh pt, pl, pb, pr = padding - layout = "NCHW" A = te.placeholder((n, ic, ih, iw), name='A') B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type, ceil_mode=ceil_mode, @@ -164,6 +168,7 @@ def check_device(device): 
check_device(device) def test_pool(): + """test cases of pool""" verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False) @@ -179,6 +184,7 @@ def test_pool(): verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True) def test_pool_grad(): + """test cases of pool_grad""" verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False) verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) @@ -200,10 +206,10 @@ def test_pool_grad(): verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True) -def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): - +def verify_global_pool(dshape, pool_type, layout='NCHW'): + """verify function of global_pool""" assert layout in ["NCHW", "NHWC"] - A = te.placeholder((n, c, h, w), name='A') + A = te.placeholder(shape=dshape, name='A') B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout) B = topi.nn.relu(B) @@ -212,7 +218,7 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): axis = (layout.find('H'), layout.find('W')) if pool_type == 'avg': b_np = np.mean(a_np, axis=axis, keepdims=True) - elif pool_type =='max': + elif pool_type == 'max': b_np = np.max(a_np, axis=axis, keepdims=True) b_np = np.maximum(b_np, 0.0) @@ -224,7 +230,10 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): s_func = topi.testing.dispatch(device, _adaptive_pool_schedule) - s = s_func(B) + if device == "cuda": + s = s_func(B, layout) + else: + s = s_func(B) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) f = tvm.build(s, [A, B], device) @@ -235,17 +244,19 @@ def check_device(device): check_device(device) def test_global_pool(): - verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC') - verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC') - verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC') - verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC') + """test cases of global_pool""" + verify_global_pool((1, 1024, 7, 7), 'avg') + verify_global_pool((4, 1024, 7, 7), 'avg') + verify_global_pool((1, 1024, 7, 7), 'max') + verify_global_pool((4, 1024, 7, 7), 'max') + verify_global_pool((1, 7, 7, 1024), 'avg', 'NHWC') + verify_global_pool((4, 7, 7, 1024), 'avg', 'NHWC') + verify_global_pool((1, 7, 7, 1024), 'max', 'NHWC') + verify_global_pool((4, 7, 7, 1024), 'max', 'NHWC') def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"): + """verify function of adaptive_pool""" np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) oshape = np_out.shape @@ -265,7 +276,10 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): s_func = topi.testing.dispatch(device, _adaptive_pool_schedule) - s = s_func(out) + if device == "cuda": + s = s_func(out, layout) + else: + s = s_func(out) a = tvm.nd.array(np_data, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx) f = tvm.build(s, [data, out], device) @@ -277,6 +291,7 @@ def check_device(device): def test_adaptive_pool(): + """test cases of adaptive_pool""" 
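    # Sketch of the reference semantics being verified here (illustration only,
    # not part of this diff): adaptive pooling maps output index i along an
    # axis of input length L and output length O to the window
    # [i*L//O, ceil((i+1)*L/O)) and reduces over it, e.g. for 'avg':
    #
    #   start = (i * L) // O
    #   end = -(-((i + 1) * L) // O)    # ceiling division
    #   out[..., i] = inp[..., start:end].mean()
    #
    # topi.testing.adaptive_pool generates the NumPy reference along these lines.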
 verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max")
 verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg")
 verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max")
@@ -295,6 +310,7 @@ def test_adaptive_pool():
 def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type,
                   ceil_mode, count_include_pad=True, layout='NCDHW'):
+    """verify function of pool3d"""
     id = iw = ih
     kd = kw = kh
     sd = sw = sh
@@ -334,6 +350,7 @@ def check_device(device):
 def test_pool3d():
+    """test cases of pool3d"""
     verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True)
     verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True)
     verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False)
@@ -351,6 +368,7 @@ def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type,
                   ceil_mode, count_include_pad=True, layout='NCW'):
+    """verify function of pool1d"""
     input_shape = (n, ic, iw)
     kernel = [kw]
     stride = [sw]
@@ -387,6 +405,7 @@ def check_device(device):
 def test_pool1d():
+    """test cases of pool1d"""
     verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
     verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
     verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py
index 32738551f7c8e..e01a4ecdf4c36 100644
--- a/tutorials/frontend/from_tflite.py
+++ b/tutorials/frontend/from_tflite.py
@@ -21,25 +21,12 @@
 This article is an introductory tutorial to deploy TFLite models with Relay.

-To get started, Flatbuffers and TFLite package needs to be installed as prerequisites.
-A quick solution is to install Flatbuffers via pip
+To get started, the TFLite package needs to be installed as a prerequisite.

 .. code-block:: bash

-    pip install flatbuffers --user
-
-
-To install TFlite packages, you could use our prebuilt wheel:
-
-.. code-block:: bash
-
-    # For python3:
-    wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl
-    pip3 install -U tflite-1.13.1-py3-none-any.whl --user
-
-    # For python2:
-    wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py2-none-any.whl
-    pip install -U tflite-1.13.1-py2-none-any.whl --user
+    # install tflite
+    pip install tflite==2.1.0 --user

 or you could generate TFLite package yourself.
The steps are the following: diff --git a/vta/runtime/device_api.cc b/vta/runtime/device_api.cc index 047a6fdbd50d7..298403ca840d7 100644 --- a/vta/runtime/device_api.cc +++ b/vta/runtime/device_api.cc @@ -22,12 +22,11 @@ * \brief TVM device API for VTA */ -#include #include +#include -#include "runtime.h" #include "../../src/runtime/workspace_pool.h" - +#include "runtime.h" namespace tvm { namespace runtime { @@ -42,25 +41,14 @@ class VTADeviceAPI final : public DeviceAPI { } } - void* AllocDataSpace(TVMContext ctx, - size_t size, - size_t alignment, - DLDataType type_hint) final { + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final { return VTABufferAlloc(size); } - void FreeDataSpace(TVMContext ctx, void* ptr) final { - VTABufferFree(ptr); - } + void FreeDataSpace(TVMContext ctx, void* ptr) final { VTABufferFree(ptr); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { int kind_mask = 0; if (ctx_from.device_type != kDLCPU) { @@ -69,33 +57,27 @@ class VTADeviceAPI final : public DeviceAPI { if (ctx_to.device_type != kDLCPU) { kind_mask |= 1; } - VTABufferCopy(from, from_offset, - to, to_offset, - size, kind_mask); + VTABufferCopy(from, from_offset, to, to_offset, size, kind_mask); } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {} void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; struct VTAWorkspacePool : public WorkspacePool { - VTAWorkspacePool() : - WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} + VTAWorkspacePool() : WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} }; void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { - return dmlc::ThreadLocalStore::Get() - ->AllocWorkspace(ctx, size); + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(ctx, size); } void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { @@ -104,10 +86,10 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { // Register device api with override. 
 static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ =
-::tvm::runtime::Registry::Register("device_api.ext_dev", true)
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    DeviceAPI* ptr = VTADeviceAPI::Global().get();
-    *rv = static_cast<void*>(ptr);
-  });
+    ::tvm::runtime::Registry::Register("device_api.ext_dev", true)
+        .set_body([](TVMArgs args, TVMRetValue* rv) {
+          DeviceAPI* ptr = VTADeviceAPI::Global().get();
+          *rv = static_cast<void*>(ptr);
+        });

 }  // namespace runtime
 }  // namespace tvm
diff --git a/vta/runtime/runtime.cc b/vta/runtime/runtime.cc
index 038d5cfa398c6..49fe9c557336f 100644
--- a/vta/runtime/runtime.cc
+++ b/vta/runtime/runtime.cc
@@ -24,24 +24,23 @@
  * The runtime depends on specific instruction
  * stream spec as specified in hw_spec.h
  */
-#include
-#include
+#include "runtime.h"
+
 #include
 #include
+#include
+#include
 #include
 #include
 #include
-#include
 #include
-
-#include "runtime.h"
+#include

 namespace vta {

 // Avoid bad configurations.
-static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8,
-              "VTA_UOP_WIDTH do not match VTAUop size");
+static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, "VTA_UOP_WIDTH does not match VTAUop size");

 /*! \brief Enable coherent access of data buffers between VTA and CPU */
 static const bool kBufferCoherent = VTA_COHERENT_ACCESSES;
@@ -53,13 +52,9 @@ static const bool kAlwaysCache = true;
  */
 struct DataBuffer {
   /*! \return Virtual address of the data. */
-  void* virt_addr() const {
-    return data_;
-  }
+  void* virt_addr() const { return data_; }
   /*! \return Physical address of the data. */
-  vta_phy_addr_t phy_addr() const {
-    return phy_addr_;
-  }
+  vta_phy_addr_t phy_addr() const { return phy_addr_; }
   /*!
    * \brief Invalidate the cache of given location in data buffer.
    * \param offset The offset to the data.
@@ -67,9 +62,7 @@ struct DataBuffer {
    */
   void InvalidateCache(size_t offset, size_t size) {
     if (!kBufferCoherent && kAlwaysCache) {
-      VTAInvalidateCache(reinterpret_cast<char*>(data_) + offset,
-                         phy_addr_ + offset,
-                         size);
+      VTAInvalidateCache(reinterpret_cast<char*>(data_) + offset, phy_addr_ + offset, size);
     }
   }
   /*!
@@ -79,16 +72,14 @@ struct DataBuffer {
    */
   void FlushCache(size_t offset, size_t size) {
     if (!kBufferCoherent && kAlwaysCache) {
-      VTAFlushCache(reinterpret_cast<char*>(data_) + offset,
-                    phy_addr_ + offset,
-                    size);
+      VTAFlushCache(reinterpret_cast<char*>(data_) + offset, phy_addr_ + offset, size);
     }
   }
   /*!
    * \brief Performs a copy operation from host memory to buffer allocated with VTAMemAlloc.
-   * \param dst The desination buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc().
-   * \param src The source buffer in host memory.
-   * \param size Size of the region in Bytes.
+   * \param dst The destination buffer in FPGA-accessible memory. Has to be allocated with
+   *        VTAMemAlloc().
+   * \param src The source buffer in host memory.
+   * \param size Size of the region in Bytes.
    */
   void MemCopyFromHost(void* dst, const void* src, size_t size) {
     VTAMemCopyFromHost(dst, src, size);
@@ -99,9 +90,7 @@ struct DataBuffer {
    * \param src The source buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc().
    * \param size Size of the region in Bytes.
    */
-  void MemCopyToHost(void* dst, const void* src, size_t size) {
-    VTAMemCopyToHost(dst, src, size);
-  }
+  void MemCopyToHost(void* dst, const void* src, size_t size) { VTAMemCopyToHost(dst, src, size); }
   /*!
    * \brief Allocate a buffer of a given size.
    * \param size The size of the buffer.
@@ -128,8 +117,7 @@ struct DataBuffer {
    * \return The corresponding data buffer header.
*/ static DataBuffer* FromHandle(const void* buffer) { - return const_cast( - reinterpret_cast(buffer)); + return const_cast(reinterpret_cast(buffer)); } private: @@ -157,9 +145,7 @@ class UopKernel { * \param signature The pointer to signature. * \param nbytes Number of bytes. */ - UopKernel(const char* signature, int nbytes) - : signature_(signature, signature + nbytes) { - } + UopKernel(const char* signature, int nbytes) : signature_(signature, signature + nbytes) {} /*! * \brief Verify if the signature is correct. * \param signature Signature ptr. @@ -170,21 +156,13 @@ class UopKernel { return memcmp(signature, signature_.data(), nbytes) == 0; } /*! \return Whether the kernel is cached in SRAM. */ - bool cached() const { - return sram_begin_ != sram_end_; - } + bool cached() const { return sram_begin_ != sram_end_; } /*! \return The length of the micro op sequence. */ - size_t size() const { - return seq_.size(); - } + size_t size() const { return seq_.size(); } /*! \return The micro-op data. */ - const VTAUop* data() const { - return seq_.data(); - } + const VTAUop* data() const { return seq_.data(); } /*! \return The loop structure. */ - const std::vector& loop() const { - return loop_; - } + const std::vector& loop() const { return loop_; } /*! * \brief Declare loop start. * \param extent The loop extent. @@ -192,9 +170,7 @@ class UopKernel { * \param src_factor Loop factor of input index * \param wgt_factor Loop factor of weight index. */ - void PushLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, + void PushLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor) { LoopEntry le; le.extent = extent; @@ -209,9 +185,7 @@ class UopKernel { /*! * \brief Declare loop end. */ - void PushLoopEnd() { - --loop_ptr_; - } + void PushLoopEnd() { --loop_ptr_; } /*! * \brief Push micro op into kernel. * \param mode Set to GEMM mode if set to 0, ALU mode is set to 1. @@ -223,14 +197,8 @@ class UopKernel { * \param use_imm Use immediate in ALU mode if set to true. * \param imm_val Immediate value in ALU mode. */ - void Push(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val) { + void Push(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) { // The loop nest structure VerifyDep(dst_index); VTAUop op; @@ -268,10 +236,7 @@ class UopKernel { uint32_t size = seq_.size(); printf("There are %u uops\n", size); for (uint32_t i = 0; i < size; ++i) { - printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", - i, - seq_[i].dst_idx, - seq_[i].src_idx, + printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", i, seq_[i].dst_idx, seq_[i].src_idx, seq_[i].wgt_idx); } printf("\n"); @@ -294,7 +259,7 @@ class UopKernel { } } // The uop buffer - template + template friend class UopQueue; friend class CommandQueue; // SRAM location if begin != end @@ -322,26 +287,21 @@ class BaseQueue { } } /*! \return Content of DRAM buffer. */ - char* dram_buffer() const { - return dram_buffer_; - } + char* dram_buffer() const { return dram_buffer_; } /*! \return Physical address of DRAM. */ vta_phy_addr_t dram_phy_addr() const { CHECK(fpga_buff_phy_); return fpga_buff_phy_; } /*! \return Whether there is pending information. */ - bool pending() const { - return sram_begin_ != sram_end_; - } + bool pending() const { return sram_begin_ != sram_end_; } /*! \brief Initialize the space of the buffer. 
*/ void InitSpace(uint32_t elem_bytes, uint32_t max_bytes, bool coherent, bool always_cache) { coherent_ = coherent; always_cache_ = always_cache; elem_bytes_ = elem_bytes; // Allocate buffer ahead of time - fpga_buff_ = static_cast(VTAMemAlloc( - max_bytes, coherent_ || always_cache_)); + fpga_buff_ = static_cast(VTAMemAlloc(max_bytes, coherent_ || always_cache_)); CHECK(fpga_buff_ != nullptr); fpga_buff_phy_ = VTAMemGetPhyAddr(fpga_buff_); } @@ -351,6 +311,9 @@ class BaseQueue { */ virtual void Reset() { dram_buffer_.clear(); + // reset to 0 as we always copy data to area starting from fpga_buff base + // we do mem copy for every DeviceRun + sram_end_ = 0; sram_begin_ = sram_end_; } @@ -376,14 +339,12 @@ class BaseQueue { /*! * \brief Micro op buffer that manages the micro op cache. */ -template +template class UopQueue : public BaseQueue { public: - void InitSpace() { - BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); - } + void InitSpace() { BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); } // Push data to the queue - template + template void Push(UopKernel* kernel, FAutoSync fautosync) { // if the micro-op is cached in VTA SRAM, skip if (kernel->cached()) return; @@ -446,13 +407,18 @@ class UopQueue : public BaseQueue { } /*! \brief clear cache and reset base queue buffer.*/ void Reset() { + // unmark "cached" status + // as we cannot assume it is still in SRAM across DeviceRun + for (UopKernel* kernel : cache_) { + kernel->sram_begin_ = 0; + kernel->sram_end_ = 0; + } + cache_.clear(); cache_idx_ = 0; BaseQueue::Reset(); } - void AutoReadBarrier() { - ReadBarrier(); - } + void AutoReadBarrier() { ReadBarrier(); } /*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. */ void ReadBarrier() { CHECK(fpga_buff_ != nullptr); @@ -467,18 +433,14 @@ class UopQueue : public BaseQueue { uint32_t offset = 0; for (uint32_t i = 0; i < cache_.size(); ++i) { uint32_t ksize = cache_[i]->size() * kElemBytes; - VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, - cache_[i]->data(), - ksize); + VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, cache_[i]->data(), ksize); // Update offset offset += ksize; } // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { - VTAFlushCache(fpga_buff_, - fpga_buff_phy_, - offset); + VTAFlushCache(fpga_buff_, fpga_buff_phy_, offset); } } @@ -497,8 +459,7 @@ class UopQueue : public BaseQueue { class UopKernelMap { public: // Simple hash map - UopKernel** Get(void* signature, - int nbytes) { + UopKernel** Get(void* signature, int nbytes) { uint32_t key = 0; CHECK(nbytes == 0 || nbytes == sizeof(int)); if (nbytes == sizeof(int)) { @@ -516,15 +477,10 @@ class UopKernelMap { std::vector kmap_; }; -enum PipelineStage : int { - kNoneStage = 0, - kLoadStage = 1, - kComputeStage = 2, - kStoreStage = 3 -}; +enum PipelineStage : int { kNoneStage = 0, kLoadStage = 1, kComputeStage = 2, kStoreStage = 3 }; // Instruction Queue -template +template class InsnQueue : public BaseQueue { public: /*! \brief Initialize the space. */ @@ -535,13 +491,9 @@ class InsnQueue : public BaseQueue { std::fill(pending_pop_next_, pending_pop_next_ + 4, 0); } /*! \return The data pointer. */ - VTAGenericInsn* data() { - return dram_buffer_.data(); - } + VTAGenericInsn* data() { return dram_buffer_.data(); } /*! \return Number of instructions. 
*/ - uint32_t count() { - return dram_buffer_.size(); - } + uint32_t count() { return dram_buffer_.size(); } // Insert dependency push of load void DepPop(int from, int to) { // NOTE: This instruction executes on queue[to] @@ -569,10 +521,12 @@ class InsnQueue : public BaseQueue { if (GetPipelineStage(mptr) == from) { if (from < to && !mptr->push_next_dep) { // push(LD->C) or push(C->ST) - mptr->push_next_dep = true; return; + mptr->push_next_dep = true; + return; } else if (from > to && !mptr->push_prev_dep) { // push(C->LD) or push(ST->C) - mptr->push_prev_dep = true; return; + mptr->push_prev_dep = true; + return; } } } @@ -585,25 +539,15 @@ class InsnQueue : public BaseQueue { } } // Create a new instruction for a GEMM stage - VTAGemInsn* CreateGemInsn() { - return reinterpret_cast( - Create(kComputeStage)); - } + VTAGemInsn* CreateGemInsn() { return reinterpret_cast(Create(kComputeStage)); } // Create a new instruction for a ALU stage - VTAAluInsn* CreateAluInsn() { - return reinterpret_cast( - Create(kComputeStage)); - } + VTAAluInsn* CreateAluInsn() { return reinterpret_cast(Create(kComputeStage)); } // Create a new instruction for a memory stage VTAMemInsn* CreateMemInsn(int memory_type) { - return reinterpret_cast( - Create(GetMemPipelineStage(memory_type))); + return reinterpret_cast(Create(GetMemPipelineStage(memory_type))); } // create a new instruction for a store stage - VTAMemInsn* CreateStoreInsn() { - return reinterpret_cast( - Create(kStoreStage)); - } + VTAMemInsn* CreateStoreInsn() { return reinterpret_cast(Create(kStoreStage)); } // Rewrite instruction stream to force serial execution void RewriteForceSerial() { int insn_count = count(); @@ -653,7 +597,7 @@ class InsnQueue : public BaseQueue { } CommitPendingPop(kComputeStage); } else { - pending_pop_next_[kComputeStage] = 0; + pending_pop_next_[kComputeStage] = 0; } DepPush(kComputeStage, kLoadStage); DepPop(kLoadStage, kComputeStage); @@ -666,30 +610,30 @@ class InsnQueue : public BaseQueue { } // Helper function: Get Opcode string const char* getOpcodeString(int opcode, bool use_imm) { - // The string name - if (opcode == VTA_ALU_OPCODE_MIN) { - if (use_imm) { - return "min imm"; - } else { - return "min"; - } - } else if (opcode == VTA_ALU_OPCODE_MAX) { - if (use_imm) { - return "max imm"; - } else { - return "max"; - } - } else if (opcode == VTA_ALU_OPCODE_ADD) { - if (use_imm) { - return "add imm"; - } else { - return "add"; - } - } else if (opcode == VTA_ALU_OPCODE_SHR) { - return "shr"; + // The string name + if (opcode == VTA_ALU_OPCODE_MIN) { + if (use_imm) { + return "min imm"; + } else { + return "min"; + } + } else if (opcode == VTA_ALU_OPCODE_MAX) { + if (use_imm) { + return "max imm"; + } else { + return "max"; + } + } else if (opcode == VTA_ALU_OPCODE_ADD) { + if (use_imm) { + return "add imm"; + } else { + return "add"; } + } else if (opcode == VTA_ALU_OPCODE_SHR) { + return "shr"; + } - return "unknown op"; + return "unknown op"; } // Dump instructions in the queue void DumpInsn() { @@ -718,10 +662,8 @@ class InsnQueue : public BaseQueue { printf("NOP-MEMORY-STAGE\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); // Count status in queues if (c.mem.opcode == VTA_OPCODE_STORE) { 
CHECK(c.mem.pop_next_dep == false); @@ -729,8 +671,7 @@ class InsnQueue : public BaseQueue { if (c.mem.pop_prev_dep) g2s_queue--; if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD && - (c.mem.memory_type == VTA_MEM_ID_INP || - c.mem.memory_type == VTA_MEM_ID_WGT) ) { + (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) { CHECK(c.mem.pop_prev_dep == false); CHECK(c.mem.push_prev_dep == false); if (c.mem.pop_next_dep) g2l_queue--; @@ -757,65 +698,44 @@ class InsnQueue : public BaseQueue { printf("STORE:\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); - printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", - static_cast(c.mem.dram_base), + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); + printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", static_cast(c.mem.dram_base), static_cast(c.mem.sram_base)); - printf("\ty: size=%d, pad=[%d, %d]\n", - static_cast(c.mem.y_size), - static_cast(c.mem.y_pad_0), - static_cast(c.mem.y_pad_1)); - printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", - static_cast(c.mem.x_size), - static_cast(c.mem.x_stride), - static_cast(c.mem.x_pad_0), + printf("\ty: size=%d, pad=[%d, %d]\n", static_cast(c.mem.y_size), + static_cast(c.mem.y_pad_0), static_cast(c.mem.y_pad_1)); + printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", static_cast(c.mem.x_size), + static_cast(c.mem.x_stride), static_cast(c.mem.x_pad_0), static_cast(c.mem.x_pad_1)); } else if (c.mem.opcode == VTA_OPCODE_GEMM) { // Print instruction field information printf("GEMM\n"); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); printf("\treset_out: %d\n", static_cast(c.gemm.reset_reg)); - printf("\trange (%d, %d)\n", - static_cast(c.gemm.uop_bgn), + printf("\trange (%d, %d)\n", static_cast(c.gemm.uop_bgn), static_cast(c.gemm.uop_end)); printf("\touter loop - iter: %d, wgt: %d, inp: %d, acc: %d\n", - static_cast(c.gemm.iter_out), - static_cast(c.gemm.wgt_factor_out), - static_cast(c.gemm.src_factor_out), - static_cast(c.gemm.dst_factor_out)); + static_cast(c.gemm.iter_out), static_cast(c.gemm.wgt_factor_out), + static_cast(c.gemm.src_factor_out), static_cast(c.gemm.dst_factor_out)); printf("\tinner loop - iter: %d, wgt: %d, inp: %d, acc: %d\n", - static_cast(c.gemm.iter_in), - static_cast(c.gemm.wgt_factor_in), - static_cast(c.gemm.src_factor_in), - static_cast(c.gemm.dst_factor_in)); + static_cast(c.gemm.iter_in), static_cast(c.gemm.wgt_factor_in), + static_cast(c.gemm.src_factor_in), static_cast(c.gemm.dst_factor_in)); } else if (c.mem.opcode == VTA_OPCODE_ALU) { // Print instruction field information printf("ALU - %s\n", getOpcodeString(c.alu.alu_opcode, c.alu.use_imm)); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), 
static_cast(c.mem.push_next_dep)); printf("\treset_out: %d\n", static_cast(c.alu.reset_reg)); - printf("\trange (%d, %d)\n", - static_cast(c.alu.uop_bgn), + printf("\trange (%d, %d)\n", static_cast(c.alu.uop_bgn), static_cast(c.alu.uop_end)); - printf("\touter loop - iter: %d, dst: %d, src: %d\n", - static_cast(c.alu.iter_out), - static_cast(c.alu.dst_factor_out), - static_cast(c.alu.src_factor_out)); - printf("\tinner loop - iter: %d, dst: %d, src: %d\n", - static_cast(c.alu.iter_in), - static_cast(c.alu.dst_factor_in), - static_cast(c.alu.src_factor_in)); + printf("\touter loop - iter: %d, dst: %d, src: %d\n", static_cast(c.alu.iter_out), + static_cast(c.alu.dst_factor_out), static_cast(c.alu.src_factor_out)); + printf("\tinner loop - iter: %d, dst: %d, src: %d\n", static_cast(c.alu.iter_in), + static_cast(c.alu.dst_factor_in), static_cast(c.alu.src_factor_in)); } else if (c.mem.opcode == VTA_OPCODE_FINISH) { printf("FINISH\n"); } @@ -823,25 +743,23 @@ class InsnQueue : public BaseQueue { // Count status in queues if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) { if (c.mem.opcode == VTA_OPCODE_STORE) { - CHECK(c.mem.pop_next_dep == false); - CHECK(c.mem.push_next_dep == false); - if (c.mem.pop_prev_dep) g2s_queue--; - if (c.mem.push_prev_dep) s2g_queue++; + CHECK(c.mem.pop_next_dep == false); + CHECK(c.mem.push_next_dep == false); + if (c.mem.pop_prev_dep) g2s_queue--; + if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD && - (c.mem.memory_type == VTA_MEM_ID_INP || - c.mem.memory_type == VTA_MEM_ID_WGT) ) { - CHECK(c.mem.pop_prev_dep == false); - CHECK(c.mem.push_prev_dep == false); - if (c.mem.pop_next_dep) g2l_queue--; - if (c.mem.push_next_dep) l2g_queue++; + (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) { + CHECK(c.mem.pop_prev_dep == false); + CHECK(c.mem.push_prev_dep == false); + if (c.mem.pop_next_dep) g2l_queue--; + if (c.mem.push_next_dep) l2g_queue++; } else { - if (c.mem.pop_prev_dep) l2g_queue--; - if (c.mem.push_prev_dep) g2l_queue++; - if (c.mem.pop_next_dep) s2g_queue--; - if (c.mem.push_next_dep) g2s_queue++; + if (c.mem.pop_prev_dep) l2g_queue--; + if (c.mem.push_prev_dep) g2l_queue++; + if (c.mem.pop_next_dep) s2g_queue--; + if (c.mem.push_next_dep) g2s_queue++; } - } else if (c.mem.opcode == VTA_OPCODE_GEMM || - c.mem.opcode == VTA_OPCODE_ALU) { + } else if (c.mem.opcode == VTA_OPCODE_GEMM || c.mem.opcode == VTA_OPCODE_ALU) { // Print instruction field information if (c.gemm.pop_prev_dep) l2g_queue--; if (c.gemm.push_prev_dep) g2l_queue++; @@ -857,11 +775,8 @@ class InsnQueue : public BaseQueue { // Handle the LD<->compute queue // NOTE: pop executes on target(stage) CHECK(stage > 0 && stage < 4); - if (pending_pop_prev_[stage] || - pending_pop_next_[stage]) { - PushNoop(stage, false, false, - pending_pop_prev_[stage], - pending_pop_next_[stage]); + if (pending_pop_prev_[stage] || pending_pop_next_[stage]) { + PushNoop(stage, false, false, pending_pop_prev_[stage], pending_pop_next_[stage]); pending_pop_prev_[stage] = 0; pending_pop_next_[stage] = 0; } @@ -878,9 +793,7 @@ class InsnQueue : public BaseQueue { } return false; } - void AutoReadBarrier() { - ReadBarrier(); - } + void AutoReadBarrier() { ReadBarrier(); } /*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. 
*/ void ReadBarrier() { CHECK(fpga_buff_ != nullptr); @@ -888,15 +801,11 @@ class InsnQueue : public BaseQueue { uint32_t buff_size = dram_buffer_.size() * elem_bytes_; CHECK(buff_size <= kMaxBytes); // Copy contents of DRAM buffer to FPGA buff - VTAMemCopyFromHost(fpga_buff_, - dram_buffer_.data(), - buff_size); + VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size); // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { - VTAFlushCache(fpga_buff_, - fpga_buff_phy_, - buff_size); + VTAFlushCache(fpga_buff_, fpga_buff_phy_, buff_size); } } @@ -947,15 +856,14 @@ class InsnQueue : public BaseQueue { // Get stage of memory and computation static PipelineStage GetPipelineStageAll(VTAMemInsn* insn) { - PipelineStage stage = GetPipelineStage(insn); - if (stage != kNoneStage) return stage; - return GetMemPipelineStage(insn->memory_type); + PipelineStage stage = GetPipelineStage(insn); + if (stage != kNoneStage) return stage; + return GetMemPipelineStage(insn->memory_type); } // Push no-op - void PushNoop(int stage, - bool push_prev_dep, bool push_next_dep, - bool pop_prev_dep, bool pop_next_dep) { + void PushNoop(int stage, bool push_prev_dep, bool push_next_dep, bool pop_prev_dep, + bool pop_next_dep) { VTAMemInsn* insn = reinterpret_cast(NextInsn()); insn->opcode = (stage == kStoreStage ? VTA_OPCODE_STORE : VTA_OPCODE_LOAD); insn->push_prev_dep = push_prev_dep; @@ -987,9 +895,7 @@ class InsnQueue : public BaseQueue { */ class CommandQueue { public: - CommandQueue() { - this->InitSpace(); - } + CommandQueue() { this->InitSpace(); } void InitSpace() { uop_queue_.InitSpace(); insn_queue_.InitSpace(); @@ -997,31 +903,29 @@ class CommandQueue { CHECK(device_ != nullptr); } - ~CommandQueue() { - VTADeviceFree(device_); - } + ~CommandQueue() { VTADeviceFree(device_); } uint32_t GetElemBytes(uint32_t memory_id) { uint32_t elem_bytes = 0; switch (memory_id) { case VTA_MEM_ID_UOP: - elem_bytes = VTA_UOP_ELEM_BYTES; - break; + elem_bytes = VTA_UOP_ELEM_BYTES; + break; case VTA_MEM_ID_INP: - elem_bytes = VTA_INP_ELEM_BYTES; - break; + elem_bytes = VTA_INP_ELEM_BYTES; + break; case VTA_MEM_ID_WGT: - elem_bytes = VTA_WGT_ELEM_BYTES; - break; + elem_bytes = VTA_WGT_ELEM_BYTES; + break; case VTA_MEM_ID_ACC: - elem_bytes = VTA_ACC_ELEM_BYTES; - break; + elem_bytes = VTA_ACC_ELEM_BYTES; + break; case VTA_MEM_ID_OUT: - elem_bytes = VTA_OUT_ELEM_BYTES; - break; + elem_bytes = VTA_OUT_ELEM_BYTES; + break; default: - LOG(FATAL) << "Memory id not recognized:" << memory_id; - break; + LOG(FATAL) << "Memory id not recognized:" << memory_id; + break; } /* * elements size should not larger than VTA_PAGE_BYTES. 
@@ -1031,16 +935,9 @@ class CommandQueue { return elem_bytes; } - void LoadBuffer2D(void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, + void LoadBuffer2D(void* src_dram_addr, uint32_t src_elem_offset, uint32_t x_size, uint32_t y_size, + uint32_t x_stride, uint32_t x_pad_before, uint32_t y_pad_before, + uint32_t x_pad_after, uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(dst_memory_type); insn->opcode = VTA_OPCODE_LOAD; @@ -1058,12 +955,8 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void StoreBuffer2D(uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, + void StoreBuffer2D(uint32_t src_sram_index, uint32_t src_memory_type, void* dst_dram_addr, + uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size, uint32_t x_stride) { VTAMemInsn* insn = insn_queue_.CreateStoreInsn(); insn->opcode = VTA_OPCODE_STORE; @@ -1081,27 +974,21 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void DepPush(int from_qid, int to_qid) { - insn_queue_.DepPush(from_qid, to_qid); - } + void DepPush(int from_qid, int to_qid) { insn_queue_.DepPush(from_qid, to_qid); } - void DepPop(int from_qid, int to_qid) { - insn_queue_.DepPop(from_qid, to_qid); - } + void DepPop(int from_qid, int to_qid) { insn_queue_.DepPop(from_qid, to_qid); } void ReadBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { if (!(debug_flag_ & VTA_DEBUG_SKIP_READ_BARRIER)) { uint32_t elem_bytes = (elem_bits + 8 - 1) / 8; - DataBuffer::FromHandle(buffer)->FlushCache( - elem_bytes * start, elem_bytes * extent); + DataBuffer::FromHandle(buffer)->FlushCache(elem_bytes * start, elem_bytes * extent); } } void WriteBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { if (!(debug_flag_ & VTA_DEBUG_SKIP_WRITE_BARRIER)) { uint32_t elem_bytes = (elem_bits + 8 - 1) / 8; - DataBuffer::FromHandle(buffer)->InvalidateCache( - elem_bytes * start, elem_bytes * extent); + DataBuffer::FromHandle(buffer)->InvalidateCache(elem_bytes * start, elem_bytes * extent); } } @@ -1131,16 +1018,13 @@ class CommandQueue { insn_queue_.DumpInsn(); } // Make sure that the last instruction is a finish instruction - CHECK(reinterpret_cast( - insn_queue_.data())[insn_queue_.count()-1].opcode == VTA_OPCODE_FINISH); + CHECK(reinterpret_cast(insn_queue_.data())[insn_queue_.count() - 1].opcode == + VTA_OPCODE_FINISH); // Make sure that we don't exceed contiguous physical memory limits CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER); - int timeout = VTADeviceRun( - device_, - insn_queue_.dram_phy_addr(), - insn_queue_.count(), - wait_cycles); + int timeout = + VTADeviceRun(device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles); CHECK_EQ(timeout, 0); // Reset buffers uop_queue_.Reset(); @@ -1154,14 +1038,9 @@ class CommandQueue { } // Set debug flag - void SetDebugFlag(int debug_flag) { - debug_flag_ = debug_flag; - } + void SetDebugFlag(int debug_flag) { debug_flag_ = debug_flag; } - void PushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { + void PushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { UopKernelMap** uptr = reinterpret_cast(uop_handle); if (uptr[0] == nullptr) { uptr[0] = new 
UopKernelMap(); @@ -1180,10 +1059,7 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void PushALUUop(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { + void PushALUUop(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { UopKernelMap** uptr = reinterpret_cast(uop_handle); if (uptr[0] == nullptr) { uptr[0] = new UopKernelMap(); @@ -1203,23 +1079,19 @@ class CommandQueue { } static std::shared_ptr& ThreadLocal() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); if (inst == nullptr) { inst = std::make_shared(); } return inst; } - static void Shutdown() { - ThreadLocal().reset(); - } + static void Shutdown() { ThreadLocal().reset(); } private: // Push GEMM uop to the command buffer void PushGEMMOp(UopKernel* kernel) { - uop_queue_.Push(kernel, - [this]() { this->AutoSync(); }); + uop_queue_.Push(kernel, [this]() { this->AutoSync(); }); if (uop_queue_.pending()) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP); insn->opcode = VTA_OPCODE_LOAD; @@ -1230,7 +1102,7 @@ class CommandQueue { insn->reset_reg = kernel->reset_out_; insn->uop_bgn = kernel->sram_begin_; insn->uop_end = kernel->sram_end_; - const std::vector &loop = kernel->loop(); + const std::vector& loop = kernel->loop(); if (loop.size() > 0) { insn->iter_out = loop[0].extent; insn->wgt_factor_out = loop[0].wgt_factor; @@ -1257,8 +1129,7 @@ class CommandQueue { // Push ALU uop to the command buffer void PushALUUop(UopKernel* kernel) { - uop_queue_.Push(kernel, - [this]() { this->AutoSync(); }); + uop_queue_.Push(kernel, [this]() { this->AutoSync(); }); if (uop_queue_.pending()) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP); insn->opcode = VTA_OPCODE_LOAD; @@ -1272,7 +1143,7 @@ class CommandQueue { insn->alu_opcode = kernel->opcode_; insn->use_imm = kernel->use_imm_; insn->imm = kernel->imm_val_; - const std::vector &loop = kernel->loop(); + const std::vector& loop = kernel->loop(); if (loop.size() == 0) { insn->iter_out = 1; insn->dst_factor_out = 0; @@ -1305,9 +1176,7 @@ class CommandQueue { } } // Auto sync when instruction overflow - void AutoSync() { - this->Synchronize(1 << 31); - } + void AutoSync() { this->Synchronize(1 << 31); } // Internal debug flag int debug_flag_{0}; @@ -1323,19 +1192,11 @@ class CommandQueue { } // namespace vta -void* VTABufferAlloc(size_t size) { - return vta::DataBuffer::Alloc(size); -} +void* VTABufferAlloc(size_t size) { return vta::DataBuffer::Alloc(size); } -void VTABufferFree(void* buffer) { - vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); -} +void VTABufferFree(void* buffer) { vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); } -void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, +void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, int kind_mask) { vta::DataBuffer* from_buffer = nullptr; vta::DataBuffer* to_buffer = nullptr; @@ -1353,143 +1214,87 @@ void VTABufferCopy(const void* from, // This is an FPGA to host mem transfer from_buffer->InvalidateCache(from_offset, size); from_buffer->MemCopyToHost(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); } else if (to_buffer) { // This is a host to FPGA mem transfer to_buffer->MemCopyFromHost(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); 
to_buffer->FlushCache(to_offset, size); } } -VTACommandHandle VTATLSCommandHandle() { - return vta::CommandQueue::ThreadLocal().get(); -} +VTACommandHandle VTATLSCommandHandle() { return vta::CommandQueue::ThreadLocal().get(); } -void VTARuntimeShutdown() { - vta::CommandQueue::Shutdown(); -} +void VTARuntimeShutdown() { vta::CommandQueue::Shutdown(); } void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) { - static_cast(cmd)-> - SetDebugFlag(debug_flag); + static_cast(cmd)->SetDebugFlag(debug_flag); } void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) { return vta::DataBuffer::FromHandle(buffer)->virt_addr(); } -void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { - static_cast(cmd)-> - WriteBarrier(buffer, elem_bits, start, extent); + static_cast(cmd)->WriteBarrier(buffer, elem_bits, start, extent); } -void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { - static_cast(cmd)-> - ReadBarrier(buffer, elem_bits, start, extent); + static_cast(cmd)->ReadBarrier(buffer, elem_bits, start, extent); } -void VTALoadBuffer2D(VTACommandHandle cmd, - void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, - uint32_t dst_memory_type) { - static_cast(cmd)-> - LoadBuffer2D(src_dram_addr, src_elem_offset, - x_size, y_size, x_stride, - x_pad_before, y_pad_before, - x_pad_after, y_pad_after, - dst_sram_index, dst_memory_type); +void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset, + uint32_t x_size, uint32_t y_size, uint32_t x_stride, uint32_t x_pad_before, + uint32_t y_pad_before, uint32_t x_pad_after, uint32_t y_pad_after, + uint32_t dst_sram_index, uint32_t dst_memory_type) { + static_cast(cmd)->LoadBuffer2D( + src_dram_addr, src_elem_offset, x_size, y_size, x_stride, x_pad_before, y_pad_before, + x_pad_after, y_pad_after, dst_sram_index, dst_memory_type); } -void VTAStoreBuffer2D(VTACommandHandle cmd, - uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride) { - static_cast(cmd)-> - StoreBuffer2D(src_sram_index, src_memory_type, - dst_dram_addr, dst_elem_offset, - x_size, y_size, x_stride); +void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, uint32_t src_memory_type, + void* dst_dram_addr, uint32_t dst_elem_offset, uint32_t x_size, + uint32_t y_size, uint32_t x_stride) { + static_cast(cmd)->StoreBuffer2D( + src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, x_size, y_size, x_stride); } -void VTAUopPush(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val) { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->Push(mode, reset_out, dst_index, src_index, - wgt_index, opcode, use_imm, imm_val); +void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) { + vta::CommandQueue::ThreadLocal()->record_kernel()->Push(mode, 
reset_out, dst_index, src_index, + wgt_index, opcode, use_imm, imm_val); } -void VTAUopLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, +void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor) { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->PushLoopBegin(extent, dst_factor, src_factor, wgt_factor); + vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopBegin(extent, dst_factor, src_factor, + wgt_factor); } -void VTAUopLoopEnd() { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->PushLoopEnd(); -} +void VTAUopLoopEnd() { vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopEnd(); } -int VTAPushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { - vta::CommandQueue::ThreadLocal()-> - PushGEMMOp(uop_handle, finit, signature, nbytes); +int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { + vta::CommandQueue::ThreadLocal()->PushGEMMOp(uop_handle, finit, signature, nbytes); return 0; } -int VTAPushALUOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { - vta::CommandQueue::ThreadLocal()-> - PushALUUop(uop_handle, finit, signature, nbytes); +int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { + vta::CommandQueue::ThreadLocal()->PushALUUop(uop_handle, finit, signature, nbytes); return 0; } int VTADepPush(VTACommandHandle cmd, int from_qid, int to_qid) { - static_cast(cmd)-> - DepPush(from_qid, to_qid); + static_cast(cmd)->DepPush(from_qid, to_qid); return 0; } int VTADepPop(VTACommandHandle cmd, int from_qid, int to_qid) { - static_cast(cmd)-> - DepPop(from_qid, to_qid); + static_cast(cmd)->DepPop(from_qid, to_qid); return 0; } void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) { - static_cast(cmd)-> - Synchronize(wait_cycles); + static_cast(cmd)->Synchronize(wait_cycles); } diff --git a/vta/runtime/runtime.h b/vta/runtime/runtime.h index bb16d3a3bfc21..24ebb8e1247b3 100644 --- a/vta/runtime/runtime.h +++ b/vta/runtime/runtime.h @@ -64,12 +64,8 @@ TVM_DLL void VTABufferFree(void* buffer); * \param size Size of copy. * \param kind_mask The memory copy kind. */ -TVM_DLL void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - int kind_mask); +TVM_DLL void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t size, int kind_mask); /*! \brief VTA command handle */ typedef void* VTACommandHandle; @@ -99,10 +95,7 @@ TVM_DLL void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); * \param start The start of the region (in elements). * \param extent The end of the region (in elements). */ -TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent); /*! @@ -113,10 +106,7 @@ TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, * \param start The start of the region (in elements). * \param extent The end of the region (in elements). */ -TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent); /*! @@ -142,17 +132,10 @@ TVM_DLL void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); * \param dst_sram_index Destination SRAM index. 
* \param dst_memory_type Destination memory type. */ -TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, - void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, +TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset, + uint32_t x_size, uint32_t y_size, uint32_t x_stride, + uint32_t x_pad_before, uint32_t y_pad_before, uint32_t x_pad_after, + uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type); /*! @@ -167,13 +150,9 @@ TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, * \param y_size The number of rows. * \param x_stride The x axis stride. */ -TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, - uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, +TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, + uint32_t src_memory_type, void* dst_dram_addr, + uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size, uint32_t x_stride); /*! @@ -207,14 +186,8 @@ TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, * \param use_imm Use immediate in ALU mode if set to true. * \param imm_val Immediate value in ALU mode. */ -TVM_DLL void VTAUopPush(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val); +TVM_DLL void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val); /*! * \brief Mark start of a micro op loop. @@ -223,9 +196,7 @@ TVM_DLL void VTAUopPush(uint32_t mode, * \param src_factor The input factor. * \param wgt_factor The weight factor. */ -TVM_DLL void VTAUopLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, +TVM_DLL void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor); /*! @@ -241,10 +212,7 @@ TVM_DLL void VTAUopLoopEnd(); * \param nbytes Number of bytes to in the closure arguments. * \return 0 if success. */ -TVM_DLL int VTAPushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes); +TVM_DLL int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes); /*! * \brief Push ALU uop kernel into the command handle. @@ -254,10 +222,7 @@ TVM_DLL int VTAPushGEMMOp(void** uop_handle, * \param nbytes Number of bytes to in the closure arguments. * \return 0 if success. */ -TVM_DLL int VTAPushALUOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes); +TVM_DLL int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes); /*! * \brief Push dependence token. 
diff --git a/web/.eslintignore b/web/.eslintignore new file mode 100644 index 0000000000000..1521c8b7652b1 --- /dev/null +++ b/web/.eslintignore @@ -0,0 +1 @@ +dist diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 0000000000000..a3135cf24b9d8 --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,6 @@ +.vscode +*~ +out +node_modules +package-lock.json +build diff --git a/web/.jsdoc_conf.json b/web/.jsdoc_conf.json deleted file mode 100644 index 33783b3bbb21c..0000000000000 --- a/web/.jsdoc_conf.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "templates": { - "default": { - "includeDate": false - } - } -} diff --git a/web/Makefile b/web/Makefile new file mode 100644 index 0000000000000..c0b8f077dbf77 --- /dev/null +++ b/web/Makefile @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_ROOT=$(shell cd ..; pwd) + +INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + +.PHONY: clean all removetypedep + +all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js + +EMCC = emcc + +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++14 -Wno-ignored-attributes \ + -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1 -s ERROR_ON_UNDEFINED_SYMBOLS=0 + +EMCC_LDFLAGS = --pre-js emcc/preload.js + +dist/wasm/%.bc: emcc/%.cc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/$*.bc $< >dist/wasm/$*.d + $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $< + + +dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc dist/wasm/webgpu_runtime.bc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS) + + +dist/wasm/tvmjs_runtime.wasi.js: dist/wasm/tvmjs_runtime.wasm emcc/decorate_as_wasi.py + python3 emcc/decorate_as_wasi.py dist/wasm/tvmjs_runtime.js $@ + +clean: + @rm -rf dist/wasm + +-include dist/wasm/*.d diff --git a/web/README.md b/web/README.md index 5dfd6917934bf..358884ca26b11 100644 --- a/web/README.md +++ b/web/README.md @@ -15,163 +15,83 @@ -# TVM WebAssembly and Javascript Backend +# TVM WebAssembly Runtime -This folder contains TVM WebAssembly and Javascript backend through Emscripten. +This folder contains TVM WebAssembly Runtime. ## Installation -While the LLVM main branch support webassembly as a target. We still need a good runtime with libc and other -system library support. Emscripten toolchain offers that nicely. The general idea is to build TVM against -the fastcomp LLVM backend in the Emscripten project and allow us to generate ```asmjs-unknown-emscripten``` -as a backend target. + +The LLVM main branch support webassembly as a target, we can directly +build TVM with LLVM mainline to generate wasm modules. 
+Note that, however, we still need emscripten to compile the runtime and provide system library support. + +Note that so far we require everything to be in the source tree and set up PYTHONPATH (instead of using setup.py install). ### Setup Emscripten -Checkout [Emscripten Portable SDK Downloads](https://kripken.github.io/emscripten-site/docs/getting_started/downloads.html) -to download emsdk-portable and unzip it on a local folder. Follow the installation guide from emscripten document. -```bash -./emsdk update -./emsdk install latest -./emsdk activate latest -``` +We use emscripten to compile our runtime wasm library as well as a WASI variant that we can deploy +to the browser environment. -Because we need to compile against the LLVM backend of emscripten, we will need the source and llvm library. -Which can be installed via following command. +Follow [Emscripten](https://emscripten.org/) to download emsdk and install emcc on your local environment. -```bash -./emsdk install clang-incoming-64bit -./emsdk activate clang-incoming-64bit -``` +### Build TVM Wasm Runtime -### Setup Environment Variable +After emcc is set up correctly, we can build tvm's wasm runtime by typing `make` in the web folder. -In normal setting, we can setup the necessary environment variable with the following command. ```bash -source /path-to-emsdk-portable/emsdk_env.sh +make ``` -However, this will put emscripten's clang and llvm path ahead of the current system path. -What you can do is to set the path manually, by putting emscripten's path after the PATH like the following ones. -You can get the detailed path by type ```./emsdk activate``` -```bash -export PATH=${PATH}:/emsdk-related-path-here +This command will create the following files: +- `dist/wasm/libtvm_runtime.bc`: the bitcode library that `tvm.contrib.emcc` will link into. +- `dist/wasm/tvmjs_runtime.wasm`: a standalone wasm runtime for testing purposes. +- `dist/wasm/tvmjs_runtime.wasi.js`: a WASI compatible library generated by emscripten that can be fed into the runtime. -``` -### Build TVM with Fastcomp LLVM +### Build TVM Wasm JS Frontend -To build TVM with Emscripten's Fastcomp LLVM, we can modify the LLVM_CONFIG in ```config.mk``` -to point to fastcomp's llvm-config and build TVM normally. +Type the following command in the web folder. ```bash -LLVM_CONFIG = /path/to/emsdk-portable/clang/fastcomp/build_incoming_64/bin/llvm-config +npm run bundle ``` -### Build TVM Web Runtime +This command will create the tvmjs library that we can use to interface with the wasm runtime. -The above command gives us the TVM compiling environment. Now we need to build runtime, -to do so, make sure we set the environment correctly as in previous section and type -```bash -make web -``` -This will create ```build/libtvm_web_runtime.bc``` and ```build/libtvm_web_runtime.js```. - -## Use TVM to Generate Javascript Library +## Use TVM to Generate Wasm Library and Run it - -The general idea is to use TVM as normally and set target to be ```llvm -target=asmjs-unknown-emscripten -system-lib```. - -The following code snippet from [tests/web/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/tests/web/prepare_test_libs.py) demonstrate -the compilation process.
- -```python -import tvm -from tvm import te -from tvm.contrib import emscripten -import os -def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - if not tvm.runtime.enabled(target): - raise RuntimeError("Target %s is not enbaled" % target) - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') - s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path) - -if __name__ == "__main__": - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) -``` +Check the code snippets in -In this workflow, we use TVM to generate a ```.bc``` file and statically link -that with the ```build/libtvm_web_runtime.bc```(emscripten.create_js will help you do that). -The result js library is a library that contains both TVM runtime and the compiled function. - - -## Run the Generated Library - -The following code snippet from [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/tests/web/test_module_load.js) demonstrate -how to run the compiled library. - -```js -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -// Load system library, the compiled function is registered in sysLib. -var sysLib = tvm.systemLib(); - -function randomArray(length, max) { - return Array.apply(null, Array(length)).map(function() { - return Math.random() * max; - }); -} - -function testAddOne() { - // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); - // call the function. - faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array - // verify - for (var i = 0; i < BB.length; ++i) { - assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); - } - faddOne.release(); -} - -testAddOne(); -sysLib.release(); -console.log("Finish verifying test_module_load"); -``` +- [tests/python/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/web/tests/python/prepare_test_libs.py) + shows how to create a wasm library that links with the tvm runtime. + - Note that all wasm libraries have to be created using the `--system-lib` option + - emcc.create_wasm will automatically link the runtime library `dist/wasm/libtvm_runtime.bc` +- [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/web/tests/node/test_module_load.js) demonstrates + how to run the generated library through the tvmjs API. -Current example supports static linking, which is the preferred way to get more efficiency -in javascript backend. -## Proxy based RPC +## Run Wasm Remotely through WebSocket RPC -We can now use javascript end to start an RPC server and connect to it from python side, +We can now use the JS side to start an RPC server and connect to it from the Python side, making the testing flow easier.
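Before the step-by-step flow below, here is a minimal Python sketch of the whole loop: build a wasm library with `--system-lib`, hand it to a WebSocket RPC session, and call the function remotely. It is pieced together from the test scripts referenced above, not a definitive API reference: the target triple string, the helper name `emcc.create_tvmjs_wasm`, the proxy port `9090`, the `"rpc.WasmSession"` constructor key, and the file name `test_addone.wasm` are all assumptions taken from this PR's test setup.

```python
# Hedged sketch; see the assumptions listed above.
import numpy as np
import tvm
from tvm import te, rpc
from tvm.contrib import emcc

# 1. Build a one-operator module targeting the wasm system library.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
s = te.create_schedule(B.op)
target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"  # assumed triple
fadd = tvm.build(s, [A, B], target, name="add_one")
fadd.export_library("test_addone.wasm", emcc.create_tvmjs_wasm)

# 2. Connect to the proxy; the browser/node RPC server registered under
#    the "wasm" key receives the binary and instantiates it remotely.
wasm_binary = open("test_addone.wasm", "rb").read()
remote = rpc.connect(
    "localhost", 9090, key="wasm",
    session_constructor_args=["rpc.WasmSession", wasm_binary])

# 3. Because of --system-lib, the function lives in the remote system
#    library rather than in a separate module file.
ctx = remote.cpu(0)
fadd_remote = remote.system_lib().get_function("add_one")
a = tvm.nd.array(np.random.uniform(size=1024).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), ctx)
fadd_remote(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)
```

The maintained version of this flow is the `websock_rpc_test.py` script referenced in the steps below.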
-The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install) -- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -- Open broswer, goto the server webpage click Connect to proxy. - - Alternatively run "node web/example_rpc_node.js" -- run "python tests/web/websock_rpc_test.py" to run the rpc client. +The following is an example to reproduce this. +- run `python -m tvm.exec.rpc_proxy --example-rpc=1` to start the proxy. +- Start the WebSocket RPC server + - Browser version: open https://localhost:8888 and click "Connect to proxy" + - NodeJS version: `npm run rpc` +- run `python tests/node/websock_rpc_test.py` to run the rpc client. + + +## WebGPU Experiments + +WebGPU is still experimental, so APIs can change. +Right now we use SPIR-V to generate shaders that can be accepted by Chrome and Firefox. -The general idea is to use Emscripten's dynamic linking to dynamically load modules. +- Obtain a browser that supports WebGPU. + - So far only Chrome Canary on macOS works + - Firefox should be close, pending its support of Fence. +- Download the Vulkan SDK (1.1 or higher), which supports SPIR-V 1.3 +- Start the WebSocket RPC server +- run `python tests/node/webgpu_rpc_test.py` diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html new file mode 100644 index 0000000000000..6d353e29b08dd --- /dev/null +++ b/web/apps/browser/rpc_server.html @@ -0,0 +1,79 @@ + TVM RPC Test Page

+ [rpc_server.html body; HTML markup lost in extraction — recoverable text below]
+ TVM WebSocket RPC Server
+ To use this page:
+   • Run "make" and "npm run bundle" to create the libraries.
+   • Run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start the proxy.
+   • Click "Connect to proxy".
+   • Run "python tests/python/websock_rpc_test.py" to run the RPC client.
+ Options: Proxy URL, RPC Server Key
+ + + diff --git a/web/apps/node/example.js b/web/apps/node/example.js new file mode 100644 index 0000000000000..f81a9c903e5d8 --- /dev/null +++ b/web/apps/node/example.js @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +// the async version of the API. +tvmjs.instantiate(wasmSource, new EmccWASI()) +.then((tvm) => { + // List all the global functions from the runtime. + console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); +}); + diff --git a/web/apps/node/wasi_example.js b/web/apps/node/wasi_example.js new file mode 100644 index 0000000000000..95ec2e0b1d075 --- /dev/null +++ b/web/apps/node/wasi_example.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const { WASI } = require('wasi'); +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +const wasi = new WASI({ args: process.argv, env: process.env }); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), wasi); + +// List all the global functions from the runtime. 
+console.log("Runtime using WASI\n", tvm.listGlobalFuncNames()); diff --git a/web/apps/node/wasi_rpc_server.js b/web/apps/node/wasi_rpc_server.js new file mode 100644 index 0000000000000..eb4c6ed52be9e --- /dev/null +++ b/web/apps/node/wasi_rpc_server.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Example code to start the RPC server on nodejs using WASI + */ +const { WASI } = require("wasi"); +const tvmjs = require("../../dist"); + +// Get import returns a fresh library in each call. +const getImports = () => { + return new WASI({ + args: process.argv, + env: process.env + }); +}; + +const proxyUrl = "ws://localhost:8888/ws"; + +new tvmjs.RPCServer(proxyUrl, "wasm", getImports, console.log); diff --git a/web/emcc/decorate_as_wasi.py b/web/emcc/decorate_as_wasi.py new file mode 100644 index 0000000000000..741e33bb22ea3 --- /dev/null +++ b/web/emcc/decorate_as_wasi.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Decorate emcc generated js to a WASI compatible API.""" + +import sys + +template_head = """ +function EmccWASI() { +""" + +template_tail = """ + this.Module = Module; + this.start = Module.wasmLibraryProvider.start; + this.imports = Module.wasmLibraryProvider.imports; + this.wasiImport = this.imports["wasi_snapshot_preview1"]; +} + +if (typeof module !== "undefined" && module.exports) { + module.exports = EmccWASI; +} +""" + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage ") + result = template_head + open(sys.argv[1]).read() + template_tail + with open(sys.argv[2], "w") as fo: + fo.write(result) diff --git a/web/emcc/preload.js b/web/emcc/preload.js new file mode 100644 index 0000000000000..882280f9cac09 --- /dev/null +++ b/web/emcc/preload.js @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-unused-vars */ +/** + * JS config used by --pre-js in emcc. + * Wrap module as a LibraryProvider. + */ + +var __wasmLib = {}; + +function __wasmLibInstantiateWasm(imports, successCallback) { + __wasmLib.imports = imports; + __wasmLib.successCallback = successCallback; +} + +function __wasmLibStart(wasmInstance) { + __wasmLib.successCallback(wasmInstance); +} + +__wasmLib.start = __wasmLibStart; + +var Module = { + "instantiateWasm": __wasmLibInstantiateWasm, + "wasmLibraryProvider": __wasmLib +}; diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc new file mode 100644 index 0000000000000..6abd12252d1d8 --- /dev/null +++ b/web/emcc/tvmjs_support.cc @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvmjs_support.cc + * \brief Support functions to be linked with wasm_runtime to provide + * PackedFunc callbacks in tvmjs. + * We do not need to link this file in standalone wasm. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include +#include +#include +#include + +#include "../../src/runtime/rpc/rpc_local_session.h" + +extern "C" { +// --- Additional C API for the Wasm runtime --- +/*! + * \brief Allocate space aligned to 64 bit. + * \param size The size of the space. + * \return The allocated space. + */ +TVM_DLL void* TVMWasmAllocSpace(int size); + +/*! + * \brief Free the space allocated by TVMWasmAllocSpace. + * \param data The data pointer. + */ +TVM_DLL void TVMWasmFreeSpace(void* data); + +/*! + * \brief Create PackedFunc from a resource handle. + * \param resource_handle The handle to the resource. + * \param out The output PackedFunc. + * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer +3A * \return 0 if success. + */ +TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out); + +// --- APIs to be implemented by the frontend. --- +/*! + * \brief Wasm frontend packed function caller. 
+ * + * \param args The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * \param ret The return value handle. + * \param resource_handle The handle additional resouce handle from fron-end. + * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. + */ +extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, + void* resource_handle); + +/*! + * \brief Wasm frontend resource finalizer. + * \param resource_handle The pointer to the external resource. + */ +extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +} // extern "C" + +void* TVMWasmAllocSpace(int size) { + int num_count = (size + 7) / 8; + return new int64_t[num_count]; +} + +void TVMWasmFreeSpace(void* arr) { delete[] static_cast(arr); } + +int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) { + return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer, + out); +} + +namespace tvm { +namespace runtime { + +// A special local session that can interact with async +// functions in the JS runtime. +class AsyncLocalSession : public LocalSession { + public: + AsyncLocalSession() {} + + PackedFuncHandle GetFunction(const std::string& name) final { + if (name == "runtime.RPCTimeEvaluator") { + return get_time_eval_placeholder_.get(); + } else if (auto* fp = tvm::runtime::Registry::Get(name)) { + // return raw handle because the remote need to explicitly manage it. + return new PackedFunc(*fp); + } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) { + auto* rptr = new PackedFunc(*fp); + async_func_set_.insert(rptr); + return rptr; + } else { + return nullptr; + } + } + + void FreeHandle(void* handle, int type_code) final { + if (type_code == kTVMPackedFuncHandle) { + auto it = async_func_set_.find(handle); + if (it != async_func_set_.end()) { + async_func_set_.erase(it); + } + } + if (handle != get_time_eval_placeholder_.get()) { + LocalSession::FreeHandle(handle, type_code); + } + } + + void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, FAsyncCallback callback) final { + auto it = async_func_set_.find(func); + if (it != async_func_set_.end()) { + PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) { + int code = args[0]; + TVMRetValue rv; + rv = args[1]; + this->EncodeReturn(std::move(rv), + [&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); }); + }); + + TVMRetValue temp; + std::vector values(arg_values, arg_values + num_args); + std::vector type_codes(arg_type_codes, arg_type_codes + num_args); + values.emplace_back(TVMValue()); + type_codes.emplace_back(0); + + TVMArgsSetter setter(&values[0], &type_codes[0]); + // pass the callback as the last argument. + setter(num_args, packed_callback); + + auto* pf = static_cast(func); + pf->CallPacked(TVMArgs(values.data(), type_codes.data(), num_args + 1), &temp); + } else if (func == get_time_eval_placeholder_.get()) { + // special handle time evaluator. + try { + TVMArgs args(arg_values, arg_type_codes, num_args); + PackedFunc retfunc = + this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + TVMRetValue rv; + rv = retfunc; + this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + // mark as async. 
+ async_func_set_.insert(encoded_args.values[1].v_handle); + callback(RPCCode::kReturn, encoded_args); + }); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } + } else { + LocalSession::AsyncCallFunc(func, arg_values, arg_type_codes, num_args, callback); + } + } + + void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, FAsyncCallback on_complete) final { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + try { + this->GetDeviceAPI(remote_ctx_to) + ->CopyDataFromTo(local_from, local_from_offset, remote_to, remote_to_offset, nbytes, + cpu_ctx, remote_ctx_to, type_hint, nullptr); + this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete); + } catch (const std::runtime_error& e) { + this->SendException(on_complete, e.what()); + } + } + + void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from, + DLDataType type_hint, FAsyncCallback on_complete) final { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + try { + this->GetDeviceAPI(remote_ctx_from) + ->CopyDataFromTo(remote_from, remote_from_offset, local_to, local_to_offset, nbytes, + remote_ctx_from, cpu_ctx, type_hint, nullptr); + this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete); + } catch (const std::runtime_error& e) { + this->SendException(on_complete, e.what()); + } + } + + void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete) final { + if (ctx.device_type == kDLCPU) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; + on_complete(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } else { + CHECK(ctx.device_type == static_cast(kDLWebGPU)); + if (async_wait_ == nullptr) { + async_wait_ = tvm::runtime::Registry::Get("__async.wasm.WebGPUWaitForTasks"); + } + CHECK(async_wait_ != nullptr); + PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) { + int code = args[0]; + on_complete(static_cast(code), + TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1)); + }); + (*async_wait_)(packed_callback); + } + } + + bool IsAsync() const final { return true; } + + private: + std::unordered_set async_func_set_; + std::unique_ptr get_time_eval_placeholder_ = std::make_unique(); + const PackedFunc* async_wait_{nullptr}; + + // time evaluator + PackedFunc GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, + int device_id, int number, int repeat, int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + return WrapWasmTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapWasmTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms); + } + } + + // time evaluator + PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, + int min_repeat_ms) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) { + // the function is a async function. 
+ PackedFunc on_complete = args[args.size() - 1]; + // keep argument alive in finvoke so that they + // can be used throughout the async benchmark + std::vector values(args.values, args.values + args.size() - 1); + std::vector type_codes(args.type_codes, args.type_codes + args.size() - 1); + + auto finvoke = [pf, values, type_codes](int n) { + TVMRetValue temp; + TVMArgs invoke_args(values.data(), type_codes.data(), values.size()); + for (int i = 0; i < n; ++i) { + pf.CallPacked(invoke_args, &temp); + } + }; + auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution"); + CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function"; + (*time_exec)(TypedPackedFunc(finvoke), ctx, number, repeat, min_repeat_ms, + on_complete); + }; + return PackedFunc(ftimer); + } +}; + +TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc new file mode 100644 index 0000000000000..a67b4c3dcd146 --- /dev/null +++ b/web/emcc/wasm_runtime.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file wasm_runtime.cc + * \brief TVM wasm runtime library pack. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include + +#include "src/runtime/c_runtime_api.cc" +#include "src/runtime/cpu_device_api.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/graph/graph_runtime.cc" +#include "src/runtime/library_module.cc" +#include "src/runtime/module.cc" +#include "src/runtime/ndarray.cc" +#include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/rpc/rpc_channel.cc" +#include "src/runtime/rpc/rpc_endpoint.cc" +#include "src/runtime/rpc/rpc_event_impl.cc" +#include "src/runtime/rpc/rpc_local_session.cc" +#include "src/runtime/rpc/rpc_module.cc" +#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/system_library.cc" +#include "src/runtime/workspace_pool.cc" + +// --- Implementations of backend and wasm runtime API. 
--- + +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { + TVMParallelGroupEnv env; + env.num_task = 1; + flambda(0, &env, cdata); + return 0; +} + +int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { return 0; } + +// --- Environment PackedFuncs for testing --- +namespace tvm { +namespace runtime { + +TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0]; +}); + +TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); + +TVM_REGISTER_GLOBAL("testing.wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf]() { pf(); }); +}); +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc new file mode 100644 index 0000000000000..7f0b0d9f72cb2 --- /dev/null +++ b/web/emcc/webgpu_runtime.cc @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file webgpu_runtime.cc + * \brief WebGPU runtime based on the TVM JS. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include +#include +#include +#include + +#include "../../src/runtime/meta_data.h" +#include "../../src/runtime/vulkan/vulkan_shader.h" +#include "../../src/runtime/workspace_pool.h" + +namespace tvm { +namespace runtime { + +/*! \brief Thread local workspace */ +class WebGPUThreadEntry { + public: + /*! \brief thread local pool*/ + WorkspacePool pool; + /*! \brief constructor */ + WebGPUThreadEntry(); + // get the threadlocal workspace + static WebGPUThreadEntry* ThreadLocal(); +}; + +// All the implementations are redirectly to the JS side. 
+// All the implementations are redirected to the JS side.
+class WebGPUDeviceAPI : public DeviceAPI {
+ public:
+  WebGPUDeviceAPI() {
+    auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUDeviceAPI");
+    CHECK(fp != nullptr) << "Cannot find wasm.WebGPUDeviceAPI in the env";
+    auto getter = TypedPackedFunc<PackedFunc(std::string)>(*fp);
+    alloc_space_ = getter("deviceAllocDataSpace");
+    free_space_ = getter("deviceFreeDataSpace");
+    copy_to_gpu_ = getter("deviceCopyToGPU");
+    copy_from_gpu_ = getter("deviceCopyFromGPU");
+    copy_within_gpu_ = getter("deviceCopyWithinGPU");
+  }
+
+  void SetDevice(TVMContext ctx) final {}
+  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
+    if (kind == kExist) {
+      *rv = 1;
+    }
+  }
+
+  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
+                       DLDataType type_hint) final {
+    double ptr_number = alloc_space_(nbytes);
+    return reinterpret_cast<void*>(static_cast<int64_t>(ptr_number));
+  }
+
+  void FreeDataSpace(TVMContext ctx, void* ptr) final { return free_space_(ptr); }
+
+  void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+                      TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
+                      TVMStreamHandle stream) final {
+    if (static_cast<int>(ctx_from.device_type) == kDLWebGPU &&
+        static_cast<int>(ctx_to.device_type) == kDLWebGPU) {
+      CHECK_EQ(ctx_from.device_id, ctx_to.device_id);
+      copy_within_gpu_(const_cast<void*>(from), from_offset, to, to_offset, size);
+    } else if (static_cast<int>(ctx_from.device_type) == kDLWebGPU &&
+               ctx_to.device_type == kDLCPU) {
+      void* to_ptr = static_cast<uint8_t*>(to) + to_offset;
+      copy_from_gpu_(const_cast<void*>(from), from_offset, to_ptr, size);
+    } else if (ctx_from.device_type == kDLCPU &&
+               static_cast<int>(ctx_to.device_type) == kDLWebGPU) {
+      void* from_ptr = static_cast<uint8_t*>(const_cast<void*>(from)) + from_offset;
+      copy_to_gpu_(from_ptr, to, to_offset, size);
+    } else {
+      LOG(FATAL) << "expect copy from/to WebGPU or between WebGPU";
+    }
+  }
+
+  TVMStreamHandle CreateStream(TVMContext ctx) final {
+    LOG(FATAL) << "Not implemented";
+    return nullptr;
+  }
+
+  void FreeStream(TVMContext ctx, TVMStreamHandle stream) final {
+    LOG(FATAL) << "Not implemented";
+    return;
+  }
+
+  void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
+    LOG(FATAL) << "Not implemented";
+    return;
+  }
+
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; }
+
+  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
+    LOG(FATAL) << "Not implemented";
+    return;
+  }
+
+  void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
+    return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+  }
+
+  void FreeWorkspace(TVMContext ctx, void* data) final {
+    WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
+  }
+
+  static const std::shared_ptr<WebGPUDeviceAPI>& Global() {
+    static std::shared_ptr<WebGPUDeviceAPI> inst = std::make_shared<WebGPUDeviceAPI>();
+    return inst;
+  }
+
+ private:
+  // NOTE: the JS side returns pointers as doubles.
+  TypedPackedFunc<double(int64_t)> alloc_space_;
+  TypedPackedFunc<void(void*)> free_space_;
+  TypedPackedFunc<void(void*, void*, int64_t, int64_t)> copy_to_gpu_;
+  TypedPackedFunc<void(void*, int64_t, void*, int64_t)> copy_from_gpu_;
+  TypedPackedFunc<void(void*, int64_t, void*, int64_t, int64_t)> copy_within_gpu_;
+};
+
+typedef dmlc::ThreadLocalStore<WebGPUThreadEntry> WebGPUThreadStore;
+
+WebGPUThreadEntry::WebGPUThreadEntry()
+    : pool(static_cast<DLDeviceType>(kDLWebGPU), WebGPUDeviceAPI::Global()) {}
+
+WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); }
+
+class WebGPUModuleNode final : public runtime::ModuleNode {
+ public:
+  explicit WebGPUModuleNode(std::unordered_map<std::string, VulkanShader> smap,
+                            std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
+      : smap_(smap), fmap_(fmap), source_(source) {
+    auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader");
+    CHECK(fp != nullptr);
+    create_shader_ = *fp;
+  }
+
+  const char* type_key() const final { return "webgpu"; }
+
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    auto it = smap_.find(name);
+    if (it != smap_.end()) {
+      FunctionInfo info = fmap_.at(name);
+      info.name = name;
+      std::ostringstream os;
+      dmlc::JSONWriter writer(&os);
+      info.Save(&writer);
+      TVMByteArray arr;
+      arr.data = reinterpret_cast<const char*>(it->second.data.data());
+      arr.size = it->second.data.size() * sizeof(it->second.data[0]);
+      return create_shader_(os.str(), arr);
+    } else {
+      return PackedFunc(nullptr);
+    }
+  }
+
+  void SaveToFile(const std::string& file_name, const std::string& format) final {
+    LOG(FATAL) << "Not implemented";
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; }
+
+  std::string GetSource(const std::string& format) final {
+    // can only return source code.
+    return source_;
+  }
+
+ private:
+  // shader table.
+  std::unordered_map<std::string, VulkanShader> smap_;
+  // function information table.
+  std::unordered_map<std::string, FunctionInfo> fmap_;
+  // The source
+  std::string source_;
+  // Callback to get the GPU function.
+  TypedPackedFunc<PackedFunc(std::string, TVMByteArray)> create_shader_;
+};
+
+Module WebGPUModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::unordered_map<std::string, VulkanShader> smap;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+
+  std::string fmt;
+  stream->Read(&fmt);
+  stream->Read(&fmap);
+  stream->Read(&smap);
+  return Module(make_object<WebGPUModuleNode>(smap, fmap, ""));
+}
+
+// for now webgpu is hosted via a vulkan module.
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary);
+
+TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) {
+  DeviceAPI* ptr = WebGPUDeviceAPI::Global().get();
+  *rv = static_cast<void*>(ptr);
+});
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/web/example_rpc.html b/web/example_rpc.html
deleted file mode 100644
index ae2b1dd9c44bc..0000000000000
--- a/web/example_rpc.html
+++ /dev/null
@@ -1,61 +0,0 @@
-  TVM RPC Test Page

-  TVM Test Page
-
-  To use this page, the easiest way is to:
-    - run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start the proxy.
-    - click "Connect to proxy".
-    - run "python tests/web/websock_rpc_test.py" to run the rpc client.
-
-  Options
-    - Proxy URL
-    - RPC Server Key
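Returning to webgpu_runtime.cc above: the C++ WebGPUDeviceAPI constructor resolves all five device hooks through a single getter PackedFunc registered under wasm.WebGPUDeviceAPI. A sketch of the JS-side shape this implies (the helper names here are hypothetical; the concrete hooks are implemented by WebGPUContext in web/src/webgpu.ts):

    import { Instance } from "./runtime";

    // Placeholder stubs standing in for WebGPUContext methods (assumptions).
    declare function gpuAllocDataSpace(nbytes: number): number;
    declare function gpuFreeDataSpace(ptr: number): void;

    function registerWebGPUHooks(inst: Instance): void {
      const hooks: Record<string, Function> = {
        deviceAllocDataSpace: gpuAllocDataSpace, // returns the GPU pointer as a double
        deviceFreeDataSpace: gpuFreeDataSpace,
        // deviceCopyToGPU / deviceCopyFromGPU / deviceCopyWithinGPU follow suit.
      };
      // The C++ constructor calls this getter once per hook name.
      inst.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => hooks[name]);
    }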
- - - - diff --git a/web/package.json b/web/package.json new file mode 100644 index 0000000000000..f6b700d5ff6d8 --- /dev/null +++ b/web/package.json @@ -0,0 +1,28 @@ +{ + "name": "tvmjs", + "displayName": "TVM Wasm JS runtime", + "license": "Apache-2.0", + "version": "0.7.0", + "scripts": { + "build": "tsc -b", + "lint": "eslint -c .eslintrc.json .", + "bundle": "npm run build && rollup -c rollup.config.js", + "example": "npm run bundle && node apps/node/example.js", + "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js", + "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js" + }, + "devDependencies": { + "typescript": "^3.8.3", + "@types/node": "^12.12.37", + "@webgpu/types": "^0.0.24", + "eslint": "^6.8.0", + "@typescript-eslint/eslint-plugin": "^2.29.0", + "@typescript-eslint/parser": "^2.29.0", + "typedoc": "^0.17.6", + "rollup": "^2.7.6", + "ws": "^7.2.5", + "@rollup/plugin-commonjs": "^11.1.0", + "@rollup/plugin-node-resolve": "^7.1.3", + "rollup-plugin-typescript2": "^0.27.0" + } +} diff --git a/web/rollup.config.js b/web/rollup.config.js new file mode 100644 index 0000000000000..9090c77868fe5 --- /dev/null +++ b/web/rollup.config.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import commonjs from '@rollup/plugin-commonjs'; +import resolve from '@rollup/plugin-node-resolve'; + +export default { + input: 'dist/index.js', + output: { + file: 'dist/tvmjs.bundle.js', + format: 'umd', + name: 'tvmjs', + exports: 'named', + globals: {'ws': 'ws', + 'perf_hooks': 'perf_hooks', + '@webgpu/types': 'webgputypes'} + }, + plugins: [commonjs(), resolve()], + external: ['ws', 'perf_hooks', '@webgpu/types'] +}; diff --git a/web/src/compact.ts b/web/src/compact.ts new file mode 100644 index 0000000000000..29569b5d005d3 --- /dev/null +++ b/web/src/compact.ts @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/** NodeJS and Web compact layer */
+
+/**
+ * Get performance measurement.
+ */
+export function getPeformance(): Performance {
+  if (typeof performance == "undefined") {
+    // eslint-disable-next-line @typescript-eslint/no-var-requires
+    const performanceNode = require("perf_hooks");
+    return performanceNode.performance as Performance;
+  } else {
+    return performance as Performance;
+  }
+}
+
+/**
+ * Create a new websocket for a given URL
+ * @param url The url.
+ */
+export function createWebSocket(url: string): WebSocket {
+  if (typeof WebSocket == "undefined") {
+    // eslint-disable-next-line @typescript-eslint/no-var-requires
+    const WebSocket = require("ws");
+    return new WebSocket(url);
+  } else {
+    return new (WebSocket as any)(url);
+  }
+}
\ No newline at end of file
diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts
new file mode 100644
index 0000000000000..f533b4e491a6f
--- /dev/null
+++ b/web/src/ctypes.ts
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Types for C API.
+ */
+
+/** A pointer that points to the raw address space. */
+export type Pointer = number;
+
+/** A pointer offset, need to add a base address to get a valid ptr.
*/ +export type PtrOffset = number; + +// -- TVM runtime C API -- +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; + +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = ( + mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; + +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; + +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = ( + func: Pointer, argValues: Pointer, typeCode: Pointer, + nargs: number, retValue: Pointer, retCode: Pointer) => number; + +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = ( + ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; + +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; + +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; + +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = ( + name: Pointer, f: Pointer, override: number) => number; + +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; + +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = ( + shape: Pointer, ndim: number, + dtypeCode: number, dtypeBits: number, + dtypeLanes: number, deviceType: number, deviceId: number, + out: Pointer) => number; + +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; + +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = ( + from: Pointer, to: Pointer, stream: Pointer) => number; + +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = ( + deviceType: number, deviceId: number, stream: Pointer) => number; + +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = ( + argValues: Pointer, argCodes: Pointer, nargs: 
number, + outValue: Pointer, outCode: Pointer) => number; + +// -- TVM Wasm Auxiliary C API -- + +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; + +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; + +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = ( + args: Pointer, typeCodes: Pointer, nargs: number, + ret: Pointer, resourceHandle: Pointer) => number; + +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = ( + resource: Pointer, out: Pointer) => number; + +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; + +/** + * Size of common data types. + */ +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = I32, + DLContext = I32 + I32, +} + +/** + * Type code in TVM FFI. + */ +export const enum TypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + TVMContext = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} \ No newline at end of file diff --git a/web/src/environment.ts b/web/src/environment.ts new file mode 100644 index 0000000000000..df0fe68c81e06 --- /dev/null +++ b/web/src/environment.ts @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Runtime environment that provide js libaries calls. + */ +import { Pointer } from "./ctypes"; +import { LibraryProvider } from "./types"; +import { assert } from "./support"; +import * as ctypes from "./ctypes"; + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider( + importObject: Record +): LibraryProvider | undefined { + if ( + importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined + ) { + const item = importObject as { wasmLibraryProvider: LibraryProvider }; + // create provider so that we capture imports in the provider. 
+ return { + imports: item.wasmLibraryProvider.imports, + start: (inst: WebAssembly.Instance): void => { + item.wasmLibraryProvider.start(inst); + }, + }; + } else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject as LibraryProvider; + } else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: (inst: WebAssembly.Instance): void => { + importObject["start"](inst); + } + }; + } else { + return undefined; + } +} + +/** + * Environment to impelement most of the JS library functions. + */ +export class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array = []; + + private libProvider?: LibraryProvider; + + constructor( + importObject: Record = {}, + logger: (msg: string) => void = console.log + ) { + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + } + + private environment(initEnv: Record): Record { + // default env can be be overriden by libraries. + const defaultEnv = { + "__cxa_thread_atexit": (): void => {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": (index: number): void => {} + }; + const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + args: Pointer, + typeCodes: Pointer, + nargs: number, + ret: Pointer, + resourceHandle: Pointer + ): number => { + const cfunc = this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + + const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( + resourceHandle: Pointer + ): void => { + this.packedCFuncTable[resourceHandle] = undefined; + this.packedCFuncTableFreeId.push(resourceHandle); + }; + + const newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": (msg: string): void => { + this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + } +} \ No newline at end of file diff --git a/web/src/index.ts b/web/src/index.ts new file mode 100644 index 0000000000000..2d99fc9106ccc --- /dev/null +++ b/web/src/index.ts @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { + Scalar, DLContext, DLDataType, + PackedFunc, Module, NDArray, Instance, + instantiate +} from "./runtime"; +export { Disposable, LibraryProvider } from "./types"; +export { RPCServer } from "./rpc_server"; +export { wasmPath } from "./support"; +export { detectGPUDevice } from "./webgpu"; +export { assert } from "./support"; \ No newline at end of file diff --git a/web/src/memory.ts b/web/src/memory.ts new file mode 100644 index 0000000000000..ac737b7c297d7 --- /dev/null +++ b/web/src/memory.ts @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Disposable } from "./types"; +import { assert, StringToUint8Array } from "./support"; + +import * as ctypes from "./ctypes"; + +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export class Memory { + memory: WebAssembly.Memory; + wasm32 = true; + private buffer: ArrayBuffer | SharedArrayBuffer; + private viewU8: Uint8Array; + private viewU16: Uint16Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF32: Float32Array; + private viewF64: Float64Array; + + constructor(memory: WebAssembly.Memory) { + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + + loadU8(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + } + + loadU16(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + } + + loadU32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + } + + loadI32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + } + + loadI64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const base = ptr >> 2; + // assumes little endian, for now truncate high. 
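+    // (annotation, not part of the original patch) Only the low 32 bits are
+    // read here, so i64 values above 2^31 - 1 are truncated. A full-width
+    // read would need a BigInt view over the same buffer, e.g.:
+    //   const viewI64 = new BigInt64Array(this.buffer);
+    //   return Number(viewI64[ptr >> 3]); // still capped at 2^53 by JS numbers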
+ return this.viewI32[base]; + } + + loadF32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + } + + loadF64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + } + + loadPointer(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + loadUSize(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + sizeofPtr(): number { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + } + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + } + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array { + const data = this.loadPointer(ptr); + const length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + } + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + const ret = []; + let ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + } + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + } + + /** + * Update memory view after the memory growth. + */ + private updateViews(): void { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} + +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. 
*/ + tempArgs: Array = []; + + private memory: Memory; + private cAllocSpace: ctypes.FTVMWasmAllocSpace; + private cFreeSpace: ctypes.FTVMWasmFreeSpace; + + private buffer: ArrayBuffer; + private viewU8: Uint8Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF64: Float64Array; + + private stackTop: PtrOffset = 0; + private basePtr: Pointer = 0; + + private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = []; + + constructor( + memory: Memory, + allocSpace: ctypes.FTVMWasmAllocSpace, + freeSpace: ctypes.FTVMWasmFreeSpace + ) { + const initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + + dispose(): void { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + } + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + (this.tempArgs.pop() as Disposable).dispose(); + } + } + + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes: number = this.stackTop): void { + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [ + number, + number + ]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + } + + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + + if (this.stackTop + nbytes > this.buffer.byteLength) { + const newSize = Math.max( + this.buffer.byteLength * 2, + this.stackTop + nbytes + ); + const oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + const retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + } + + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + } + + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. 
+ */ + ptrFromOffset(offset: PtrOffset): Pointer { + return this.basePtr + offset; + } + + // Store APIs + storePtr(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeUSize(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeI32(offset: PtrOffset, value: number): void { + this.viewI32[offset >> 2] = value; + } + + storeU32(offset: PtrOffset, value: number): void { + this.viewU32[offset >> 2] = value; + } + + storeI64(offset: PtrOffset, value: number): void { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + const low = value & 0xffffffff; + const base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + } + + storeF64(offset: PtrOffset, value: number): void { + this.viewF64[offset >> 3] = value; + } + + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void { + this.viewU8.set(bytes, offset); + } + + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void { + const strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + } + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void { + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + } + + /** + * Update internal cache views. + */ + private updateViews(): void { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts new file mode 100644 index 0000000000000..50227dc792818 --- /dev/null +++ b/web/src/rpc_server.ts @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { SizeOf, TypeCode } from "./ctypes"; +import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; +import { detectGPUDevice } from "./webgpu"; +import * as compact from "./compact"; +import * as runtime from "./runtime"; + +enum RPCServerState { + InitHeader, + InitHeaderKey, + InitServer, + WaitForCallback, + ReceivePacketHeader, + ReceivePacketBody, +} + +/** RPC magic header */ +const RPC_MAGIC = 0xff271; + +/** + * An utility class to read from binary bytes. + */ +class ByteStreamReader { + offset = 0; + bytes: Uint8Array; + + constructor(bytes: Uint8Array) { + this.bytes = bytes; + } + + readU32(): number { + const i = this.offset; + const b = this.bytes; + const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24); + this.offset += 4; + return val; + } + + readU64(): number { + const val = this.readU32(); + this.offset += 4; + return val; + } + + readByteArray(): Uint8Array { + const len = this.readU64(); + assert(this.offset + len <= this.bytes.byteLength); + const ret = new Uint8Array(len); + ret.set(this.bytes.slice(this.offset, this.offset + len)); + this.offset += len; + return ret; + } +} + +/** + * A websocket based RPC + */ +export class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState = RPCServerState.InitHeader; + logger: (msg: string) => void; + getImports: () => Record; + private pendingSend: Promise = Promise.resolve(); + private name: string; + private inst?: runtime.Instance = undefined; + private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; + private currPacketHeader?: Uint8Array; + private currPacketLength = 0; + private remoteKeyLength = 0; + private pendingBytes = 0; + private buffredBytes = 0; + private messageQueue: Array = []; + + constructor( + url: string, + key: string, + getImports: () => Record, + logger: (msg: string) => void = console.log + ) { + this.url = url; + this.key = key; + this.name = "WebSocketRPCServer[" + this.key + "]: "; + this.getImports = getImports; + this.logger = logger; + + this.checkLittleEndian(); + this.socket = compact.createWebSocket(url); + this.socket.binaryType = "arraybuffer"; + + this.socket.addEventListener("open", (event: Event) => { + return this.onOpen(event); + }); + this.socket.addEventListener("message", (event: MessageEvent) => { + return this.onMessage(event); + }); + this.socket.addEventListener("close", (event: CloseEvent) => { + return this.onClose(event); + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onClose(_event: CloseEvent): void { + if (this.inst !== undefined) { + this.inst.dispose(); + } + if (this.state == RPCServerState.ReceivePacketHeader) { + this.log("Closing the server in clean state"); + this.log("Automatic reconnecting.."); + new RPCServer(this.url, this.key, this.getImports, this.logger); + } else { + this.log("Closing the server, final state=" + this.state); + } + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onOpen(_event: Event): void { + // Send the headers + let bkey = StringToUint8Array("server:" + this.key); + bkey = bkey.slice(0, bkey.length - 1); + const intbuf = new Int32Array(1); + intbuf[0] = RPC_MAGIC; + this.socket.send(intbuf); + intbuf[0] = bkey.length; + this.socket.send(intbuf); + this.socket.send(bkey); + this.log("connected..."); + // request bytes: magic + keylen + this.requestBytes(SizeOf.I32 + 
SizeOf.I32); + this.state = RPCServerState.InitHeader; + } + + /** Handler for raw message. */ + private onMessage(event: MessageEvent): void { + const buffer = event.data; + this.buffredBytes += buffer.byteLength; + this.messageQueue.push(new Uint8Array(buffer)); + this.processEvents(); + } + /** Process ready events. */ + private processEvents(): void { + while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) { + this.onDataReady(); + } + } + /** State machine to handle each request */ + private onDataReady(): void { + switch (this.state) { + case RPCServerState.InitHeader: { + this.handleInitHeader(); + break; + } + case RPCServerState.InitHeaderKey: { + this.handleInitHeaderKey(); + break; + } + case RPCServerState.ReceivePacketHeader: { + this.currPacketHeader = this.readFromBuffer(SizeOf.I64); + const reader = new ByteStreamReader(this.currPacketHeader); + this.currPacketLength = reader.readU64(); + assert(this.pendingBytes == 0); + this.requestBytes(this.currPacketLength); + this.state = RPCServerState.ReceivePacketBody; + break; + } + case RPCServerState.ReceivePacketBody: { + const body = this.readFromBuffer(this.currPacketLength); + assert(this.pendingBytes == 0); + assert(this.currPacketHeader !== undefined); + this.onPacketReady(this.currPacketHeader, body); + break; + } + case RPCServerState.WaitForCallback: { + assert(this.pendingBytes == 0); + break; + } + default: { + throw new Error("Cannot handle state " + this.state); + } + } + } + + private onPacketReady(header: Uint8Array, body: Uint8Array): void { + if (this.inst === undefined) { + // initialize server. + const reader = new ByteStreamReader(body); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const code = reader.readU32(); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ver = Uint8ArrayToString(reader.readByteArray()); + const nargs = reader.readU32(); + const tcodes = []; + const args = []; + for (let i = 0; i < nargs; ++i) { + tcodes.push(reader.readU32()); + } + + for (let i = 0; i < nargs; ++i) { + const tcode = tcodes[i]; + if (tcode == TypeCode.TVMStr) { + const str = Uint8ArrayToString(reader.readByteArray()); + args.push(str); + } else if (tcode == TypeCode.TVMBytes) { + args.push(reader.readByteArray()); + } else { + throw new Error("cannot support type code " + tcode); + } + } + this.onInitServer(args, header, body); + } else { + assert(this.serverRecvData !== undefined); + this.serverRecvData(header, body); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + } + + /** Event handler during server initialization. 
*/ + private onInitServer( + args: Array, + header: Uint8Array, + body: Uint8Array + ): void { + // start the server + assert(args[0] == "rpc.WasmSession"); + assert(this.pendingBytes == 0); + + const asyncInitServer = async (): Promise => { + assert(args[1] instanceof Uint8Array); + const inst = await runtime.instantiate( + args[1].buffer, + this.getImports(), + this.logger + ); + try { + const gpuDevice: GPUDevice | undefined = await detectGPUDevice(); + if (gpuDevice !== undefined) { + const label = gpuDevice.label?.toString() || "WebGPU"; + this.log("Initialize GPU device: " + label); + inst.initWebGPU(gpuDevice); + } + } catch (err) { + this.log("Cannnot initialize WebGPU, " + err.toString()); + } + + this.inst = inst; + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + + const messageHandler = fcreate( + (cbytes: Uint8Array): runtime.Scalar => { + assert(this.inst !== undefined); + if (this.socket.readyState == 1) { + // WebSocket will automatically close the socket + // if we burst send data that exceeds its internal buffer + // wait a bit before we send next one. + const sendDataWithCongestionControl = async (): Promise => { + const packetSize = 4 << 10; + const maxBufferAmount = 4 * packetSize; + const waitTimeMs = 20; + for ( + let offset = 0; + offset < cbytes.length; + offset += packetSize + ) { + const end = Math.min(offset + packetSize, cbytes.length); + while (this.socket.bufferedAmount >= maxBufferAmount) { + await new Promise((r) => setTimeout(r, waitTimeMs)); + } + this.socket.send(cbytes.slice(offset, end)); + } + }; + // Chain up the pending send so that the async send is always in-order. + this.pendingSend = this.pendingSend.then( + sendDataWithCongestionControl + ); + // Directly return since the data are "sent" from the caller's pov. + return this.inst.scalar(cbytes.length, "int32"); + } else { + return this.inst.scalar(0, "int32"); + } + }, + this.name, + this.key + ); + + fcreate.dispose(); + const writeFlag = this.inst.scalar(3, "int32"); + + this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { + if (messageHandler(header, writeFlag) == 0) { + this.socket.close(); + } + if (messageHandler(body, writeFlag) == 0) { + this.socket.close(); + } + }; + + // Forward the same init sequence to the wasm RPC. + // The RPC will look for "rpc.wasmSession" + // and we will redirect it to the correct local session. + // register the callback to redirect the session to local. + const flocal = this.inst.getGlobalFunc("wasm.LocalSession"); + const localSession = flocal(); + flocal.dispose(); + assert(localSession instanceof runtime.Module); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + this.inst.registerFunc( + "rpc.WasmSession", + // eslint-disable-next-line @typescript-eslint/no-unused-vars + (_args: unknown): runtime.Module => { + return localSession; + } + ); + messageHandler(header, writeFlag); + messageHandler(body, writeFlag); + localSession.dispose(); + + this.log("Finish initializing the Wasm Server.."); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + // call process events in case there are bufferred data. 
+ this.processEvents(); + }; + + this.state = RPCServerState.WaitForCallback; + asyncInitServer(); + } + + private log(msg: string): void { + this.logger(this.name + msg); + } + + private handleInitHeader(): void { + const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2)); + const magic = reader.readU32(); + if (magic == RPC_MAGIC + 1) { + throw new Error("key: " + this.key + " has already been used in proxy"); + } else if (magic == RPC_MAGIC + 2) { + throw new Error("RPCProxy do not have matching client key " + this.key); + } + assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy"); + this.remoteKeyLength = reader.readU32(); + assert(this.pendingBytes == 0); + this.requestBytes(this.remoteKeyLength); + this.state = RPCServerState.InitHeaderKey; + } + + private handleInitHeaderKey(): void { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const remoteKey = Uint8ArrayToString( + this.readFromBuffer(this.remoteKeyLength) + ); + assert(this.pendingBytes == 0); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + + private checkLittleEndian(): void { + const a = new ArrayBuffer(4); + const b = new Uint8Array(a); + const c = new Uint32Array(a); + b[0] = 0x11; + b[1] = 0x22; + b[2] = 0x33; + b[3] = 0x44; + assert(c[0] === 0x44332211, "RPCServer little endian to work"); + } + + private requestBytes(nbytes: number): void { + this.pendingBytes += nbytes; + } + + private readFromBuffer(nbytes: number): Uint8Array { + const ret = new Uint8Array(nbytes); + let ptr = 0; + while (ptr < nbytes) { + assert(this.messageQueue.length != 0); + const nleft = nbytes - ptr; + if (this.messageQueue[0].byteLength <= nleft) { + const buffer = this.messageQueue.shift() as Uint8Array; + ret.set(buffer, ptr); + ptr += buffer.byteLength; + } else { + const buffer = this.messageQueue[0]; + ret.set(buffer.slice(0, nleft), ptr); + this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength); + ptr += nleft; + } + } + this.buffredBytes -= nbytes; + this.pendingBytes -= nbytes; + return ret; + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts new file mode 100644 index 0000000000000..bcf7be7d5544a --- /dev/null +++ b/web/src/runtime.ts @@ -0,0 +1,1363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset, SizeOf, TypeCode } from "./ctypes"; +import { Disposable } from "./types"; +import { Memory, CachedCallStack } from "./memory"; +import { assert, StringToUint8Array } from "./support"; +import { Environment } from "./environment"; +import { WebGPUContext } from "./webgpu"; + +import * as compact from "./compact"; +import * as ctypes from "./ctypes"; + +/** + * Type for PackedFunc inthe TVMRuntime. 
+ */ +export type PackedFunc = ((...args: any) => any) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + webGPUContext?: WebGPUContext; + private wasmInstance: WebAssembly.Instance; + private recycledCallStacks: Array = []; + + constructor( + wasmInstance: WebAssembly.Instance, + imports: Record + ) { + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert( + this.wasmInstance.exports !== undefined, + "Expect the library module contains exports" + ); + this.exports = this.wasmInstance.exports as Record; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + + dispose(): void { + while (this.recycledCallStacks.length != 0) { + (this.recycledCallStacks.pop() as Disposable).dispose(); + } + } + + sizeofPtr(): number { + return this.memory.sizeofPtr(); + } + + checkCall(code: number): void { + if (code != 0) { + const msgPtr = (this.exports + .TVMGetLastError as ctypes.FTVMGetLastError)(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + } + + getOrAllocCallStack(): CachedCallStack { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop() as CachedCallStack; + } + return new CachedCallStack( + this.memory, + this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, + this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace + ); + } + + recycleCallStack(callstack: CachedCallStack): void { + callstack.reset(); + this.recycledCallStacks.push(callstack); + } + + private validateInstance(): void { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + } + + private checkExports(funcNames: Array): void { + const missList = []; + for (const name of funcNames) { + const f = this.exports[name]; + if (!(f instanceof Function)) { + missList.push(name); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + } + + private detectWasmMemory( + instance: WebAssembly.Instance, + imports: Record + ): WebAssembly.Memory { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + + throw new Error( + "Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports + ); + } +} + +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + + constructor(value: number, dtype: string) { + this.value = value; + this.dtype = dtype; + } +} + +/** + * Cell holds the PackedFunc object. 
+ */ +class PackedFuncCell implements Disposable { + handle: Pointer; + private lib: FFILibrary; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) + ); + this.handle = 0; + } + } +} + +const DeviceEnumToStr: Record = { + 1: "cpu", + 2: "gpu", + 4: "opencl", + 8: "metal", + 15: "webgpu" +}; + +const DeviceStrToEnum: Record = { + cpu: 1, + gpu: 2, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, + webgpu: 15 +}; + +/** + * Represent a runtime context where a NDArray can reside. + */ +export class DLContext { + /** The device type code of the context. */ + deviceType: number; + /** The device index. */ + deviceId: number; + + private lib: FFILibrary; + + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { + const tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + if (this.deviceType == undefined) { + throw new Error("Cannot recogonize deviceType " + deviceType); + } + } else if (tp == "number") { + this.deviceType = deviceType as number; + } else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + + /** + * Synchronize the context + */ + async sync(): Promise { + if (this.deviceType == DeviceStrToEnum.webgpu) { + assert(this.lib.webGPUContext !== undefined); + await this.lib.webGPUContext.sync(); + } + } + + toString(): string { + return ( + DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + ); + } +} + +const DLDataTypeCodeToStr: Record = { + 0: "int", + 1: "uint", + 2: "float", + 4: "handle", +}; + +/** + * Runtime data type of NDArray. + */ +export class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + + constructor(code: number, bits: number, lanes: number) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + + toString(): string { + const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } else { + return ret; + } + } + + numStorageBytes(): number { + return (this.bits * this.lanes + 7) >> 3; + } +} + +/** + * n-dimnesional array. + */ +export class NDArray implements Disposable { + /** Internal array handle. */ + handle: Pointer; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Context of the array. */ + context: DLContext; + /** Whether it is a temporary view that can become invalid after the call. */ + private isView: boolean; + private byteOffset: number; + private dltensor: Pointer; + private dataPtr: Pointer; + private lib: FFILibrary; + private dlDataType: DLDataType; + + constructor(handle: Pointer, isView: boolean, lib: FFILibrary) { + this.handle = handle; + this.isView = isView; + this.lib = lib; + + if (this.isView) { + this.dltensor = handle; + } else { + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + } + // constant offsets. 
+ const arrayOffsetData = 0; + const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + const arrayOffsetDevType = arrayOffsetContext; + const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext; + const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + const arrayOffsetDtypeCode = arrayOffsetDtype; + const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // dataPtr + this.dataPtr = lib.memory.loadPointer(this.dltensor); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (let i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + + // ctx + const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.context = new DLContext(deviceType, deviceId, lib); + + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + + dispose(): void { + if (this.handle != 0 && !this.isView) { + this.lib.checkCall( + (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) + ); + this.handle = 0; + } + } + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array | Float32Array): this { + if (data instanceof NDArray) { + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( + data.handle, + this.handle, + 0 + ) + ); + return this; + } else { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + if (data.length != size) { + throw new Error( + "data size and shape mismatch data.length" + + data.length + + " vs " + + size + ); + } + let buffer: ArrayBuffer; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + } + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. 
+   * @returns this
+   */
+  copyFromRawBytes(data: Uint8Array): this {
+    const size = this.shape.reduce((a, b) => {
+      return a * b;
+    }, 1);
+    const nbytes = this.dlDataType.numStorageBytes() * size;
+    if (nbytes != data.length) {
+      throw new Error("Expect the data's length to equal nbytes=" + nbytes);
+    }
+
+    const stack = this.lib.getOrAllocCallStack();
+
+    const tempOffset = stack.allocRawBytes(nbytes);
+    const tempPtr = stack.ptrFromOffset(tempOffset);
+    this.lib.memory.storeRawBytes(tempPtr, data);
+    this.lib.checkCall(
+      (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)(
+        this.handle,
+        tempPtr,
+        nbytes
+      )
+    );
+
+    this.lib.recycleCallStack(stack);
+    return this;
+  }
+  /**
+   * Return a copied Uint8Array of the raw bytes in the NDArray.
+   * @returns The result array.
+   */
+  toRawBytes(): Uint8Array {
+    if (this.context.deviceType != DeviceStrToEnum.cpu) {
+      throw new Error(
+        "Can only read raw bytes from a CPU array; copy a GPU array to CPU via copyFrom first."
+      );
+    }
+    const size = this.shape.reduce((a, b) => {
+      return a * b;
+    }, 1);
+
+    const nbytes = this.dlDataType.numStorageBytes() * size;
+    const stack = this.lib.getOrAllocCallStack();
+
+    const tempOffset = stack.allocRawBytes(nbytes);
+    const tempPtr = stack.ptrFromOffset(tempOffset);
+    this.lib.checkCall(
+      (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)(
+        this.handle,
+        tempPtr,
+        nbytes
+      )
+    );
+    const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes);
+
+    this.lib.recycleCallStack(stack);
+    return ret;
+  }
+
+  /**
+   * Return a TypedArray copy of the NDArray; the specific type depends on
+   * the dtype of the NDArray.
+   * @returns The result array.
+   */
+  toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array {
+    const stype = this.dtype;
+    if (stype == "float32") {
+      return new Float32Array(this.toRawBytes().buffer);
+    } else if (stype == "float64") {
+      return new Float64Array(this.toRawBytes().buffer);
+    } else if (stype == "int32") {
+      return new Int32Array(this.toRawBytes().buffer);
+    } else if (stype == "int8") {
+      return new Int8Array(this.toRawBytes().buffer);
+    } else if (stype == "uint8") {
+      return new Uint8Array(this.toRawBytes().buffer);
+    } else {
+      throw new Error("Unsupported data type " + this.dtype);
+    }
+  }
+
+  private getDLTensorFromArrayHandle(handle: Pointer): Pointer {
+    // Note: this depends on the NDArray C ABI.
+    // Keep this function in case of an ABI change.
+    return handle;
+  }
+}
+
+/**
+ * Runtime Module.
+ */
+export class Module implements Disposable {
+  handle: Pointer;
+  private lib: FFILibrary;
+  private makePackedFunc: (ptr: Pointer) => PackedFunc;
+
+  constructor(
+    handle: Pointer,
+    lib: FFILibrary,
+    makePackedFunc: (ptr: Pointer) => PackedFunc
+  ) {
+    this.handle = handle;
+    this.lib = lib;
+    this.makePackedFunc = makePackedFunc;
+  }
+
+  dispose(): void {
+    if (this.handle != 0) {
+      this.lib.checkCall(
+        (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle)
+      );
+      this.handle = 0;
+    }
+  }
+
+  /**
+   * Get a function in the module.
+   * @param name The name of the function.
+   * @returns The result function.
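+   *
+   * @example
+   * // Usage sketch ("add_one" is a sample kernel name taken from the node
+   * // tests below, not something this module guarantees):
+   * //   const faddOne = sysLib.getFunction("add_one");
+   * //   faddOne(A, B);
+   * //   faddOne.dispose();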
+   */
+  getFunction(name: string): PackedFunc {
+    const stack = this.lib.getOrAllocCallStack();
+    const nameOffset = stack.allocRawBytes(name.length + 1);
+    stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+
+    const outOffset = stack.allocPtrArray(1);
+    const outPtr = stack.ptrFromOffset(outOffset);
+
+    stack.commitToWasmMemory(outOffset);
+
+    this.lib.checkCall(
+      (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)(
+        this.handle,
+        stack.ptrFromOffset(nameOffset),
+        1,
+        outPtr
+      )
+    );
+    const handle = this.lib.memory.loadPointer(outPtr);
+    this.lib.recycleCallStack(stack);
+    if (handle == 0) {
+      throw Error("Cannot find function " + name);
+    }
+    const ret = this.makePackedFunc(handle);
+    return ret;
+  }
+
+  /**
+   * Import another module into the current runtime module.
+   * @param mod The module to be imported.
+   */
+  importModule(mod: Module): void {
+    this.lib.checkCall(
+      (this.lib.exports.TVMModImport as ctypes.FTVMModImport)(
+        this.handle,
+        mod.handle
+      )
+    );
+  }
+}
+
+/**
+ * Graph runtime.
+ *
+ * This is a thin wrapper of the underlying TVM module.
+ * You can also directly call set_input, run, and get_output
+ * of the underlying module functions.
+ */
+class GraphRuntime implements Disposable {
+  module: Module;
+  private packedSetInput: PackedFunc;
+  private packedRun: PackedFunc;
+  private packedGetOutput: PackedFunc;
+  private packedLoadParams: PackedFunc;
+
+  /**
+   * Constructor
+   * @param module The underlying module.
+   */
+  constructor(module: Module) {
+    this.module = module;
+    this.packedSetInput = module.getFunction("set_input");
+    this.packedRun = module.getFunction("run");
+    this.packedGetOutput = module.getFunction("get_output");
+    this.packedLoadParams = module.getFunction("load_params");
+  }
+
+  dispose(): void {
+    this.packedSetInput.dispose();
+    this.packedRun.dispose();
+    this.packedGetOutput.dispose();
+    this.packedLoadParams.dispose();
+  }
+
+  /**
+   * Set input to the executor.
+   *
+   * @param key The input key.
+   * @param value The value to set.
+   */
+  setInput(key: number | string, value: NDArray): void {
+    if (typeof key == "number") {
+      this.packedSetInput(new Scalar(key, "int32"), value);
+    } else {
+      this.packedSetInput(key, value);
+    }
+  }
+
+  /**
+   * Execute the underlying graph.
+   */
+  run(): void {
+    this.packedRun();
+  }
+
+  /**
+   * Get index-th output.
+   * @param index The index number.
+   * @param out The optional output storage parameters.
+   * @returns The output array.
+   */
+  getOutput(index: number, out: NDArray | undefined = undefined): NDArray {
+    if (out !== undefined) {
+      this.packedGetOutput(new Scalar(index, "int32"), out);
+      return out;
+    } else {
+      return this.packedGetOutput(new Scalar(index, "int32"));
+    }
+  }
+
+  /**
+   * Load parameters from parameter binary.
+   * @param paramBinary The parameter binary.
+   */
+  loadParams(paramBinary: Uint8Array): void {
+    this.packedLoadParams(paramBinary);
+  }
+
+  /**
+   * Benchmark stable execution of the graph (without data copy).
+   * @param ctx The context to sync during each run.
+   * @param number The number of times to compute the average.
+   * @param repeat The number of times to repeat the run.
+   */
+  async benchmarkRuns(ctx: DLContext, number = 10, repeat = 4): Promise<Array<number>> {
+    // Skip first run as it can involve GPU warmup and module loading time.
+    const perf = compact.getPeformance();
+    const results: Array<number> = [];
+    this.run();
+    await ctx.sync();
+    for (let k = 0; k < repeat; ++k) {
+      const tstart = perf.now();
+      for (let i = 0; i < number; ++i) {
+        this.run();
+      }
+      await ctx.sync();
+      const tend = perf.now();
+      results.push((tend - tstart) / number);
+    }
+    return results;
+  }
+}
+
+/** Code used as the first argument of the async callback. */
+const enum AsyncCallbackCode {
+  kReturn = 4,
+  kException = 5,
+}
+
+/**
+ * TVM runtime instance.
+ */
+export class Instance implements Disposable {
+  memory: Memory;
+  exports: Record<string, Function>;
+  private lib: FFILibrary;
+  private env: Environment;
+
+  /**
+   * Internal function (registered by the runtime).
+   */
+  private wasmCreateLibraryModule?: PackedFunc &
+    ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc);
+
+  /**
+   * Constructor
+   *
+   * importObject can also be a {@link LibraryProvider} object,
+   * a WASI object, or an object containing wasmLibraryProvider field.
+   *
+   * @param wasmModule The input module or instance.
+   * @param importObject The imports to initialize the wasmInstance if it is not provided.
+   * @param wasmInstance Additional wasm instance argument for deferred construction.
+   * @param env Directly specified environment module.
+   *
+   * @see Please use the async version {@link instantiate} when targeting browsers.
+   */
+  constructor(
+    wasmModule: WebAssembly.Module,
+    importObject: Record<string, any> = {},
+    wasmInstance?: WebAssembly.Instance,
+    env?: Environment
+  ) {
+    if (wasmInstance instanceof WebAssembly.Instance) {
+      assert(
+        env instanceof Environment,
+        "env must be provided when passing in instance"
+      );
+    } else {
+      assert(env === undefined);
+      env = new Environment(importObject);
+      wasmInstance = new WebAssembly.Instance(wasmModule, env.imports);
+    }
+
+    env.start(wasmInstance);
+    this.env = env;
+    this.lib = new FFILibrary(wasmInstance, env.imports);
+    this.memory = this.lib.memory;
+    this.exports = this.lib.exports;
+    this.registerEnvGlobalPackedFuncs();
+  }
+
+  dispose(): void {
+    this.lib.dispose();
+  }
+  /**
+   * Get system-wide library module in the wasm.
+   * System lib is a global module that contains self register functions in startup.
+   * @returns The system library module.
+   */
+  systemLib(): Module {
+    const getSysLib = this.getGlobalFunc("runtime.SystemLib");
+    const mod = getSysLib() as Module;
+    getSysLib.dispose();
+    return mod;
+  }
+  /**
+   * List all the global function names registered in the runtime.
+   * @returns The name list.
+   */
+  listGlobalFuncNames(): Array<string> {
+    const stack = this.lib.getOrAllocCallStack();
+
+    const outSizeOffset = stack.allocPtrArray(2);
+
+    const outSizePtr = stack.ptrFromOffset(outSizeOffset);
+    const outArrayPtr = stack.ptrFromOffset(
+      outSizeOffset + this.lib.sizeofPtr()
+    );
+
+    this.lib.checkCall(
+      (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)(
+        outSizePtr,
+        outArrayPtr
+      )
+    );
+
+    const size = this.memory.loadI32(outSizePtr);
+    const array = this.memory.loadPointer(outArrayPtr);
+    const names: Array<string> = [];
+
+    for (let i = 0; i < size; ++i) {
+      names.push(
+        this.memory.loadCString(
+          this.memory.loadPointer(array + this.lib.sizeofPtr() * i)
+        )
+      );
+    }
+
+    this.lib.recycleCallStack(stack);
+    return names;
+  }
+
+  /**
+   * Register function to be a global function in the tvm runtime.
+   * @param name The name of the function.
+   * @param func The function to be registered.
+   * @param override Whether to override an existing entry in the registry.
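+   *
+   * @example
+   * // Usage sketch (mirrors the "xyz" case in the node tests below):
+   * //   instance.registerFunc("xyz", (x: number, y: number) => x + y);
+   * //   const f = instance.getGlobalFunc("xyz"); // f(1, 2) == 3
+   * //   f.dispose();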
+   */
+  registerFunc(
+    name: string,
+    func: PackedFunc | Function,
+    override = false
+  ): void {
+    const packedFunc = this.toPackedFunc(func);
+    const ioverride = override ? 1 : 0;
+
+    const stack = this.lib.getOrAllocCallStack();
+    const nameOffset = stack.allocRawBytes(name.length + 1);
+    stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+    stack.commitToWasmMemory();
+
+    this.lib.checkCall(
+      (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)(
+        stack.ptrFromOffset(nameOffset),
+        packedFunc._tvmPackedCell.handle,
+        ioverride
+      )
+    );
+  }
+
+  /**
+   * Get global PackedFunc from the runtime.
+   * @param name The name of the function.
+   * @returns The result function.
+   */
+  getGlobalFunc(name: string): PackedFunc {
+    const stack = this.lib.getOrAllocCallStack();
+    const nameOffset = stack.allocRawBytes(name.length + 1);
+    stack.storeRawBytes(nameOffset, StringToUint8Array(name));
+    const outOffset = stack.allocPtrArray(1);
+    const outPtr = stack.ptrFromOffset(outOffset);
+
+    stack.commitToWasmMemory(outOffset);
+
+    this.lib.checkCall(
+      (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)(
+        stack.ptrFromOffset(nameOffset),
+        outPtr
+      )
+    );
+    const handle = this.memory.loadPointer(outPtr);
+    this.lib.recycleCallStack(stack);
+    if (handle == 0) {
+      throw Error("Cannot find global function " + name);
+    }
+    const ret = this.makePackedFunc(handle);
+    return ret;
+  }
+
+  /**
+   * Check if func is a PackedFunc.
+   *
+   * @param func The input.
+   * @returns The check result.
+   */
+  isPackedFunc(func: unknown): boolean {
+    // eslint-disable-next-line no-prototype-builtins
+    return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell");
+  }
+
+  /**
+   * Convert func to PackedFunc
+   *
+   * @param func Input function.
+   * @returns The converted function.
+   */
+  toPackedFunc(func: Function): PackedFunc {
+    if (this.isPackedFunc(func)) return func as PackedFunc;
+    return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func));
+  }
+
+  /**
+   * Convert dtype to {@link DLDataType}
+   *
+   * @param dtype The input dtype string or DLDataType.
+   * @returns The converted result.
+   */
+  toDLDataType(dtype: string | DLDataType): DLDataType {
+    if (dtype instanceof DLDataType) return dtype;
+    if (typeof dtype == "string") {
+      let pattern = dtype;
+      let code,
+        bits = 32,
+        lanes = 1;
+      if (pattern.substring(0, 5) == "float") {
+        pattern = pattern.substring(5, pattern.length);
+        code = TypeCode.Float;
+      } else if (pattern.substring(0, 3) == "int") {
+        pattern = pattern.substring(3, pattern.length);
+        code = TypeCode.Int;
+      } else if (pattern.substring(0, 4) == "uint") {
+        pattern = pattern.substring(4, pattern.length);
+        code = TypeCode.UInt;
+      } else if (pattern.substring(0, 6) == "handle") {
+        // strip the full "handle" prefix (6 characters).
+        pattern = pattern.substring(6, pattern.length);
+        code = TypeCode.TVMOpaqueHandle;
+        bits = 64;
+      } else {
+        throw new Error("Unknown dtype " + dtype);
+      }
+
+      const arr = pattern.split("x");
+      if (arr.length >= 1) {
+        const parsed = parseInt(arr[0]);
+        if (parsed + "" == arr[0]) {
+          bits = parsed;
+        }
+      }
+      if (arr.length >= 2) {
+        lanes = parseInt(arr[1]);
+      }
+      return new DLDataType(code, bits, lanes);
+    } else {
+      throw new Error("Unknown dtype " + dtype);
+    }
+  }
+
+  /**
+   * Create a new {@link Scalar} that can be passed to a PackedFunc.
+   * @param value The number value.
+   * @param dtype The dtype string.
+   * @returns The created scalar.
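+   *
+   * @example
+   * // Why Scalar exists: a plain JS number is packed as float64 by
+   * // setPackedArguments below, so integer arguments need an explicit dtype:
+   * //   instance.scalar(1, "int32")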
+   */
+  scalar(value: number, dtype: string): Scalar {
+    return new Scalar(value, dtype);
+  }
+
+  /**
+   * Create a new {@link DLContext}
+   * @param deviceType The device type.
+   * @param deviceId The device index.
+   * @returns The created context.
+   */
+  context(deviceType: number | string, deviceId = 0): DLContext {
+    return new DLContext(deviceType, deviceId, this.lib);
+  }
+
+  /**
+   * Create a new cpu {@link DLContext}
+   * @param deviceId The device index.
+   */
+  cpu(deviceId = 0): DLContext {
+    return this.context("cpu", deviceId);
+  }
+
+  /**
+   * Create a new webgpu {@link DLContext}
+   * @param deviceId The device index.
+   */
+  webgpu(deviceId = 0): DLContext {
+    return this.context("webgpu", deviceId);
+  }
+
+  /**
+   * Create an empty {@link NDArray} with given shape and dtype.
+   *
+   * @param shape The shape of the array.
+   * @param dtype The data type of the array.
+   * @param ctx The context of the ndarray.
+   * @returns The created ndarray.
+   */
+  empty(
+    shape: Array<number> | number,
+    dtype: string | DLDataType = "float32",
+    ctx: DLContext = this.context("cpu", 0)
+  ): NDArray {
+    dtype = this.toDLDataType(dtype);
+    shape = typeof shape == "number" ? [shape] : shape;
+
+    const stack = this.lib.getOrAllocCallStack();
+    const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64);
+    for (let i = 0; i < shape.length; ++i) {
+      stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]);
+    }
+
+    const outOffset = stack.allocPtrArray(1);
+    const outPtr = stack.ptrFromOffset(outOffset);
+    stack.commitToWasmMemory(outOffset);
+
+    this.lib.checkCall(
+      (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)(
+        stack.ptrFromOffset(shapeOffset),
+        shape.length,
+        dtype.code,
+        dtype.bits,
+        dtype.lanes,
+        ctx.deviceType,
+        ctx.deviceId,
+        outPtr
+      )
+    );
+    const ret = new NDArray(this.memory.loadPointer(outPtr), false, this.lib);
+    this.lib.recycleCallStack(stack);
+    return ret;
+  }
+
+  /**
+   * Create a new graph runtime.
+   *
+   * @param graphJson The graph runtime JSON.
+   * @param lib The underlying library.
+   * @param ctx The execution context of the graph.
+   */
+  createGraphRuntime(
+    graphJson: string,
+    lib: Module,
+    ctx: DLContext
+  ): GraphRuntime {
+    const fcreate = this.getGlobalFunc("tvm.graph_runtime.create");
+    const module = fcreate(
+      graphJson,
+      lib,
+      this.scalar(ctx.deviceType, "int32"),
+      this.scalar(ctx.deviceId, "int32")
+    ) as Module;
+    return new GraphRuntime(module);
+  }
+
+  /**
+   * Register an async function to be a global function in the server.
+   * @param name The name of the function.
+   * @param func The function to be registered.
+   * @param override Whether to override an existing entry in the registry.
+   *
+   * @note The async function will only be used for serving remote calls in the RPC.
+   */
+  registerAsyncServerFunc(
+    name: string,
+    func: Function,
+    override = false
+  ): void {
+    const asyncVariant = (...args: Array<any>): void => {
+      const fargs = args.slice(0, args.length - 1);
+      const callback = args[args.length - 1] as PackedFunc;
+      const promise: Promise<any> = func(...fargs);
+      promise.then((rv: any) => {
+        callback(this.scalar(AsyncCallbackCode.kReturn, "int32"), rv);
+      });
+    };
+    this.registerFunc("__async." + name, asyncVariant, override);
+  }
+
+  /**
+   * Initialize webgpu in the runtime.
+   * @param device The given GPU device.
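+   *
+   * @example
+   * // Usage sketch: detectGPUDevice (from webgpu.ts) resolves to undefined
+   * // when WebGPU is unavailable, so guard the call:
+   * //   const device = await detectGPUDevice();
+   * //   if (device !== undefined) instance.initWebGPU(device);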
+   */
+  initWebGPU(device: GPUDevice): void {
+    const webGPUContext = new WebGPUContext(
+      this.memory, device
+    );
+    this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => {
+      return webGPUContext.getDeviceAPI(name);
+    });
+    this.registerFunc("wasm.WebGPUCreateShader", (info: string, data: Uint8Array) => {
+      return webGPUContext.createShader(info, data);
+    });
+    this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => {
+      await webGPUContext.sync();
+    });
+    this.lib.webGPUContext = webGPUContext;
+  }
+
+  /** Register global packed functions needed by the backend to the env. */
+  private registerEnvGlobalPackedFuncs(): void {
+    // Register the timer function to enable the time_evaluator.
+    const perf = compact.getPeformance();
+
+    // Helper function to time the finvoke
+    const timeExecution = async (
+      finvoke: PackedFunc,
+      ctx: DLContext,
+      nstep: number,
+      repeat: number,
+      minRepeatMs: number
+    ): Promise<Uint8Array> => {
+      finvoke(this.scalar(1, "int32"));
+      await ctx.sync();
+      const result: Array<number> = [];
+      let setupNumber: number = nstep;
+
+      for (let i = 0; i < repeat; ++i) {
+        let durationMs = 0.0;
+        do {
+          if (durationMs > 0.0) {
+            setupNumber = Math.floor(
+              Math.max(minRepeatMs / (durationMs / nstep) + 1, nstep * 1.618)
+            );
+          }
+          const tstart: number = perf.now();
+          finvoke(this.scalar(setupNumber, "int32"));
+          await ctx.sync();
+          const tend: number = perf.now();
+
+          durationMs = tend - tstart;
+        } while (durationMs < minRepeatMs);
+        const speed = durationMs / setupNumber / 1000;
+        result.push(speed);
+      }
+      const ret = new Float64Array(result.length);
+      ret.set(result);
+      return new Uint8Array(ret.buffer);
+    };
+
+    const addOne = async (x: number): Promise<number> => {
+      await new Promise(resolve => setTimeout(resolve, 100));
+      return x + 1;
+    };
+
+    this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution);
+    this.registerAsyncServerFunc("testing.asyncAddOne", addOne);
+  }
+
+  private createPackedFuncFromCFunc(
+    func: ctypes.FTVMWasmPackedCFunc
+  ): PackedFunc {
+    let findex = this.env.packedCFuncTable.length;
+    if (this.env.packedCFuncTableFreeId.length != 0) {
+      findex = this.env.packedCFuncTableFreeId.pop() as number;
+    } else {
+      this.env.packedCFuncTable.push(undefined);
+    }
+    this.env.packedCFuncTable[findex] = func;
+
+    const stack = this.lib.getOrAllocCallStack();
+    const outOffset = stack.allocPtrArray(1);
+    const outPtr = stack.ptrFromOffset(outOffset);
+    this.lib.checkCall(
+      (this.exports
+        .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)(
+        findex,
+        outPtr
+      )
+    );
+    const ret = this.makePackedFunc(this.memory.loadPointer(outPtr));
+    this.lib.recycleCallStack(stack);
+    return ret;
+  }
+
+  /**
+   * Set packed function arguments into the location indicated by argsValue and argsCode.
+   * Allocate new temporary space from the stack if necessary.
+   *
+   * @param stack The call stack
+   * @param args The input arguments.
+   * @param argsValue The offset of argsValue.
+   * @param argsCode The offset of argsCode.
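+   *
+   * @example
+   * // Memory layout sketch for n arguments (derived from the offsets used
+   * // below): argsValue holds n TVMValue slots of SizeOf.TVMValue bytes
+   * // each; argsCode holds n int32 type codes, one per slot.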
+   */
+  setPackedArguments(
+    stack: CachedCallStack,
+    args: Array<any>,
+    argsValue: PtrOffset,
+    argsCode: PtrOffset
+  ): void {
+    for (let i = 0; i < args.length; ++i) {
+      let val = args[i];
+      const tp = typeof val;
+      const valueOffset = argsValue + i * SizeOf.TVMValue;
+      const codeOffset = argsCode + i * SizeOf.I32;
+      if (val instanceof NDArray) {
+        stack.storePtr(valueOffset, val.handle);
+        stack.storeI32(codeOffset, TypeCode.TVMNDArrayHandle);
+      } else if (val instanceof Scalar) {
+        if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) {
+          stack.storeI64(valueOffset, val.value);
+          stack.storeI32(codeOffset, TypeCode.Int);
+        } else if (val.dtype.startsWith("float")) {
+          stack.storeF64(valueOffset, val.value);
+          stack.storeI32(codeOffset, TypeCode.Float);
+        } else {
+          assert(val.dtype == "handle", "Expect handle");
+          stack.storePtr(valueOffset, val.value);
+          stack.storeI32(codeOffset, TypeCode.TVMOpaqueHandle);
+        }
+      } else if (val instanceof DLContext) {
+        stack.storeI32(valueOffset, val.deviceType);
+        stack.storeI32(valueOffset + SizeOf.I32, val.deviceId);
+        stack.storeI32(codeOffset, TypeCode.TVMContext);
+      } else if (tp == "number") {
+        stack.storeF64(valueOffset, val);
+        stack.storeI32(codeOffset, TypeCode.Float);
+        // eslint-disable-next-line no-prototype-builtins
+      } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) {
+        stack.storePtr(valueOffset, val._tvmPackedCell.handle);
+        stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle);
+      } else if (val === null || val == undefined) {
+        stack.storePtr(valueOffset, 0);
+        stack.storeI32(codeOffset, TypeCode.Null);
+      } else if (tp == "string") {
+        stack.allocThenSetArgString(valueOffset, val);
+        stack.storeI32(codeOffset, TypeCode.TVMStr);
+      } else if (val instanceof Uint8Array) {
+        stack.allocThenSetArgBytes(valueOffset, val);
+        stack.storeI32(codeOffset, TypeCode.TVMBytes);
+      } else if (val instanceof Function) {
+        val = this.toPackedFunc(val);
+        stack.tempArgs.push(val);
+        stack.storePtr(valueOffset, val._tvmPackedCell.handle);
+        stack.storeI32(codeOffset, TypeCode.TVMPackedFuncHandle);
+      } else if (val instanceof Module) {
+        stack.storePtr(valueOffset, val.handle);
+        stack.storeI32(codeOffset, TypeCode.TVMModuleHandle);
+      } else {
+        throw new Error("Unsupported argument type " + tp);
+      }
+    }
+  }
+
+  private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc {
+    const lib = this.lib;
+    return (
+      argValues: Pointer,
+      argCodes: Pointer,
+      nargs: number,
+      ret: Pointer,
+      // eslint-disable-next-line @typescript-eslint/no-unused-vars
+      _handle: Pointer
+    ): number => {
+      const jsArgs = [];
+      for (let i = 0; i < nargs; ++i) {
+        const valuePtr = argValues + i * SizeOf.TVMValue;
+        const codePtr = argCodes + i * SizeOf.I32;
+        let tcode = lib.memory.loadI32(codePtr);
+
+        if (
+          tcode == TypeCode.TVMObjectHandle ||
+          tcode == TypeCode.TVMObjectRValueRefArg ||
+          tcode == TypeCode.TVMPackedFuncHandle ||
+          tcode == TypeCode.TVMModuleHandle
+        ) {
+          lib.checkCall(
+            (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)(
+              valuePtr,
+              codePtr
+            )
+          );
+        }
+        tcode = lib.memory.loadI32(codePtr);
+        jsArgs.push(this.retValueToJS(valuePtr, tcode, true));
+      }
+
+      const rv = func(...jsArgs);
+
+      if (rv !== undefined && rv !== null) {
+        const stack = lib.getOrAllocCallStack();
+        const valueOffset = stack.allocRawBytes(SizeOf.TVMValue);
+        const codeOffset = stack.allocRawBytes(SizeOf.I32);
+        this.setPackedArguments(stack, [rv], valueOffset, codeOffset);
+        const valuePtr = stack.ptrFromOffset(valueOffset);
+        const codePtr = stack.ptrFromOffset(codeOffset);
+        stack.commitToWasmMemory();
+        lib.checkCall(
+          (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)(
+            ret,
+            valuePtr,
+            codePtr,
+            1
+          )
+        );
+        lib.recycleCallStack(stack);
+      }
+      return 0;
+    };
+  }
+
+  private makePackedFunc(handle: Pointer): PackedFunc {
+    const cell = new PackedFuncCell(handle, this.lib);
+
+    const packedFunc = (...args: any): any => {
+      const stack = this.lib.getOrAllocCallStack();
+
+      const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length);
+      const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length);
+
+      this.setPackedArguments(stack, args, valueOffset, tcodeOffset);
+
+      const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue);
+      const rcodeOffset = stack.allocRawBytes(SizeOf.I32);
+      const rvaluePtr = stack.ptrFromOffset(rvalueOffset);
+      const rcodePtr = stack.ptrFromOffset(rcodeOffset);
+
+      // Commit to wasm memory up to rvalueOffset
+      // (the return value slot doesn't need to be committed).
+      stack.commitToWasmMemory(rvalueOffset);
+
+      this.lib.checkCall(
+        (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)(
+          handle,
+          stack.ptrFromOffset(valueOffset),
+          stack.ptrFromOffset(tcodeOffset),
+          args.length,
+          rvaluePtr,
+          rcodePtr
+        )
+      );
+
+      const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false);
+      this.lib.recycleCallStack(stack);
+      return ret;
+    };
+    // Attach attributes to the function type.
+    // This is because JavaScript does not allow us to overload call.
+    const ret: any = packedFunc;
+    ret.dispose = (): void => {
+      cell.dispose();
+    };
+    ret._tvmPackedCell = cell;
+    return ret as PackedFunc;
+  }
+
+  private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any {
+    switch (tcode) {
+      case TypeCode.Int:
+      case TypeCode.UInt:
+        return this.memory.loadI64(rvaluePtr);
+      case TypeCode.Float:
+        return this.memory.loadF64(rvaluePtr);
+      case TypeCode.TVMOpaqueHandle: {
+        return this.memory.loadPointer(rvaluePtr);
+      }
+      case TypeCode.TVMNDArrayHandle: {
+        return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib);
+      }
+      case TypeCode.TVMDLTensorHandle: {
+        assert(callbackArg);
+        return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib);
+      }
+      case TypeCode.TVMPackedFuncHandle: {
+        return this.makePackedFunc(this.memory.loadPointer(rvaluePtr));
+      }
+      case TypeCode.TVMModuleHandle: {
+        return new Module(
+          this.memory.loadPointer(rvaluePtr),
+          this.lib,
+          (ptr: Pointer) => {
+            return this.makePackedFunc(ptr);
+          }
+        );
+      }
+      case TypeCode.Null: return undefined;
+      case TypeCode.TVMContext: {
+        const deviceType = this.memory.loadI32(rvaluePtr);
+        const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32);
+        return this.context(deviceType, deviceId);
+      }
+      case TypeCode.TVMStr: {
+        const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr));
+        return ret;
+      }
+      case TypeCode.TVMBytes: {
+        return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr));
+      }
+      default:
+        throw new Error("Unsupported return type code=" + tcode);
+    }
+  }
+}
+
+/**
+ * Asynchronously instantiate a new {@link Instance}.
+ *
+ * importObject can also be a {@link LibraryProvider} object,
+ * a WASI object, or an object containing wasmLibraryProvider field.
+ * We can take advantage of syslib implementations from Emscripten
+ * by passing its generated js Module as the imports.
+ *
+ * @param bufferSource The source to be compiled.
+ * @param importObject The import objects.
+ * @param logger The system logger.
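+ *
+ * @example
+ * // Usage sketch in node.js (file name is a placeholder, mirroring the
+ * // tests below):
+ * //   const wasmSource = fs.readFileSync("dist/wasm/tvmjs_runtime.wasm");
+ * //   const tvm = await instantiate(wasmSource, new EmccWASI());
+ * //   const names = tvm.listGlobalFuncNames();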
+ */
+export function instantiate(
+  bufferSource: ArrayBuffer,
+  importObject: Record<string, any> = {},
+  logger: (msg: string) => void = console.log
+): Promise<Instance> {
+  const env = new Environment(importObject, logger);
+
+  return WebAssembly.instantiate(bufferSource, env.imports).then(
+    (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => {
+      return new Instance(result.module, {}, result.instance, env);
+    }
+  );
+}
diff --git a/web/src/support.ts b/web/src/support.ts
new file mode 100644
index 0000000000000..7a2667a2299f0
--- /dev/null
+++ b/web/src/support.ts
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Convert string to Uint8Array.
+ * @param str The string.
+ * @returns The corresponding Uint8Array.
+ */
+export function StringToUint8Array(str: string): Uint8Array {
+  const arr = new Uint8Array(str.length + 1);
+  for (let i = 0; i < str.length; ++i) {
+    arr[i] = str.charCodeAt(i);
+  }
+  arr[str.length] = 0;
+  return arr;
+}
+
+/**
+ * Convert Uint8Array to string.
+ * @param arr The array.
+ * @returns The corresponding string.
+ */
+export function Uint8ArrayToString(arr: Uint8Array): string {
+  const ret = [];
+  for (const ch of arr) {
+    ret.push(String.fromCharCode(ch));
+  }
+  return ret.join("");
+}
+
+/**
+ * Internal assert helper.
+ * @param condition The condition to check.
+ * @param msg The error message.
+ */
+export function assert(condition: boolean, msg?: string): asserts condition {
+  if (!condition) {
+    throw new Error("AssertError:" + (msg || ""));
+  }
+}
+
+/**
+ * Get the path to the wasm library in nodejs.
+ * @return The wasm path.
+ */
+export function wasmPath(): string {
+  return __dirname + "/wasm";
+}
\ No newline at end of file
diff --git a/web/src/types.ts b/web/src/types.ts
new file mode 100644
index 0000000000000..621375a23f5ff
--- /dev/null
+++ b/web/src/types.ts
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/** Common type definitions.
+ */
+
+/**
+ * Library interface provider that can provide
+ * syslibs (e.g. libs provided by WASI and beyond) for the Wasm runtime.
+ *
+ * It can be viewed as a generalization of imports used in WebAssembly instance creation.
+ *
+ * The {@link LibraryProvider.start} callback will be called
+ * to allow the library provider to initialize related resources during startup time.
+ *
+ * We can use an Emscripten-generated js Module as a { wasmLibraryProvider: LibraryProvider }.
+ */
+export interface LibraryProvider {
+  /** The imports that can be passed to WebAssembly instance creation. */
+  imports: Record<string, any>;
+  /**
+   * Callback function to notify the provider of the created instance.
+   * @param inst The created instance.
+   */
+  start: (inst: WebAssembly.Instance) => void;
+}
+
+/**
+ * Disposable classes that contain resources (WasmMemory, GPU buffer)
+ * which need to be explicitly disposed.
+ */
+export interface Disposable {
+  /**
+   * Dispose the internal resource.
+   * This function can be called multiple times;
+   * only the first call will take effect.
+   */
+  dispose: () => void;
+}
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
new file mode 100644
index 0000000000000..640f7b4a71637
--- /dev/null
+++ b/web/src/webgpu.ts
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import "@webgpu/types";
+import { assert } from "./support";
+import { Pointer } from "./ctypes";
+import { Memory } from "./memory";
+
+/** A pointer into the raw GPU address space. */
+export type GPUPointer = number;
+
+/**
+ * Detect GPU device in the environment.
+ */
+export async function detectGPUDevice(): Promise<GPUDevice | undefined> {
+  if (typeof navigator !== "undefined" && navigator.gpu !== undefined) {
+    const adapter = await navigator.gpu.requestAdapter();
+    return await adapter.requestDevice();
+  } else {
+    return undefined;
+  }
+}
+
+interface FunctionInfo {
+  name: string;
+  arg_types: Array<string>;
+  thread_axis_tags: Array<string>;
+}
+
+/**
+ * WebGPU context.
+ * Manages all the webgpu resources here.
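+ *
+ * @example
+ * // Construction sketch (normally done for you by Instance.initWebGPU in
+ * // runtime.ts, shown here only to illustrate the wiring):
+ * //   const gpuCtx = new WebGPUContext(instance.memory, device);
+ * //   instance.registerFunc("wasm.WebGPUDeviceAPI",
+ * //                         (n: string) => gpuCtx.getDeviceAPI(n));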
+ */
+export class WebGPUContext {
+  device: GPUDevice;
+  memory: Memory;
+
+  //private readBuffer:;
+  private bufferTable: Array<GPUBuffer | undefined> = [undefined];
+  private bufferTableFreeId: Array<number> = [];
+  private pendingRead: Promise<void> = Promise.resolve();
+  private numPendingReads = 0;
+
+  constructor(memory: Memory, device: GPUDevice) {
+    this.memory = memory;
+    this.device = device;
+  }
+
+  /**
+   * Wait for all pending GPU tasks to complete.
+   */
+  async sync(): Promise<void> {
+    const fence = this.device.defaultQueue.createFence();
+    this.device.defaultQueue.signal(fence, 1);
+    if (this.numPendingReads != 0) {
+      // eslint-disable-next-line @typescript-eslint/no-empty-function
+      await Promise.all([fence.onCompletion(1), this.pendingRead]);
+    } else {
+      await fence.onCompletion(1);
+    }
+  }
+
+  /**
+   * Create a PackedFunc that runs the given shader.
+   *
+   * @param info The function information in json.
+   * @param data The shader data (in SPIRV).
+   */
+  createShader(info: string, data: Uint8Array): Function {
+    const finfo: FunctionInfo = JSON.parse(info);
+    const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
+    for (let i = 0; i < finfo.arg_types.length; ++i) {
+      const dtype = finfo.arg_types[i];
+      if (dtype == "handle") {
+        layoutEntries.push({
+          binding: i,
+          visibility: GPUShaderStage.COMPUTE,
+          type: "storage-buffer"
+        });
+      } else {
+        throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader");
+      }
+    }
+    const bindGroupLayout = this.device.createBindGroupLayout({
+      entries: layoutEntries
+    });
+
+    const pipeline = this.device.createComputePipeline({
+      layout: this.device.createPipelineLayout({
+        bindGroupLayouts: [ bindGroupLayout ]
+      }),
+      computeStage: {
+        module: this.device.createShaderModule({
+          code: new Uint32Array(data.buffer)
+        }),
+        entryPoint: "main"
+      }
+    });
+
+    const dispatchToDim: Array<number> = [];
+
+    for (let i = 0; i < finfo.thread_axis_tags.length; ++i) {
+      const tag: string = finfo.thread_axis_tags[i];
+      if (tag.startsWith("blockIdx.")) {
+        const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0));
+        assert(target >= 0 && target < 3);
+        dispatchToDim.push(target);
+      } else if (tag.startsWith("threadIdx.")) {
+        const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0));
+        assert(target >= 0 && target < 3);
+        dispatchToDim.push(target + 3);
+      } else {
+        throw new Error("Cannot handle thread_axis " + tag);
+      }
+    }
+
+    const submitShader = (...args: Array<number>): void => {
+      const commandEncoder = this.device.createCommandEncoder();
+      const compute = commandEncoder.beginComputePass();
+      compute.setPipeline(pipeline);
+      const bindGroupEntries: Array<GPUBindGroupEntry> = [];
+      assert(args.length == layoutEntries.length + dispatchToDim.length);
+
+      for (let i = 0; i < layoutEntries.length; ++i) {
+        bindGroupEntries.push({
+          binding: i,
+          resource: {
+            buffer: this.gpuBufferFromPtr(args[i])
+          }
+        });
+      }
+
+      compute.setBindGroup(0, this.device.createBindGroup({
+        layout: bindGroupLayout,
+        entries: bindGroupEntries
+      }));
+      const wl: Array<number> = [1, 1, 1, 1, 1, 1];
+      for (let i = 0; i < dispatchToDim.length; ++i) {
+        wl[dispatchToDim[i]] = args[layoutEntries.length + i];
+      }
+      compute.dispatch(wl[0], wl[1], wl[2]);
+      compute.endPass();
+      const command = commandEncoder.finish();
+      this.device.defaultQueue.submit([command]);
+    };
+
+    return submitShader;
+  }
+
+  /**
+   * Get the device API according to its name.
+   * @param name The name of the API.
+   * @returns The corresponding device api.
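+   *
+   * @example
+   * // Dispatch sketch: the wasm backend looks these up by name, e.g.
+   * //   const alloc = gpuCtx.getDeviceAPI("deviceAllocDataSpace");
+   * //   const ptr = alloc(256); // returns a GPUPointer index into bufferTable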
+   */
+  getDeviceAPI(name: string): Function {
+    if (name == "deviceAllocDataSpace") {
+      return (nbytes: number): GPUPointer => {
+        return this.deviceAllocDataSpace(nbytes);
+      };
+    } else if (name == "deviceFreeDataSpace") {
+      return (ptr: GPUPointer): void => {
+        return this.deviceFreeDataSpace(ptr);
+      };
+    } else if (name == "deviceCopyToGPU") {
+      return (
+        from: Pointer,
+        to: GPUPointer,
+        toOffset: number,
+        nbytes: number
+      ): void => {
+        this.deviceCopyToGPU(from, to, toOffset, nbytes);
+      };
+    } else if (name == "deviceCopyFromGPU") {
+      return (
+        from: GPUPointer,
+        fromOffset: number,
+        to: Pointer,
+        nbytes: number
+      ): void => {
+        this.deviceCopyFromGPU(from, fromOffset, to, nbytes);
+      };
+    } else if (name == "deviceCopyWithinGPU") {
+      return (
+        from: GPUPointer,
+        fromOffset: number,
+        to: GPUPointer,
+        toOffset: number,
+        nbytes: number
+      ): void => {
+        this.deviceCopyWithinGPU(from, fromOffset, to, toOffset, nbytes);
+      };
+    } else {
+      throw new Error("Unknown DeviceAPI function " + name);
+    }
+  }
+
+  // DeviceAPI
+  private deviceAllocDataSpace(nbytes: number): GPUPointer {
+    const buffer = this.device.createBuffer({
+      size: nbytes,
+      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+    });
+    return this.attachToBufferTable(buffer);
+  }
+
+  private deviceFreeDataSpace(ptr: GPUPointer): void {
+    const idx = ptr;
+    const buffer = this.bufferTable[idx];
+    this.bufferTable[idx] = undefined;
+    assert(buffer !== undefined);
+    this.bufferTableFreeId.push(idx);
+    buffer.destroy();
+  }
+
+  private deviceCopyToGPU(
+    from: Pointer,
+    to: GPUPointer,
+    toOffset: number,
+    nbytes: number
+  ): void {
+    // Perhaps it would be more useful to use a staging buffer?
+    const [gpuTemp, cpuTemp] = this.device.createBufferMapped({
+      size: nbytes,
+      usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC,
+    });
+
+    const viewU8 = new Uint8Array(cpuTemp);
+    viewU8.set(this.memory.loadRawBytes(from, nbytes));
+    gpuTemp.unmap();
+
+    const copyEncoder = this.device.createCommandEncoder();
+    copyEncoder.copyBufferToBuffer(
+      gpuTemp,
+      0,
+      this.gpuBufferFromPtr(to),
+      toOffset,
+      nbytes
+    );
+    const copyCommands = copyEncoder.finish();
+    this.device.defaultQueue.submit([copyCommands]);
+    gpuTemp.destroy();
+  }
+
+  private deviceCopyFromGPU(
+    from: GPUPointer,
+    fromOffset: number,
+    to: Pointer,
+    nbytes: number
+  ): void {
+    // Perhaps it would be more useful to reuse a staging buffer?
+    const gpuTemp = this.device.createBuffer({
+      size: nbytes,
+      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+    });
+
+    const copyEncoder = this.device.createCommandEncoder();
+    copyEncoder.copyBufferToBuffer(
+      this.gpuBufferFromPtr(from),
+      fromOffset,
+      gpuTemp,
+      0,
+      nbytes
+    );
+    const copyCommands = copyEncoder.finish();
+    this.device.defaultQueue.submit([copyCommands]);
+
+    this.numPendingReads += 1;
+    const readEvent = gpuTemp.mapReadAsync().then((data: ArrayBuffer) => {
+      this.memory.storeRawBytes(to, new Uint8Array(data));
+      this.numPendingReads -= 1;
+      gpuTemp.destroy();
+    });
+
+    if (this.numPendingReads == 1) {
+      this.pendingRead = readEvent;
+    } else {
+      this.pendingRead = Promise.all([
+        this.pendingRead,
+        readEvent,
+        // eslint-disable-next-line @typescript-eslint/no-empty-function
+      ]).then(() => {});
+    }
+  }
+
+  private deviceCopyWithinGPU(
+    from: GPUPointer,
+    fromOffset: number,
+    to: GPUPointer,
+    toOffset: number,
+    nbytes: number
+  ): void {
+    const copyEncoder = this.device.createCommandEncoder();
+    copyEncoder.copyBufferToBuffer(
+      this.gpuBufferFromPtr(from),
+      fromOffset,
+      this.gpuBufferFromPtr(to),
+      toOffset,
+      nbytes
+    );
+    const copyCommands = copyEncoder.finish();
+    this.device.defaultQueue.submit([copyCommands]);
+  }
+
+  private gpuBufferFromPtr(ptr: GPUPointer): GPUBuffer {
+    const buffer = this.bufferTable[ptr];
+    assert(buffer !== undefined);
+    return buffer;
+  }
+
+  private attachToBufferTable(buffer: GPUBuffer): GPUPointer {
+    if (this.bufferTableFreeId.length != 0) {
+      const idx = this.bufferTableFreeId.pop() as number;
+      this.bufferTable[idx] = buffer;
+      return idx;
+    } else {
+      const idx = this.bufferTable.length;
+      this.bufferTable.push(buffer);
+      return idx;
+    }
+  }
+}
diff --git a/tests/web/test_module_load.js b/web/tests/node/test_module_load.js
similarity index 64%
rename from tests/web/test_module_load.js
rename to web/tests/node/test_module_load.js
index f4c809536bb53..45e84fd404a9a 100644
--- a/tests/web/test_module_load.js
+++ b/web/tests/node/test_module_load.js
@@ -19,14 +19,18 @@
 // Load Emscripten Module, need to change path to root/lib
 const path = require("path");
-process.chdir(path.join(__dirname, "../../build"));
-var Module = require("../../build/test_module.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
+const fs = require("fs");
+const assert = require("assert");
+const tvmjs = require("../../dist");
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "test_addone.wasm"));
+
+const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());
 // Load system library
-var sysLib = tvm.systemLib();
+const sysLib = tvm.systemLib();
 
 function randomArray(length, max) {
   return Array.apply(null, Array(length)).map(function() {
@@ -36,23 +40,22 @@ function randomArray(length, max) {
 
 function testAddOne() {
   // grab pre-loaded function
-  var faddOne = sysLib.getFunction("add_one");
-  var assert = require('assert');
-  tvm.assert(tvm.isPackedFunc(faddOne));
-  var n = 124;
-  var A = tvm.empty(n).copyFrom(randomArray(n, 1));
-  var B = tvm.empty(n);
+  const faddOne = sysLib.getFunction("add_one");
+  assert(tvm.isPackedFunc(faddOne));
+  const n = 124;
+  const A = tvm.empty(n).copyFrom(randomArray(n, 1));
+  const B = tvm.empty(n);
   // call the function.
   faddOne(A, B);
-  AA = A.asArray();  // retrieve values in js array
-  BB = B.asArray();  // retrieve values in js array
+  const AA = A.toArray();  // retrieve values in js array
+  const BB = B.toArray();  // retrieve values in js array
   // verify
   for (var i = 0; i < BB.length; ++i) {
     assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5);
   }
-  faddOne.release();
+  faddOne.dispose();
 }
 
 testAddOne();
-sysLib.release();
+sysLib.dispose();
 console.log("Finish verifying test_module_load");
diff --git a/tests/web/test_basic.js b/web/tests/node/test_ndarray.js
similarity index 55%
rename from tests/web/test_basic.js
rename to web/tests/node/test_ndarray.js
index 6852319dbc127..ba43621ecb05b 100644
--- a/tests/web/test_basic.js
+++ b/web/tests/node/test_ndarray.js
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -16,31 +16,34 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-// Load Emscripten Module, need to change path to root/build
 const path = require("path");
-process.chdir(path.join(__dirname, "../../build"));
-var Module = require("../../build/libtvm_web_runtime.js");
-// Bootstrap TVMruntime with emscripten module.
-const tvm_runtime = require("../../web/tvm_runtime.js");
-const tvm = tvm_runtime.create(Module);
+const fs = require("fs");
+const assert = require("assert");
+const tvmjs = require("../../dist/tvmjs.bundle")
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
+
+let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());
 
 // Basic fields.
-tvm.assert(tvm.float32 == "float32");
-tvm.assert(tvm.listGlobalFuncNames() !== "undefined");
-var sysLib = tvm.systemLib();
-tvm.assert(typeof sysLib.getFunction !== "undefined");
-sysLib.release();
+assert(tvm.listGlobalFuncNames() !== undefined);
 
 // Test ndarray
-function testArrayCopy(dtype, arr) {
-  var data = [1, 2, 3, 4, 5, 6];
-  var a = tvm.empty([2, 3], dtype);
-  a.copyFrom(data);
-  var ret = a.asArray();
-  tvm.assert(ret instanceof arr);
-  tvm.assert(ret.toString() == arr.from(data));
-  a.release();
+function testArrayCopy(dtype, arrayType) {
+  let data = [1, 2, 3, 4, 5, 6];
+  let a = tvm.empty([2, 3], dtype).copyFrom(data);
+
+  assert(a.context.toString() == "cpu(0)");
+  assert(a.shape[0] == 2 && a.shape[1] == 3);
+
+  let ret = a.toArray();
+  assert(ret instanceof arrayType);
+  assert(ret.toString() == arrayType.from(data).toString());
+  // test multiple dispose.
+  a.dispose();
+  a.dispose();
 }
 
 testArrayCopy("float32", Float32Array);
@@ -48,8 +51,3 @@ testArrayCopy("int", Int32Array);
 testArrayCopy("int8", Int8Array);
 testArrayCopy("uint8", Uint8Array);
 testArrayCopy("float64", Float64Array);
-
-// Function registration
-tvm.registerFunc("xyz", function(x, y) {
-  return x + y;
-});
diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js
new file mode 100644
index 0000000000000..c961f9576e3fd
--- /dev/null
+++ b/web/tests/node/test_packed_func.js
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+const path = require("path");
+const fs = require("fs");
+const assert = require('assert');
+const tvmjs = require("../../dist")
+
+const wasmPath = tvmjs.wasmPath();
+const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js"));
+const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
+
+let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI());
+
+function testGetGlobal() {
+  let flist = tvm.listGlobalFuncNames();
+  let faddOne = tvm.getGlobalFunc("testing.add_one");
+  let fecho = tvm.getGlobalFunc("testing.echo");
+
+  assert(faddOne(tvm.scalar(1, "int")) == 2);
+  // check function argument with different types.
+  assert(fecho(1123) == 1123);
+  assert(fecho("xyz") == "xyz");
+
+  let bytes = new Uint8Array([1, 2, 3]);
+  let rbytes = fecho(bytes);
+  assert(rbytes.length == bytes.length);
+
+  for (let i = 0; i < bytes.length; ++i) {
+    assert(rbytes[i] == bytes[i]);
+  }
+
+  assert(fecho(undefined) == undefined);
+
+  let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]);
+  let arr2 = fecho(arr);
+  assert(arr.handle == arr2.handle);
+  assert(arr2.toArray().toString() == arr.toArray().toString());
+
+  let mod = tvm.systemLib();
+  let ret = fecho(mod);
+  assert(ret.handle == mod.handle);
+  assert(flist.length != 0);
+
+  mod.dispose();
+  ret.dispose();
+  arr.dispose();
+  arr2.dispose();
+  fecho.dispose();
+  faddOne.dispose();
+}
+
+function testReturnFunc() {
+  function addy(y) {
+    function add(x, z) {
+      return x + y + z;
+    }
+    return add;
+  }
+
+  let fecho = tvm.getGlobalFunc("testing.echo");
+  let myf = tvm.toPackedFunc(addy);
+  assert(tvm.isPackedFunc(myf));
+  let myf2 = tvm.toPackedFunc(myf);
+  assert(myf2._tvmPackedCell.handle === myf._tvmPackedCell.handle);
+  let f = myf(10);
+
+  assert(tvm.isPackedFunc(f));
+  assert(f(11, 0) == 21);
+  assert(f("x", 1) == "x101");
+  assert(f("x", "yz") == "x10yz");
+
+  fecho.dispose();
+  myf.dispose();
+  myf2.dispose();
+  // test multiple dispose.
+  f.dispose();
+  f.dispose();
+}
+
+function testRegisterGlobal() {
+  tvm.registerFunc("xyz", function (x, y) {
+    return x + y;
+  });
+
+  let f = tvm.getGlobalFunc("xyz");
+  assert(f(1, 2) == 3);
+  f.dispose();
+
+  let syslib = tvm.systemLib();
+  syslib.dispose();
+}
+
+function testTimer() {
+  const fecho = tvm.getGlobalFunc("testing.echo");
+  const fgetTimer = tvm.getGlobalFunc("wasm.GetTimer");
+
+  let finvoke = (n) => {
+    let x = "xyz";
+    for (let i = 0; i < n; ++i) {
+      x = fecho(x);
+    }
+  };
+  const number = 10000;
+  const invokeTimer = fgetTimer(finvoke);
+  console.log("Time cost:", number / invokeTimer(number) * 1000, " ops/sec");
+  fecho.dispose();
+  invokeTimer.dispose();
+  fgetTimer.dispose();
+}
+
+testGetGlobal();
+testRegisterGlobal();
+testReturnFunc();
+testTimer();
diff --git a/tests/web/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py
similarity index 69%
rename from tests/web/prepare_test_libs.py
rename to web/tests/python/prepare_test_libs.py
index a0e2c13eab826..ec4eb5be15368 100644
--- a/tests/web/prepare_test_libs.py
+++ b/web/tests/python/prepare_test_libs.py
@@ -14,27 +14,28 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# Prepare test library for js.
+# Prepare test library for standalone wasm runtime test.
+
 import tvm
 from tvm import te
-from tvm.contrib import emscripten
+from tvm.contrib import emcc
 import os
 
+
 def prepare_test_libs(base_path):
-    target = "llvm -target=asmjs-unknown-emscripten -system-lib"
+    target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"
     if not tvm.runtime.enabled(target):
         raise RuntimeError("Target %s is not enbaled" % target)
     n = te.var("n")
    A = te.placeholder((n,), name='A')
    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = te.create_schedule(B.op)
-    fadd1 = tvm.build(s, [A, B], target, name="add_one")
-    obj_path = os.path.join(base_path, "test_add_one.bc")
-    fadd1.save(obj_path)
-    emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path,
-                         options=["-s", "WASM=0", "-s", "USE_GLFW=3", "-s",
-                                  "USE_WEBGL2=1", "-lglfw"])
+    fadd = tvm.build(s, [A, B], target, name="add_one")
+
+    wasm_path = os.path.join(base_path, "test_addone.wasm")
+    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+
 
 if __name__ == "__main__":
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    prepare_test_libs(os.path.join(curr_path, "../../build"))
+    prepare_test_libs(os.path.join(curr_path, "../../dist/wasm"))
diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py
new file mode 100644
index 0000000000000..d16ba3f3304ec
--- /dev/null
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Simple test code to test JavaScript RPC

+To use it, start an RPC proxy with "python -m tvm.exec.rpc_proxy".
+Connect the JavaScript end to the websocket port and connect to the RPC.
+"""
+
+import tvm
+from tvm import te
+from tvm import rpc
+from tvm.contrib import util, emcc
+import numpy as np
+
+proxy_host = "localhost"
+proxy_port = 9090
+
+
+def test_rpc():
+    if not tvm.runtime.enabled("rpc"):
+        return
+    # generate the wasm library
+    target_device = "webgpu"
+    target_host = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"
+    if not tvm.runtime.enabled(target_host):
+        raise RuntimeError("Target %s is not enabled" % target_host)
+
+    n = 2048
+    A = te.placeholder((n,), name='A')
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    s = te.create_schedule(B.op)
+
+    num_thread = 2
+    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+    s[B].bind(xi, te.thread_axis("threadIdx.x"))
+    s[B].bind(xo, te.thread_axis("blockIdx.x"))
+
+    fadd = tvm.build(s, [A, B], target_device, target_host=target_host, name="addone")
+    temp = util.tempdir()
+
+    wasm_path = temp.relpath("addone_gpu.wasm")
+    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+
+    wasm_binary = open(wasm_path, "rb").read()
+    remote = rpc.connect(proxy_host, proxy_port, key="wasm",
+                         session_constructor_args=["rpc.WasmSession", wasm_binary])
+
+    def check(remote):
+        # basic function checks.
+        ctx = remote.webgpu(0)
+        adata = np.random.uniform(size=n).astype(A.dtype)
+        a = tvm.nd.array(adata, ctx)
+        b = tvm.nd.array(np.zeros(n, dtype=A.dtype), ctx)
+
+        np.testing.assert_equal(a.asnumpy(), adata)
+        f1 = remote.system_lib()
+        addone = f1.get_function("addone")
+        addone(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+        print("Test passed.")
+
+    check(remote)
+
+test_rpc()
diff --git a/tests/web/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py
similarity index 53%
rename from tests/web/websock_rpc_test.py
rename to web/tests/python/websock_rpc_test.py
index 8be8ce04cb75d..f7c07924a210d 100644
--- a/tests/web/websock_rpc_test.py
+++ b/web/tests/python/websock_rpc_test.py
@@ -22,45 +22,63 @@
 import tvm
 from tvm import te
-import os
 from tvm import rpc
-from tvm.contrib import util, emscripten
+from tvm.contrib import util, emcc
 import numpy as np
 
 proxy_host = "localhost"
 proxy_port = 9090
 
-def test_rpc_array():
+def test_rpc():
     if not tvm.runtime.enabled("rpc"):
         return
-    # graph
-    n = tvm.runtime.convert(1024)
+    # generate the wasm library
+    target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"
+    if not tvm.runtime.enabled(target):
+        raise RuntimeError("Target %s is not enabled" % target)
+    n = te.var("n")
     A = te.placeholder((n,), name='A')
     B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
     s = te.create_schedule(B.op)
-    remote = rpc.connect(proxy_host, proxy_port, key="js")
-    target = "llvm -target=asmjs-unknown-emscripten -system-lib"
-    def check_remote():
-        if not tvm.runtime.enabled(target):
-            print("Skip because %s is not enabled" % target)
-            return
-        temp = util.tempdir()
+
+    fadd = tvm.build(s, [A, B], target, name="addone")
+    temp = util.tempdir()
+
+    wasm_path = temp.relpath("addone.wasm")
+    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+
+    wasm_binary = open(wasm_path, "rb").read()
+
+    remote = rpc.connect(proxy_host, proxy_port, key="wasm",
+                         session_constructor_args=["rpc.WasmSession", wasm_binary])
+
+    def check(remote):
+        # basic function checks.
+        faddone = remote.get_function("testing.asyncAddOne")
+        fecho = remote.get_function("testing.echo")
+        assert(faddone(100) == 101)
+        assert(fecho(1, 2, 3) == 1)
+        assert(fecho(100, 2, 3) == 100)
+        assert(fecho("xyz") == "xyz")
+        assert(bytes(fecho(bytearray(b"123"))) == b"123")
+
+        # run the generated library.
+        f1 = remote.system_lib()
         ctx = remote.cpu(0)
-        f = tvm.build(s, [A, B], target, name="myadd")
-        path_obj = temp.relpath("dev_lib.bc")
-        path_dso = temp.relpath("dev_lib.js")
-        f.save(path_obj)
-        emscripten.create_js(path_dso, path_obj, side_module=True)
-        # Upload to suffix as dso so it can be loaded remotely
-        remote.upload(path_dso, "dev_lib.dso")
-        data = remote.download("dev_lib.dso")
-        f1 = remote.load_module("dev_lib.dso")
         a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
-        time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
+        # invoke the function
+        addone = f1.get_function("addone")
+        addone(a, b)
+
+        # time evaluator
+        time_f = f1.time_evaluator("addone", ctx, number=100, repeat=10)
+        time_f(a, b)
         cost = time_f(a, b).mean
         print('%g secs/op' % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
-    check_remote()
-test_rpc_array()
+    check(remote)
+
+test_rpc()
diff --git a/web/tsconfig.json b/web/tsconfig.json
new file mode 100644
index 0000000000000..6aec44858a7ae
--- /dev/null
+++ b/web/tsconfig.json
@@ -0,0 +1,13 @@
+{
+  "compilerOptions": {
+    "module": "commonjs",
+    "target": "es6",
+    "outDir": "dist",
+    "rootDir": "src",
+    "declaration": true,
+    "sourceMap": true,
+    "strict": true
+  },
+  "include": ["src"],
+  "exclude": ["node_modules"]
+}
diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js
deleted file mode 100644
index b62b298d969e2..0000000000000
--- a/web/tvm_runtime.js
+++ /dev/null
@@ -1,1274 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/**
- * TVM Javascript web runtime library.
- *
- * @projectname tvm
- * @version 0.7.dev1
- */
-/* eslint no-unused-vars: "off" */
-/* eslint no-unexpected-multiline: "off" */
-/* eslint indent: "off" */
-/* eslint no-console: "off" */
-/**
- * TVM Runtime namespace.
- * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}.
- *
- * @namespace tvm_runtime
- */
-var tvm_runtime = tvm_runtime || {};
-
-/**
- * TVM root namespace.
- * The classes inside this namespace need to be constructed by factory functions.
- * Use {@link tvm_runtime}.create to get started.
- *
- * @namespace tvm
- */
-(function() {
-  /**
-   * TVMRuntime object for interacting with TVM runtime.
- * This object can be constructed using {@link tvm_runtime}.create - * - * @class - * @memberof tvm - */ - function TVMRuntime() { - "use strict"; - var runtime_ref = this; - // Utility function to throw error - function throwError(message) { - if (typeof runtime_ref.logger !== "undefined") { - runtime_ref.logger(message); - } - if (typeof Error !== "undefined") { - throw new Error(message); - } - throw message; - } - var Module = this.Module; - var Runtime = this.Runtime; - if (typeof Module === "undefined") { - throwError("Emscripten Module is not available"); - } - // constants - var SIZEOF_POINTER = 4; - var SIZEOF_SIZE_T = 4; - var SIZEOF_FLOAT = 4; - var SIZEOF_INT = 4; - var SIZEOF_INT8 = 1; - var SIZEOF_INT64 = 8; - var SIZEOF_DOUBLE = 8; - var SIZEOF_TYPE = 4; - var SIZEOF_CTX = SIZEOF_INT + SIZEOF_INT; - var SIZEOF_TVMVALUE = SIZEOF_DOUBLE; - var ARRAY_OFFSET_DATA = 0; - var ARRAY_OFFSET_CTX = ARRAY_OFFSET_DATA + SIZEOF_POINTER; - var ARRAY_OFFSET_DEV_TYPE = ARRAY_OFFSET_CTX; - var ARRAY_OFFSET_DEV_ID = ARRAY_OFFSET_CTX + SIZEOF_INT; - var ARRAY_OFFSET_NDIM = ARRAY_OFFSET_CTX + SIZEOF_CTX; - var ARRAY_OFFSET_DTYPE = ARRAY_OFFSET_NDIM + SIZEOF_INT; - var ARRAY_OFFSET_DTYPE_CODE = ARRAY_OFFSET_DTYPE; - var ARRAY_OFFSET_DTYPE_BITS = ARRAY_OFFSET_DTYPE_CODE + SIZEOF_INT8; - var ARRAY_OFFSET_DTYPE_LANES = ARRAY_OFFSET_DTYPE_BITS + SIZEOF_INT8; - var ARRAY_OFFSET_SHAPE = ARRAY_OFFSET_DTYPE + SIZEOF_TYPE; - var ARRAY_OFFSET_STRIDES = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - var ARRAY_OFFSET_BYTE_OFFSET = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - // Type codes - var kInt = 0; - var kUInt = 1; - var kFloat = 2; - var kTVMOpaqueHandle = 3; - var kNull = 4; - var kTVMDataType = 5; - var kTVMContext = 6; - var kTVMDLTensorHandle = 7; - var kTVMObjectHandle = 8; - var kTVMModuleHandle = 9; - var kTVMPackedFuncHandle = 10; - var kTVMStr = 11; - var kTVMBytes = 12; - var kTVMObjectRValueRefArg = 14; - //----------------------------------------- - // TVM CWrap library - // ---------------------------------------- - var TVMGetLastError = Module.cwrap( - "TVMGetLastError", - "string", // const char* - []); - - var TVMAPISetLastError = Module.cwrap - ("TVMAPISetLastError", - null, - ["string" // const char* - ]); - - var TVMModImport = Module.cwrap - ("TVMModImport", - "number", - ["number", // TVMModuleHandle mod - "number" // TVMModuleHandle dep - ]); - - var TVMModGetFunction = Module.cwrap - ("TVMModGetFunction", - "number", - ["number", // TVMModuleHandle mod - "string", // const char* func_name - "number", // int query_imports - "number" // TVMFunctionHandle *out - ]); - - var TVMModFree = Module.cwrap - ("TVMModFree", - "number", - ["number" // TVMModeHandle mod - ]); - - var TVMFuncFree = Module.cwrap - ("TVMFuncFree", - "number", - ["number" // TVMFunctionHandle func - ]); - - var TVMFuncCall = Module.cwrap - ("TVMFuncCall", - "number", - ["number", // TVMFunctionHandle func - "number", // TVMValue* arg_values - "number", // int* arg_tcodes - "number", // int num_args - "number", // int ret_val - "number" // int ret_type_code - ]); - - var TVMCFuncSetReturn = Module.cwrap - ("TVMCFuncSetReturn", - "number", - ["number", // TVMRetValueHandle ret - "number", // TVMValue* value - "number", // int* type_code - "number" // int num_ret - ]); - - var TVMCbArgToReturn = Module.cwrap - ("TVMCbArgToReturn", - "number", - ["number", // TVMValue* value - "number" // int* code - ]); - - var TVMFuncCreateFromCFunc = Module.cwrap - ("TVMFuncCreateFromCFunc", - "number", - ["number", // TVMPackedCFunc 
func, - "number", // void* resource_handle - "number", // TVMPackedCFuncFinalizer fin - "number" // TVMFunctionHandle *out - ]); - - var TVMFuncRegisterGlobal = Module.cwrap - ("TVMFuncRegisterGlobal", - "number", - ["string", // name - "number", // TVMFunctionHandle f - "number" // int override - ]); - - var TVMFuncGetGlobal = Module.cwrap - ("TVMFuncGetGlobal", - "number", - ["string", // const char* name - "number" // TVMFunctionHandle* out - ]); - - var TVMFuncListGlobalNames = Module.cwrap - ("TVMFuncListGlobalNames", - "number", - ["number", // int* out_size - "number" // const char*** out_array - ]); - - - var TVMArrayAlloc = Module.cwrap - ("TVMArrayAlloc", - "number", - ["number", // const tvm_index_t* shape - "number", // int ndim - "number", // int dtype_code - "number", // int dtype_bits - "number", // int dtype_lanes - "number", // int device_type - "number", // int device_id - "number" // int TVMArrayHandle* out - ]); - - var TVMArrayFree = Module.cwrap - ("TVMArrayFree", - "number", - ["number" // TVMArrayHandle handle - ]); - - var TVMArrayCopyFromTo = Module.cwrap - ("TVMArrayCopyFromTo", - "number", - ["number", // TVMArrayHandle from - "number" // TVMArrayHandle to - ]); - - var TVMArrayCopyFromBytes = Module.cwrap - ("TVMArrayCopyFromBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMArrayCopyToBytes = Module.cwrap - ("TVMArrayCopyToBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMModLoadFromFile = Module.cwrap - ("TVMModLoadFromFile", - "number", - ["string", // const char* file_name - "string", // const char* format - "number" // TVMModuleHandle* out - ]) - - //----------------------------------------- - // Static utility functions - // ---------------------------------------- - this.assert = function(condition, message) { - if (!condition) { - message = message || "assert failed"; - throwError(message); - } - }; - /** - * Logging function. - * Override this to change logger behavior. 
- * - * @param {string} message - */ - this.logger = function(message) { - console.log(message); - }; - - function logging(message) { - runtime_ref.logger(message); - } - // Override print error to logging - Module.printErr = logging; - var CHECK = this.assert; - - function TVM_CALL(ret) { - if (ret != 0) { - throwError(TVMGetLastError()); - } - } - - function CInt64ArrayToJS(ptr, size) { - var ret = []; - for (var i = 0; i < size; ++i) { - ret.push(Module.getValue(ptr + i * SIZEOF_INT64, "i64")); - } - return ret; - } - - function CStringToJS(ptr) { - var ret = []; - var ch = 1; - while (ch != 0) { - ch = Module.getValue(ptr, "i8"); - if (ch != 0) { - ret.push(String.fromCharCode(ch)); - } - ++ptr; - } - return ret.join(""); - } - - function CBytesToJS(ptr) { - var data = Module.getValue(ptr, "*"); - var size = Module.getValue(ptr + SIZEOF_POINTER, "i32"); - var ret = new Uint8Array(new ArrayBuffer(size)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, data, size)); - return ret; - } - - function StringToUint8Array(str) { - var arr = new Uint8Array(str.length + 1); - for(var i = 0; i < str.length; ++i) { - arr[i] = str.charCodeAt(i); - } - arr[str.length] = 0; - return arr; - } - //----------------------------------------- - // Class declarations - // ---------------------------------------- - function CBuffer(nbytes) { - this.data = Module._malloc(nbytes); - } - - function RefTVMValue() { - this.data = Module._malloc(SIZEOF_TVMVALUE); - } - - function TVMArgs(nargs) { - this.nargs = nargs; - this.value = Module._malloc(SIZEOF_TVMVALUE * nargs); - this.tcode = Module._malloc(SIZEOF_INT * nargs); - this.temp = []; - } - - function TVMType(code, bits, lanes) { - this.code = code; - this.bits = bits; - this.lanes = lanes; - } - /** - * TVM device context. - * @class - * @memberof tvm - */ - function TVMContext(device_type, device_id) { - this.device_type = device_type; - this.device_id = device_id; - } - /** - * TVM n-dimensional array. - * - * Use {@link tvm.TVMRuntime}.empty to create an instance. - * @class - * @memberof tvm - */ - function NDArray(handle) { - this.handle = handle; - this.ndim = Module.getValue(this.handle + ARRAY_OFFSET_NDIM, "i32"); - // shape - var cshape = Module.getValue(this.handle + ARRAY_OFFSET_SHAPE, "*"); - this.shape = CInt64ArrayToJS(cshape, this.ndim); - // dtype - var code = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_CODE, "i8"); - var bits = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_BITS, "i8"); - var lanes = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_LANES, "i16"); - var dtype = new TVMType(code, bits, lanes); - this.dtype = dtype; - this.BYTES_PER_ELEMENT = (dtype.bits * dtype.lanes / 8); - // ctx - var device_type = Module.getValue(this.handle + ARRAY_OFFSET_DEV_TYPE, "i32"); - var device_id = Module.getValue(this.handle + ARRAY_OFFSET_DEV_ID, "i32"); - this.context = new TVMContext(device_type, device_id); - // byte_offset - this.byteOffset = Module.getValue(this.handle + ARRAY_OFFSET_BYTE_OFFSET, "i64"); - } - - function TVMFunction(handle) { - this.handle = handle; - } - /** - * Module container of TVM generated functions. - * - * @class - * @memberof tvm - */ - function TVMModule(handle) { - this.handle = handle; - } - /** - * A typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * Use {@link tvm.TVMRuntime}.constant to create an instance. 
- * @class - * @memberof tvm - */ - function TVMConstant(value, dtype) { - this.value = value; - this.dtype = dtype; - } - //----------------------------------------- - // Private Functions - // ---------------------------------------- - function getTVMType(dtype) { - if (dtype instanceof TVMType) return dtype; - if (typeof dtype == "string") { - var pattern = dtype; - var code, bits = 32, lanes = 1; - if (pattern.substring(0, 5) == "float") { - pattern = pattern.substring(5, pattern.length); - code = kFloat; - } else if (pattern.substring(0, 3) == "int") { - pattern = pattern.substring(3, pattern.length); - code = kInt; - } else if (pattern.substring(0, 4) == "uint") { - pattern = pattern.substring(4, pattern.length); - code = kUInt; - } else if (pattern.substring(0, 6) == "handle") { - pattern = pattern.substring(5, pattern.length); - code = kTVMOpaqueHandle; - bits = 64; - } else { - throw throwError("Unknown dtype " + dtype); - } - var arr = pattern.split("x"); - if (arr.length >= 1) { - var parsed = parseInt(arr[0]); - if (parsed == arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new TVMType(code, bits, lanes); - } else { - throw throwError("Unknown dtype " + dtype); - } - } - - function TVMRetValueToJS(vptr, tcode) { - switch (tcode) { - case kInt: - case kUInt: return Module.getValue(vptr, "i64"); - case kFloat: return Module.getValue(vptr, "double"); - case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*")); - case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*")); - case kNull: return null; - case kTVMStr: return CStringToJS(Module.getValue(vptr, "*")); - case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*")); - default: throwError("Unsupported return type code=" + tcode); - } - } - - function makeTVMFunction(handle) { - var func = new TVMFunction(handle); - var ret = function () { - // alloc - var args = new TVMArgs(arguments.length); - var rvalue = new RefTVMValue(); - var rtcode = new RefTVMValue(); - args.setArguments(arguments); - TVM_CALL(TVMFuncCall(handle, args.value, args.tcode, - args.nargs, rvalue.data, rtcode.data)); - var rv = TVMRetValueToJS(rvalue.data, rtcode.asInt()); - // release - args.release(); - rvalue.release(); - rtcode.release(); - return rv; - }; - var release = function() { - func.release(); - }; - ret._tvm_function = func; - ret.release = release; - return ret; - } - //----------------------------------------- - // Javascript PackedCallback System - // ---------------------------------------- - var funcTable = [0]; - var freeFuncId = []; - - function invokeCallback(arg_value, arg_tcode, nargs, ret, handle) { - var args = []; - for (var i = 0; i < nargs; ++i) { - var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcodeptr = arg_tcode + i * SIZEOF_INT; - var tcode = Module.getValue(tcodeptr, "i32"); - if (tcode == kTVMObjectHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMPackedFuncHandle || - tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); - } - tcode = Module.getValue(tcodeptr, "i32"); - args.push(TVMRetValueToJS(vptr, tcode)); - } - var rv = funcTable[handle].apply(null, args); - if (typeof rv !== "undefined") { - // alloc - var rarg = new TVMArgs(1); - rarg.setArguments([rv]); - TVM_CALL(TVMCFuncSetReturn(ret, rarg.value, rarg.tcode, 1)); - // release - rarg.release(); - } - return 0; - } - function freeCallback(handle) { - funcTable[handle] = 0; - freeFuncId.push(handle); - } - var fptrInvokeCallback = null; - var 
fptrFreeCallback = null; - if (typeof Runtime !== "undefined" && - typeof Runtime.addFunction !== "undefined") { - fptrInvokeCallback = Runtime.addFunction(invokeCallback); - fptrFreeCallback = Runtime.addFunction(freeCallback); - } - /** - * Check if a function is TVM PackedFunc - * @param {Function} f function to be checked. - * @return {boolean} Whether f is PackedFunc - */ - this.isPackedFunc = function(f) { - return (typeof f == "function") && f.hasOwnProperty("_tvm_function"); - }; - var isPackedFunc = this.isPackedFunc; - /** - * Convert a javascript function to TVM function. - * @param {Function} f javascript function. - * @return {Function} The created TVMFunction. - */ - this.convertFunc = function(f) { - if (isPackedFunc(f)) return f; - CHECK(fptrInvokeCallback !== null, - "Emscripten Runtime addFunction is not available"); - var fid; - if (freeFuncId.length != 0) { - fid = freeFuncId.pop(); - } else { - fid = funcTable.length; - funcTable.push(0); - } - funcTable[fid] = f; - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncCreateFromCFunc( - fptrInvokeCallback, fid, fptrFreeCallback, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - return makeTVMFunction(out_handle); - }; - var convertFunc = this.convertFunc; - //----------------------------------------- - // Private Class declarations - // ---------------------------------------- - CBuffer.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - }; - // RefTVMValue - RefTVMValue.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - asInt : function() { - return Module.getValue(this.data, "i32"); - }, - asInt64 : function() { - return Module.getValue(this.data, "i64"); - }, - asDouble : function() { - return Module.getValue(this.data, "double"); - }, - asHandle : function() { - return Module.getValue(this.data, "*"); - } - }; - // TVMArgs - TVMArgs.prototype = { - release : function() { - if (this.value != 0) { - Module._free(this.value); - Module._free(this.tcode); - this.value = 0; - for (var i = 0; i< this.temp.length; ++i) { - if (this.temp[i].release instanceof Function) { - this.temp[i].release(); - } - } - } - }, - setInt : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kInt, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "i64"); - }, - setDouble : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kFloat, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "double"); - }, - setHandle : function(index, value, tcode) { - Module.setValue(this.tcode + index * SIZEOF_INT, tcode, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "*"); - }, - setString : function(index, value) { - var sdata = new CBuffer(value.length + 1); - Module.HEAPU8.set(StringToUint8Array(value), sdata.data); - this.temp.push(sdata); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*"); - }, - setBytes : function(index, value) { - CHECK(value instanceof Uint8Array); - var sdata = new CBuffer(value.length); - var sheader = new CBuffer(SIZEOF_POINTER + SIZEOF_SIZE_T); - Module.HEAPU8.set(new Uint8Array(value), sdata.data); - Module.setValue(sheader.data, sdata.data, "*"); - 
Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32"); - this.temp.push(sdata); - this.temp.push(sheader); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*"); - }, - setArguments : function(args) { - for (var i = 0; i < args.length; ++i) { - var v = args[i]; - var tp = typeof v; - if (v instanceof NDArray) { - this.setHandle(i, v.handle, kTVMDLTensorHandle); - } else if (v instanceof TVMConstant) { - var code = getTVMType(v.dtype).code; - if (code == kInt || code == kUInt) { - this.setInt(i, v.value); - } else if (code == kFloat) { - this.setDouble(i, v.value); - } else { - CHECK(code == kTVMOpaqueHandle); - this.setHandle(i, v.value, kTVMOpaqueHandle); - } - } else if (tp == "number") { - this.setDouble(i, v); - } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) { - this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v === null) { - this.setHandle(i, 0, kNull); - } else if (tp == "string") { - this.setString(i, v); - } else if (v instanceof Uint8Array) { - this.setBytes(i, v); - } else if (v instanceof Function) { - v = convertFunc(v); - this.temp.push(v); - this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v instanceof TVMModule) { - this.setHandle(i, v.handle, kTVMModuleHandle); - } else { - throwError("Unsupported argument type " + tp); - } - } - } - }; - // TVMType - var TYPE_CODE2STR = { - 0 : "int", - 1 : "uint", - 2 : "float", - 4 : "handle" - }; - - TVMType.prototype = { - toString : function() { - var ret = TYPE_CODE2STR[this.code] + this.bits.toString(); - if (this.lanes != 1) { - return ret + "x" + this.lanes.toString(); - } else { - return ret; - } - } - }; - // TVMFunction - TVMFunction.prototype = { - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMFuncFree(this.handle)); - this.handle = 0; - } - } - }; - // TVMContext - var CTX_MASK2STR = { - 1 : "cpu", - 2 : "gpu", - 4 : "opencl", - 7 : "vulkan", - 8 : "metal", - 9 : "vpi", - 11 : "opengl", - }; - var CTX_STR2MASK = { - "cpu": 1, - "gpu": 2, - "cuda": 2, - "cl": 4, - "opencl": 4, - "vulkan": 7, - "metal": 8, - "vpi": 9, - "opengl": 11, - }; - TVMContext.prototype = { - toString : function() { - return CTX_MASK2STR[this.device_type] + "(" + this.device_id.toString() + ")"; - } - }; - //----------------------------------------- - // Public Functions - // ---------------------------------------- - /** - * Construct a TVMContext given device type and id. - * - * @param {number} device_type, string or int, The device type. - * @param {number} device_id, the device id. - * @return {tvm.TVMContext} The created TVMContext - */ - this.context = function(device_type, device_id) { - if (typeof device_type == "string") { - device_type = CTX_STR2MASK[device_type]; - } - return new TVMContext(device_type, device_id); - }; - var context = this.context; - /** - * Create empty ndarray with given shape. - * - * @param {Array.} shape The shape of the array. - * @param {string} dtype The data type of the array, optional, default="float32" - * @param {tvm.TVMContext} ctx The context of the array, optional, default=cpu(0). - * @return {tvm.NDArray} The created ndarray. - */ - this.empty = function(shape, dtype, ctx) { - dtype = (typeof dtype !== "undefined") ? dtype: "float32"; - ctx = (typeof ctx !== "undefined") ? ctx : context("cpu", 0); - shape = (typeof shape == "number") ? 
[shape] : shape; - // alloc - var cshape = Module._malloc(SIZEOF_INT64 * shape.length); - var out = new RefTVMValue(); - for (var i = 0; i < shape.length; ++i) { - Module.setValue(cshape + i * SIZEOF_INT64, shape[i], "i64"); - } - dtype = getTVMType(dtype); - TVM_CALL(TVMArrayAlloc(cshape, shape.length, - dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, - out.data)); - var out_handle = out.asHandle(); - // release - Module._free(cshape); - out.release(); - return new NDArray(out_handle); - }; - /** - * List all global function names in the TVM runtime. - * @return {Array.} List of global function names. - */ - this.listGlobalFuncNames = function() { - // alloc - var out_size = new RefTVMValue(); - var out_array = new RefTVMValue(); - TVM_CALL(TVMFuncListGlobalNames(out_size.data, out_array.data)); - var length = out_size.asInt(); - var base = out_array.asHandle(); - var names = []; - for (var i = 0 ; i < length; ++i) { - names.push( - CStringToJS(Module.getValue(base + i * SIZEOF_POINTER, "*"))); - } - // release - out_size.release(); - out_array.release(); - return names; - }; - var listGlobalFuncNames = this.listGlobalFuncNames; - /** - * Get a global function from TVM runtime. - * - * @param {string} The name of the function. - * @return {Function} The corresponding function, null if function do not exist - */ - this.getGlobalFunc = function (name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncGetGlobal(name, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return makeTVMFunction(out_handle); - } else { - return null; - } - }; - var getGlobalFunc = this.getGlobalFunc; - /** - * Register function to be global function in tvm runtime. - * @param {string} name The name of the function. - * @param {Function} f function to be registered. - * @param {boolean} override Whether overwrite function in existing registry. - */ - this.registerFunc = function(name, f, override) { - f = convertFunc(f); - override = (typeof override !== "undefined") ? override: false; - var ioverride = override ? 1 : 0; - TVM_CALL(TVMFuncRegisterGlobal(name, f._tvm_function.handle, ioverride)); - }; - /** - * Create a typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * - * @param {number} value The value of the data. - * @param {string} dtype The data type. - * @param {tvm.TVMConstant} The created typed scalar. - */ - this.constant = function(value, dtype) { - return new TVMConstant(value, dtype); - }; - //----------------------------------------- - // Wrap of TVM Functions. - // ---------------------------------------- - var systemFunc = {}; - /** - * Get system-wide library module singleton.5A - * System lib is a global module that contains self register functions in startup. - * @return {tvm.TVMModule} The system module singleton. 
- */ - this.systemLib = function() { - if (typeof systemFunc.fGetSystemLib === "undefined") { - systemFunc.fGetSystemLib = getGlobalFunc("runtime.SystemLib"); - } - return systemFunc.fGetSystemLib(); - }; - - this.startRPCServer = function(url, key, counter) { - if (typeof key === "undefined") { - key = ""; - } - if (typeof counter === "undefined") { - counter = 1; - } - // Node js, import websocket - var bkey = StringToUint8Array("server:" + key); - bkey = bkey.slice(0, bkey.length - 1); - var server_name = "WebSocketRPCServer[" + key + "]"; - var RPC_MAGIC = 0xff271; - function checkEndian() { - var a = new ArrayBuffer(4); - var b = new Uint8Array(a); - var c = new Uint32Array(a); - b[0] = 0x11; - b[1] = 0x22; - b[2] = 0x33; - b[3] = 0x44; - CHECK(c[0] === 0x44332211, "Need little endian to work"); - } - checkEndian(); - // start rpc - function RPCServer(counter) { - var socket; - if (typeof module !== "undefined" && module.exports) { - // WebSocket for nodejs - const WebSocket = require("ws"); - socket = new WebSocket(url); - } else { - socket = new WebSocket(url); - } - var self = this; - socket.binaryType = "arraybuffer"; - this.init = true; - this.counter = counter; - - if (typeof systemFunc.fcreateServer === "undefined") { - systemFunc.fcreateServer = - getGlobalFunc("rpc._CreateEventDrivenServer"); - } - if (systemFunc.fcreateServer == null) { - throwError("RPCServer is not included in runtime"); - } - - var message_handler = systemFunc.fcreateServer( - function(cbytes) { - if (socket.readyState == 1) { - socket.send(cbytes); - return new TVMConstant(cbytes.length, "int32"); - } else { - return new TVMConstant(0, "int32"); - } - } , server_name, "%toinit"); - - function on_open(event) { - var intbuf = new Int32Array(1); - intbuf[0] = RPC_MAGIC; - socket.send(intbuf); - intbuf[0] = bkey.length; - socket.send(intbuf); - socket.send(bkey); - logging(server_name + " connected..."); - } - - function on_message(event) { - if (self.init) { - var msg = new Uint8Array(event.data); - CHECK(msg.length >= 4, "Need message header to be bigger than 4"); - var magic = new Int32Array(event.data)[0]; - - if (magic == RPC_MAGIC + 1) { - throwError("key: " + key + " has already been used in proxy"); - } else if (magic == RPC_MAGIC + 2) { - logging(server_name + ": RPCProxy do not have matching client key " + key); - } else { - CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy"); - self.init = false; - } - logging(server_name + "init end..."); - if (msg.length > 4) { - if (message_handler( - new Uint8Array(event.data, 4, msg.length -4), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } else { - if (message_handler(new Uint8Array(event.data), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } - function on_close(event) { - message_handler.release(); - logging(server_name + ": closed finish..."); - if (!self.init && self.counter != 0) { - logging(server_name + ":reconnect to serve another request, session left=" + counter); - // start a new server. - new RPCServer(counter - 1); - } - } - socket.addEventListener("open", on_open); - socket.addEventListener("message", on_message); - socket.addEventListener("close", on_close); - } - return new RPCServer(counter); - }; - - /** - * Load a TVM module from a library file. - * The file must be present in the Emscripten virtual file system. - * For example, you can pass "--preload-file file" or "--preload-file dir/" - * to "emcc" when compiling the TVM library, in order to populate files into - * the file system. 
- * For more detail, see: - * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files - * @param {string} file_name Path of the file to be loaded. The path refers - * to the Emscripten virtual file system. - * @param {string} format The format of the file. - * @return {tvm.TVMModule} The loaded module. - */ - this.loadModuleFromFile = function (file_name, format) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModLoadFromFile(file_name, format, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return new TVMModule(out_handle); - } else { - return null; - } - }; - var loadModuleFromFile = this.loadModuleFromFile; - - /** - * Wrapper runtime module. - * Wraps around set_input, load_params, run, and get_output. - * - * @class - * @memberof tvm - */ - function GraphModule(tvm_graph_module, ctx) { - CHECK(tvm_graph_module instanceof TVMModule, - "tvm_graph_module must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - this.tvm_graph_module = tvm_graph_module; - this.ctx = ctx; - this._set_input = tvm_graph_module.getFunction("set_input"); - this._load_params = tvm_graph_module.getFunction("load_params"); - this._run = tvm_graph_module.getFunction("run"); - this._get_output = tvm_graph_module.getFunction("get_output"); - }; - - GraphModule.prototype = { - /** - * Set input to graph module. - * - * @param {string} key The name of the input. - * @param {NDArray} value The input value. - */ - "set_input" : function(key, value) { - CHECK(typeof key == "string", "key must be string"); - CHECK(value instanceof NDArray, "value must be NDArray"); - this._set_input(key, value); - }, - - /** - * Load parameters from serialized byte array of parameter dict. - * - * @param {Uint8Array} params The serialized parameter dict. - */ - "load_params" : function(params) { - CHECK(params instanceof Uint8Array, "params must be Uint8Array"); - this._load_params(params); - }, - - /** - * Load parameters from serialized base64 string of parameter dict. - * - * @param {string} base64_params The serialized parameter dict. - */ - "load_base64_params" : function(base64_params) { - CHECK(typeof base64_params == "string", "base64_params must be string"); - var decoded_string = atob(base64_params); - var decoded_u8 = new Uint8Array(decoded_string.length); - for (var i = 0; i < decoded_string.length; i++) { - decoded_u8[i] = decoded_string[i].charCodeAt(0); - } - this.load_params(decoded_u8); - }, - - /** - * Run forward execution of the graph. - */ - "run" : function() { - this._run(); - }, - - /** - * Get index-th output to out. - * - * @param {NDArray} out The output array container. - * @return {NDArray} The output array container. - */ - "get_output" : function(index, out) { - CHECK(typeof index == "number", "index must be number"); - CHECK(out instanceof NDArray, "out must be NDArray"); - this._get_output(new TVMConstant(index, "int32"), out); - return out; - } - }; - - /** - * Create a runtime executor module given a graph and a module. - * @param {string} graph_json_str The Json string of the graph. - * @param {TVMModule} libmod The TVM module. - * @param {TVMContext} ctx The context to deploy the module. - * @return {GraphModule} Runtime graph module for executing the graph. 
- */ - this.createGraphRuntime = function(graph_json_str, libmod, ctx) { - CHECK(typeof graph_json_str == "string", "graph_json_str must be string"); - CHECK(libmod instanceof TVMModule, "libmod must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - var fcreate = getGlobalFunc("tvm.graph_runtime.create"); - CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create"); - - var tvm_graph_module = fcreate(graph_json_str, libmod, - new TVMConstant(ctx.device_type, "int32"), - new TVMConstant(ctx.device_id, "int32")); - - return new GraphModule(tvm_graph_module, ctx); - }; - - //----------------------------------------- - // Class defintions - // ---------------------------------------- - // NDArray. - NDArray.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMArrayFree(this.handle)); - this.handle = 0; - } - }, - /** - * Copy data from another NDArray or javascript array. - * The number of elements must match. - * - * @param {Array} data The source data array. - */ - copyFrom : function(data) { - if (data instanceof NDArray) { - TVM_CALL(TVMArrayCopyFromTo(data.handle, this.handle)); - } else { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - if (data.length != size) { - throwError("data size and shape mismatch data.length" + data.length + " vs " + size); - } - if (this.dtype == "float32") { - data = Float32Array.from(data); - } else if (this.dtype == "float64") { - data = Float64Array.from(data); - } else if (this.dtype == "int32") { - data = Int32Array.from(data); - } else if (this.dtype == "int8") { - data = Int8Array.from(data); - } else if (this.dtype == "uint8") { - data = Uint8Array.from(data); - } else { - throwError("Unsupported data type " + this.dtype); - } - return this.copyFromRawBytes(new Uint8Array(data.buffer)); - } - }, - /** - * Copy data from raw bytes. - * @param {Uint8Array} data Uint8Array of bytes. - */ - copyFromRawBytes : function(data) { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var dtype = getTVMType(this.dtype); - var nbytes = this.BYTES_PER_ELEMENT * size; - CHECK(data instanceof Uint8Array); - CHECK(data.length == nbytes, - "Data length and bytes do not match " + data.length + - " vs " + nbytes); - var temp = Module._malloc(nbytes); - Module.HEAPU8.set(data, temp); - TVM_CALL(TVMArrayCopyFromBytes(this.handle, temp, nbytes)); - Module._free(temp); - return this; - }, - /** - * Return a copied Uint8Array of the raw bytes in the NDArray. - * @return {Uint8Array} The created array. - */ - asRawBytes : function() { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var nbytes = this.BYTES_PER_ELEMENT * size; - var temp = Module._malloc(nbytes); - TVM_CALL(TVMArrayCopyToBytes(this.handle, temp, nbytes)); - var ret = new Uint8Array(new ArrayBuffer(nbytes)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, temp, nbytes)); - Module._free(temp); - return ret; - }, - /** - * Return Array data content as javascript typed array. - * @return {TypedArray} The created array. 
- */ - asArray : function() { - if (this.dtype == "float32") { - return new Float32Array(this.asRawBytes().buffer); - } else if (this.dtype == "float64") { - return new Float64Array(this.asRawBytes().buffer); - } else if (this.dtype == "int32") { - return new Int32Array(this.asRawBytes().buffer); - } else if (this.dtype == "int8") { - return new Int8Array(this.asRawBytes().buffer); - } else if (this.dtype == "uint8") { - return new Uint8Array(this.asRawBytes().buffer); - } else { - throwError("Unsupported data type " + this.dtype); - } - } - }; - - TVMModule.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMModFree(this.handle)); - this.handle = 0; - } - }, - /** - * Get function from the module. - * @param {string} name The name of the function. - * @return {Function} The correspondin function. - */ - getFunction : function(name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModGetFunction(this.handle, name, 0, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle == 0) { - throwError("Module has no function " + name); - } - return makeTVMFunction(out_handle); - }, - /** - * Add module to the import list of current one. - * @param {tvm.TVMModule} mod The other module to be imported. - */ - import_module : function(mod) { - CHECK(mod instanceof TVMModule, "mod must be instance of TVMModule"); - TVM_CALL(TVMModImport(this.handle, mod.handle)); - } - }; - //----------------------------------------- - // Static variables. - // ---------------------------------------- - /** Float32 type */ - this.float32 = "float32"; - /** Int32 type */ - this.int32 = "int32"; - } - /** - * Create a TVM runtime given emscripten module. - * @property {string} create - * @memberof tvm_runtime - * @param Module The emscripten module. - * @return {tvm.TVMRuntime} The created TVM runtime. - */ - this.create = function(Module) { - var tvm = {}; - tvm.Module = Module; - if (typeof Module.addFunction !== "undefined") { - tvm.Runtime = Module; - } else { - tvm.Runtime = Module.Runtime; - } - TVMRuntime.apply(tvm); - return tvm; - }; -}).apply(tvm_runtime); - -// export things in node -if (typeof module !== "undefined" && module.exports) { - module.exports = tvm_runtime; -} diff --git a/web/web_runtime.cc b/web/web_runtime.cc deleted file mode 100644 index 701ded76288e0..0000000000000 --- a/web/web_runtime.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file web_runtime.cc - */ -#include -#include - -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" -#include "../src/runtime/module.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" -#include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" -#include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/opengl/opengl_device_api.cc" -#include "../src/runtime/opengl/opengl_module.cc" - -namespace tvm { -namespace contrib { - -struct RPCEnv { - public: - RPCEnv() { - base_ = "/rpc"; - mkdir(&base_[0], 0777); - } - // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + "/" + file_name; - } - - private: - std::string base_; -}; - -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body_typed([](std::string path) { - static RPCEnv env; - return env.GetPath(path); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body_typed([](std::string path) { - std::string file_name = "/rpc/" + path; - LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); - }); -} // namespace contrib -} // namespace tvm - -// dummy parallel runtime -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { - TVMAPISetLastError("Parallel is not supported in Web runtime"); - return -1; -} - -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; -}
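For context, the renamed tests above assume two things are already in place: a running RPC proxy, and a browser or node peer that has registered the server-side "wasm" key with that proxy. The sketch below is a minimal, hypothetical driver for the CPU wasm test, not part of this change: the `python -m tvm.exec.rpc_proxy` invocation comes from the test docstring, while the `--host`/`--port` flags, the startup wait, and the assumption that `websock_rpc_test.py` is importable from the working directory are all illustrative.

    # hypothetical_run_websock_test.py -- illustrative sketch only
    import subprocess
    import sys
    import time

    # Start the websocket RPC proxy the tests connect to (localhost:9090).
    # The --host/--port flags are assumptions; check
    # `python -m tvm.exec.rpc_proxy --help` for the actual options.
    proxy = subprocess.Popen(
        [sys.executable, "-m", "tvm.exec.rpc_proxy",
         "--host", "localhost", "--port", "9090"]
    )
    try:
        time.sleep(2)  # crude wait for the proxy to bind its sockets
        # A browser/node peer must now connect with the server-side "wasm"
        # key so that rpc.connect(..., key="wasm") in the test can be paired.
        import websock_rpc_test  # runs test_rpc() at import time
    finally:
        proxy.terminate()
        proxy.wait()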