diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a475cd513..f9fe32861 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -379,7 +379,7 @@ jobs:
             pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
           )
           "${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
-            ./python/amd/test_tilelang_test_amd.py
+            ./python/amd

      # Apple Metal tests
      - name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
diff --git a/docker/Dockerfile.cu118 b/docker/Dockerfile.cu118
index 9256fc09b..be8274461 100644
--- a/docker/Dockerfile.cu118
+++ b/docker/Dockerfile.cu118
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu120 b/docker/Dockerfile.cu120
index c89ce82ef..7ca1d931f 100644
--- a/docker/Dockerfile.cu120
+++ b/docker/Dockerfile.cu120
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu121 b/docker/Dockerfile.cu121
index 5b092773d..f91029d75 100644
--- a/docker/Dockerfile.cu121
+++ b/docker/Dockerfile.cu121
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu123 b/docker/Dockerfile.cu123
index 2715536a8..b3d1217fd 100644
--- a/docker/Dockerfile.cu123
+++ b/docker/Dockerfile.cu123
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu124 b/docker/Dockerfile.cu124
index fb9654f48..335f52565 100644
--- a/docker/Dockerfile.cu124
+++ b/docker/Dockerfile.cu124
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu125 b/docker/Dockerfile.cu125
index c409667cb..148e44b41 100644
--- a/docker/Dockerfile.cu125
+++ b/docker/Dockerfile.cu125
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu126 b/docker/Dockerfile.cu126
index 93593b5df..c031c2bc9 100644
--- a/docker/Dockerfile.cu126
+++ b/docker/Dockerfile.cu126
@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && ./install_cuda.sh
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.cu128 b/docker/Dockerfile.cu128
index db5e1cb57..2b895ecd8 100644
--- a/docker/Dockerfile.cu128
+++ b/docker/Dockerfile.cu128
@@ -26,6 +26,6 @@ RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev z
 RUN pip install cython

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
-    && cd TileLang && cmake -S . -B build -DUSE_CUDA=ON && cmake --build build -j
+    && cd TileLang && USE_CUDA=1 pip install -e . -v

 CMD bash
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 1fb23a9f3..f519bb0aa 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \

 RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

 RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
-    conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
+    conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"

 RUN conda init bash
diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
index bf4d49e41..a01bd4596 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -22,15 +22,6 @@ def tl_matmul(
         b_transposed=True,
         k_pack=1,
 ):
-    assert in_dtype in [
-        "float16",
-        "int8",
-    ], "Currently only float16 and int8 are supported"
-    assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
-    ], "Currently only float16, float32 and int32 are supported"

     micro_size_x = micro_size_y = micro_size_k = 16

@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
     if in_dtype == "int8":
         A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    elif in_dtype == "float8_e4m3fnuz":
+        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
     else:
         A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
         B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
index 73cdc280b..b215f0d45 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
     if in_dtype == "int8":
         A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    elif in_dtype == "float8_e4m3fnuz":
+        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
     else:
         A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
         B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,

 @tilelang.testing.requires_rocm
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
         256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
         k_pack=2,
         b_preshuffle=True)

+    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256,
+        256,
+        512,
+        "float8_e4m3fnuz",
+        "float32",
+        k_pack=2,
+        b_transposed=False,
+        b_preshuffle=True)
+

 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py
index 8829fae25..84e4c21b9 100644
--- a/tilelang/intrinsics/mfma_macro_generator.py
+++ b/tilelang/intrinsics/mfma_macro_generator.py
@@ -374,8 +374,6 @@ def mfma(self,
         a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
         b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0

-        print(a_local_stride, b_local_stride)
-
         @T.macro
         def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
             for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
@@ -678,34 +676,27 @@ def __init__(
         is_m_first: bool | None = False,
         a_preshuffle: bool | None = False,
         b_preshuffle: bool | None = False,
+        thread_var: Var | None = None,
     ):
-
-        self.a_dtype = a_dtype
-        self.b_dtype = b_dtype
-        self.accum_dtype = accum_dtype
-        self.a_transposed = a_transposed
-        self.b_transposed = b_transposed
-        # Hint Information
-        self.block_row_warps = block_row_warps
-        self.block_col_warps = block_col_warps
-        self.warp_row_tiles = warp_row_tiles
-        self.warp_col_tiles = warp_col_tiles
-        self.chunk = chunk
-        self._initialize_k_dim(a_dtype)
-        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
-        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
-        self._initialize_mfma_prefix(self.k_dim)
-        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
-        self._initialize_k_pack(k_pack)
-        self._initialize_is_m_first(is_m_first)
+        super().__init__(
+            a_dtype=a_dtype,
+            b_dtype=b_dtype,
+            accum_dtype=accum_dtype,
+            a_transposed=a_transposed,
+            b_transposed=b_transposed,
+            block_row_warps=block_row_warps,
+            block_col_warps=block_col_warps,
+            warp_row_tiles=warp_row_tiles,
+            warp_col_tiles=warp_col_tiles,
+            chunk=chunk,
+            reduce_k=reduce_k,
+            num_elems_per_byte=num_elems_per_byte,
+            k_pack=k_pack,
+            is_m_first=is_m_first,
+            thread_var=thread_var,
+        )
         self._initialize_preshuffle(a_preshuffle, b_preshuffle)
-        self.warp_rows = warp_row_tiles // self.micro_size_x
-        self.warp_cols = warp_col_tiles // self.micro_size_y
-        self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
-        self.num_elems_per_byte = num_elems_per_byte
-

     def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
         if a_preshuffle is not None:
             self.a_preshuffle = a_preshuffle