This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
softmax_output.cc
289 lines (255 loc) · 12.1 KB
/
softmax_output.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file softmax_output.cc
* \brief
* \author Bing Xu, Zhang Rong A
*/
#include "./softmax_output-inl.h"
#if MXNET_USE_ONEDNN == 1
#include "operator/nn/dnnl/dnnl_base-inl.h"
#include "operator/nn/dnnl/dnnl_softmax_output-inl.h"
#endif
namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(SoftmaxOutputParam);
struct SoftmaxOutputGrad {
const char* op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
std::vector<nnvm::NodeEntry> out_data(n->num_outputs());
for (uint32_t i = 0; i < out_data.size(); ++i) {
out_data[i] = nnvm::NodeEntry{n, i, 0};
}
std::vector<nnvm::NodeEntry> heads;
heads.push_back(out_data[softmaxout_enum::kOut]);
heads.push_back(n->inputs[softmaxout_enum::kLabel]);
nnvm::ObjectPtr gnode = nnvm::Node::Create();
gnode->inputs = std::move(heads);
gnode->control_deps.emplace_back(n);
gnode->attrs = n->attrs;
gnode->attrs.op = nnvm::Op::Get("_backward_SoftmaxOutput");
gnode->attrs.name = n->attrs.name + "_backward";
std::vector<nnvm::NodeEntry> in_grad(2);
in_grad[0] = nnvm::NodeEntry{gnode, 0, 0};
in_grad[1] = nnvm::NodeEntry{gnode, 1, 0};
return in_grad;
}
};
static inline std::vector<std::string> ListArguments() {
return {"data", "label"};
}
static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
CHECK_EQ(in_type->size(), 2U);
int dtype = (*in_type)[0];
if (type_is_none(dtype)) {
// Input type is undefined, we try backward inference
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
// Neither the input nor the output are defined,
// types cannot be infered for this op
return false;
} else {
// Input type is undefined but output type is: backward inference
dtype = (*out_type)[0];
}
} else {
// Input type is defined but output type is not: forward inference
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
return true;
}
static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
using namespace mshadow;
const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]";
const mxnet::TShape& dshape = in_shape->at(0);
if (!mxnet::ndim_is_known(dshape))
return false;
// label.shape == data.shape: use probability as label
if (dshape != (*in_shape)[softmaxout_enum::kLabel]) {
if (param.multi_output) {
mxnet::TShape lshape1 = Shape2(dshape[0], dshape.Size() / dshape[0] / dshape[1]);
mxnet::TShape lshape2(dshape.ndim() - 1, -1);
lshape2[0] = dshape[0];
for (int i = 2; i < dshape.ndim(); ++i)
lshape2[i - 1] = dshape[i];
mxnet::TShape lshape3 = dshape;
lshape3[1] = 1;
if (!mxnet::ndim_is_known(in_shape->at(softmaxout_enum::kLabel))) {
in_shape->at(softmaxout_enum::kLabel) = lshape1;
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) {
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) {
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape3) {
} else {
std::ostringstream os;
os << "Expecting " << lshape1 << " or " << lshape2 << ". But got "
<< in_shape->at(softmaxout_enum::kLabel);
throw InferShapeError(os.str(), softmaxout_enum::kLabel);
}
} else {
mxnet::TShape label_shape(dshape.ndim() - 1, -1);
for (int i = 0; i + 1 < dshape.ndim(); ++i)
label_shape[i] = dshape[i];
SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape);
}
}
out_shape->clear();
out_shape->push_back(dshape);
return true;
}
#if MXNET_USE_ONEDNN == 1
inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
if (SupportDNNLSoftmaxOutput(param, inputs[0]) && !ctx.is_train) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLSoftmaxOutputForward, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif
NNVM_REGISTER_OP(SoftmaxOutput)
.describe(R"code(Computes the gradient of cross entropy loss with respect to softmax output.
- This operator computes the gradient in two steps.
The cross entropy loss does not actually need to be computed.
- Applies softmax function on the input array.
- Computes and returns the gradient of cross entropy loss w.r.t. the softmax output.
- The softmax function, cross entropy loss and gradient is given by:
- Softmax Function:
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
- Cross Entropy Function:
.. math:: \text{CE(label, output)} = - \sum_i \text{label}_i \log(\text{output}_i)
- The gradient of cross entropy loss w.r.t softmax output:
.. math:: \text{gradient} = \text{output} - \text{label}
- During forward propagation, the softmax function is computed for each instance in the input array.
For general *N*-D input arrays with shape :math:`(d_1, d_2, ..., d_n)`. The size is
:math:`s=d_1 \cdot d_2 \cdot \cdot \cdot d_n`. We can use the parameters `preserve_shape`
and `multi_output` to specify the way to compute softmax:
- By default, `preserve_shape` is ``false``. This operator will reshape the input array
into a 2-D array with shape :math:`(d_1, \frac{s}{d_1})` and then compute the softmax function for
each row in the reshaped array, and afterwards reshape it back to the original shape
:math:`(d_1, d_2, ..., d_n)`.
- If `preserve_shape` is ``true``, the softmax function will be computed along
the last axis (`axis` = ``-1``).
- If `multi_output` is ``true``, the softmax function will be computed along
the second axis (`axis` = ``1``).
- During backward propagation, the gradient of cross-entropy loss w.r.t softmax output array is computed.
The provided label can be a one-hot label array or a probability label array.
- If the parameter `use_ignore` is ``true``, `ignore_label` can specify input instances
with a particular label to be ignored during backward propagation. **This has no effect when
softmax `output` has same shape as `label`**.
Example::
data = [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
label = [1,0,2,3]
ignore_label = 1
SoftmaxOutput(data=data, label = label,\
multi_output=true, use_ignore=true,\
ignore_label=ignore_label)
## forward softmax output
[[ 0.0320586 0.08714432 0.23688284 0.64391428]
[ 0.25 0.25 0.25 0.25 ]
[ 0.25 0.25 0.25 0.25 ]
[ 0.25 0.25 0.25 0.25 ]]
## backward gradient output
[[ 0. 0. 0. 0. ]
[-0.75 0.25 0.25 0.25]
[ 0.25 0.25 -0.75 0.25]
[ 0.25 0.25 0.25 -0.75]]
## notice that the first row is all 0 because label[0] is 1, which is equal to ignore_label.
- The parameter `grad_scale` can be used to rescale the gradient, which is often used to
give each loss function different weights.
- This operator also supports various ways to normalize the gradient by `normalization`,
The `normalization` is applied if softmax output has different shape than the labels.
The `normalization` mode can be set to the followings:
- ``'null'``: do nothing.
- ``'batch'``: divide the gradient by the batch size.
- ``'valid'``: divide the gradient by the number of instances which are not ignored.
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SoftmaxOutputParam>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxOutputStorageType)
.set_attr<bool>("TIsDNNL", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxOutputComputeExCPU)
#endif
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "label"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", SoftmaxOutputShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOutputType)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxOutputCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", SoftmaxOutputGrad{"_backward_SoftmaxOutput"})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "Input array.")
.add_argument("label", "NDArray-or-Symbol", "Ground truth label.")
.add_arguments(SoftmaxOutputParam::__FIELDS__());
// Softmax symbol is renamed to SoftmaxOutput and deprecated since Dec, 2015
NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax");
NNVM_REGISTER_OP(_backward_SoftmaxOutput)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr_parser(ParamParser<SoftmaxOutputParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxOutputGradCompute<cpu>);
} // namespace op
} // namespace mxnet