Skip to content

Commit f40164a

Browse files
authored
[Bugfix] Fix a bug when simplifying warp combination for T.gemm (#540)
1 parent db86ec4 commit f40164a

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

src/op/gemm.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,6 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
197197
// Try all possible combinations that satisfy the constraints
198198
for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
199199
int n = num_warps / m;
200-
if (n > max_n_warps)
201-
continue;
202-
if (m * n != num_warps)
203-
continue;
204200

205201
// Calculate how balanced this partition is
206202
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);

tilelang/jit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
181181
else:
182182
raise ValueError(f"Invalid function type: {type(program_result_source)}")
183183

184+
if self.verbose:
185+
logger.info(f"Verbose: Compiling for program \n {program_result.script()}")
186+
184187
kernel_result = compile(
185188
program_result,
186189
out_idx=self.out_idx,

0 commit comments

Comments
 (0)