Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,16 +741,11 @@ Stmt BufferCompactorCompact(
}

PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
auto storage_align = CollectStorageAlignAnnotation(f->body);
fptr->body = BufferCompactorCompact(f, region, storage_align);
return f;
} else {
return f;
}
PrimFuncNode* fptr = f.CopyOnWrite();
auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
auto storage_align = CollectStorageAlignAnnotation(f->body);
fptr->body = BufferCompactorCompact(f, region, storage_align);
return f;
}

namespace transform {
Expand Down
11 changes: 3 additions & 8 deletions src/tir/transforms/convert_blocks_to_opaque.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,9 @@ class OpaqueBlockConverter : public StmtExprMutator {
};

PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockConverter::Substitute(f);
return f;
} else {
return f;
}
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockConverter::Substitute(f);
return f;
}

namespace transform {
Expand Down
9 changes: 1 addition & 8 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
Map<Var, Buffer> updated_extern_buffer_map_;
};

PrimFunc FlattenBuffer(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
return BufferFlattener::Flatten(f);
} else {
return f;
}
}
PrimFunc FlattenBuffer(PrimFunc f) { return BufferFlattener::Flatten(f); }

namespace transform {

Expand Down
5 changes: 0 additions & 5 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -594,11 +594,6 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region
return result;
}

Bool IsFromLegacyTESchedule(PrimFunc f) {
Optional<Bool> from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false));
return from_legacy_te_schedule.value();
}

Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
// extract equations and related vars from condition expression.
// currently only extract simple integral equations which could be solvable.
Expand Down
10 changes: 0 additions & 10 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,6 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region
*/
Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer);

/*!
* \brief Check if a given PrimFunc originated from a TE schedule.
*
* Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc.
*
* \param f PrimFunc to check
* \return Whether or not the PrimFunc was created from a te schedule
*/
Bool IsFromLegacyTESchedule(PrimFunc f);

/*!
* \brief Context helper to update domain map within conditional scope.
* Assume the condition is `0 <= i && i < 9` and domain of i is [0, 20], Then
Expand Down
11 changes: 3 additions & 8 deletions src/tir/transforms/lift_thread_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,9 @@ class ThreadBindingLifter : public StmtExprMutator {
};

PrimFunc LiftThreadBinding(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = ThreadBindingLifter()(std::move(fptr->body));
return f;
} else {
return f;
}
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = ThreadBindingLifter()(std::move(fptr->body));
return f;
}

namespace transform {
Expand Down
11 changes: 3 additions & 8 deletions src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -920,14 +920,9 @@ class CrossThreadReductionTransformer : public StmtMutator {
};

PrimFunc LowerCrossThreadReduction(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = CrossThreadReductionTransformer()(f->body);
return f;
} else {
return f;
}
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = CrossThreadReductionTransformer()(f->body);
return f;
}

namespace transform {
Expand Down
11 changes: 3 additions & 8 deletions src/tir/transforms/lower_init_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,9 @@ class InitBlockLower : public StmtMutator {
};

PrimFunc LowerInitBlock(PrimFunc func) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(func)) {
auto fptr = func.CopyOnWrite();
fptr->body = InitBlockLower()(std::move(fptr->body));
return func;
} else {
return func;
}
auto fptr = func.CopyOnWrite();
fptr->body = InitBlockLower()(std::move(fptr->body));
return func;
}

namespace transform {
Expand Down
11 changes: 3 additions & 8 deletions src/tir/transforms/lower_opaque_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,9 @@ class OpaqueBlockLower : public StmtExprMutator {
};

PrimFunc LowerOpaqueBlock(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
auto fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
return f;
} else {
return f;
}
auto fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
return f;
}

namespace transform {
Expand Down
13 changes: 4 additions & 9 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,10 @@ class BufferAllocationLocator : public StmtExprMutator {
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(func)) {
auto fptr = func.CopyOnWrite();
BufferAllocationLocator locator(func);
fptr->body = locator(fptr->body);
return func;
} else {
return func;
}
auto fptr = func.CopyOnWrite();
BufferAllocationLocator locator(func);
fptr->body = locator(fptr->body);
return func;
}

namespace transform {
Expand Down
11 changes: 3 additions & 8 deletions src/tir/transforms/unify_thread_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,9 @@ class ThreadBindingUnifier : public StmtExprMutator {
};

PrimFunc UnifyThreadBinding(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
return f;
} else {
return f;
}
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
return f;
}

namespace transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,7 @@ def main(
),
) -> None:
# function attr dict
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_1 = T.match_buffer(
A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1
)
Expand Down Expand Up @@ -112,13 +106,7 @@ def main(
),
) -> None:
# function attr dict
T.func_attr(
{
"from_legacy_te_schedule": True,
"global_symbol": "main",
"tir.noalias": True,
}
)
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_1 = T.match_buffer(
A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_cce_loop_3():
def partitioned_concat(
A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32"), C: T.Buffer((32,), "float32")
) -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in T.serial(0, 16):
C[i] = A[i]
for i in T.serial(0, 16):
Expand Down
6 changes: 3 additions & 3 deletions tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3644,7 +3644,7 @@ def func():
def string_stride():
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
T.func_attr({"global_symbol": "main", "tir.noalias": True})
n = T.int32()
A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto")
B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto")
Expand All @@ -3663,7 +3663,7 @@ def main(a: T.handle, b: T.handle):
def string_stride_int64():
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
T.func_attr({"global_symbol": "main", "tir.noalias": True})
n = T.int64()
A_s0 = T.int64()
B_s0 = T.int64()
Expand All @@ -3679,7 +3679,7 @@ def merge_shape_var_def():
# uninitialized vars
@T.prim_func(check_well_formed=False)
def main(A: T.handle, B: T.handle):
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m, n = T.int32(), T.int32()
A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto")
B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"), buffer_type="auto")
Expand Down
Loading