Skip to content

Commit

Permalink
recover testing
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Jan 20, 2024
1 parent 37984f0 commit 070dee4
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 40 deletions.
1 change: 0 additions & 1 deletion src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
for (size_t i = 0; i < initial_ranges.size(); i++) {
input_iters.Set(self->initial_indices[i], initial_ranges[i]);
}
// LOG(INFO) << "input_iters = " << input_iters;
// Unpack the output indices into linear combinations of the initial
// indices.
auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /*predicate=*/1,
Expand Down
3 changes: 1 addition & 2 deletions tests/python/dlight/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,5 +314,4 @@ def test_extract_func_info_from_prim_func():


if __name__ == "__main__":
test_benchmark_prim_func_local()
test_benchmark_prim_func_full_local()
tvm.testing.main()
3 changes: 1 addition & 2 deletions tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,5 +1845,4 @@ def main(


if __name__ == "__main__":
# tvm.testing.main()
test_gather()
tvm.testing.main()
36 changes: 1 addition & 35 deletions tests/python/tir-schedule/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,38 +455,4 @@ def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")):


if __name__ == "__main__":
# tvm.testing.main()
matmul = create_prim_func(te_workload.matmul(128, 128, 128, in_dtype="float16", out_dtype="float32"))
workload = matmul
conv2d = create_prim_func(
te_workload.conv2d_nhwc(4, 16, 16, 64, 64, 3, 1, 1, in_dtype="float16", out_dtype="float32")
)
workload = conv2d
block_name = "conv2d_nhwc"
def elementwise_copy(M, N, dtype="float16"):
@tvm.script.ir_module
class ElementWiseCopy:
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, [M, N], dtype=dtype)
B = T.match_buffer(b, [M, N], dtype=dtype)

for i, j in T.grid(M, N):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj]
return ElementWiseCopy
# workload = elementwise_copy(128, 128, dtype="float16")
# block_name = 'B'
s = Schedule(workload)
block = s.get_block(block_name)
desc_func = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f32_INTRIN).desc
info = get_auto_tensorize_mapping_info(s, block, desc_func)
print(info.mappings)
# print(info.mappings[0])
# print(info.mappings[1])
# print(info.mappings[2])

tvm.testing.main()

0 comments on commit 070dee4

Please sign in to comment.