diff --git a/src/op/gemm.cc b/src/op/gemm.cc index d7d7d81bc..14c73bc72 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -197,10 +197,6 @@ std::pair Gemm::ComputeWarpPartition(int num_warps, Target target, // Try all possible combinations that satisfy the constraints for (int m = 1; m <= max_m_warps && m <= num_warps; m++) { int n = num_warps / m; - if (n > max_n_warps) - continue; - if (m * n != num_warps) - continue; // Calculate how balanced this partition is float m_per_warp = static_cast(this->M) / (m * kMPerWarp); diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 103a15172..6e593f27c 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -181,6 +181,9 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: else: raise ValueError(f"Invalid function type: {type(program_result_source)}") + if self.verbose: + logger.info(f"Verbose: Compiling for program \n {program_result.script()}") + kernel_result = compile( program_result, out_idx=self.out_idx,