1717#include " ../transform/loop_partition.h"
1818#include " ../transform/loop_vectorize.h"
1919#include " builtin.h"
20+ #include " region.h"
2021
2122namespace tvm {
2223namespace tl {
@@ -62,7 +63,30 @@ using namespace tir;
6263Fill::Fill (Array<PrimExpr> args, BufferMap vmap) {
6364 ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
6465
65- if (args[0 ]->IsInstance <BufferLoadNode>()) {
66+ // Case 1: Call expression: either a tl.region descriptor or a tvm_access_ptr call (any other op is rejected)
67+ if (const auto *call = args[0 ].as <CallNode>()) {
68+ if (call->op .same_as (RegionOp::Get ())) {
69+ auto region = RegionOp (call->args , vmap);
70+ node->dst = region->GetBuffer ();
71+ node->region = region->GetRanges ();
72+ } else if (call->op .same_as (builtin::tvm_access_ptr ())) {
73+ node->dst = vmap[GetVarFromAccessPtr (args[0 ])];
74+ for (int i = 0 ; i < node->dst ->shape .size (); i++) {
75+ node->region .push_back (Range (0 , node->dst ->shape [i]));
76+ }
77+ } else {
78+ ICHECK (false ) << " Unsupported call op in tl.fill: "
79+ << Downcast<Op>(call->op )->name ;
80+ }
81+
82+ // Case 2: Explicit BufferRegion (legacy path)
83+ } else if (args[0 ]->IsInstance <BufferRegionNode>()) {
84+ auto region = Downcast<BufferRegion>(args[0 ]);
85+ node->dst = region->buffer ;
86+ node->region = region->region ;
87+
88+ // Case 3: Vector/scalar region expressed via BufferLoad indices
89+ } else if (args[0 ]->IsInstance <BufferLoadNode>()) {
6690 auto buffer_load = Downcast<BufferLoad>(args[0 ]);
6791 for (const auto &index : buffer_load->indices ) {
6892 if (const auto *ramp = index.as <RampNode>()) {
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
77101 }
78102 }
79103 node->dst = buffer_load->buffer ;
104+ // Case 4: Fallback for any other argument: resolve the buffer via GetVarFromAccessPtr, fill the full buffer
80105 } else {
81106 node->dst = vmap[GetVarFromAccessPtr (args[0 ])];
82107 for (int i = 0 ; i < node->dst ->shape .size (); i++) {
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
95120 << " != " << node->dst ->shape .size ();
96121 for (int i = 0 ; i < node->region .size (); i++) {
97122 // bound check if region is static
98- if (node->region [i]->min .as <IntImm >()) {
99- int64_t min = Downcast<IntImm>(node-> region [i]-> min ) ->value ;
123+ if (const auto *min_imm = node->region [i]->min .as <IntImmNode >()) {
124+ int64_t min = min_imm ->value ;
100125 ICHECK_GE (min, 0 ) << " region[" << i << " ] = " << min << " < 0" ;
101126 }
102- if (node->region [i]->extent .as <IntImm>()) {
103- int64_t extent = Downcast<IntImm>(node->region [i]->extent )->value ;
104- ICHECK_LE (extent, Downcast<IntImm>(node->dst ->shape [i])->value )
105- << " region[" << i << " ] = " << extent << " > " << node->dst ->shape [i];
127+ if (const auto *extent_imm = node->region [i]->extent .as <IntImmNode>()) {
128+ // Only perform the upper-bound check when the destination shape
129+ // extent is also statically known. If the shape is symbolic (e.g., Var),
130+ // skip this static check to avoid invalid downcasts.
131+ if (const auto *shape_imm = node->dst ->shape [i].as <IntImmNode>()) {
132+ ICHECK_LE (extent_imm->value , shape_imm->value )
133+ << " region[" << i << " ] = " << extent_imm->value << " > "
134+ << node->dst ->shape [i];
135+ }
106136 }
107137 }
108138 data_ = std::move (node);
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
140170 for (int i = 0 ; i < ndim; i++) {
141171 Var var = Var (std::string{char (' i' + i)}, region[i]->extent ->dtype );
142172 loop_vars.push_back ({region[i], var, IterVarType::kDataPar });
143- dst_indices.push_back (var);
173+ // Offset the loop induction variable by region min to honor sliced regions
174+ dst_indices.push_back (region[i]->min + var);
144175 }
145176 Stmt body = BufferStore (dst, value, dst_indices);
146177 for (int i = ndim - 1 ; i >= 0 ; i--) {
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
202233 return vectorized_thread_loop;
203234 } else {
204235 LOG (FATAL) << " Unsupported scope " << dst.scope ();
236+ return Stmt ();
205237 }
206238}
207239
@@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
229261TVM_FFI_STATIC_INIT_BLOCK () { FillNode::RegisterReflection (); }
230262
231263} // namespace tl
232- } // namespace tvm
264+ } // namespace tvm
0 commit comments