11 changes: 7 additions & 4 deletions src/layout/layout.cc
@@ -204,9 +204,10 @@ Fragment FragmentNode::DeReplicate() const {
int(*rep_size) / factor, NullOpt);
}

Fragment FragmentNode::SetThreadRange(Range thread_range) {
thread_range_ = thread_range;
return GetRef<Fragment>(this);
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
auto n = make_object<FragmentNode>(*this);
n->thread_range_ = thread_range;
return Fragment(n);
}

Layout LayoutNode::Inverse() const {
@@ -418,11 +419,13 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
// a[i, j] = b[j, i] in register level.

bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->ThreadRange(), other->ThreadRange());
if (!ret) {
// may be broadcast case
return true;
}
if (this->thread_range_.defined() && other->thread_range_.defined()) {
ret &= StructuralEqual()(this->thread_range_, other->thread_range_);
}
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
8 changes: 1 addition & 7 deletions src/layout/layout.h
@@ -98,7 +98,7 @@ class FragmentNode : public LayoutNode {

std::string DebugOutput() const final;

Fragment SetThreadRange(Range thread_range);
Fragment BindThreadRange(Range thread_range) const;

Range ThreadRange() const { return thread_range_; }

@@ -130,12 +130,6 @@ class Fragment : public Layout {
Optional<Var> replicate_var);

TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);

Fragment SetThreadRange(Range thread_range) {
auto node = make_object<FragmentNode>(*this->get());
node->SetThreadRange(thread_range);
return Fragment(node);
}
};

Var InputPlaceholder(size_t idx);
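A minimal caller-side sketch of the new API, for orientation only (not part of the diff): it assumes a TVM build plus the TileLang layout header shown above; the include path, the `tvm::tl` namespace, and the `Example` function are illustrative assumptions. With the ref-level `Fragment::SetThreadRange` wrapper removed, callers go through `operator->` to the `const` node-level `BindThreadRange`, which copies the node and returns a fresh `Fragment` instead of mutating the receiver.

// Sketch only; assumes TVM headers plus the TileLang layout API declared above.
#include <tvm/ir/expr.h>          // tvm::Range
#include <tvm/runtime/logging.h>  // ICHECK
#include "layout/layout.h"        // Fragment, FragmentNode (path is an assumption)

namespace tvm {
namespace tl {

Fragment Example(const Fragment &frag, const Range &thread_bounds) {
  // Old (removed): frag.SetThreadRange(thread_bounds) copied the node in the
  // ref-level wrapper and then called the mutating node method.
  // New: call the const node-level method through operator->; it copies the
  // node, sets thread_range_ on the copy, and returns a new Fragment, so
  // `frag` itself is untouched and calls can be chained, e.g.
  //   fragment->CondenseReplicateVar()->BindThreadRange(thread_bounds);
  Fragment bound = frag->BindThreadRange(thread_bounds);
  ICHECK(bound->ThreadRange().same_as(thread_bounds));
  return bound;
}

}  // namespace tl
}  // namespace tvm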
20 changes: 10 additions & 10 deletions src/op/gemm.cc
@@ -175,7 +175,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment.SetThreadRange(thread_range));
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
@@ -184,7 +184,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
results.Set(A, fragment.SetThreadRange(thread_range));
results.Set(A, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
@@ -200,7 +200,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ComputeWarpPartition(block_size / warp_size, T.target);
auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment.SetThreadRange(thread_range));
results.Set(C, fragment->BindThreadRange(thread_range));

if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
@@ -213,7 +213,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(trans_A == false);
auto fragment =
makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
results.Set(A, fragment.SetThreadRange(thread_range));
results.Set(A, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
@@ -228,7 +228,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
"please raise an issue if you see this";
auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
results.Set(B, fragment.SetThreadRange(thread_range));
results.Set(B, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
@@ -242,7 +242,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment.SetThreadRange(thread_range));
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
@@ -255,7 +255,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(trans_A == false);
auto fragment =
makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
results.Set(A, fragment.SetThreadRange(thread_range));
results.Set(A, fragment->BindThreadRange(thread_range));
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
@@ -275,7 +275,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {

auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment.SetThreadRange(thread_range));
results.Set(C, fragment->BindThreadRange(thread_range));

if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
@@ -286,7 +286,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
results.Set(A, fragment.SetThreadRange(thread_range));
results.Set(A, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
@@ -299,7 +299,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
results.Set(B, fragment.SetThreadRange(thread_range));
results.Set(B, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
8 changes: 4 additions & 4 deletions src/op/parallel.cc
@@ -181,7 +181,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
.SetThreadRange(T.thread_bounds);
->BindThreadRange(T.thread_bounds);
}
};
if (source_buffer.defined()) {
@@ -272,7 +272,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
LayoutMap results;
for (const auto &[buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) {
results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange(
results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
T.thread_bounds));
}
// There may still be some conflicts, but that's fine.
@@ -285,13 +285,13 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const FragmentNode *src_layout =
T.layout_map[buffer].as<Fragment>().get();
Fragment dst_layout_fragment =
CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds);
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
const FragmentNode *dst_layout =
dst_layout_fragment.as<Fragment>().get();
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
<< "Layout may conflict with ParallelOp for buffer " << buffer
<< "\nError body begin:\n"
<< " vs. " << source_buffer << "\nError body begin:\n"
<< GetRoot()->body << "\nError body end"
<< "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput()
3 changes: 2 additions & 1 deletion src/op/reduce.cc
@@ -275,7 +275,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds);
return {{dst, dst_layout}};
}
return {};
2 changes: 1 addition & 1 deletion src/transform/loop_partition.cc
@@ -191,7 +191,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
size_t num_thread = *as_const_int(thread_range->extent);
LoopPartitioner partitioner;
Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
return fragment.SetThreadRange(thread_range);
return fragment->BindThreadRange(thread_range);
}

For LoopPragmaUnroll(For stmt) {
35 changes: 3 additions & 32 deletions tilelang/language/copy.py
@@ -76,38 +76,9 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents), f"region_extents = {region_extents}, extents = {extents}"

# If region_extents already contains all elements
# of extents (in any order), pass directly
tmp_extents = list(extents)
variable_extent_count = 0
for i in range(len(region_extents)):
v = region_extents[i]
if not isinstance(v, tir.IntImm):
variable_extent_count += 1
continue

if v in tmp_extents:
tmp_extents.remove(v)
elif isinstance(v, tir.IntImm) and v != 1:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{i}] = {v}, extents[{i}] = {extents[i]}"
)

tmp_len = len(tmp_extents) - variable_extent_count
if tmp_len > 0:
# Otherwise, align extents from the last dimension, region_extents
# can only replace 1 with extents value, otherwise raise error
for i in range(len(extents)):
idx = len(region_extents) - len(extents) + i
if region_extents[idx] != extents[i]:
if region_extents[idx] == 1:
region_extents[idx] = extents[i]
else:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{idx}] = {region_extents[idx]}, extents[{i}] = {extents[i]}"
)
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"

return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)

