Skip to content

Commit 2af3f22

Browse files
📝 Add docstrings to pytile_0826 (#770)
* 📝 Add docstrings to `pytile_0826` Docstrings generation was requested by @LeiWang1999. * #763 (comment) The following files were modified: * `src/op/atomic_add.cc` * `src/op/atomic_add.h` * `src/op/copy.cc` * `src/op/copy.h` * `src/op/elem.cc` * `src/op/elem.h` * `src/op/gemm.cc` * `src/op/gemm.h` * `src/op/gemm_sp.cc` * `src/op/gemm_sp.h` * `src/op/operator.cc` * `src/op/operator.h` * `src/op/parallel.cc` * `src/op/parallel.h` * `src/op/reduce.cc` * `src/op/reduce.h` * `src/op/region.cc` * `src/op/region.h` * `src/transform/layout_inference.cc` * `src/transform/lower_tile_op.cc` * lint fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 8eab775 commit 2af3f22

File tree

20 files changed

+2251
-152
lines changed

20 files changed

+2251
-152
lines changed

src/op/atomic_add.cc

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ namespace tl {
2121

2222
using namespace tir;
2323

24+
/**
25+
* @brief Extracts a numeric architecture identifier from a Target's "arch"
26+
* attribute.
27+
*
28+
* Reads the Target's "arch" string (must be defined) and, if it has the form
29+
* "sm_<N>", parses and returns N as an integer. For any other arch string,
30+
* returns 0.
31+
*
32+
* @param target Target whose "arch" attribute will be inspected (ICHECKs that
33+
* the attribute is defined).
34+
* @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
35+
*/
2436
static int GetArchInt(Target target) {
2537
int arch_int = 0;
2638
auto s = target->GetAttr<String>("arch");
@@ -34,6 +46,25 @@ static int GetArchInt(Target target) {
3446
return arch_int;
3547
}
3648

49+
/**
50+
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
51+
*
52+
* Builds the internal AtomicAddNode, extracts the source and destination
53+
* regions and their backing Buffers from the first two call-style expressions
54+
* in `args` (via RegionOp), and stores them along with their ranges. If a third
55+
* argument is provided, it is interpreted as an integer immediate and stored as
56+
* the node's coalesced width.
57+
*
58+
* @param args Call-style PrimExprs where:
59+
* - args[0] is the source region call,
60+
* - args[1] is the destination region call,
61+
* - args[2] (optional) is an IntImm specifying coalesced width.
62+
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
63+
*
64+
* Notes:
65+
* - The constructor checks that args[0] and args[1] are CallNodes.
66+
* - The constructed node is stored in this->data_.
67+
*/
3768
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
3869
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
3970
Array<Range> rgs[2];
@@ -54,6 +85,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
5485
data_ = std::move(node);
5586
}
5687

88+
/**
89+
* @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
90+
*
91+
* Produces a new AtomicAddNode object copied from this node. If this node has
92+
* an associated ParallelOp (par_op_), the parallel op is cloned and attached to
93+
* the new node so the cloned operator preserves parallelization state.
94+
*
95+
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
96+
*/
5797
TileOperator AtomicAddNode::Clone() const {
5898
auto op = make_object<AtomicAddNode>(*this);
5999
if (par_op_.defined()) {
@@ -62,6 +102,19 @@ TileOperator AtomicAddNode::Clone() const {
62102
return AtomicAdd(op);
63103
}
64104

105+
/**
106+
* @brief Create data-parallel iteration variables for non-singleton dimensions
107+
* of the source.
108+
*
109+
* Constructs an Array of IterVar corresponding to each dimension in `src_range`
110+
* whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a
111+
* Var named sequentially ("i", "j", "k", ...) with the same dtype as the
112+
* extent, and type IterVarType::kDataPar. The ordering of returned itervars
113+
* matches the order of dimensions in `src_range`.
114+
*
115+
* @return Array<IterVar> Iteration variables for all non-singleton extents in
116+
* `src_range`.
117+
*/
65118
Array<IterVar> AtomicAddNode::MakeIterVars() const {
66119
Array<IterVar> loop_vars;
67120
size_t idx = 0;
@@ -77,7 +130,26 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
77130
}
78131

79132
// ivs: itervars returned by MakeIterVars()
80-
// src_dst: 0 for src_indices, 1 for dst_indices
133+
/**
134+
* @brief Build index expressions for either source or destination from loop
135+
* iter vars.
136+
*
137+
* Given a list of iteration variables that correspond to the non-singleton
138+
* extents of the selected region (source when src_dst == 0, destination when
139+
* src_dst == 1), return an array of index expressions matching the full rank of
140+
* that region. For dimensions with extent == 1, the corresponding index is the
141+
* range's minimum; otherwise the index is `min + ivar`.
142+
*
143+
* @param ivs Iteration variables in order for all non-singleton dimensions of
144+
* the chosen region.
145+
* @param src_dst Selects which region to index: 0 for source (src_range), 1 for
146+
* destination (dst_range).
147+
* @return Array<PrimExpr> Index expressions for every dimension of the selected
148+
* region, in original dimension order.
149+
*
150+
* @note The function checks that the number of provided iter vars equals the
151+
* number of non-singleton extents; it will abort (ICHECK) if they differ.
152+
*/
81153
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
82154
int src_dst) const {
83155
Array<PrimExpr> indices;
@@ -97,6 +169,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
97169
return indices;
98170
}
99171

172+
/**
173+
* @brief Build a combined bound-check predicate for indexed access.
174+
*
175+
* Constructs an AND'd predicate ensuring each non-singleton index (derived from
176+
* `ivs`) stays within [0, extent) for the selected operand (source when
177+
* `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
178+
* range list this produces two conditions:
179+
* - range.min + iv >= 0
180+
* - range.min + iv < extent
181+
*
182+
* Conditions that the analyzer can prove (with symbolic bounds) are omitted.
183+
* If no uncertain conditions remain, an empty PrimExpr is returned.
184+
*
185+
* Note: the function ICHECKs that `extents.size()` equals the number of ranges
186+
* for the selected operand.
187+
*
188+
* @param ivs Iteration variables corresponding to non-singleton extents (order
189+
* matches the non-unit ranges of the chosen operand).
190+
* @param extents Per-dimension upper bounds to check against; must have the
191+
* same size as the selected range list.
192+
* @param src_dst Selects which ranges to validate: 0 => `src_range`, else
193+
* `dst_range`.
194+
* @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
195+
* an empty PrimExpr when no checks are required.
196+
*/
100197
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
101198
const Array<IterVar> &ivs,
102199
Array<PrimExpr> extents,
@@ -128,6 +225,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
128225
}
129226
}
130227

228+
/**
229+
* @brief Build a SIMT-style loop nest that performs element-wise atomic
230+
* additions from src to dst.
231+
*
232+
* Constructs a nested loop (parallelized per iter var) that loads a value from
233+
* the source buffer, optionally casts it to the destination dtype, and performs
234+
* an extern atomic add into the destination buffer address. For scalar
235+
* (zero-dimensional) operations a trivial serial For with a single BufferStore
236+
* is returned.
237+
*
238+
* The method:
239+
* - Creates iter vars for all non-singleton extents and binds them into the
240+
* provided analyzer.
241+
* - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
242+
* - Computes indexed accesses and emits optional bound predicates;
243+
* out-of-bounds accesses are masked to zero when predicates are uncertain.
244+
* - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
245+
* src_value)` call wrapped in an Evaluate statement.
246+
* - Wraps the body with a parallel For at each loop level. If `coalesced_width`
247+
* is defined it is attached as the "coalesced_width" annotation on each loop.
248+
*
249+
* Note: This function mutates the analyzer binding state by binding loop
250+
* variables and may fail via ICHECK if internal assumptions about shapes are
251+
* violated.
252+
*
253+
* @return A nested For loop (parallel loops) implementing the atomic-add
254+
* kernel. For scalar cases a serial For of extent 1 is returned.
255+
*/
131256
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
132257
Array<IterVar> loop_vars = MakeIterVars();
133258
bool is_scalar = loop_vars.size() == 0;
@@ -191,6 +316,41 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
191316
return Downcast<For>(body);
192317
}
193318

319+
/**
320+
* @brief Lower the atomic-add top-level operator into a parallel, vectorized
321+
* TIR loop.
322+
*
323+
* Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
324+
* layout inference at multiple levels, partitions the root loop by the provided
325+
* thread variable, vectorizes the thread loop, and returns the final
326+
* (optionally predicate-guarded) statement.
327+
*
328+
* The lowering pipeline:
329+
* - Build the SIMT loop via MakeSIMTLoop.
330+
* - Fuse parallel loops into a single For and wrap as a ParallelOp.
331+
* - Run layout inference at kCommon, kStrict, and kFree levels using fields
332+
* from `T`.
333+
* - Obtain the loop layout, partition the root loop with PartitionLoop by
334+
* `T.thread_var`.
335+
* - Vectorize the partitioned thread loop via VectorizeLoop.
336+
* - If the ParallelOp produced a predicate for `T.thread_var`, return an
337+
* IfThenElse that guards the vectorized loop with that predicate; otherwise
338+
* return the vectorized loop.
339+
*
340+
* @param T Lowering context whose fields are used:
341+
* - T.target: target architecture for layout inference and lowering
342+
* decisions.
343+
* - T.thread_var: the Var used to partition the outer loop for thread-level
344+
* parallelism.
345+
* - T.thread_bounds: bounds associated with the thread dimension (used during
346+
* partitioning).
347+
* - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
348+
* during InferLayout.
349+
* @param analyzer Arithmetic analyzer passed to MakeSIMTLoop and used for
350+
* symbolic reasoning during partitioning and folding.
351+
* @return Stmt A lowered TIR statement representing the parallelized and
352+
* vectorized atomic-add.
353+
*/
194354
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
195355
Target target = T.target;
196356
auto simt_loop = MakeSIMTLoop(analyzer);
@@ -221,6 +381,25 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
221381
return vectorized_thread_loop;
222382
}
223383

384+
/**
385+
* @brief Infer and return the layout map for the atomic add operator.
386+
*
387+
* Constructs a cached ParallelOp (by building the SIMT loop) if not already
388+
* present, validates that local.fragment layouts for src and dst match when
389+
* both are provided, and then delegates layout inference to the underlying
390+
* ParallelOp.
391+
*
392+
* @param T Layout inference inputs, including an optional mapping of buffers to
393+
* layouts.
394+
* @param level Inference strictness level.
395+
* @return LayoutMap The inferred layout mapping for buffers used by this
396+
* operator.
397+
*
398+
* @note This method mutates the AtomicAddNode by creating and storing a
399+
* ParallelOp on first invocation.
400+
* @note If both src and dst have layouts in `local.fragment` and their
401+
* fragment layouts differ, an ICHECK failure is raised with diagnostic output.
402+
*/
224403
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
225404
InferLevel level) const {
226405
if (!par_op_.defined()) {

src/op/atomic_add.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,79 @@
1010
#include "operator.h"
1111
#include "parallel.h"
1212

13+
/**
14+
* Lower this tile operator into a TIR statement for the given lowering context.
15+
*
16+
* @param T Lowering context containing mapped buffers and iteration
17+
* information.
18+
* @param analyzer Arithmetic analyzer used to simplify and reason about
19+
* expressions.
20+
* @return A TIR Stmt that implements the atomic-add tile operation for the
21+
* provided context.
22+
*/
23+
/**
24+
* Infer memory/layout mapping for tensors and buffers used by this operator.
25+
*
26+
* @param T Layout inference context providing buffer and shape information.
27+
* @param level Inference aggressiveness level; higher levels may perform more
28+
* speculative decisions.
29+
* @return A LayoutMap describing inferred layouts for the operator's inputs and
30+
* outputs.
31+
*/
32+
/**
33+
* Get the Op registration that identifies this tile operator.
34+
*
35+
* @return A reference to the registered Op representing this operator.
36+
*/
37+
/**
38+
* Create a deep copy of this tile operator node wrapped as a TileOperator.
39+
*
40+
* @return A TileOperator handle owning a cloned AtomicAddNode.
41+
*/
42+
/**
43+
* Construct a SIMT-style For loop nest (thread/block mapping) appropriate for
44+
* the operator.
45+
*
46+
* @param analyzer Arithmetic analyzer used to simplify loop bounds and
47+
* predicates.
48+
* @return A For loop node representing the SIMT-parallel loop structure.
49+
*/
50+
/**
51+
* Create iteration variables used by this operator's loop nest.
52+
*
53+
* @return An array of IterVar objects describing the loop iteration axes.
54+
*/
55+
/**
56+
* Produce index expressions for either source or destination buffer access
57+
* based on iteration vars.
58+
*
59+
* @param ivs IterVars created by MakeIterVars().
60+
* @param src_dst Selects which indices to produce: 0 for source indices, 1 for
61+
* destination indices.
62+
* @return An array of PrimExpr index expressions suitable for indexing the
63+
* selected buffer.
64+
*/
65+
/**
66+
* Build a predicate expression that guards out-of-bounds or conditional
67+
* accesses for src or dst.
68+
*
69+
* @param analyzer Arithmetic analyzer used to simplify the predicate.
70+
* @param ivs IterVars created by MakeIterVars().
71+
* @param extents The loop extents corresponding to the itervars.
72+
* @param src_dst Selects which side the predicate is for: 0 for source, 1 for
73+
* destination.
74+
* @return A PrimExpr boolean predicate that evaluates to true for valid
75+
* iterations.
76+
*/
77+
/**
78+
* Construct an AtomicAdd tile operator from operation arguments and a buffer
79+
* mapping.
80+
*
81+
* @param args Call arguments: the source region call, the destination region
82+
* call, and an optional IntImm giving the coalesced width.
83+
* @param vmap Mapping from buffer names to Buffer objects used by this
84+
* operator.
85+
*/
1386
namespace tvm {
1487
namespace tl {
1588

0 commit comments

Comments
 (0)