Review notes for this patch (free text ahead of the first "diff --git" line):
- src/op/copy.cc: CopyNode::MakeIterVars now builds its loop extents from
  src_range or dst_range, picked by a scope-level ranking of the two buffers,
  and checks that the picked ranges pair up with src_range.
- tilelang/language/copy.py: docstring and comment changes only.
- tilelang/language/customize.py: the reshape() assert is reformatted and its
  failure message now includes the buffer, shapes and dtype.
NOTE(review): `Array` appears without template arguments in this copy (e.g.
`Array CopyNode::MakeIterVars()`); presumably they were stripped during
extraction -- confirm against the repository before applying.
NOTE(review): the new copy() docstring says mismatched extents fall back to
`tir.max`, while the comment removed by this same patch said
"otherwise -> error" -- confirm against legalize_pairwise_extents.
NOTE(review): line breaks, blank lines and indentation were reconstructed from
a collapsed copy using the hunk line counts -- verify with `git apply --check`.

diff --git a/src/op/copy.cc b/src/op/copy.cc
index 9b93fea1d..37348386e 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -179,15 +179,95 @@ TileOperator CopyNode::Clone() const {
  * copy operation.
  */
 Array CopyNode::MakeIterVars() const {
+  // Choose the loop ranges from whichever of src/dst has the higher scope
+  // level below; ties go to src. Levels: global and unknown scopes (0) <
+  // shared/shared.dyn/shared.tmem (1) < local/local.fragment (2).
+  auto scope_level = [](const Buffer &b) -> int {
+    String s = b.scope();
+    if (s == "local.fragment" || s == "local")
+      return 2;
+    if (s == "shared" || s == "shared.dyn" || s == "shared.tmem")
+      return 1;
+    // Any other scope string is ranked the same as global.
+    return 0;
+  };
+
+  int src_level = scope_level(src);
+  int dst_level = scope_level(dst);
+  bool base_is_src = (src_level >= dst_level);
+  const Array &base_ranges = base_is_src ? src_range : dst_range;
+
+  // Sanity check: the chosen base ranges must not be provably smaller than
+  // src_range in any paired dimension, which guards against undersized loop
+  // domains. Two cursors walk base_ranges and src_range, skip dimensions
+  // whose extent is 1, and pair up the remaining ones in order. When
+  // base_ranges is src_range itself, this compares src_range against
+  // itself.
+  arith::Analyzer analyzer;
+
+  size_t base_dim = 0, src_dim = 0;
+  while (base_dim < base_ranges.size() && src_dim < src_range.size()) {
+    // Advance the base cursor past extents that are 1
+    while (base_dim < base_ranges.size() &&
+           is_one(base_ranges[base_dim]->extent)) {
+      ++base_dim;
+    }
+    // Advance the src cursor past extents that are 1
+    while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) {
+      ++src_dim;
+    }
+    // Each cursor now sits on a non-1 extent or is past the end
+    if (base_dim < base_ranges.size() && src_dim < src_range.size()) {
+      PrimExpr base_ext = base_ranges[base_dim]->extent;
+      PrimExpr src_ext = src_range[src_dim]->extent;
+      // Abort only when base extent < src extent can actually be proven
+      if (analyzer.CanProve(base_ext < src_ext)) {
+        std::ostringstream oss;
+        oss << "Selected loop range is smaller than original src range at "
+               "matched non-1 dimension: "
+            << "base(extent=" << base_ext
+            << ", scope=" << (base_is_src ? src.scope() : dst.scope())
+            << ", min=" << base_ranges[base_dim]->min
+            << ", base_dim=" << base_dim << ") < src(extent=" << src_ext
+            << ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim
+            << ", scope=" << src.scope() << ") for src=" << src->name
+            << ", dst=" << dst->name << "\n";
+        oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n";
+        oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n";
+        oss << "base_ranges[" << base_dim
+            << "]: min=" << base_ranges[base_dim]->min
+            << ", extent=" << base_ext << "\n";
+        oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min
+            << ", extent=" << src_ext << "\n";
+        LOG(FATAL) << oss.str();
+      }
+      ++base_dim;
+      ++src_dim;
+    }
+  }
+
+  // Dimensions left over on either side after pairing must all have extent
+  // == 1; otherwise the two sides have different numbers of non-1 extents.
+  while (base_dim < base_ranges.size()) {
+    ICHECK(is_one(base_ranges[base_dim]->extent))
+        << "base_ranges has extra non-1 extent at dim " << base_dim;
+    ++base_dim;
+  }
+  while (src_dim < src_range.size()) {
+    ICHECK(is_one(src_range[src_dim]->extent))
+        << "src_range has extra non-1 extent at dim " << src_dim;
+    ++src_dim;
+  }
+
   Array loop_vars;
   size_t idx = 0;
-  for (size_t i = 0; i < src_range.size(); i++) {
-    if (is_one(src_range[i]->extent))
+  for (size_t i = 0; i < base_ranges.size(); i++) {
+    if (is_one(base_ranges[i]->extent))
       continue;
-    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
+    Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype);
     idx++;
     loop_vars.push_back(
-        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
+        {Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar});
   }
   return loop_vars;
 }
diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py
index d59d73e87..965919fd4 100644
--- a/tilelang/language/copy.py
+++ b/tilelang/language/copy.py
@@ -27,6 +27,22 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
 
     Returns:
         tir.Call: A handle to the copy operation
+
+    Range handling notes:
+    - Accepts `Buffer`/`BufferRegion`/`BufferLoad` on either side. Extents are
+      derived as follows: `Buffer -> shape`, `BufferRegion -> [r.extent]`,
+      `BufferLoad -> extents from its inferred/encoded region`.
+    - If both `src` and `dst` are scalar `BufferLoad` without region extents,
+      lowers to a direct store: `dst[...] = src`.
+    - If one side is missing extents, it is treated as all-ones with the other
+      side's rank to enable broadcasting.
+    - Extents are right-aligned and legalized via `legalize_pairwise_extents`:
+      per tail-dimension, equal keeps as-is, a `1` broadcasts to the other,
+      otherwise a conservative `tir.max` is used to remain safe for dynamic
+      shapes.
+    - The finalized extents are encoded with `tl.region` via `to_buffer_region`
+      and passed through to the backend; low-level loop construction and any
+      scope-specific decisions happen during lowering.
     """
     if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
         ir.assert_structural_equal(src.shape, dst.shape)
@@ -57,16 +73,11 @@ def get_extent(data):
         return tir.BufferStore(dst.buffer, src, dst.indices)
 
     assert src_extent or dst_extent, "Can't deduce copy extents from args"
-    # Treat missing extent as length-matched ones to enable broadcasting logic.
+    # Treat missing extent as length-matched ones to enable broadcasting.
     src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
     dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
 
-    # Align and broadcast extents from the right (tail) side independently
-    # for src and dst, so we can pass them unchanged into _to_region.
-    # Rules per-dim from the right:
-    # - equal -> keep both
-    # - one is 1 -> set that side to the other side's dim
-    # - otherwise -> error
+    # Align and broadcast extents from the right (tail) side.
     src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
 
     # Use legalized extents for src and dst respectively.
diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py
index 3d40ce473..720c9e991 100644
--- a/tilelang/language/customize.py
+++ b/tilelang/language/customize.py
@@ -46,8 +46,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
     Returns:
         Buffer: A new buffer view with the specified shape
     """
-    assert prim_expr_equal(bits_product(shape, src.dtype),
-                           bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
+    assert prim_expr_equal(
+        bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)
+    ), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
     return T.Tensor(shape, src.dtype, src.data)
 
 