Overview
- A suite of pytest tests for tilelang exhibits multiple failures when run on an NVIDIA H100 GPU. The failures fall into several categories: TVM `InternalError`s, Tensor Memory Accelerator (TMA) descriptor initialization failures, numerical mismatches, and tests that hang indefinitely.
Checked Directories
- tilelang/testing/python/language
- tilelang/testing/python/pass_config
- tilelang/testing/python/primitives
- tilelang/testing/python/tilelibrary
- tilelang/testing/python/transform
- tilelang/testing/python/webgpu
Environment & Reproduction
- GPU: NVIDIA H100
- Testing Method: Tests are run individually with `pytest <path_to_test_file.py>` to avoid potential interference between tests; a driver sketch follows below.
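
For reference, a minimal driver reproducing this per-file invocation, a sketch only; the glob pattern and the 600 s timeout are illustrative choices, not part of the original setup:

```python
# Sketch: run every test file in its own pytest process so a crash or hang in
# one file cannot contaminate the others. Pattern and timeout are illustrative.
import glob
import subprocess

for path in sorted(glob.glob("tilelang/testing/python/**/test_*.py", recursive=True)):
    try:
        result = subprocess.run(["pytest", path], timeout=600)
        print(f"{path}: exit code {result.returncode}")
    except subprocess.TimeoutExpired:
        print(f"{path}: timed out after 600s (possible hang)")
```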
Observed Failures
- TVM `InternalError` (raised in the WarpSpecialized pass)
tilelang/testing/python/language/test_tilelang_language_reshape.py::test_reshape_smem_shared

```
_________________________________ test_reshape_smem_shared _________________________________

    def test_reshape_smem_shared():
>       run_reshape_smem(1024, 32, "float32")

test_tilelang_language_reshape.py:74:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_tilelang_language_reshape.py:64: in run_reshape_smem
    jit_kernel = tl.compile(program, out_idx=-1)
tilelang/tilelang/jit/__init__.py:71: in compile
    return cached(
tilelang/tilelang/cache/__init__.py:30: in cached
    return _kernel_cache_instance.cached(
tilelang/tilelang/cache/kernel_cache.py:163: in cached
    kernel = JITKernel(
tilelang/tilelang/jit/kernel.py:111: in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
tilelang/tilelang/jit/kernel.py:199: in _compile_and_create_adapter
    artifact = tilelang.lower(
tilelang/tilelang/engine/lower.py:232: in lower
    mod = OptimizeForTarget(mod, target)
tilelang/tilelang/engine/phase.py:71: in OptimizeForTarget
    mod = tilelang.transform.WarpSpecialized()(mod)
tilelang/3rdparty/tvm/python/tvm/ir/transform.py:238: in __call__
    return _ffi_transform_api.RunPass(self, mod)
tilelang/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py:239: in __call__
    raise_last_ffi_error()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    def raise_last_ffi_error():
        ...
        tb = _filter_traceback_frames(tb, filter_funcs)
        py_err = py_err.with_traceback(tb)
        # The exception PyObject may contain a large amount of state,
        # including all stack frames that may be inspected in a later
        # PDB post-mortem. Therefore, we must make sure to remove the
        # underlying PyObject* from the C++ side after we retrieve it.
        _LIB.TVMDropLastPythonError()
>       raise py_err
E       tvm.error.InternalError: Traceback (most recent call last):
E         28: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
E         27: tvm::transform::Pass::operator()(tvm::IRModule) const
E         26: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E         25: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
E         24: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tl::WarpSpecialized()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tl::WarpSpecialized()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
E         23: tvm::tl::WarpSpecializedRewriter::Substitute(tvm::tir::PrimFunc, bool)
E         22: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         21: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9
E         20: tvm::tl::WarpSpecializedRewriter::VisitStmt_(tvm::tir::BlockRealizeNode const*)
E         19: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9
E         17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         16: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runti
E         15: tvm::tl::WarpSpecializedRewriter::VisitStmt_(tvm::tir::AttrStmtNode const*)
E         14: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         13: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runti
E         12: tvm::tl::WarpSpecializedRewriter::VisitStmt_(tvm::tir::AttrStmtNode const*)
E         11: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         10: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runti
E         9: tvm::tl::WarpSpecializedRewriter::VisitStmt_(tvm::tir::AttrStmtNode const*)
E         8: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runti
E         6: tvm::tl::WarpSpecializedRewriter::VisitStmt_(tvm::tir::AttrStmtNode const*)
E         5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9
E         3: tvm::tl::WarpSpecializedRewriter::VisitStmt_(tvm::tir::BlockRealizeNode const*)
E         2: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
E         1: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9
E         0: tvm::tl::WSCodeEmitter::VisitStmt_(tvm::tir::SeqStmtNode const*)
E         File "tilelang/src/transform/warp_specialized_rewriter.cc", line 542
E       InternalError: Check failed: (map.release[i].size() > 0) is false:

tilelang/3rdparty/tvm/python/tvm/_ffi/base.py:465: InternalError
```
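
For convenience, a minimal reproduction outside pytest, sketched using only names that appear in the traceback above; the import assumes the script is run from tilelang/testing/python/language so the test module resolves:

```python
# Minimal reproduction sketch for the WarpSpecialized InternalError. Calls the
# same entry point as the failing test; run from tilelang/testing/python/language
# so the module import below resolves (path assumption, see traceback above).
from test_tilelang_language_reshape import run_reshape_smem

# Compiles via tl.compile(program, out_idx=-1) internally and raises
# tvm.error.InternalError from warp_specialized_rewriter.cc:542.
run_reshape_smem(1024, 32, "float32")
```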
- TMA (Tensor Memory Accelerator) Related Failures
tilelang/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py::test_gemm

```
FAILED test_tilelang_tilelibrary_gemm.py::test_gemm - RuntimeError: Kernel call failed: Error: Failed to initialize the TMA descriptor A_desc
```
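
Since TMA descriptor creation requires Hopper-class hardware and a CUDA 12.x toolkit/driver, a quick environment check can rule out a setup problem. A hedged sketch, assuming PyTorch is available in the test environment:

```python
# Diagnostic sketch: TMA descriptors require Hopper (sm_90 or newer) and a
# CUDA 12.x toolkit/driver. Assumes PyTorch is installed.
import torch

major, minor = torch.cuda.get_device_capability()
print(f"Device: {torch.cuda.get_device_name()} (sm_{major}{minor})")  # H100 reports sm_90
print(f"CUDA runtime bundled with PyTorch: {torch.version.cuda}")
assert (major, minor) >= (9, 0), "TMA is only supported on sm_90+ GPUs"
```

Given that the reported GPU is an H100, this check should pass, which would point the failure at the descriptor parameters themselves (shape, stride, or alignment of A_desc) rather than missing hardware support.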
- Numerical Mismatch / Assertion Error
tilelang/testing/python/primitives/test_tilelang_primitives_mma.py::test_gemm_f16f16f16_nt_ssr

```
FAILED test_tilelang_primitives_mma.py::test_gemm_f16f16f16_nt_ssr - AssertionError: Too many mismatched elements: 96601 > 52428 (5.00% allowed, but get 9.21%).
```
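
For context, the numbers in the message are consistent with a 1024x1024 output: 5% of 1,048,576 elements is 52,428, and 96,601 mismatches is 9.21%. Below is a sketch of a ratio-based comparison of this kind; it is an assumed reimplementation of the semantics, not tilelang's actual helper:

```python
# Sketch of a mismatch-ratio assertion matching the message format above.
# Assumed reimplementation, not tilelang's own checker.
import torch

def assert_mismatch_ratio(actual, expected, rtol=1e-2, atol=1e-2, max_ratio=0.05):
    mismatched = ~torch.isclose(actual, expected, rtol=rtol, atol=atol)
    count = int(mismatched.sum())
    allowed = int(actual.numel() * max_ratio)
    assert count <= allowed, (
        f"Too many mismatched elements: {count} > {allowed} "
        f"({max_ratio:.2%} allowed, but get {count / actual.numel():.2%})."
    )

# Example: a 1024x1024 fp16 GEMM output has numel() == 1048576, so
# allowed == 52428, matching the failure message above.
```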
- Stuck / Hung Tests:
- tilelang/testing/python/language/test_tilelang_language_all_of.py: the test `test_block_sparse_matmul_local` appears to hang indefinitely.
- tilelang/testing/python/language/test_tilelang_language_any_of.py: the tests `test_block_sparse_matmul_shared` and `test_block_sparse_matmul_shared` appear to hang indefinitely.
- tilelang/testing/python/language/test_tilelang_language_reduce_sum.py: the test `test_reduce_sum` appears to hang indefinitely.
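
To pinpoint where a hang occurs, one option is to dump the Python stacks after a hard timeout. A sketch using the standard-library faulthandler, with an illustrative 300 s timeout and an assumed import path (run from tilelang/testing/python/language); if the dump shows the host thread blocked in a CUDA synchronize call, the hang is inside the kernel itself:

```python
# Hang-diagnosis sketch: dump all Python stacks and exit if the test has not
# finished after 300 seconds. Timeout and import path are illustrative.
import faulthandler

faulthandler.dump_traceback_later(300, exit=True)

from test_tilelang_language_reduce_sum import test_reduce_sum

test_reduce_sum()  # if this hangs, the stack dump shows where it is blocked
```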