Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,95 @@ TileOperator CopyNode::Clone() const {
* copy operation.
*/
Array<IterVar> CopyNode::MakeIterVars() const {
// Builds one data-parallel IterVar per non-unit dimension of the range set
// chosen below, after checking that the chosen ranges are compatible with
// src_range.
//
// Pick which side's ranges define the loop domain: the buffer whose scope
// level is numerically higher wins, and a tie keeps src. Scope levels as
// assigned by the lambda below:
//   0: global and any unrecognised scope
//   1: shared / shared.dyn / shared.tmem
//   2: local.fragment / local
auto scope_level = [](const Buffer &b) -> int {
String s = b.scope();
if (s == "local.fragment" || s == "local")
return 2;
if (s == "shared" || s == "shared.dyn" || s == "shared.tmem")
return 1;
// default to global level for unknown scopes
return 0;
};

int src_level = scope_level(src);
int dst_level = scope_level(dst);
// `>=` means src is kept as the base whenever the two levels are equal.
bool base_is_src = (src_level >= dst_level);
const Array<Range> &base_ranges = base_is_src ? src_range : dst_range;

// Sanity check: when switching away from the original (src_range),
// ensure the chosen base ranges are not provably smaller than the original
// per dimension. This guards against generating undersized loop domains.
// The check walks base_ranges and src_range with two cursors, skipping
// dimensions whose extent is 1 on either side, and compares the remaining
// (non-1) extents pairwise in order. The number of non-1 extents must match;
// that is enforced by the ICHECK loops after the main loop.
// Note: when base_is_src is true, base_ranges aliases src_range, so every
// comparison below is an extent against itself.
arith::Analyzer analyzer;

size_t base_dim = 0, src_dim = 0;
while (base_dim < base_ranges.size() && src_dim < src_range.size()) {
// Skip base extents that are 1
while (base_dim < base_ranges.size() &&
is_one(base_ranges[base_dim]->extent)) {
++base_dim;
}
// Skip src extents that are 1
while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) {
++src_dim;
}
// Both indices now at non-1, or at end
if (base_dim < base_ranges.size() && src_dim < src_range.size()) {
PrimExpr base_ext = base_ranges[base_dim]->extent;
PrimExpr src_ext = src_range[src_dim]->extent;
// Only fail if base extent is provably smaller than src extent; extents
// the analyzer cannot order (e.g. symbolic ones) are let through.
if (analyzer.CanProve(base_ext < src_ext)) {
// Assemble a multi-line diagnostic naming both buffers, their scopes and
// the offending pair of ranges before aborting.
std::ostringstream oss;
oss << "Selected loop range is smaller than original src range at "
"matched non-1 dimension: "
<< "base(extent=" << base_ext
<< ", scope=" << (base_is_src ? src.scope() : dst.scope())
<< ", min=" << base_ranges[base_dim]->min
<< ", base_dim=" << base_dim << ") < src(extent=" << src_ext
<< ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim
<< ", scope=" << src.scope() << ") for src=" << src->name
<< ", dst=" << dst->name << "\n";
oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n";
oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n";
oss << "base_ranges[" << base_dim
<< "]: min=" << base_ranges[base_dim]->min
<< ", extent=" << base_ext << "\n";
oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min
<< ", extent=" << src_ext << "\n";
LOG(FATAL) << oss.str();
}
// This pair is matched; advance both cursors to look for the next pair.
++base_dim;
++src_dim;
}
}

// Any remaining unmatched dimensions in either range must all have extent ==
// 1. A leftover non-1 extent means the two sides have a different number of
// non-unit dimensions, which is rejected here.
while (base_dim < base_ranges.size()) {
ICHECK(is_one(base_ranges[base_dim]->extent))
<< "base_ranges has extra non-1 extent at dim " << base_dim;
++base_dim;
}
while (src_dim < src_range.size()) {
ICHECK(is_one(src_range[src_dim]->extent))
<< "src_range has extra non-1 extent at dim " << src_dim;
++src_dim;
}

// Create one loop variable per non-unit dimension, each spanning
// [0, extent). Variables are named "i", "j", "k", ... in dimension order.
// NOTE(review): this text comes from a diff view whose +/- markers were
// lost. In each adjacent pair below, the src_range line appears to be the
// removed (pre-change) side and the base_ranges line the added side — confirm
// against the merged file.
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
for (size_t i = 0; i < base_ranges.size(); i++) {
// Unit-extent dimensions get no loop variable.
if (is_one(base_ranges[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
{Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
Expand Down
25 changes: 18 additions & 7 deletions tilelang/language/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,

Returns:
tir.Call: A handle to the copy operation

Range handling notes:
- Accepts `Buffer`/`BufferRegion`/`BufferLoad` on either side. Extents are
derived as follows: `Buffer -> shape`, `BufferRegion -> [r.extent]`,
`BufferLoad -> extents from its inferred/encoded region`.
- If both `src` and `dst` are scalar `BufferLoad` without region extents,
lowers to a direct store: `dst[...] = src`.
- If one side is missing extents, it is treated as all-ones with the other
side's rank to enable broadcasting.
- Extents are right-aligned and legalized via `legalize_pairwise_extents`:
for each dimension, counting from the tail, equal extents are kept as-is,
an extent of `1` broadcasts to the other side's extent, and otherwise a
conservative `tir.max` is used so the result stays safe for dynamic shapes.
- The finalized extents are encoded with `tl.region` via `to_buffer_region`
and passed through to the backend; low-level loop construction and any
scope-specific decisions happen during lowering.
"""
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
ir.assert_structural_equal(src.shape, dst.shape)
Expand Down Expand Up @@ -57,16 +73,11 @@ def get_extent(data):
return tir.BufferStore(dst.buffer, src, dst.indices)

assert src_extent or dst_extent, "Can't deduce copy extents from args"
# Treat missing extent as length-matched ones to enable broadcasting logic.
# Treat missing extent as length-matched ones to enable broadcasting.
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)

# Align and broadcast extents from the right (tail) side independently
# for src and dst, so we can pass them unchanged into _to_region.
# Rules per-dim from the right:
# - equal -> keep both
# - one is 1 -> set that side to the other side's dim
# - otherwise -> error
# Align and broadcast extents from the right (tail) side.
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)

# Use legalized extents for src and dst respectively.
Expand Down
5 changes: 3 additions & 2 deletions tilelang/language/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns:
Buffer: A new buffer view with the specified shape
"""
assert prim_expr_equal(bits_product(shape, src.dtype),
bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
assert prim_expr_equal(
bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)
), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
return T.Tensor(shape, src.dtype, src.data)


Expand Down
Loading