Skip to content

Conversation

@iloveai8086
Copy link
Contributor

@iloveai8086 iloveai8086 commented Oct 9, 2025

Extended the SM architecture range accepted by the sm100 target check so that the sm100 code path also runs on Thor (SM 110). The results are as follows.

~/Project/tilelang/examples/gemm_sm100$ python gemm_tcgen5mma.py 
after extents [128, 128]
after extents [256, 128]
after extents [128, 256]
2025-10-09 15:53:57  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[2]`
major=11, minor=0
2025-10-09 15:54:08  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `main`
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

// Forward declaration of the generated kernel; the definition follows immediately.
extern "C" __global__ void main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C);
// Generated kernel: streams tiles of A and B from global memory into dynamic
// shared memory with asynchronous copies, feeds them to tl::tcgen5mma_gemm_ss
// (accumulating through the handle stored in C_tmem), then reads the result
// back, converts it to bfloat16 and writes it to C.
//
// Launch configuration is fixed by the host wrapper: 256 threads per block
// (matching __launch_bounds__(256, 1)) and 196608 bytes of dynamic shared memory.
//
// NOTE(review): the exact semantics of the tl:: helpers (tmem_allocate,
// tcgen5mma_gemm_ss, tcgen05_ld_32dp32bNx, Barrier) are defined in the
// tl_templates headers and are not visible here; comments about them below are
// hedged accordingly.
extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
  // Dynamic shared memory: B tiles are staged starting at byte offset 0 and
  // A tiles starting at byte offset 131072 (see the copy loops below).
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  // Shared slot filled in by tl::tmem_allocate; later passed as the accumulator
  // handle to the MMA and to the tcgen05 load.
  __shared__ uint C_tmem[1];
  // Raw storage for a single Barrier object, reinterpreted below.
  __shared__ uint64_t mbar_mem[1];
  auto mbar = reinterpret_cast<Barrier*>(mbar_mem);
  // Per-thread staging for 128 float accumulator values.
  float C_local[128];
  // Only the first warp (threads 0-31) performs the allocation.
  // presumably 256 is the number of tensor-memory columns requested - TODO confirm
  if ((((int)threadIdx.x) >> 5) == 0) {
    tl::tmem_allocate((&(C_tmem[0])), 256);
  }
  __syncthreads();
  // A single elected thread initialises the barrier with a count of 1.
  if (tl::tl_shuffle_elect<0>()) {
    mbar[0].init(1);
  }
  __syncthreads();
  // Prologue: issue the async copies for the first A tile (8 x 16-byte copies
  // per thread) into the A region at byte offset 131072.
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((((int)threadIdx.x) & 15) >> 3) * 16384) + (i * 2048)) + ((((int)threadIdx.x) >> 4) * 128)) + (((((((int)threadIdx.x) & 127) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 131072), A+((((((int)blockIdx.y) * 1048576) + (i * 131072)) + ((((int)threadIdx.x) >> 4) * 8192)) + ((((int)threadIdx.x) & 15) * 8)));
  }
  // Prologue: issue the async copies for the first B tile (16 x 16-byte copies
  // per thread) into the B region at byte offset 0.
  #pragma unroll
  for (int i_1 = 0; i_1 < 16; ++i_1) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((((((int)threadIdx.x) & 15) >> 3) * 32768) + (i_1 * 2048)) + ((((int)threadIdx.x) >> 4) * 128)) + (((((((int)threadIdx.x) & 127) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)), B+((((((int)blockIdx.x) * 2097152) + (i_1 * 131072)) + ((((int)threadIdx.x) >> 4) * 8192)) + ((((int)threadIdx.x) & 15) * 8)));
  }
  tl::cp_async_commit();
  // Main loop over 63 steps: each step issues the copies for tile k+1 into the
  // shared-memory half selected by ((k + 1) & 1), then runs the MMA on the half
  // selected by (k & 1), which was filled by the previous step (or the prologue).
  for (int k = 0; k < 63; ++k) {
    __syncthreads();
    // Copies for the next A tile; the global offset advances by 128 elements per step.
    #pragma unroll
    for (int i_2 = 0; i_2 < 8; ++i_2) {
      tl::cp_async_gs<16>(buf_dyn_shmem+((((((((((k + 1) & 1) * 32768) + (((((int)threadIdx.x) & 15) >> 3) * 16384)) + (i_2 * 2048)) + ((((int)threadIdx.x) >> 4) * 128)) + (((((((int)threadIdx.x) & 127) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 131072), A+((((((((int)blockIdx.y) * 1048576) + (i_2 * 131072)) + ((((int)threadIdx.x) >> 4) * 8192)) + (k * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 128));
    }
    // Copies for the next B tile; the global offset advances by 128 elements per step.
    #pragma unroll
    for (int i_3 = 0; i_3 < 16; ++i_3) {
      tl::cp_async_gs<16>(buf_dyn_shmem+(((((((((k + 1) & 1) * 65536) + (((((int)threadIdx.x) & 15) >> 3) * 32768)) + (i_3 * 2048)) + ((((int)threadIdx.x) >> 4) * 128)) + (((((((int)threadIdx.x) & 127) >> 6) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + (((int)threadIdx.x) & 1)) & 1) * 16)), B+((((((((int)blockIdx.x) * 2097152) + (i_3 * 131072)) + ((((int)threadIdx.x) >> 4) * 8192)) + (k * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 128));
    }
    tl::cp_async_commit();
    tl::fence_proxy_async();
    // presumably leaves only the group committed just above in flight, so the
    // tile consumed by this step's MMA has landed - TODO confirm against cp_async_wait
    tl::cp_async_wait<1>();
    __syncthreads();
    // Only the first warp issues the MMA for the current tile.
    // NOTE(review): the last argument (k == 0) presumably tells the MMA to
    // overwrite rather than accumulate on the first step - confirm in gemm.h.
    if ((((int)threadIdx.x) >> 5) == 0) {
      tl::tcgen5mma_gemm_ss<128, 256, 128, 128, 256, 16, 0, 1, float>((&(((bfloat16_t*)buf_dyn_shmem)[(((k & 1) * 16384) + 65536)])), (&(((bfloat16_t*)buf_dyn_shmem)[((k & 1) * 32768)])), C_tmem[0], (&(mbar[0])), (k == 0));
    }
    // All threads wait on the barrier; the argument alternates 0/1 with k.
    mbar[0].wait((k & 1));
  }
  // Epilogue of the pipeline: drain all outstanding copies, then run the final
  // MMA on the second half of each region (element indices 81920 and 32768 are
  // what the in-loop expressions would produce for k = 63).
  tl::cp_async_wait<0>();
  __syncthreads();
  if ((((int)threadIdx.x) >> 5) == 0) {
    tl::tcgen5mma_gemm_ss<128, 256, 128, 128, 256, 16, 0, 1, float>((&(((bfloat16_t*)buf_dyn_shmem)[81920])), (&(((bfloat16_t*)buf_dyn_shmem)[32768])), C_tmem[0], (&(mbar[0])), (bool)0);
  }
  mbar[0].wait(1);
  // Read 128 float values per thread into C_local; the offset argument is 0 for
  // threads 0-127 and 128 for threads 128-255.
  tl::tcgen05_ld_32dp32bNx<128>(C_tmem[0], ((((int)threadIdx.x) >> 7) * 128), (&(C_local[0])));
  __syncthreads();
  // Convert the accumulators to bfloat16, four at a time, and stage them in
  // shared memory starting at bfloat16 index 65536.
  // NOTE(review): if bfloat16_t is 2 bytes this is byte offset 131072, i.e. the
  // staging area overlaps the region used for A tiles above.
  #pragma unroll
  for (int i_4 = 0; i_4 < 32; ++i_4) {
    uint2 __1;
    float4 v_ = *(float4*)(C_local + (i_4 * 4));
    ((nv_bfloat162*)(&(__1.x)))->x = (bfloat16_t)(v_.x);
    ((nv_bfloat162*)(&(__1.x)))->y = (bfloat16_t)(v_.y);
    ((nv_bfloat162*)(&(__1.y)))->x = (bfloat16_t)(v_.z);
    ((nv_bfloat162*)(&(__1.y)))->y = (bfloat16_t)(v_.w);
    *(uint2*)(((bfloat16_t*)buf_dyn_shmem) + (((((((int)threadIdx.x) & 127) * 256) + ((((int)threadIdx.x) >> 7) * 128)) + (i_4 * 4)) + 65536)) = __1;
  }
  __syncthreads();
  // Write the staged output from shared memory to C; each call moves one
  // ulonglong4 (16 bfloat16 elements) per thread.
  #pragma unroll
  for (int i_5 = 0; i_5 < 8; ++i_5) {
    tl::st_global_256(&(*(ulonglong4*)(C + (((((((int)blockIdx.y) * 524288) + (i_5 * 65536)) + ((((int)threadIdx.x) >> 4) * 4096)) + (((int)blockIdx.x) * 256)) + ((((int)threadIdx.x) & 15) * 16)))), *(ulonglong4*)(((bfloat16_t*)buf_dyn_shmem) + (((i_5 * 4096) + (((int)threadIdx.x) * 16)) + 65536)));
  }
  tl::fence_proxy_async();
  // The first warp releases what it allocated at the top of the kernel.
  if ((((int)threadIdx.x) >> 5) == 0) {
    tl::tmem_deallocate((&(C_tmem[0])), 256);
  }
}


// Capacity, in bytes, of the buffer that holds the most recent error message.
#define ERROR_BUF_SIZE 1024
// File-local storage for the most recent error message (NUL-terminated).
static char error_buf[ERROR_BUF_SIZE];

// Returns a pointer to the file-local error buffer. The pointer is always the
// same and never null; the text it points at is overwritten by later failures.
extern "C" const char* get_last_error() {
    const char* const message = &error_buf[0];
    return message;
}

extern "C" int init() {
    error_buf[0] = '\0';
    
    cudaError_t result_main_kernel = cudaFuncSetAttribute(main_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 196608);
    if (result_main_kernel != CUDA_SUCCESS) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 196608, cudaGetErrorString(result_main_kernel));
        return -1;
    }

    return 0;
}

// Launches main_kernel on `stream` over a 16 x 32 x 1 grid of 256-thread blocks
// with 196608 bytes of dynamic shared memory, then checks for a launch error
// via TILELANG_CHECK_LAST_ERROR. Returns 0 when execution reaches the end.
extern "C" int call(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {
    const dim3 launch_grid(16, 32, 1);
    const dim3 launch_block(256, 1, 1);

    main_kernel<<<launch_grid, launch_block, 196608, stream>>>(A, B, C);
    TILELANG_CHECK_LAST_ERROR("main_kernel");

    return 0;
}

Latency: 9.004223823547363 ms
Flops: 30.52766260931387 TFLOPS

Summary by CodeRabbit

  • New Features
    • Expanded CUDA GPU architecture compatibility: the app now recognizes and runs on newer GPUs up to SM 110 while retaining support for SM 100+. This enhances out-of-the-box operation on the latest hardware with no user changes required. Existing behavior for older GPUs is preserved, with no impact on current performance settings or flags.

@github-actions
Copy link

github-actions bot commented Oct 9, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 9, 2025

Walkthrough

Updated TargetIsSm100 in src/target/utils.cc to expand the accepted CUDA SM architecture range from 100–103 to 100–110. The function still checks CUDA targeting first and retrieves the architecture via GetArchInt; other control flow and checks remain unchanged.

Changes

Cohort / File(s) Summary
Target arch check update
src/target/utils.cc
Expanded SM range check in TargetIsSm100 from <=103 to <=110; retained existing CUDA target check, arch retrieval via GetArchInt, and overall control flow.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Poem

I twitch my whiskers, hop with glee,
SMs now span a broader sea—
From hundred’s dawn to ten-and-ten,
The checks align, we run again.
A tiny tweak, a cleaner run,
More arches welcomed—bun done! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title clearly and concisely summarizes the primary change—updating the SM architecture number to support sm110—without extraneous details or noise.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f6d4bd3 and c0e6229.

📒 Files selected for processing (1)
  • src/target/utils.cc (1 hunks)
🔇 Additional comments (1)
src/target/utils.cc (1)

60-60: Verify SM111–SM119 support and bound inclusivity

  • RG search shows no references to SM111–SM119; confirm whether architectures 111–119 exist or are planned.
  • Align upper-bound logic with other ranges (e.g., use < 111 instead of <= 110 if exclusive).

return false;
int arch = GetArchInt(target);
- return arch >= 100 & arch <= 103;
+ return arch >= 100 & arch <= 110;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Use logical AND (&&) instead of bitwise AND (&).

Line 60 incorrectly uses the bitwise AND operator & instead of the logical AND operator &&. While this may accidentally produce correct results due to implicit bool-to-int conversion, it is incorrect and inconsistent with all other architecture check functions in this file (lines 32, 39, 46, 53, 67).

Apply this diff to fix the operator:

-  return arch >= 100 & arch <= 110;
+  return arch >= 100 && arch <= 110;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- return arch >= 100 & arch <= 110;
+ return arch >= 100 && arch <= 110;
🤖 Prompt for AI Agents
In src/target/utils.cc around line 60, the range check uses a bitwise AND (`&`)
instead of a logical AND; replace the operator so the return becomes a proper
boolean range test (use `&&` between the two comparisons, e.g. return arch >=
100 && arch <= 110;) to match the other architecture check functions and ensure
correct logical behavior.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants