From 95fed34ccc8f38e7b3e6e1f878bac5cf96138bc2 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 3 Mar 2025 18:29:35 +0800 Subject: [PATCH] [Refactor] Remove legacy TE schedule tag This commit removes the TE schedule tag from TIR transforms as all legacy TE schedules have been removed. --- src/tir/transforms/compact_buffer_region.cc | 15 +++++---------- src/tir/transforms/convert_blocks_to_opaque.cc | 11 +++-------- src/tir/transforms/flatten_buffer.cc | 9 +-------- src/tir/transforms/ir_utils.cc | 5 ----- src/tir/transforms/ir_utils.h | 10 ---------- src/tir/transforms/lift_thread_binding.cc | 11 +++-------- .../transforms/lower_cross_thread_reduction.cc | 11 +++-------- src/tir/transforms/lower_init_block.cc | 11 +++-------- src/tir/transforms/lower_opaque_block.cc | 11 +++-------- .../plan_update_buffer_allocation_location.cc | 13 ++++--------- src/tir/transforms/unify_thread_binding.cc | 11 +++-------- .../test_tir_transform_inject_rolling_buffer.py | 16 ++-------------- .../test_tir_transform_loop_partition.py | 2 +- .../python/tvmscript/test_tvmscript_roundtrip.py | 6 +++--- 14 files changed, 34 insertions(+), 108 deletions(-) diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 7385af49528b..1907c7ca5038 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -741,16 +741,11 @@ Stmt BufferCompactorCompact( } PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict); - auto storage_align = CollectStorageAlignAnnotation(f->body); - fptr->body = BufferCompactorCompact(f, region, storage_align); - return f; - } else { - return f; - } + PrimFuncNode* fptr = f.CopyOnWrite(); + auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict); + auto storage_align = CollectStorageAlignAnnotation(f->body); + fptr->body = BufferCompactorCompact(f, region, storage_align); + return f; } namespace transform { diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index 95648713494c..ab8d98a00e0e 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -108,14 +108,9 @@ class OpaqueBlockConverter : public StmtExprMutator { }; PrimFunc ConvertBlocksToOpaque(PrimFunc f) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = OpaqueBlockConverter::Substitute(f); - return f; - } else { - return f; - } + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockConverter::Substitute(f); + return f; } namespace transform { diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index c04e12b8395e..a6da7f7fc407 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -268,14 +268,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { Map updated_extern_buffer_map_; }; -PrimFunc FlattenBuffer(PrimFunc f) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - return BufferFlattener::Flatten(f); - } else { - return f; - } -} +PrimFunc FlattenBuffer(PrimFunc f) { return BufferFlattener::Flatten(f); } namespace transform { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 7026215a015b..b63192f50a28 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -594,11 +594,6 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region return result; } -Bool IsFromLegacyTESchedule(PrimFunc f) { - Optional from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false)); - return from_legacy_te_schedule.value(); -} - Optional ConditionalBoundsContext::TrySolveCondition() { // extract equations and related vars from condition expression. // currently only extract simple integral equations which could be solvable. diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 05345aab8628..94054f5d2cfe 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -234,16 +234,6 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region */ Array GetBufferAllocationShape(const Buffer& buffer); -/*! - * \brief Check if a given PrimFunc originated from a TE schedule. - * - * Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc. - * - * \param f PrimFunc to check - * \return Whether or not the PrimFunc was created from a te schedule - */ -Bool IsFromLegacyTESchedule(PrimFunc f); - /*! * \brief Context helper to update domain map within conditional scope. * Assume the condition is `0 <= i && i < 9` and domain of i is [0, 20], Then diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 9d7d455dbaed..8cb88fa653c0 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -169,14 +169,9 @@ class ThreadBindingLifter : public StmtExprMutator { }; PrimFunc LiftThreadBinding(PrimFunc f) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = ThreadBindingLifter()(std::move(fptr->body)); - return f; - } else { - return f; - } + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingLifter()(std::move(fptr->body)); + return f; } namespace transform { diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 0146e2aebf46..325d8e5bb578 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -920,14 +920,9 @@ class CrossThreadReductionTransformer : public StmtMutator { }; PrimFunc LowerCrossThreadReduction(PrimFunc f) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = CrossThreadReductionTransformer()(f->body); - return f; - } else { - return f; - } + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = CrossThreadReductionTransformer()(f->body); + return f; } namespace transform { diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index 17b4e3fb22e6..3e8fc204314d 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -65,14 +65,9 @@ class InitBlockLower : public StmtMutator { }; PrimFunc LowerInitBlock(PrimFunc func) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(func)) { - auto fptr = func.CopyOnWrite(); - fptr->body = InitBlockLower()(std::move(fptr->body)); - return func; - } else { - return func; - } + auto fptr = func.CopyOnWrite(); + fptr->body = InitBlockLower()(std::move(fptr->body)); + return func; } namespace transform { diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 08642a598b74..96c6d3759cae 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -200,14 +200,9 @@ class OpaqueBlockLower : public StmtExprMutator { }; PrimFunc LowerOpaqueBlock(PrimFunc f) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - auto fptr = f.CopyOnWrite(); - fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body)); - return f; - } else { - return f; - } + auto fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body)); + return f; } namespace transform { diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index f9ce708c78b7..5ce8ade2085c 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -242,15 +242,10 @@ class BufferAllocationLocator : public StmtExprMutator { }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(func)) { - auto fptr = func.CopyOnWrite(); - BufferAllocationLocator locator(func); - fptr->body = locator(fptr->body); - return func; - } else { - return func; - } + auto fptr = func.CopyOnWrite(); + BufferAllocationLocator locator(func); + fptr->body = locator(fptr->body); + return func; } namespace transform { diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index 02fa333dbe14..67c7f05ff413 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -185,14 +185,9 @@ class ThreadBindingUnifier : public StmtExprMutator { }; PrimFunc UnifyThreadBinding(PrimFunc f) { - // Only apply this pass to TIR that is not from TE schedules - if (!IsFromLegacyTESchedule(f)) { - PrimFuncNode* fptr = f.CopyOnWrite(); - fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); - return f; - } else { - return f; - } + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); + return f; } namespace transform { diff --git a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py index 3d8f85bf79dd..4dd1380c8f58 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py @@ -37,13 +37,7 @@ def main( ), ) -> None: # function attr dict - T.func_attr( - { - "from_legacy_te_schedule": True, - "global_symbol": "main", - "tir.noalias": True, - } - ) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) A_1 = T.match_buffer( A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1 ) @@ -112,13 +106,7 @@ def main( ), ) -> None: # function attr dict - T.func_attr( - { - "from_legacy_te_schedule": True, - "global_symbol": "main", - "tir.noalias": True, - } - ) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) A_1 = T.match_buffer( A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1 ) diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py index 25660880e13f..1e079ada5556 100644 --- a/tests/python/tir-transform/test_tir_transform_loop_partition.py +++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py @@ -238,7 +238,7 @@ def test_cce_loop_3(): def partitioned_concat( A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32"), C: T.Buffer((32,), "float32") ) -> None: - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.serial(0, 16): C[i] = A[i] for i in T.serial(0, 16): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index b44ff5ad7241..f29c03c640ab 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3644,7 +3644,7 @@ def func(): def string_stride(): @T.prim_func def main(a: T.handle, b: T.handle): - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) n = T.int32() A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto") B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto") @@ -3663,7 +3663,7 @@ def main(a: T.handle, b: T.handle): def string_stride_int64(): @T.prim_func def main(a: T.handle, b: T.handle): - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) n = T.int64() A_s0 = T.int64() B_s0 = T.int64() @@ -3679,7 +3679,7 @@ def merge_shape_var_def(): # uninitialized vars @T.prim_func(check_well_formed=False) def main(A: T.handle, B: T.handle): - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) m, n = T.int32(), T.int32() A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto") B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"), buffer_type="auto")