diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index ed43ef68474e..a8e8baa63fa3 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -41,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC { void AddFunction(LoweredFunc f); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_int8_ || need_math_constants_h_); + return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior void VisitStmt_(const ir::For* op) final;