Skip to content

Commit 089e63a

Browse files
authored
[Refactor] Improve layout equality checks and error messaging (#471)
* [Refactor] Simplify buffer_region_to_tile_region function in copy.py * Removed redundant logic for handling region extents in the buffer_region_to_tile_region function, streamlining the code for better readability and maintainability. * Enhanced error handling by focusing on essential checks while eliminating unnecessary complexity related to variable extents. * [Refactor] Improve layout equality checks and error messaging * Updated the `IsEqual` method in `FragmentNode` to ensure consistent evaluation of thread ranges. * Enhanced error messaging in `ParallelOp::InferLayout` to include source buffer information for better debugging. * Adjusted `ReduceOp::InferLayout` to set thread range during layout condensation, improving layout inference accuracy. * lintfix * [Refactor] Rename SetThreadRange to BindThreadRange for clarity * Updated the `SetThreadRange` method in `FragmentNode` and related classes to `BindThreadRange`, improving method naming consistency and clarity. * Adjusted all references to the renamed method across the codebase, ensuring proper functionality and maintaining existing behavior. * Enhanced layout equality checks to handle thread ranges more robustly in `IsEqual` method. * Updated layout inference methods in `Gemm`, `ParallelOp`, and `ReduceOp` to utilize the new method name, ensuring seamless integration with the updated API. * [Refactor] Update BindThreadRange usage across layout inference methods * Modified the implementation of `BindThreadRange` in `FragmentNode` to create a new object instance, enhancing thread range binding functionality. * Updated all references to `BindThreadRange` in layout inference methods across `Gemm`, `ParallelOp`, and `ReduceOp` to ensure consistency with the new implementation. * Adjusted the return statements in various layout inference functions to utilize the updated method, maintaining existing behavior while improving clarity. * lint fix
1 parent 9b3d135 commit 089e63a

File tree

6 files changed

+25
-27
lines changed

6 files changed

+25
-27
lines changed

src/layout/layout.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,10 @@ Fragment FragmentNode::DeReplicate() const {
204204
int(*rep_size) / factor, NullOpt);
205205
}
206206

207-
Fragment FragmentNode::SetThreadRange(Range thread_range) {
208-
thread_range_ = thread_range;
209-
return GetRef<Fragment>(this);
207+
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
208+
auto n = make_object<FragmentNode>(*this);
209+
n->thread_range_ = thread_range;
210+
return Fragment(n);
210211
}
211212

212213
Layout LayoutNode::Inverse() const {
@@ -418,11 +419,13 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
418419
// a[i, j] = b[j, i] in register level.
419420

420421
bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
421-
ret &= StructuralEqual()(this->ThreadRange(), other->ThreadRange());
422422
if (!ret) {
423423
// may be broadcast case
424424
return true;
425425
}
426+
if (this->thread_range_.defined() && other->thread_range_.defined()) {
427+
ret &= StructuralEqual()(this->thread_range_, other->thread_range_);
428+
}
426429
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
427430
ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
428431
ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());

src/layout/layout.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class FragmentNode : public LayoutNode {
9898

9999
std::string DebugOutput() const final;
100100

101-
Fragment SetThreadRange(Range thread_range);
101+
Fragment BindThreadRange(Range thread_range) const;
102102

103103
Range ThreadRange() const { return thread_range_; }
104104

@@ -130,12 +130,6 @@ class Fragment : public Layout {
130130
Optional<Var> replicate_var);
131131

132132
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
133-
134-
Fragment SetThreadRange(Range thread_range) {
135-
auto node = make_object<FragmentNode>(*this->get());
136-
node->SetThreadRange(thread_range);
137-
return Fragment(node);
138-
}
139133
};
140134

141135
Var InputPlaceholder(size_t idx);

src/op/gemm.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
175175
ComputeWarpPartition(block_size / warp_size, T.target);
176176
auto fragment =
177177
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
178-
results.Set(C, fragment.SetThreadRange(thread_range));
178+
results.Set(C, fragment->BindThreadRange(thread_range));
179179
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
180180
int dim_A = A->shape.size();
181181
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
@@ -184,7 +184,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
184184
} else if (A.scope() == "local.fragment") {
185185
ICHECK(trans_A == false);
186186
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
187-
results.Set(A, fragment.SetThreadRange(thread_range));
187+
results.Set(A, fragment->BindThreadRange(thread_range));
188188
} else {
189189
ICHECK(0);
190190
}
@@ -200,7 +200,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
200200
ComputeWarpPartition(block_size / warp_size, T.target);
201201
auto fragment =
202202
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
203-
results.Set(C, fragment.SetThreadRange(thread_range));
203+
results.Set(C, fragment->BindThreadRange(thread_range));
204204

205205
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
206206
int dim_A = A->shape.size();
@@ -213,7 +213,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
213213
ICHECK(trans_A == false);
214214
auto fragment =
215215
makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
216-
results.Set(A, fragment.SetThreadRange(thread_range));
216+
results.Set(A, fragment->BindThreadRange(thread_range));
217217
} else {
218218
ICHECK(0);
219219
}
@@ -228,7 +228,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
228228
ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
229229
"please raise an issue if you see this";
230230
auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
231-
results.Set(B, fragment.SetThreadRange(thread_range));
231+
results.Set(B, fragment->BindThreadRange(thread_range));
232232
} else {
233233
ICHECK(0);
234234
}
@@ -242,7 +242,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
242242
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
243243
C->dtype.bits())
244244
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
245-
results.Set(C, fragment.SetThreadRange(thread_range));
245+
results.Set(C, fragment->BindThreadRange(thread_range));
246246
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
247247
int dim_A = A->shape.size();
248248
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
@@ -255,7 +255,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
255255
ICHECK(trans_A == false);
256256
auto fragment =
257257
makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
258-
results.Set(A, fragment.SetThreadRange(thread_range));
258+
results.Set(A, fragment->BindThreadRange(thread_range));
259259
}
260260
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
261261
int dim_B = B->shape.size();
@@ -275,7 +275,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
275275

276276
auto fragment =
277277
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
278-
results.Set(C, fragment.SetThreadRange(thread_range));
278+
results.Set(C, fragment->BindThreadRange(thread_range));
279279

280280
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
281281
int dim_A = A->shape.size();
@@ -286,7 +286,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
286286
} else if (A.scope() == "local.fragment") {
287287
auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
288288
A->dtype.bits(), trans_A);
289-
results.Set(A, fragment.SetThreadRange(thread_range));
289+
results.Set(A, fragment->BindThreadRange(thread_range));
290290
} else {
291291
ICHECK(0);
292292
}
@@ -299,7 +299,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
299299
results.Set(B, shared_layout);
300300
} else if (B.scope() == "local.fragment") {
301301
auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
302-
results.Set(B, fragment.SetThreadRange(thread_range));
302+
results.Set(B, fragment->BindThreadRange(thread_range));
303303
} else {
304304
ICHECK(0);
305305
}

src/op/parallel.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
181181
PrimExpr loop_var_to_thread =
182182
src_layout->ForwardThread(indice_map_[buffer], rep);
183183
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
184-
.SetThreadRange(T.thread_bounds);
184+
->BindThreadRange(T.thread_bounds);
185185
}
186186
};
187187
if (source_buffer.defined()) {
@@ -272,7 +272,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
272272
LayoutMap results;
273273
for (const auto &[buffer, _] : indice_map_) {
274274
if (!T.layout_map.count(buffer)) {
275-
results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange(
275+
results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
276276
T.thread_bounds));
277277
}
278278
// Though they may exist some conflicts, but it's fine.
@@ -285,13 +285,13 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
285285
const FragmentNode *src_layout =
286286
T.layout_map[buffer].as<Fragment>().get();
287287
Fragment dst_layout_fragment =
288-
CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds);
288+
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
289289
const FragmentNode *dst_layout =
290290
dst_layout_fragment.as<Fragment>().get();
291291
if (src_layout && dst_layout) {
292292
ICHECK(src_layout->IsEqual(dst_layout, true))
293293
<< "Layout may conflict with ParallelOp for buffer " << buffer
294-
<< "\nError body begin:\n"
294+
<< " vs. " << source_buffer << "\nError body begin:\n"
295295
<< GetRoot()->body << "\nError body end"
296296
<< "\nLHS = " << src_layout->DebugOutput()
297297
<< "\nRHS = " << dst_layout->DebugOutput()

src/op/reduce.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
275275
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
276276
Fragment dst_layout =
277277
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
278-
->CondenseReplicateVar();
278+
->CondenseReplicateVar()
279+
->BindThreadRange(T.thread_bounds);
279280
return {{dst, dst_layout}};
280281
}
281282
return {};

src/transform/loop_partition.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
191191
size_t num_thread = *as_const_int(thread_range->extent);
192192
LoopPartitioner partitioner;
193193
Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
194-
return fragment.SetThreadRange(thread_range);
194+
return fragment->BindThreadRange(thread_range);
195195
}
196196

197197
For LoopPragmaUnroll(For stmt) {

0 commit comments

Comments
 (0)