/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/*!
* \file src/op/grad/nn.cc
* \brief Declaration of gradients
*/
#include <raf/value.h>
#include <raf/op_utils.h>
#include "./grad_utils.h"
namespace raf {
namespace op {
namespace grad {
using namespace raf::ir;
using namespace raf::value;
Array<Expr> BiasAddGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
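// Gradients of bias_add(x, bias, axis):
//   dx    = reshape dy back to the shape of x
//   dbias = sum of dy over every axis except `axis` (exclude = true)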
using namespace raf::value;
static auto sum = Op::Get("raf.op.sum");
const CallNode* call = orig_call.as<CallNode>();
const Expr& x = call->args[0];
const Expr& bias = call->args[1];
const Expr& axis = call->args[2];
Expr keep_dims = MakeConstant(ScalarValue::make((int64_t)0));
Expr exclude = MakeConstant(ScalarValue::make(true));
return {GetReshapeLike(dy, x), Call(sum, {dy, axis, keep_dims, exclude})};
}
RAF_OP_GRAD("raf.op.bias_add", BiasAddGrad);
Array<Expr> ContribDropoutGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dout) {
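// The forward op returns a tuple (output, mask, reserve_space); only the first
// element receives an incoming gradient. The backward op reuses mask and
// reserve_space from the forward output together with the dropout rate p.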
static auto dropout_dx = Op::Get("raf.op._contrib_dropout_dx");
const Expr& dy = AsTupleExpr(dout, 2)[0];
const Expr& mask = TupleGetItem(y, 1);
const Expr& reserve_space = TupleGetItem(y, 2);
const Expr& p = orig_args[1];
return {Call(dropout_dx, {dy, mask, reserve_space, p})};
}
RAF_OP_GRAD("raf.op._contrib_dropout", ContribDropoutGrad);
template <const char* GradOp>
Array<Expr> PoolGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
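// max_pool2d_dx / avg_pool2d_dx take the forward input x, the forward output y,
// the output gradient dy, and the original pooling attributes.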
static auto op_dx = Op::Get(GradOp);
const CallNode* call = orig_call.as<CallNode>();
const Expr& x = call->args[0];
const Expr& kernel = call->args[1];
const Expr& stride = call->args[2];
const Expr& padding = call->args[3];
const Expr& dilation = call->args[4];
const Expr& ceil_mode = call->args[5];
const Expr& include_pad = call->args[6];
const Expr& layout = orig_args[7];
const auto* layout_const = layout.as<ConstantNode>();
if (layout_const) {
const auto* layout_str = layout_const->value.as<value::StringValueObj>();
CHECK(layout_str && layout_str->value == "NCHW")
<< "PoolGrad supports NCHW layout only. Layout = " << layout_str->value;
}
return {Call(op_dx, {x, y, dy, kernel, stride, padding, dilation, ceil_mode, include_pad})};
}
const char MAX_POOL2D_DX[] = "raf.op.max_pool2d_dx";
auto MaxPool2dGrad = PoolGrad<MAX_POOL2D_DX>;
RAF_OP_GRAD("raf.op.max_pool2d", MaxPool2dGrad);
const char AVG_POOL2D_DX[] = "raf.op.avg_pool2d_dx";
auto AvgPool2dGrad = PoolGrad<AVG_POOL2D_DX>;
RAF_OP_GRAD("raf.op.avg_pool2d", AvgPool2dGrad);
template <const char* GradOp>
Array<Expr> AdaptivePoolGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
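// adaptive_*_pool2d_dx takes the forward input x, the forward output y,
// the output gradient dy, and the target output shape.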
static auto op_dx = Op::Get(GradOp);
const CallNode* call = orig_call.as<CallNode>();
const Expr& x = call->args[0];
const Expr& shape = call->args[1];
const Expr& layout = call->args[2];
const auto* layout_const = layout.as<ConstantNode>();
if (layout_const) {
const auto* layout_str = layout_const->value.as<value::StringValueObj>();
CHECK(layout_str && layout_str->value == "NCHW")
<< "AdaptivePoolGrad supports NCHW layout only. Layout = " << layout_str->value;
}
return {Call(op_dx, {x, y, dy, shape})};
}
const char ADAPTIVE_MAX_POOL2D_DX[] = "raf.op.adaptive_max_pool2d_dx";
auto AdaptiveMaxPool2dGrad = AdaptivePoolGrad<ADAPTIVE_MAX_POOL2D_DX>;
RAF_OP_GRAD("raf.op.adaptive_max_pool2d", AdaptiveMaxPool2dGrad);
const char ADAPTIVE_AVG_POOL2D_DX[] = "raf.op.adaptive_avg_pool2d_dx";
auto AdaptiveAvgPool2dGrad = AdaptivePoolGrad<ADAPTIVE_AVG_POOL2D_DX>;
RAF_OP_GRAD("raf.op.adaptive_avg_pool2d", AdaptiveAvgPool2dGrad);
Array<Expr> Conv2dGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
// schema for conv2d is:
// x, w, stride, padding, dilation, groups
// schema for conv2d_grad is:
// x_or_w, y, dy, shape, stride, padding, dilation, groups
static auto op_dx = Op::Get("raf.op.conv2d_dx");
static auto op_dw = Op::Get("raf.op.conv2d_dw");
const CallNode* call = orig_call.as<CallNode>();
// TODO(@junrushao1994): this piece of code is particularly suitable for auto-gen
CHECK_GE(call->args.size(), 6);
const Expr& x = call->args[0];
const Expr& w = call->args[1];
const Expr& stride = call->args[2];
const Expr& padding = call->args[3];
const Expr& dilation = call->args[4];
const Expr& groups = call->args[5];
const Expr& layout = orig_args[6];
const auto* layout_const = layout.as<ConstantNode>();
if (layout_const) {
const auto* layout_str = layout_const->value.as<value::StringValueObj>();
CHECK(layout_str && layout_str->value == "NCHW")
<< "Conv2dGrad supports NCHW layout only. Layout = " << layout_str->value;
}
// dx: w, y, dy, shape(x), stride, padding, dilation, groups
// dw: x, y, dy, shape(w), stride, padding, dilation, groups
return {Call(op_dx, {w, y, dy, GetShape(x), stride, padding, dilation, groups}),
Call(op_dw, {x, y, dy, GetShape(w), stride, padding, dilation, groups})};
}
RAF_OP_GRAD("raf.op.conv2d", Conv2dGrad);
Array<Expr> Conv2dTransGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
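// schema for conv2d_transpose is:
// x, w, stride, padding, output_padding, dilation, groups
// schema for conv2d_transpose_dx/dw is:
// x_or_w, y, dy, shape, stride, padding, output_padding, dilation, groups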
static auto op_dx = Op::Get("raf.op.conv2d_transpose_dx");
static auto op_dw = Op::Get("raf.op.conv2d_transpose_dw");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 7);
const Expr& x = call->args[0];
const Expr& w = call->args[1];
const Expr& stride = call->args[2];
const Expr& padding = call->args[3];
const Expr& output_padding = call->args[4];
const Expr& dilation = call->args[5];
const Expr& groups = call->args[6];
const Expr& layout = orig_args[7];
const auto* layout_const = layout.as<ConstantNode>();
if (layout_const) {
const auto* layout_str = layout_const->value.as<value::StringValueObj>();
CHECK(layout_str && layout_str->value == "NCHW")
<< "Conv2dTransGrad supports NCHW layout only. Layout = " << layout_str->value;
}
// dx: w, y, dy, shape(x), stride, padding, output_padding, dilation, groups
// dw: x, y, dy, shape(w), stride, padding, output_padding, dilation, groups
return {Call(op_dx, {w, y, dy, GetShape(x), stride, padding, output_padding, dilation, groups}),
Call(op_dw, {x, y, dy, GetShape(w), stride, padding, output_padding, dilation, groups})};
}
RAF_OP_GRAD("raf.op.conv2d_transpose", Conv2dTransGrad);
template <const char* GradOp>
Array<Expr> UnaryGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
// schema for unary_op is:
// x
// schema for unary_dx_op is:
// x, y, dy
// Note that unary_dx schema allows y to be optional. If the unary_dx op supports
// optional y, we need to specify it in GradientInputSelection pass to remove the
// y dependency.
static auto op_dx = Op::Get(GradOp);
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
return {Call(op_dx, {x, y, dy})};
}
Array<Expr> RsqrtGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
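// d/dx x^(-1/2) = -1/2 * x^(-3/2).
// The constants (1, 2, 1/2) are built from x itself (e.g. one = x / x) so that
// they automatically match x's dtype and shape.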
static auto op_sqrt = Op::Get("raf.op.sqrt");
static auto op_multiply = Op::Get("raf.op.multiply");
static auto op_divide = Op::Get("raf.op.divide");
static auto op_negative = Op::Get("raf.op.negative");
static auto op_add = Op::Get("raf.op.add");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call one = Call(op_divide, {x, x});
Call two = Call(op_add, {one, one, MakeNull(), MakeNull()});
Call half = Call(op_divide, {one, two});
Call neg_half = Call(op_negative, {half});
Call x_pow_2 = Call(op_multiply, {x, x});
Call x_pow_3 = Call(op_multiply, {x_pow_2, x});
Call sqrt_x_pow_3 = Call(op_sqrt, {x_pow_3});
Call rsqrt_x_pow_3 = Call(op_divide, {one, sqrt_x_pow_3});
Call dx = Call(op_multiply, {neg_half, rsqrt_x_pow_3});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.rsqrt", RsqrtGrad);
Array<Expr> TruncGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
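// trunc is piecewise constant, so its gradient is zero everywhere. When the
// static type of x is known, build zeros from its shape and dtype directly;
// otherwise fall back to zeros_like(x).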
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call grads;
if (x->checked_type_.defined()) {
static auto op_zeros = Op::Get("raf.op.zeros");
auto ttype = x->checked_type().as<TensorTypeNode>();
auto shape = MakeConstant(ArrayToIntTuple(ttype->shape));
auto dtype = MakeConstant(StringValue::make(tvm::runtime::DLDataType2String(ttype->dtype)));
grads = Call(op_zeros, {shape, dtype});
} else {
static auto op_zeros_like = Op::Get("raf.op.zeros_like");
grads = Call(op_zeros_like, {x});
}
return {grads};
}
RAF_OP_GRAD("raf.op.trunc", TruncGrad);
Array<Expr> LogGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
// d/dx log(x) = 1 / x, so dx = dy * (1 / x)
static auto op_div = Op::Get("raf.op.divide");
static auto op_multiply = Op::Get("raf.op.multiply");
static auto op_ones = Op::Get("raf.op.ones_like");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call ones = Call(op_ones, {x});
Call one_by_x = Call(op_div, {ones, x});
Call result = Call(op_multiply, {dy, one_by_x});
return {result};
}
RAF_OP_GRAD("raf.op.log", LogGrad);
Array<Expr> Log2Grad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
// d/dx log2(x) = 1 / (x * ln(2)), so dx = dy * (1 / (log(2) * x))
static auto op_div = Op::Get("raf.op.divide");
static auto op_multiply = Op::Get("raf.op.multiply");
static auto op_ones = Op::Get("raf.op.ones_like");
static auto op_add = Op::Get("raf.op.add");
static auto op_log = Op::Get("raf.op.log");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call ones = Call(op_ones, {x});
Call two = Call(op_add, {ones, ones, MakeNull(), MakeNull()});
Call log_two = Call(op_log, {two});
Call log_two_x = Call(op_multiply, {log_two, x});
Call dx = Call(op_div, {ones, log_two_x});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.log2", Log2Grad);
Array<Expr> CosGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
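// d/dx cos(x) = -sin(x)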
static auto op_sin = Op::Get("raf.op.sin");
static auto op_negative = Op::Get("raf.op.negative");
static auto op_multiply = Op::Get("raf.op.multiply");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call sin_x = Call(op_sin, {x});
Call dx = Call(op_negative, {sin_x});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.cos", CosGrad);
Array<Expr> SinGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
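// d/dx sin(x) = cos(x)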
static auto op_cos = Op::Get("raf.op.cos");
static auto op_multiply = Op::Get("raf.op.multiply");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call dx = Call(op_cos, {x});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.sin", SinGrad);
Array<Expr> ExpGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
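// d/dx exp(x) = exp(x)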
static auto op_exp = Op::Get("raf.op.exp");
static auto op_multiply = Op::Get("raf.op.multiply");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call dx = Call(op_exp, {x});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.exp", ExpGrad);
Array<Expr> AtanGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
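// d/dx atan(x) = 1 / (1 + x^2); `one` is built as x / x to match x's dtype and shape.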
static auto op_divide = Op::Get("raf.op.divide");
static auto op_multiply = Op::Get("raf.op.multiply");
static auto op_add = Op::Get("raf.op.add");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
const Expr& x = call->args[0];
Call one = Call(op_divide, {x, x});
Call x_square = Call(op_multiply, {x, x});
Call denominator = Call(op_add, {x_square, one, MakeNull(), MakeNull()});
Call dx = Call(op_divide, {one, denominator});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.atan", AtanGrad);
const char RELU_DX[] = "raf.op.relu_dx";
auto ReluGrad = UnaryGrad<RELU_DX>;
RAF_OP_GRAD("raf.op.relu", ReluGrad);
const char GELU_DX[] = "raf.op.gelu_dx";
auto GeluGrad = UnaryGrad<GELU_DX>;
RAF_OP_GRAD("raf.op.gelu", GeluGrad);
const char TANH_DX[] = "raf.op.tanh_dx";
auto TanhGrad = UnaryGrad<TANH_DX>;
RAF_OP_GRAD("raf.op.tanh", TanhGrad);
const char SIGMOID_DX[] = "raf.op.sigmoid_dx";
auto SigmoidGrad = UnaryGrad<SIGMOID_DX>;
RAF_OP_GRAD("raf.op.sigmoid", SigmoidGrad);
const char ERF_DX[] = "raf.op.erf_dx";
auto ErfGrad = UnaryGrad<ERF_DX>;
RAF_OP_GRAD("raf.op.erf", ErfGrad);
const char SQRT_DX[] = "raf.op.sqrt_dx";
auto SqrtGrad = UnaryGrad<SQRT_DX>;
RAF_OP_GRAD("raf.op.sqrt", SqrtGrad);
Array<Expr> BatchNormTrainGrad(const Expr& orig_call, const Var& y, const Expr& dymv,
const Array<Expr>& igrads) {
// schema for batch_norm_train is:
// x, running_mean, running_var, w, b, momentum, eps
// schema for batch_norm_train_dxwb is:
// dy, x, w, b, eps
static auto op_dxwb = Op::Get("raf.op.batch_norm_train_dxwb");
const Expr& dy = AsTupleExpr(dymv, 3)[0];
const CallNode* call = orig_call.as<CallNode>();
const Expr& x = call->args[0];
const Expr& w = call->args[3];
const Expr& b = call->args[4];
const Expr& eps = call->args[6];
const Expr& ret = Call(op_dxwb, {dy, x, w, b, eps});
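// Return gradients in the order of the forward arguments:
// x, running_mean (none), running_var (none), w, b.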
return {
TupleGetItem(ret, 0), NullValue<Expr>(), NullValue<Expr>(),
TupleGetItem(ret, 1), TupleGetItem(ret, 2),
};
}
RAF_OP_FUSED_GRAD("raf.op.batch_norm_train", BatchNormTrainGrad);
template <const char* GradOp>
Array<Expr> SoftmaxGradImpl(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
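// softmax_dx computes the input gradient from the forward output y, dy, and the axis.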
static auto op_dx = Op::Get(GradOp);
const CallNode* call = orig_call.as<CallNode>();
const Expr& axis = call->args[1];
return {Call(op_dx, {y, dy, axis})};
}
const char SOFTMAX_DX[] = "raf.op.softmax_dx";
auto SoftmaxGrad = SoftmaxGradImpl<SOFTMAX_DX>;
RAF_OP_GRAD("raf.op.softmax", SoftmaxGrad);
Array<Expr> LogSoftmaxGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
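// For y = log_softmax(x, axis): dx = dy - exp(y) * sum(dy, axis, keepdims=true).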
using namespace raf::value;
static auto op_exp = Op::Get("raf.op.exp");
static auto op_sum = Op::Get("raf.op.sum");
static auto op_multiply = Op::Get("raf.op.multiply");
static auto op_subtract = Op::Get("raf.op.subtract");
const CallNode* call = orig_call.as<CallNode>();
const Expr& axis = call->args[1];
Expr keep_dims = MakeConstant(ScalarValue::make((int64_t)1));
Expr e_1 = Call(op_sum, {dy, axis, keep_dims, MakeConstant(BoolValue::make(false))});
Expr e_2 = Call(op_multiply, {Call(op_exp, {y}), e_1});
Expr e_3 = Call(op_subtract, {dy, e_2, MakeNull(), MakeNull()});
return {e_3};
}
RAF_OP_GRAD("raf.op.log_softmax", LogSoftmaxGrad);
Array<Expr> LayerNormGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
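// layer_norm_dx returns a tuple (dx, dscale, dbias) when scale and bias are
// learnable; when scale is an undefined constant, only dx is returned.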
static auto op_dx = Op::Get("raf.op.layer_norm_dx");
const CallNode* call = orig_call.as<CallNode>();
CHECK(call != nullptr);
const Expr& x = call->args[0];
const Expr& scale = call->args[1];
const Expr& axis = call->args[3];
const Expr& eps = call->args[4];
const Expr& ret = Call(op_dx, {x, scale, dy, axis, eps});
const auto* kscale = scale.as<tvm::relay::ConstantNode>();
if (kscale && !static_cast<const ConstantNode*>(kscale)->value.defined()) {
// scale and bias are not learnable parameters.
return {ret};
}
return {
TupleGetItem(ret, 0),
TupleGetItem(ret, 1),
TupleGetItem(ret, 2),
};
}
RAF_OP_GRAD("raf.op.layer_norm", LayerNormGrad);
Array<Expr> ThresholdGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
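// threshold_dx passes dy through where x exceeds the threshold and yields zero elsewhere.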
static auto op_dx = Op::Get("raf.op.threshold_dx");
const CallNode* call = orig_call.as<CallNode>();
CHECK(call != nullptr);
const Expr& x = call->args[0];
const Expr& threshold = call->args[1];
return {Call(op_dx, {x, dy, threshold})};
}
RAF_OP_GRAD("raf.op.threshold", ThresholdGrad);
Array<Expr> LayerNormTrainGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dymv) {
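// layer_norm_train returns (y, mean, invvar); the backward op takes the saved
// mean and inverse variance from the forward output. Only dy, the gradient of
// the first output, is used here.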
static auto op_dx = Op::Get("raf.op.layer_norm_train_dx");
const CallNode* call = orig_call.as<CallNode>();
CHECK(call != nullptr);
const Expr& dy = AsTupleExpr(dymv, 3)[0];
const Expr& x = call->args[0];
const Expr& scale = call->args[1];
const Expr& axis = call->args[3];
const Expr& eps = call->args[4];
auto mean = TupleGetItem(y, 1);
auto invvar = TupleGetItem(y, 2);
const Expr& ret = Call(op_dx, {x, scale, dy, mean, invvar, axis, eps});
const auto* kscale = scale.as<tvm::relay::ConstantNode>();
if (kscale && !static_cast<const ConstantNode*>(kscale)->value.defined()) {
// scale and bias are not learnable parameters.
return {ret, NullValue<Expr>(), NullValue<Expr>()};
}
return {TupleGetItem(ret, 0), TupleGetItem(ret, 1), TupleGetItem(ret, 2), NullValue<Expr>(),
NullValue<Expr>()};
}
RAF_OP_GRAD("raf.op.layer_norm_train", LayerNormTrainGrad);
Array<Expr> ReciprocalGrad(const Expr& orig_call, const Array<Expr> orig_args, const Var& y,
const Expr& dy) {
// d/dx (1 / x) = -1 / x^2 = -(y * y), where y = 1 / x.
static auto op_multiply = Op::Get("raf.op.multiply");
static auto op_negative = Op::Get("raf.op.negative");
const CallNode* call = orig_call.as<CallNode>();
CHECK_GE(call->args.size(), 1);
Call y_multiply_y = Call(op_multiply, {y, y});
Call dx = Call(op_negative, {y_multiply_y});
return {Call(op_multiply, {dy, dx})};
}
RAF_OP_GRAD("raf.op.reciprocal", ReciprocalGrad);
} // namespace grad
} // namespace op
} // namespace raf