rewrite_cooperative_fetch.cc
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"
namespace tvm {
namespace tir {
/*!
* \brief Parse instruction: sch.bind(..., axis)
* \param sch The schedule
* \param inst The instruction to be parsed
* \param axis The axis name expected
 * \return NullOpt if parsing fails; otherwise, the extent of the thread axis
*/
Optional<Integer> ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) {
static InstructionKind inst_kind_bind = InstructionKind::Get("Bind");
if (!inst->kind.same_as(inst_kind_bind)) {
return NullOpt;
}
ICHECK_EQ(inst->inputs.size(), 1);
ICHECK_EQ(inst->attrs.size(), 1);
String thread_axis = Downcast<String>(inst->attrs[0]);
if (thread_axis != axis) {
return NullOpt;
}
return Downcast<Integer>(sch->Get(Downcast<LoopRV>(inst->inputs[0]))->extent);
}
/*!
* \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch)
* \param sch The schedule
* \param inst The instruction to be parsed
 * \param vector_lane The number of vector lanes in vectorized cooperative fetching
 * \return NullOpt if parsing fails; otherwise, the annotated block
*/
Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
int64_t* vector_lane) {
static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_kind_annotate)) {
return NullOpt;
}
ICHECK_EQ(inst->inputs.size(), 2);
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
if (ann_key != attr::meta_schedule_cooperative_fetch) {
return NullOpt;
}
*vector_lane = Downcast<Integer>(sch->Get(Downcast<ExprRV>(inst->inputs[1])))->value;
return Downcast<BlockRV>(inst->inputs[0]);
}
/*!
* \brief Parse instruction: sch.annotate(..., attr::warp_execution)
* \param sch The schedule
* \param inst The instruction to be parsed
 * \return Whether the parsing is successful
*/
bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_kind_annotate)) {
return false;
}
ICHECK_EQ(inst->inputs.size(), 2);
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
return ann_key == attr::warp_execution;
}
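
/*!
 * \brief Compute the maximum dtype size, in bytes, used inside a block, i.e. the widest type
 *  among buffer loads, buffer stores, casts, and the 64-bit q_multiply_shift intrinsics.
 * \param block The block to inspect
 * \return The maximum number of bytes (at least 1)
 */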
size_t GetMaxUsedDtypeBytes(Block block) {
size_t max_bytes = 1;
static auto q_multiply_shift_per_axis = Op::Get("tir.q_multiply_shift_per_axis");
static auto q_multiply_shift = Op::Get("tir.q_multiply_shift");
tir::PostOrderVisit(block->body, [&](const ObjectRef& obj) {
if (const auto* store = obj.as<tir::BufferStoreNode>()) {
max_bytes = std::max(max_bytes, static_cast<size_t>(store->value->dtype.bytes()));
} else if (const auto* load = obj.as<tir::BufferLoadNode>()) {
max_bytes = std::max(max_bytes, static_cast<size_t>(load->dtype.bytes()));
} else if (const auto* call = obj.as<tir::CallNode>()) {
if (call->op.same_as(q_multiply_shift_per_axis) || call->op.same_as(q_multiply_shift)) {
// q_multiply_shift uses 64 bit multiply
max_bytes = std::max<size_t>(max_bytes, 8);
}
} else if (const auto* cast = obj.as<tir::CastNode>()) {
max_bytes = std::max<size_t>(max_bytes, cast->dtype.bytes());
}
});
return max_bytes;
}
} // namespace tir
namespace meta_schedule {
/*!
 * \brief Rewrite the cooperative-fetch annotation into actual, optionally vectorized, cooperative
 * fetching via loop splitting and thread bindings.
*/
class RewriteCooperativeFetchNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
this->thread_warp_size_ = v.value()->value;
} else {
TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
}
}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;
Postproc Clone() const {
ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>(*this);
return Postproc(n);
}
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);
private:
int thread_warp_size_ = -1;
};
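
/*!
 * \brief Scan the trace for thread bindings, warp-execution annotations and cooperative-fetch
 *  annotations; for each annotated block, split its innermost loop, bind the resulting loops to
 *  threadIdx.y / threadIdx.x, and vectorize the innermost chunk when a vector lane greater than
 *  one is requested and safe.
 */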
bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
tir::Trace trace = sch->trace().value();
int64_t thread_extent_x = -1;
int64_t thread_extent_y = -1;
int64_t vector_lane = 1;
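  // Rewrites are collected as deferred tasks and executed only after the scan, so that the new
  // instructions emitted by Split/Bind/Vectorize do not modify the trace while it is iterated.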
std::vector<std::function<void()>> tasks;
for (const tir::Instruction& inst : trace->insts) {
if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) {
thread_extent_x = new_thread_extent.value()->value;
continue;
}
if (Optional<Integer> new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) {
thread_extent_y = new_thread_extent.value()->value;
continue;
}
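    // A block annotated with warp_execution is executed by a full warp, so threadIdx.x is
    // implicitly bound to the warp size.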
if (tir::ParseWarpExecutionAnn(sch, inst)) {
thread_extent_x = thread_warp_size_;
continue;
}
Optional<tir::BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
if (!opt_block_rv.defined()) {
continue;
}
auto task = [thread_extent_x, thread_extent_y, vector_lane, sch,
block = opt_block_rv.value()]() mutable -> void {
sch->Unannotate(block, tir::attr::meta_schedule_cooperative_fetch);
tir::LoopRV fused = sch->GetLoops(block).back();
int64_t fused_extent = -1;
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(fused).get())) {
fused_extent = *extent;
} else {
return;
}
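      // The requested vector lane must evenly divide the fused loop extent; otherwise fall back
      // to unvectorized cooperative fetching.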
if (fused_extent % vector_lane != 0) {
vector_lane = 1;
}
// If the block involves 64 bit values, disable vectorization for now since
// vectorization of 64 bit values does not work well on CUDA.
// TODO(masahi, vinx13): Decouple epilogue fusion computation and shared to global store, so
// that we can always vectorize the latter.
if (tir::GetMaxUsedDtypeBytes(sch->Get(block)) > 4) {
vector_lane = 1;
}
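      // Both threadIdx.y and threadIdx.x are in use: split the loop into
      // [outer, threadIdx.y, threadIdx.x(, vector)] and bind/vectorize accordingly.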
if (thread_extent_y != -1) {
if (vector_lane > 1) {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
Integer(thread_extent_y), //
Integer(thread_extent_x), //
Integer(vector_lane)});
sch->Vectorize(split[3]);
sch->Bind(split[2], "threadIdx.x");
sch->Bind(split[1], "threadIdx.y");
} else {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
Integer(thread_extent_y), //
Integer(thread_extent_x)});
sch->Bind(split[2], "threadIdx.x");
sch->Bind(split[1], "threadIdx.y");
}
} else {
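        // Only threadIdx.x is in use: split the loop into [outer, threadIdx.x(, vector)].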
if (vector_lane > 1) {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
Integer(thread_extent_x), //
Integer(vector_lane)});
sch->Vectorize(split[2]);
sch->Bind(split[1], "threadIdx.x");
} else {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)});
sch->Bind(split[1], "threadIdx.x");
}
}
};
tasks.push_back(task);
}
for (auto&& task : tasks) {
task();
}
return true;
}
Postproc Postproc::RewriteCooperativeFetch() {
ObjectPtr<RewriteCooperativeFetchNode> n = make_object<RewriteCooperativeFetchNode>();
return Postproc(n);
}
TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch")
.set_body_typed(Postproc::RewriteCooperativeFetch);
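
// Illustrative sketch (not part of the implementation): this postproc is normally constructed
// through the FFI global registered above, e.g. from the Python side as
// `ms.postproc.RewriteCooperativeFetch()` (module alias assumed), or directly in C++:
//
//   Postproc postproc = Postproc::RewriteCooperativeFetch();
//   postproc->InitializeWithTuneContext(context);  // picks up the target's thread_warp_size
//   bool ok = postproc->Apply(sch);                // rewrites cooperative-fetch annotations
//
// where `context` is a TuneContext with a GPU target and `sch` is a traced tir::Schedule.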
} // namespace meta_schedule
} // namespace tvm