Skip to content

Commit b38bd69

Browse files
authored
[Refactor] Refactor Operator into TileOperator and with tvm reflection (#763)
* Refactor operator classes to inherit from TileOperator and update layout inference methods - Changed base class of several operator classes (AtomicAdd, Copy, Gemm, etc.) from Operator to TileOperator for better alignment with tile operations. - Updated InferLayout and Lower methods to use 'override' specifier for clarity and consistency. - Adjusted header inclusions to replace "op.h" with "operator.h" across multiple files for improved organization. - Added missing layout inference implementations for Fill and Conv2DIm2ColOp. - Removed deprecated op.cc and op.h files to streamline the codebase. * lint fix * Refactor operator classes to use Node pattern and improve memory management - Updated several operator classes (AtomicAdd, Copy, Gemm, etc.) to utilize the Node pattern for better memory management and encapsulation. - Changed constructors to initialize member variables through a node object, enhancing clarity and reducing direct member access. - Updated Clone methods to return TileOperator instances instead of unique pointers, aligning with the new design. - Refactored InferLayout and Lower methods to ensure consistency across operator implementations. - Adjusted header files to reflect the new class structure and removed deprecated code for a cleaner codebase. * Enhance Clone methods in AtomicAdd and Copy classes to support parallel operation cloning - Updated the Clone methods in AtomicAddNode and CopyNode to ensure that the parallel operation (par_op_) is properly cloned when defined, improving the integrity of cloned objects. - Refactored the FillNode class to use ParallelOp directly instead of std::make_unique, streamlining the creation of parallel operations. - Made minor adjustments in layout inference and other related methods for consistency and clarity. * Refactor FillNode::Lower method to remove unused global function call - Eliminated the call to the global function "tl.fill.lower" in the FillNode::Lower method, streamlining the code and improving clarity. - Retained the core functionality of the method while enhancing maintainability by reducing unnecessary dependencies.
1 parent 277ed53 commit b38bd69

22 files changed

+760
-518
lines changed

src/op/atomic_add.cc

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* Define elment-wise operators.
55
*/
66

7-
#include "atomic_add.h"
8-
7+
#include "./atomic_add.h"
8+
#include "./region.h"
99
#include <tvm/tir/builtin.h>
1010
#include <tvm/tir/op.h>
1111
#include <tvm/tir/op_attr_types.h>
@@ -34,25 +34,35 @@ static int GetArchInt(Target target) {
3434
return arch_int;
3535
}
3636

37-
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
37+
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
38+
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
3839
Array<Range> rgs[2];
3940
Buffer bf[2];
4041
for (int i = 0; i < 2; i++) {
4142
auto expr = args[i];
4243
auto call = expr.as<CallNode>();
4344
ICHECK(call);
4445
auto region = RegionOp(call->args, vmap);
45-
rgs[i] = region.GetRanges();
46-
bf[i] = region.GetBuffer();
46+
rgs[i] = region->GetRanges();
47+
bf[i] = region->GetBuffer();
4748
}
48-
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
49-
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
49+
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
50+
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
5051
if (args.size() >= 3) {
51-
coalesced_width = Downcast<IntImm>(args[2]);
52+
node->coalesced_width = Downcast<IntImm>(args[2]);
5253
}
54+
data_ = std::move(node);
5355
}
5456

55-
Array<IterVar> AtomicAdd::MakeIterVars() const {
57+
TileOperator AtomicAddNode::Clone() const {
58+
auto op = make_object<AtomicAddNode>(*this);
59+
if (par_op_.defined()) {
60+
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
61+
}
62+
return AtomicAdd(op);
63+
}
64+
65+
Array<IterVar> AtomicAddNode::MakeIterVars() const {
5666
Array<IterVar> loop_vars;
5767
size_t idx = 0;
5868
for (size_t i = 0; i < src_range.size(); i++) {
@@ -68,8 +78,8 @@ Array<IterVar> AtomicAdd::MakeIterVars() const {
6878

6979
// ivs: itervars returned by MakeIterVars()
7080
// src_dst: 0 for src_indices, 1 for dst_indices
71-
Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
72-
int src_dst) const {
81+
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
82+
int src_dst) const {
7383
Array<PrimExpr> indices;
7484
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
7585
size_t idx = 0;
@@ -87,9 +97,10 @@ Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
8797
return indices;
8898
}
8999

90-
PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
91-
const Array<IterVar> &ivs,
92-
Array<PrimExpr> extents, int src_dst) const {
100+
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
101+
const Array<IterVar> &ivs,
102+
Array<PrimExpr> extents,
103+
int src_dst) const {
93104
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
94105
Array<PrimExpr> cond_list;
95106
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
@@ -117,7 +128,7 @@ PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
117128
}
118129
}
119130

120-
For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
131+
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
121132
Array<IterVar> loop_vars = MakeIterVars();
122133
bool is_scalar = loop_vars.size() == 0;
123134
if (is_scalar) {
@@ -180,16 +191,16 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
180191
return Downcast<For>(body);
181192
}
182193

183-
Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
194+
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
184195
Target target = T.target;
185196
auto simt_loop = MakeSIMTLoop(analyzer);
186197
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
187-
auto par_op = std::make_unique<ParallelOp>(fused_loop);
198+
auto par_op = ParallelOp(fused_loop);
188199

189200
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
190201
InferLevel::kFree};
191202
for (auto level : levels) {
192-
par_op->InferLayout(
203+
(par_op)->InferLayout(
193204
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
194205
}
195206
auto loop_layout = par_op->GetLoopLayout();
@@ -210,10 +221,11 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
210221
return vectorized_thread_loop;
211222
}
212223

213-
LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, InferLevel level) {
214-
if (par_op_ == nullptr) {
224+
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
225+
InferLevel level) const {
226+
if (!par_op_.defined()) {
215227
arith::Analyzer analyzer;
216-
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
228+
par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
217229
}
218230
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
219231
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
@@ -236,10 +248,5 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
236248
.set_attr<TCallEffectKind>("TCallEffectKind",
237249
Integer(CallEffectKind::kOpaque));
238250

239-
// TVM_REGISTER_OP("tl.atomicadd")
240-
// .set_num_inputs(2)
241-
// .add_argument("ref", "Buffer", "The destination buffer")
242-
// .add_argument("val", "Expr", "The value to be added atomically");
243-
244251
} // namespace tl
245252
} // namespace tvm

src/op/atomic_add.h

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,31 @@
77
#ifndef TVM_TL_OP_ATOMIC_ADD_H_
88
#define TVM_TL_OP_ATOMIC_ADD_H_
99

10-
#include "op.h"
10+
#include "operator.h"
1111
#include "parallel.h"
1212

1313
namespace tvm {
1414
namespace tl {
1515

1616
using namespace tir;
1717

18-
class AtomicAdd : public Operator {
18+
class AtomicAddNode : public TileOperatorNode {
1919
public:
20-
AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
21-
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
22-
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
20+
Array<PrimExpr> args_;
2321

24-
static const Op &Get();
22+
Buffer src, dst;
23+
Array<Range> src_range, dst_range;
24+
IntImm coalesced_width;
2525

26-
AtomicAdd(const AtomicAdd &other)
27-
: args_(other.args_), src(other.src), dst(other.dst),
28-
src_range(other.src_range), dst_range(other.dst_range),
29-
coalesced_width(other.coalesced_width) {
30-
// No clone nullptr
31-
if (other.par_op_)
32-
par_op_ = std::unique_ptr<ParallelOp>(
33-
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
34-
}
35-
std::unique_ptr<Operator> Clone() const final {
36-
return std::make_unique<AtomicAdd>(*this);
37-
}
26+
mutable ParallelOp par_op_;
27+
static constexpr const char *_type_key = "tl.AtomicAdd";
28+
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
29+
30+
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
31+
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
32+
33+
static const Op &Get();
34+
TileOperator Clone() const;
3835

3936
protected:
4037
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
@@ -46,14 +43,13 @@ class AtomicAdd : public Operator {
4643

4744
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
4845
Array<PrimExpr> extents, int src_dst) const;
46+
};
4947

50-
Array<PrimExpr> args_;
51-
52-
Buffer src, dst;
53-
Array<Range> src_range, dst_range;
54-
IntImm coalesced_width;
55-
56-
std::unique_ptr<ParallelOp> par_op_;
48+
class AtomicAdd : public TileOperator {
49+
public:
50+
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
51+
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
52+
static const Op &Get();
5753
};
5854

5955
} // namespace tl

src/op/builtin.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#ifndef TVM_TL_OP_BUILTIN_H_
88
#define TVM_TL_OP_BUILTIN_H_
99

10-
#include "op.h"
10+
#include "operator.h"
1111
#include <tvm/ir/transform.h>
1212

1313
namespace tvm {

0 commit comments

Comments
 (0)