Skip to content

Commit 35cf888

Browse files
[Enhancement] Remove constraint requiring last dimension stride to be 1 (#1040)
* remove last dimension stride must be 1 constraint * add vectorize test * minor fix * [Lint]: [pre-commit.ci] auto fixes [...] --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fd1493b commit 35cf888

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch
2+
import tilelang.testing
3+
import tilelang.language as T
4+
5+
6+
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
7+
def vectorize_test(N, M, stride_A, stride_B):
8+
assert N % 128 == 0 and M % 128 == 0
9+
10+
@T.prim_func
11+
def main(
12+
A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821
13+
B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821
14+
):
15+
with T.Kernel(M // 128, threads=128) as (bx):
16+
tx = T.get_thread_binding(0)
17+
col = bx * 128 + tx
18+
19+
for row in T.vectorized(N):
20+
B[row, col] = A[row, col]
21+
22+
return main
23+
24+
25+
def run_vectorize(N, M, stride_A, stride_B):
26+
assert stride_A >= N and stride_B >= N
27+
28+
jit_kernel = vectorize_test(N, M, stride_A, stride_B)
29+
30+
base_a = torch.randn(stride_A, M, device="cuda", dtype=torch.float32)
31+
base_b = torch.zeros(stride_B, M, device="cuda", dtype=torch.float32)
32+
a = torch.as_strided(base_a, size=(N, M), stride=(1, stride_A))
33+
b = torch.as_strided(base_b, size=(N, M), stride=(1, stride_B))
34+
35+
jit_kernel(a, b)
36+
37+
torch.testing.assert_close(a, b, atol=1e-8, rtol=1e-8)
38+
39+
code = jit_kernel.get_kernel_source()
40+
41+
vectorize_size = 1
42+
while vectorize_size <= 2 and \
43+
stride_A % (vectorize_size * 2) == 0 and \
44+
stride_B % (vectorize_size * 2) == 0:
45+
vectorize_size *= 2
46+
47+
if vectorize_size == 4:
48+
assert "float4" in code
49+
elif vectorize_size == 2:
50+
assert "float2" in code
51+
52+
53+
def test_vectorize():
54+
N, M = 512, 256
55+
56+
run_vectorize(N, M, N, N)
57+
run_vectorize(N, M, N + 2, N + 4)
58+
run_vectorize(N, M, N + 4, N + 8)
59+
run_vectorize(N, M, N + 8, N + 16)
60+
61+
62+
if __name__ == "__main__":
63+
tilelang.testing.main()

tilelang/language/proxy.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,6 @@ def __call__(self,
178178
scope=None) -> tir.Buffer:
179179
if len(shape) != len(strides):
180180
raise ValueError("Invalid shape/strides' dimensions")
181-
if not bool(strides[-1] == 1):
182-
# TODO(chenggang): shall we support non-contiguous even for the last dimension?
183-
raise ValueError("The stride of the last dimension must be 1 (contiguous)")
184181
return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)
185182

186183

0 commit comments

Comments
 (0)