2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -379,7 +379,7 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
- ./python/amd/test_tilelang_test_amd.py
+ ./python/amd

# Apple Metal tests
- name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu118
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu120
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu121
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu123
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu124
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu125
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu126
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && ./install_cuda.sh
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.cu128
@@ -26,6 +26,6 @@ RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev z
RUN pip install cython

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
- && cd TileLang && cmake -S . -B build -DUSE_CUDA=ON && cmake --build build -j
+ && cd TileLang && USE_CUDA=1 pip install -e . -v

CMD bash
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm
@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
- conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
+ conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"

RUN conda init bash

12 changes: 3 additions & 9 deletions testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -22,15 +22,6 @@ def tl_matmul(
b_transposed=True,
k_pack=1,
):
- assert in_dtype in [
- "float16",
- "int8",
- ], "Currently only float16 and int8 are supported"
- assert out_dtype in [
- "float16",
- "float32",
- "int32",
- ], "Currently only float16, float32 and int32 are supported"

micro_size_x = micro_size_y = micro_size_k = 16

@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+ elif in_dtype == "float8_e4m3fnuz":
+ A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+ B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
24 changes: 21 additions & 3 deletions testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+ elif in_dtype == "float8_e4m3fnuz":
+ A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+ B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(
- 256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+ 256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
- 256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+ 256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
- 256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+ 256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)

assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
k_pack=2,
b_preshuffle=True)

+ assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
+ assert_tl_matmul_correctness(
+ 256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
+ assert_tl_matmul_correctness(
+ 256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
+ assert_tl_matmul_correctness(
+ 256,
+ 256,
+ 512,
+ "float8_e4m3fnuz",
+ "float32",
+ k_pack=2,
+ b_transposed=False,
+ b_preshuffle=True)


if __name__ == "__main__":
tilelang.testing.main()
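A note on the new fp8 inputs in the tests above: torch.rand cannot sample directly into an fp8 dtype, so the tests draw in float16 and narrow afterwards. A minimal standalone sketch of the same construction, assuming a PyTorch build that exposes torch.float8_e4m3fnuz (e.g. ROCm wheels); the shape here is illustrative, not from the PR:

import torch

# Sample in fp16, then narrow to fp8: torch.rand does not support fp8 dtypes directly.
A = torch.rand((256, 512), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fnuz)
# Reference checks typically upcast again, since most eager-mode ops lack fp8 kernels.
A_ref = A.to(torch.float32)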
45 changes: 18 additions & 27 deletions tilelang/intrinsics/mfma_macro_generator.py
@@ -374,8 +374,6 @@ def mfma(self,
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0

- print(a_local_stride, b_local_stride)

@T.macro
def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
@@ -678,34 +676,27 @@ def __init__(
is_m_first: bool | None = False,
a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False,
+ thread_var: Var | None = None,
):

- self.a_dtype = a_dtype
- self.b_dtype = b_dtype
- self.accum_dtype = accum_dtype
- self.a_transposed = a_transposed
- self.b_transposed = b_transposed
- # Hint Information
- self.block_row_warps = block_row_warps
- self.block_col_warps = block_col_warps
- self.warp_row_tiles = warp_row_tiles
- self.warp_col_tiles = warp_col_tiles
- self.chunk = chunk
- self._initialize_k_dim(a_dtype)
- self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
- self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
- self._initialize_mfma_prefix(self.k_dim)
- self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
- self._initialize_k_pack(k_pack)
- self._initialize_is_m_first(is_m_first)
+ super().__init__(
+ a_dtype=a_dtype,
+ b_dtype=b_dtype,
+ accum_dtype=accum_dtype,
+ a_transposed=a_transposed,
+ b_transposed=b_transposed,
+ block_row_warps=block_row_warps,
+ block_col_warps=block_col_warps,
+ warp_row_tiles=warp_row_tiles,
+ warp_col_tiles=warp_col_tiles,
+ chunk=chunk,
+ reduce_k=reduce_k,
+ num_elems_per_byte=num_elems_per_byte,
+ k_pack=k_pack,
+ is_m_first=is_m_first,
+ thread_var=thread_var,
+ )
Comment on lines +681 to +697
⚠️ Potential issue | 🟠 Major

Ensure preshuffle emitter honors custom thread binding

The new thread_var argument is exposed here, but ldmatrix_a/ldmatrix_b still fetch the binding from T.KernelLaunchFrame.Current(). If a caller now supplies thread_var, those methods ignore it and still assert on an active kernel frame, so the preshuffle path breaks for the newly supported use case. Please route both loaders through self.get_thread_binding() like the base emitter.

Apply this diff:

@@ def ldmatrix_a(...):
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
@@ def ldmatrix_b(...):
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around lines 681 to 697, the
constructor now accepts thread_var but the preshuffle loaders
ldmatrix_a/ldmatrix_b still read thread binding from
T.KernelLaunchFrame.Current(); update those loader calls so they use
self.get_thread_binding() (which respects the supplied thread_var) instead of
directly querying KernelLaunchFrame.Current(), and remove or replace the
assertion that an active kernel frame is required so the preshuffle path works
with an externally provided thread_var.
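For context, a minimal sketch of the fallback the suggestion relies on, assuming the base emitter's get_thread_binding() prefers an explicitly supplied thread_var and only then consults the kernel frame (illustrative only; the actual tilelang implementation may differ):

# Hypothetical sketch of the base emitter's fallback, not verbatim tilelang code.
def get_thread_binding(self):
    if self.thread_var is not None:
        # Honor the binding passed to the constructor.
        return self.thread_var
    # Otherwise fall back to the enclosing kernel launch frame,
    # which exists only inside an active T.Kernel context.
    current_frame = T.KernelLaunchFrame.Current()
    assert current_frame is not None, "no active kernel launch frame"
    return current_frame.get_thread_binding()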

self._initialize_preshuffle(a_preshuffle, b_preshuffle)

self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte

def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
if a_preshuffle is not None:
self.a_preshuffle = a_preshuffle