@@ -21,6 +21,18 @@ namespace tl {
 
 using namespace tir;
 
+/**
+ * @brief Extracts a numeric architecture identifier from a Target's "arch"
+ * attribute.
+ *
+ * Reads the Target's "arch" string (must be defined) and, if it has the form
+ * "sm_<N>", parses and returns N as an integer. For any other arch string,
+ * returns 0.
+ *
+ * @param target Target whose "arch" attribute will be inspected (ICHECKs that
+ * the attribute is defined).
+ * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
+ */
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
@@ -34,6 +46,25 @@ static int GetArchInt(Target target) {
   return arch_int;
 }
 
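For illustration, the same parsing rule as a standalone sketch (not part of this diff; assumes only the standard library):

  #include <string>

  // Mirrors GetArchInt: only an arch string of the form "sm_<N>" yields a
  // nonzero result.
  static int ParseArchString(const std::string &arch) {
    if (arch.rfind("sm_", 0) == 0)        // arch starts with "sm_"
      return std::stoi(arch.substr(3));   // "sm_90" -> 90
    return 0;                             // e.g. "gfx942" -> 0
  }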
+/**
+ * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
+ *
+ * Builds the internal AtomicAddNode, extracts the source and destination
+ * regions and their backing Buffers from the first two call-style expressions
+ * in `args` (via RegionOp), and stores them along with their ranges. If a
+ * third argument is provided, it is interpreted as an integer immediate and
+ * stored as the node's coalesced width.
+ *
+ * @param args Call-style PrimExprs where:
+ *   - args[0] is the source region call,
+ *   - args[1] is the destination region call,
+ *   - args[2] (optional) is an IntImm specifying the coalesced width.
+ * @param vmap Mapping from buffers used by RegionOp to concrete Buffer
+ *   objects.
+ *
+ * Notes:
+ *   - The constructor checks that args[0] and args[1] are CallNodes.
+ *   - The constructed node is stored in this->data_.
+ */
 AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
   Array<Range> rgs[2];
@@ -54,6 +85,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
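A usage sketch of the argument convention (hypothetical values; the region calls are built by the TileLang frontend and are not shown in this diff):

  // Array<PrimExpr> args = {
  //     src_region_call,                        // args[0]: source region
  //     dst_region_call,                        // args[1]: destination region
  //     IntImm(DataType::Int(32), /*width=*/8)  // args[2]: optional coalesced width
  // };
  // AtomicAdd atomic_add(args, vmap);           // vmap maps region buffers to Buffers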
+/**
+ * @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
+ *
+ * Produces a new AtomicAddNode copied from this node. If this node has an
+ * associated ParallelOp (par_op_), the parallel op is cloned and attached to
+ * the new node so the cloned operator preserves parallelization state.
+ *
+ * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
+ */
 TileOperator AtomicAddNode::Clone() const {
   auto op = make_object<AtomicAddNode>(*this);
   if (par_op_.defined()) {
@@ -62,6 +102,19 @@ TileOperator AtomicAddNode::Clone() const {
   return AtomicAdd(op);
 }
 
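Usage sketch: clone before mutating, so the original operator (including its cached par_op_) is left untouched:

  // TileOperator copy = atomic_add->Clone();  // deep copy, ParallelOp included
  // ... mutate `copy` freely; `atomic_add` is unaffected ...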
+/**
+ * @brief Create data-parallel iteration variables for non-singleton
+ * dimensions of the source.
+ *
+ * Constructs an Array of IterVar corresponding to each dimension in
+ * `src_range` whose extent is not equal to 1. Each IterVar has domain
+ * Range(0, extent), a Var named sequentially ("i", "j", "k", ...) with the
+ * same dtype as the extent, and type IterVarType::kDataPar. The ordering of
+ * the returned iter vars matches the order of dimensions in `src_range`.
+ *
+ * @return Array<IterVar> Iteration variables for all non-singleton extents
+ * in `src_range`.
+ */
 Array<IterVar> AtomicAddNode::MakeIterVars() const {
   Array<IterVar> loop_vars;
   size_t idx = 0;
@@ -77,7 +130,26 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
 }
 
 // ivs: itervars returned by MakeIterVars()
-// src_dst: 0 for src_indices, 1 for dst_indices
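A worked example of the rule above (hypothetical shapes, with c and v symbolic):

  // src_range = { [c, c+1), [v, v+64), [0, 8) }    // extents 1, 64, 8
  // MakeIterVars() -> { i in [0, 64), j in [0, 8) }
  // (the extent-1 dimension gets no iter var; names advance "i", "j", "k", ...)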
+/**
+ * @brief Build index expressions for either source or destination from loop
+ * iter vars.
+ *
+ * Given a list of iteration variables that correspond to the non-singleton
+ * extents of the selected region (source when src_dst == 0, destination when
+ * src_dst == 1), return an array of index expressions matching the full rank
+ * of that region. For dimensions with extent == 1, the corresponding index is
+ * the range's minimum; otherwise the index is `min + ivar`.
+ *
+ * @param ivs Iteration variables in order for all non-singleton dimensions
+ *   of the chosen region.
+ * @param src_dst Selects which region to index: 0 for source (src_range), 1
+ *   for destination (dst_range).
+ * @return Array<PrimExpr> Index expressions for every dimension of the
+ *   selected region, in original dimension order.
+ *
+ * @note The function checks that the number of provided iter vars equals the
+ *   number of non-singleton extents; it will abort (ICHECK) if they differ.
+ */
 Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
                                            int src_dst) const {
   Array<PrimExpr> indices;
@@ -97,6 +169,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
   return indices;
 }
 
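Continuing the example above (ivs = {i, j}), the full-rank indices come back in dimension order:

  // MakeIndices(ivs, /*src_dst=*/0) -> { c, v + i, 0 + j }
  // (extent-1 dims contribute range.min; the rest contribute min + iv)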
+/**
+ * @brief Build a combined bound-check predicate for indexed access.
+ *
+ * Constructs an AND'd predicate ensuring each non-singleton index (derived
+ * from `ivs`) stays within [0, extent) for the selected operand (source when
+ * `src_dst == 0`, destination otherwise). For each non-unit Range in the
+ * chosen range list this produces two conditions:
+ *   - range.min + iv >= 0
+ *   - range.min + iv < extent
+ *
+ * Conditions that the analyzer can prove (with symbolic bounds) are omitted.
+ * If no uncertain conditions remain, an empty PrimExpr is returned.
+ *
+ * Note: the function ICHECKs that `extents.size()` equals the number of
+ * ranges for the selected operand.
+ *
+ * @param ivs Iteration variables corresponding to non-singleton extents
+ *   (order matches the non-unit ranges of the chosen operand).
+ * @param extents Per-dimension upper bounds to check against; must have the
+ *   same size as the selected range list.
+ * @param src_dst Selects which ranges to validate: 0 => `src_range`, else
+ *   `dst_range`.
+ * @return PrimExpr A conjunction of the remaining (non-provable) bounds
+ *   checks, or an empty PrimExpr when no checks are required.
+ */
 PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
                                       const Array<IterVar> &ivs,
                                       Array<PrimExpr> extents,
@@ -128,6 +225,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
   }
 }
 
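For the same hypothetical shapes, with symbolic extents {N0, N1, N2}, the predicate keeps only what the analyzer cannot prove (a sketch):

  // MakePredicate(analyzer, ivs, {N0, N1, N2}, /*src_dst=*/0)
  //   -> (v + i >= 0) && (v + i < N1) && (j < N2)
  // (j >= 0 is provable from the binding j in [0, 8) and is dropped; an
  //  empty PrimExpr is returned when every condition is provable)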
+/**
+ * @brief Build a SIMT-style loop nest that performs element-wise atomic
+ * additions from src to dst.
+ *
+ * Constructs a nested loop (parallelized per iter var) that loads a value
+ * from the source buffer, optionally casts it to the destination dtype, and
+ * performs an extern atomic add into the destination buffer address. For
+ * scalar (zero-dimensional) operations a trivial serial For with a single
+ * BufferStore is returned.
+ *
+ * The method:
+ *   - Creates iter vars for all non-singleton extents and binds them into
+ *     the provided analyzer.
+ *   - Validates loop variable counts against src/dst ranges (ICHECK on
+ *     mismatch).
+ *   - Computes indexed accesses and emits optional bound predicates;
+ *     out-of-bounds accesses are masked to zero when predicates are
+ *     uncertain.
+ *   - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
+ *     src_value)` call wrapped in an Evaluate statement.
+ *   - Wraps the body with a parallel For at each loop level. If
+ *     `coalesced_width` is defined, it is attached as the "coalesced_width"
+ *     annotation on each loop.
+ *
+ * Note: this function mutates the analyzer binding state by binding loop
+ * variables and may fail via ICHECK if internal assumptions about shapes are
+ * violated.
+ *
+ * @return A nested For loop (parallel loops) implementing the atomic-add
+ * kernel. For scalar cases a serial For of extent 1 is returned.
+ */
 For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
   bool is_scalar = loop_vars.size() == 0;
@@ -191,6 +316,41 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
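The emitted nest for the running example looks roughly like this (pseudo-TIR in comments; dst indices are the src_dst == 1 counterparts):

  // for (i, 0, 64) "parallel" {                  // annotated with
  //   for (j, 0, 8) "parallel" {                 //   "coalesced_width" if set
  //     AtomicAdd(&dst[...], src[c, v + i, j]);  // call_extern + address_of,
  //   }                                          //   src cast to dst dtype
  // }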
+/**
+ * @brief Lower the atomic-add top-level operator into a parallel, vectorized
+ * TIR loop.
+ *
+ * Constructs a SIMT-style loop for the atomic add, fuses parallel loops,
+ * runs layout inference at multiple levels, partitions the root loop by the
+ * provided thread variable, vectorizes the thread loop, and returns the
+ * final (optionally predicate-guarded) statement.
+ *
+ * The lowering pipeline:
+ *   - Build the SIMT loop via MakeSIMTLoop.
+ *   - Fuse parallel loops into a single For and wrap it as a ParallelOp.
+ *   - Run layout inference at the kCommon, kStrict, and kFree levels using
+ *     fields from `T`.
+ *   - Obtain the loop layout and partition the root loop with PartitionLoop
+ *     by `T.thread_var`.
+ *   - Vectorize the partitioned thread loop via VectorizeLoop.
+ *   - If the ParallelOp produced a predicate for `T.thread_var`, return an
+ *     IfThenElse that guards the vectorized loop with that predicate;
+ *     otherwise return the vectorized loop.
+ *
+ * @param T Lowering context whose fields are used:
+ *   - T.target: target architecture for layout inference and lowering
+ *     decisions.
+ *   - T.thread_var: the Var used to partition the outer loop for
+ *     thread-level parallelism.
+ *   - T.thread_bounds: bounds associated with the thread dimension (used
+ *     during partitioning).
+ *   - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
+ *     during InferLayout.
+ * @param analyzer Analyzer used for symbolic reasoning during partitioning
+ *   and folding.
+ * @return Stmt A lowered TIR statement representing the parallelized and
+ *   vectorized atomic add.
+ */
 Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Target target = T.target;
   auto simt_loop = MakeSIMTLoop(analyzer);
@@ -221,6 +381,25 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return vectorized_thread_loop;
 }
 
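A condensed view of the same pipeline (sketch; argument lists abbreviated with "..." where this diff does not show them):

  // For fused = /* MakeSIMTLoop + parallel-loop fusion */;
  // ParallelOp par(fused);
  // for (auto level : {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree})
  //   par->InferLayout(..., level);              // uses T.target, T.layout_map, ...
  // For thread_loop = PartitionLoop(..., T.thread_var, ...);
  // Stmt vec = VectorizeLoop(thread_loop);
  // return predicate.defined() ? IfThenElse(predicate, vec) : vec;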
+/**
+ * @brief Infer and return the layout map for the atomic-add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if one is not
+ * already present, validates that the local.fragment layouts for src and dst
+ * match when both are provided, and then delegates layout inference to the
+ * underlying ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers
+ *   to layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this
+ *   operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a
+ *   ParallelOp on first invocation.
+ * @throws If both src and dst have layouts in `local.fragment` and their
+ *   fragment layouts differ, an ICHECK failure is raised with diagnostic
+ *   output.
+ */
 LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
                                      InferLevel level) const {
   if (!par_op_.defined()) {
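Behavioral sketch of the contract described above (in comments; only names already used in this file):

  // First call: build and cache par_op_ from the fused SIMT loop.
  // If T.layout_map holds local.fragment layouts for both src and dst,
  // ICHECK that the two fragment layouts match (diagnostic output on
  // mismatch), then delegate:
  //   return par_op_->InferLayout(T, level);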