Skip to content

Commit 5f50e3a

Browse files
authored
[CI][Test] Add test cases for tilelang transform AnnotateDeviceRegions and MakePackedAPI (#26)
* installation script fix * readme typo fix * doc fix for dequantize gemm * [Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in CONTRIBUTING.md * [Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPPORT.md
1 parent a575f72 commit 5f50e3a

File tree

4 files changed

+413
-29
lines changed

4 files changed

+413
-29
lines changed

SUPPORT.md

Lines changed: 0 additions & 29 deletions
This file was deleted.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tilelang
19+
import tilelang.testing
20+
from tilelang import language as T
21+
22+
23+
class BaseCompare(tilelang.testing.CompareBeforeAfter):
24+
transform = tilelang.transform.AnnotateDeviceRegions()
25+
26+
27+
class TestAnnotateThreadExtent(BaseCompare):
28+
"""Annotation inserted at the "thread_extent" attribute"""
29+
30+
def before(A: T.Buffer(16, "float32")):
31+
T.func_attr({"target": T.target("cuda", host="llvm")})
32+
i = T.launch_thread("threadIdx.x", 16)
33+
A[i] = 0.0
34+
35+
def expected(A: T.Buffer(16, "float32")):
36+
T.func_attr({"target": T.target("cuda", host="llvm")})
37+
T.attr(T.target("cuda"), "target", 0)
38+
i = T.launch_thread("threadIdx.x", 16)
39+
A[i] = 0.0
40+
41+
42+
class TestAnnotateDeviceScope(BaseCompare):
43+
"""Annotation inserted at the "device_scope" attribute"""
44+
45+
def before(A: T.Buffer(1, "float32")):
46+
T.func_attr({"target": T.target("cuda", host="llvm")})
47+
T.attr(0, "device_scope", 0)
48+
A[0] = 0.0
49+
50+
def expected(A: T.Buffer(1, "float32")):
51+
T.func_attr({"target": T.target("cuda", host="llvm")})
52+
T.attr(T.target("cuda"), "target", 0)
53+
T.attr(0, "device_scope", 0)
54+
A[0] = 0.0
55+
56+
57+
if __name__ == "__main__":
58+
tilelang.testing.main()

0 commit comments

Comments
 (0)