diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 63520edfe..bea87a220 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -204,9 +204,10 @@ Fragment FragmentNode::DeReplicate() const { int(*rep_size) / factor, NullOpt); } -Fragment FragmentNode::SetThreadRange(Range thread_range) { - thread_range_ = thread_range; - return GetRef<Fragment>(this); +Fragment FragmentNode::BindThreadRange(Range thread_range) const { + auto n = make_object<FragmentNode>(*this); + n->thread_range_ = thread_range; + return Fragment(n); } Layout LayoutNode::Inverse() const { @@ -418,11 +419,13 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const { // a[i, j] = b[j, i] in register level. bool ret = StructuralEqual()(this->InputShape(), other->InputShape()); - ret &= StructuralEqual()(this->ThreadRange(), other->ThreadRange()); if (!ret) { // may be broadcast case return true; } + if (this->thread_range_.defined() && other->thread_range_.defined()) { + ret &= StructuralEqual()(this->thread_range_, other->thread_range_); + } ret &= StructuralEqual()(this->OutputShape(), other->OutputShape()); ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent()); ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent()); diff --git a/src/layout/layout.h b/src/layout/layout.h index 54c9ad19a..cc6b56ec0 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -98,7 +98,7 @@ class FragmentNode : public LayoutNode { std::string DebugOutput() const final; - Fragment SetThreadRange(Range thread_range); + Fragment BindThreadRange(Range thread_range) const; Range ThreadRange() const { return thread_range_; } @@ -130,12 +130,6 @@ class Fragment : public Layout { Optional<Var> replicate_var); TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode); - - Fragment SetThreadRange(Range thread_range) { - auto node = make_object<FragmentNode>(*this->get()); - node->SetThreadRange(thread_range); - return Fragment(node); - } }; Var InputPlaceholder(size_t idx); diff --git 
a/src/op/gemm.cc b/src/op/gemm.cc index c9562c5c3..400aa2fa9 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -175,7 +175,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ComputeWarpPartition(block_size / warp_size, T.target); auto fragment = makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); - results.Set(C, fragment.SetThreadRange(thread_range)); + results.Set(C, fragment->BindThreadRange(thread_range)); if (A.scope() == "shared" || A.scope() == "shared.dyn") { int dim_A = A->shape.size(); results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]), @@ -184,7 +184,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { } else if (A.scope() == "local.fragment") { ICHECK(trans_A == false); auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n); - results.Set(A, fragment.SetThreadRange(thread_range)); + results.Set(A, fragment->BindThreadRange(thread_range)); } else { ICHECK(0); } @@ -200,7 +200,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ComputeWarpPartition(block_size / warp_size, T.target); auto fragment = makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); - results.Set(C, fragment.SetThreadRange(thread_range)); + results.Set(C, fragment->BindThreadRange(thread_range)); if (A.scope() == "shared" || A.scope() == "shared.dyn") { int dim_A = A->shape.size(); @@ -213,7 +213,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ICHECK(trans_A == false); auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits()); - results.Set(A, fragment.SetThreadRange(thread_range)); + results.Set(A, fragment->BindThreadRange(thread_range)); } else { ICHECK(0); } @@ -228,7 +228,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, " "please raise an issue if you see this"; auto fragment 
= makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n); - results.Set(B, fragment.SetThreadRange(thread_range)); + results.Set(B, fragment->BindThreadRange(thread_range)); } else { ICHECK(0); } @@ -242,7 +242,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits()) : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); - results.Set(C, fragment.SetThreadRange(thread_range)); + results.Set(C, fragment->BindThreadRange(thread_range)); if (A.scope() == "shared" || A.scope() == "shared.dyn") { int dim_A = A->shape.size(); const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]); @@ -255,7 +255,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ICHECK(trans_A == false); auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits()); - results.Set(A, fragment.SetThreadRange(thread_range)); + results.Set(A, fragment->BindThreadRange(thread_range)); } if (B.scope() == "shared" || B.scope() == "shared.dyn") { int dim_B = B->shape.size(); @@ -275,7 +275,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { auto fragment = makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits()); - results.Set(C, fragment.SetThreadRange(thread_range)); + results.Set(C, fragment->BindThreadRange(thread_range)); if (A.scope() == "shared" || A.scope() == "shared.dyn") { int dim_A = A->shape.size(); @@ -286,7 +286,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { } else if (A.scope() == "local.fragment") { auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits(), trans_A); - results.Set(A, fragment.SetThreadRange(thread_range)); + results.Set(A, fragment->BindThreadRange(thread_range)); } else { ICHECK(0); } @@ -299,7 +299,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { results.Set(B, shared_layout); } else if 
(B.scope() == "local.fragment") { auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n); - results.Set(B, fragment.SetThreadRange(thread_range)); + results.Set(B, fragment->BindThreadRange(thread_range)); } else { ICHECK(0); } diff --git a/src/op/parallel.cc b/src/op/parallel.cc index ed44dbb85..83dfc8a93 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -181,7 +181,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { PrimExpr loop_var_to_thread = src_layout->ForwardThread(indice_map_[buffer], rep); return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter) - .SetThreadRange(T.thread_bounds); + ->BindThreadRange(T.thread_bounds); } }; if (source_buffer.defined()) { @@ -272,7 +272,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap results; for (const auto &[buffer, _] : indice_map_) { if (!T.layout_map.count(buffer)) { - results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange( + results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange( T.thread_bounds)); } // Though they may exist some conflicts, but it's fine. @@ -285,13 +285,13 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { const FragmentNode *src_layout = T.layout_map[buffer].as<Fragment>().get(); Fragment dst_layout_fragment = - CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds); + CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds); const FragmentNode *dst_layout = dst_layout_fragment.as<Fragment>().get(); if (src_layout && dst_layout) { ICHECK(src_layout->IsEqual(dst_layout, true)) << "Layout may conflict with ParallelOp for buffer " << buffer - << "\nError body begin:\n" + << " vs. 
" << source_buffer << "\nError body begin:\n" << GetRoot()->body << "\nError body end" << "\nLHS = " << src_layout->DebugOutput() << "\nRHS = " << dst_layout->DebugOutput() diff --git a/src/op/reduce.cc b/src/op/reduce.cc index f08bf2a61..1ba03c65f 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -275,7 +275,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); Fragment dst_layout = Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt) - ->CondenseReplicateVar(); + ->CondenseReplicateVar() + ->BindThreadRange(T.thread_bounds); return {{dst, dst_layout}}; } return {}; diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index a80605438..eee2b6c53 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -191,7 +191,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) { size_t num_thread = *as_const_int(thread_range->extent); LoopPartitioner partitioner; Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size); - return fragment.SetThreadRange(thread_range); + return fragment->BindThreadRange(thread_range); } For LoopPragmaUnroll(For stmt) { diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index ef7fd9f9b..0b72c91ae 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -76,38 +76,9 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s mins = [x.min for x in buffer_region.region] region_extents = [x.extent for x in buffer_region.region] assert len(region_extents) >= len( - extents), f"region_extents = {region_extents}, extents = {extents}" - - # If region_extents already contains all elements - # of extents (in any order), pass directly - tmp_extents = list(extents) - variable_extent_count = 0 - for i in range(len(region_extents)): - v = region_extents[i] - if not isinstance(v, tir.IntImm): - 
variable_extent_count += 1 - continue - - if v in tmp_extents: - tmp_extents.remove(v) - elif isinstance(v, tir.IntImm) and v != 1: - raise ValueError( - f"buffer {buffer_region.buffer} region_extents[{i}] = {v}, extents[{i}] = {extents[i]}" - ) - - tmp_len = len(tmp_extents) - variable_extent_count - if tmp_len > 0: - # Otherwise, align extents from the last dimension, region_extents - # can only replace 1 with extents value, otherwise raise error - for i in range(len(extents)): - idx = len(region_extents) - len(extents) + i - if region_extents[idx] != extents[i]: - if region_extents[idx] == 1: - region_extents[idx] = extents[i] - else: - raise ValueError( - f"buffer {buffer_region.buffer} region_extents[{idx}] = {region_extents[idx]}, extents[{i}] = {extents[i]}" - ) + extents + ), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" + return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)