/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2015 by Contributors
 * \file np_broadcast_reduce_op.h
 * \brief Function definition of broadcast and reduce operators
 */
#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_

#include <algorithm>
#include <vector>
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
namespace op {

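/*! \brief Parameters of numpy-compatible reduce operators (e.g. np.sum). */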
struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
  dmlc::optional<mxnet::Tuple<int>> axis;
  dmlc::optional<int> dtype;
  bool keepdims;
  dmlc::optional<double> initial;
  DMLC_DECLARE_PARAMETER(NumpyReduceAxesParam) {
    DMLC_DECLARE_FIELD(axis)
      .set_default(dmlc::optional<mxnet::Tuple<int>>())
      .describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
                "all of the elements of the input array. If axis is negative it counts from the "
                "last to the first axis.");
    DMLC_DECLARE_FIELD(dtype)
      .add_enum("float16", mshadow::kFloat16)
      .add_enum("float32", mshadow::kFloat32)
      .add_enum("float64", mshadow::kFloat64)
      .add_enum("int8", mshadow::kInt8)
      .add_enum("int32", mshadow::kInt32)
      .add_enum("int64", mshadow::kInt64)
      .set_default(dmlc::optional<int>())
      .describe("The type of the returned array and of the accumulator in which the elements are "
                "summed. The dtype of a is used by default unless a has an integer dtype of less "
                "precision than the default platform integer. In that case, if a is signed then "
                "the platform integer is used while if a is unsigned then an unsigned integer of "
                "the same precision as the platform integer is used.");
    DMLC_DECLARE_FIELD(keepdims).set_default(false)
      .describe("If this is set to `True`, the reduced axes are left "
                "in the result as dimensions with size one.");
    DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
      .describe("Starting value for the sum.");
  }
};

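/*!
 * \brief Infer the output shape of a numpy-style reduction over `axis`.
 * For example (illustrative), ishape = (2, 3, 4) with axis = (0, -1)
 * normalizes the axes to (0, 2); keepdims = false yields oshape = (3,),
 * while keepdims = true yields oshape = (1, 3, 1).
 */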
inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                       const dmlc::optional<mxnet::Tuple<int>>& axis,
                                       bool keepdims) {
  // TODO(junwu): improve the logic
  // If the input is a scalar, the output should be a scalar too
  if (ishape.ndim() == 0) {
    if (axis.has_value()) {
      const mxnet::Tuple<int>& axes = axis.value();
      if (axes.ndim() > 0) {
        CHECK_EQ(axes.ndim(), 1);
        CHECK(axes[0] == 0 || axes[0] == -1);
      }
    }
    return TShape(0, -1);
  }

  // axis=None, do a global reduction
  if (!axis.has_value()) {
    if (keepdims) {
      return TShape(ishape.ndim(), 1);
    } else {
      return TShape(0, -1);
    }
  }

  // axis=(), return identity(input)
  if (axis.value().ndim() == 0) {
    return ishape;
  }

  // axis has a value: normalize negative axes, then sort
  mxnet::Tuple<int> axes(axis.value());
  for (index_t i = 0; i < axes.ndim(); i++) {
    if (axes[i] < 0) {
      axes[i] += ishape.ndim();
    }
  }
  std::sort(axes.begin(), axes.end());

  for (index_t i = 1; i < axes.ndim(); i++) {
    CHECK_LT(axes[i-1], axes[i])
      << "Reduction axes have duplicates "
      << axes;
  }
  CHECK_LT(axes[axes.ndim()-1], ishape.ndim())
    << "Reduction axis " << axes[axes.ndim()-1]
    << " exceeds input dimensions " << ishape;
  CHECK_GE(axes[0], 0)
    << "Reduction axis " << axis.value()
    << " exceeds input dimensions " << ishape;

  TShape oshape;
  if (keepdims) {
    // Keep the rank; set each reduced dimension to 1
    oshape = TShape(ishape);
    for (index_t i = 0; i < axes.ndim(); ++i) {
      oshape[axes[i]] = 1;
    }
  } else {
    // Drop the reduced dimensions entirely
    oshape = TShape(ishape.ndim() - axes.ndim(), -1);
    for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
      if (j < axes.ndim() && i == axes[j]) {
        ++j;
        continue;
      }
      oshape[k++] = ishape[i];
    }
  }
  return oshape;
}

inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
                                 std::vector<TShape> *in_attrs,
                                 std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  if (!shape_is_known(in_attrs->at(0))) {
    return false;
  }
  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
  return shape_is_known(out_attrs->at(0));
}

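// Safe accumulation uses a wider accumulator type to reduce rounding error,
// e.g. accumulating float16 inputs in float32. It is engaged only when the
// caller hints that it may be needed and the reduction is not already a
// float32/float64 input accumulated into the same type.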
template<bool safe_acc_hint = false>
inline bool NeedSafeAcc(int itype, int otype) {
  bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
  return safe_acc_hint && rule;
}

template<typename xpu, typename reducer, bool safe_acc_hint = false, bool normalize = false,
         typename OP = op::mshadow_op::identity>
void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
  if (param.initial.has_value()) {
    LOG(FATAL) << "initial is not supported yet";
  }
  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
    // axis=() is an identity mapping; do not fall through to the reduction
    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
    return;
  }
  TShape small;
  if (param.keepdims) {
    small = outputs[0].shape_;
  } else {
    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
  }

  if (NeedSafeAcc<safe_acc_hint>(inputs[0].type_flag_, outputs[0].type_flag_)) {
    ReduceAxesComputeImpl<xpu, reducer, true, normalize, OP>(ctx, inputs, req, outputs, small);
  } else {
    ReduceAxesComputeImpl<xpu, reducer, false, normalize, OP>(ctx, inputs, req, outputs, small);
  }
}
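
// Illustrative sketch (an assumption, not part of this header): an operator
// built on these helpers would typically be registered in a .cc file roughly
// like the following, shown here for a hypothetical `_np_sum`:
//
//   NNVM_REGISTER_OP(_np_sum)
//   .set_attr_parser(ParamParser<NumpyReduceAxesParam>)
//   .set_num_inputs(1)
//   .set_num_outputs(1)
//   .set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesShape)
//   .set_attr<FCompute>("FCompute<cpu>",
//                       NumpyReduceAxesCompute<cpu, mshadow_op::sum, true>);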

template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                           const OpContext& ctx,
                                           const std::vector<TBlob>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
  TShape small;
  if (param.keepdims) {
    small = inputs[0].shape_;
  } else {
    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
  }

  BroadcastComputeImpl<xpu>(attrs, ctx, inputs, req, outputs, small);
  if (normalize) {
    // For a mean reduction, scale the broadcast gradient by
    // 1 / (number of reduced elements)
    Stream<xpu> *s = ctx.get_stream<xpu>();
    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
      Tensor<xpu, 1, IType> igrad = outputs[0].FlatTo1D<xpu, IType>(s);
      igrad /= scalar<IType>(outputs[0].Size() / inputs[0].Size());
    });
  }
}
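
// Example for NumpyReduceAxesBackwardUseNone (illustrative): for a mean over
// axis=1 of a (2, 3) input, the incoming gradient of shape (2,) is broadcast
// back to (2, 3) and, with normalize = true, divided by 3.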

}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_