Commit 562f148 (1 parent: 51c07e5)

[Do not review] [Do not merge] New numpy-compatible sum (apache#14739)

* Add numpy namespace and initial impl of np.sum (not complete)
* Clean up
* Fix import error
* numpy sum
* add test and backward data type support
* add license to test_numpy_op.py
* improve test to reduce flakiness
* fix sanity build
* extra numeric test and imperative test
* add error message for initial argument

13 files changed: +563, -18 lines
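
At a glance: the Python-side changes register backend operators whose names start with '_numpy_' under a new mxnet.numpy module, so a backend op '_numpy_sum' surfaces to users as mxnet.numpy.sum; the C++ header adds the shape inference and forward/backward kernels behind it. A minimal sketch of the intended imperative usage, assuming an op registered as '_numpy_sum' with the numpy.sum-style parameters declared later in this diff (axis, dtype, keepdims; initial is declared but still rejected at compute time):

    # Hypothetical usage; the actual function is generated from the op
    # registry at import time, not hand-written in this commit.
    import mxnet as mx

    a = mx.nd.ones((2, 3))
    s = mx.numpy.sum(a, axis=0, keepdims=True)  # assumes '_numpy_sum' is registered
    print(s)                                    # expected: shape (1, 3), all 2.0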

python/mxnet/__init__.py (+1)

@@ -27,6 +27,7 @@
 from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
 from . import base
 from . import library
+from . import numpy
 from . import contrib
 from . import ndarray
 from . import ndarray as nd

python/mxnet/base.py (+20, -1)

@@ -561,7 +561,7 @@ def _as_list(obj):
     return [obj]


-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']


 def _get_op_name_prefix(op_name):
@@ -607,6 +607,15 @@ def _init_op_module(root_namespace, module_name, make_op_func):
     # use mx.nd.contrib or mx.sym.contrib from now on
     contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
     contrib_module_old = sys.modules[contrib_module_name_old]
+    # special handling of registering numpy ops
+    # only expose mxnet.numpy.op_name to users for imperative mode.
+    # Symbolic mode should be used in Gluon.
+    if module_name == 'ndarray':
+        numpy_module_name = "%s.numpy" % root_namespace
+        numpy_module = sys.modules[numpy_module_name]
+    else:
+        numpy_module_name = None
+        numpy_module = None
     submodule_dict = {}
     for op_name_prefix in _OP_NAME_PREFIX_LIST:
         submodule_dict[op_name_prefix] =\
@@ -645,6 +654,16 @@ def _init_op_module(root_namespace, module_name, make_op_func):
             function.__module__ = contrib_module_name_old
             setattr(contrib_module_old, function.__name__, function)
             contrib_module_old.__all__.append(function.__name__)
+        elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
+            # only register numpy ops under mxnet.numpy in imperative mode
+            hdl = OpHandle()
+            check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+            # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
+            func_name = name[len(op_name_prefix):]
+            function = make_op_func(hdl, name, func_name)
+            function.__module__ = numpy_module_name
+            setattr(numpy_module, function.__name__, function)
+            numpy_module.__all__.append(function.__name__)


 def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
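
In short, when _init_op_module runs for the ndarray namespace, every backend op named '_numpy_<name>' is re-exposed as mxnet.numpy.<name> with the prefix stripped; the symbol pass skips this branch. A self-contained toy model of that flow (plain Python; register and _sum_impl are hypothetical stand-ins for the C-backend op table, not MXNet API):

    import types

    _OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']

    def _get_op_name_prefix(op_name):
        for prefix in _OP_NAME_PREFIX_LIST:
            if op_name.startswith(prefix):
                return prefix
        return ''

    numpy_module = types.ModuleType('mxnet.numpy')
    numpy_module.__all__ = []

    def register(name, function):
        if _get_op_name_prefix(name) == '_numpy_':
            func_name = name[len('_numpy_'):]        # '_numpy_sum' -> 'sum'
            function.__name__ = func_name
            function.__module__ = numpy_module.__name__
            setattr(numpy_module, func_name, function)
            numpy_module.__all__.append(func_name)

    def _sum_impl(a, axis=None):                     # hypothetical imperative impl
        return sum(a)

    register('_numpy_sum', _sum_impl)
    print(numpy_module.sum([1, 2, 3]))               # 6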

python/mxnet/ndarray/__init__.py (+1, -1)

@@ -17,7 +17,7 @@

 """NDArray API of MXNet."""

-from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray
+from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import *  # pylint: disable=unused-wildcard-import

python/mxnet/ndarray/numpy.py (new file, +18)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

__all__ = []

python/mxnet/numpy/__init__.py (new file, +20)

#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

__all__ = []

python/mxnet/symbol/__init__.py (+1, -1)

@@ -17,7 +17,7 @@

 """Symbol API of MXNet."""

-from . import _internal, contrib, linalg, op, random, sparse, image, symbol
+from . import _internal, contrib, linalg, op, random, sparse, image, symbol, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import *  # pylint: disable=unused-wildcard-import

python/mxnet/symbol/numpy.py (new file, +18)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

__all__ = []

src/operator/numpy/np_broadcast_reduce_op.h (new file, +218; filename missing from the page, inferred from the include guard below)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2015 by Contributors
 * \file broadcast_reduce_op.h
 * \brief Function definition of broadcast and reduce operators
 */
#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_

#include <algorithm>
#include <vector>
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
namespace op {

struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
  dmlc::optional<mxnet::Tuple<int>> axis;
  dmlc::optional<int> dtype;
  bool keepdims;
  dmlc::optional<double> initial;
  DMLC_DECLARE_PARAMETER(NumpyReduceAxesParam) {
    DMLC_DECLARE_FIELD(axis)
      .set_default(dmlc::optional<mxnet::Tuple<int>>())
      .describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
                "all of the elements of the input array. If axis is negative it counts from the "
                "last to the first axis.");
    DMLC_DECLARE_FIELD(dtype)
      .add_enum("float16", mshadow::kFloat16)
      .add_enum("float32", mshadow::kFloat32)
      .add_enum("float64", mshadow::kFloat64)
      .add_enum("int8", mshadow::kInt8)
      .add_enum("int32", mshadow::kInt32)
      .add_enum("int64", mshadow::kInt64)
      .set_default(dmlc::optional<int>())
      .describe("The type of the returned array and of the accumulator in which the elements are "
                "summed. The dtype of a is used by default unless a has an integer dtype of less "
                "precision than the default platform integer. In that case, if a is signed then "
                "the platform integer is used while if a is unsigned then an unsigned integer of "
                "the same precision as the platform integer is used.");
    DMLC_DECLARE_FIELD(keepdims).set_default(false)
      .describe("If this is set to `True`, the reduced axes are left "
                "in the result as dimension with size one.");
    DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
      .describe("Starting value for the sum.");
  }
};
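
These fields mirror the corresponding parameters of numpy.sum. For reference, the target semantics in plain NumPy (note that this commit still rejects a non-default initial at compute time; see NumpyReduceAxesCompute below):

    import numpy as np

    a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    np.sum(a)                          # 21.0: axis=None reduces everything
    np.sum(a, axis=(0,))               # [5. 7. 9.]
    np.sum(a, axis=-1, keepdims=True)  # [[ 6.] [15.]]: reduced axis kept as size 1
    np.sum(a, dtype=np.float64)        # accumulates and returns in float64
    np.sum(a, initial=100.0)           # 121.0: starting value for the sum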
inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                       const dmlc::optional<mxnet::Tuple<int>>& axis,
                                       bool keepdims) {
  // TODO(junwu): improve the logic
  // If input is a scalar, output should be a scalar too
  if (ishape.ndim() == 0) {
    if (axis.has_value()) {
      const mxnet::Tuple<int>& axes = axis.value();
      if (axes.ndim() > 0) {
        CHECK_EQ(axes.ndim(), 1);
        CHECK(axes[0] == 0 || axes[0] == -1);
      }
    }
    return TShape(0, -1);
  }

  // axis=None, do global reduction
  if (!axis.has_value()) {
    if (keepdims) {
      return TShape(ishape.ndim(), 1);
    } else {
      return TShape(0, -1);
    }
  }

  // axis = (), will return identity(input)
  if (axis.value().ndim() == 0) {
    return ishape;
  }

  // axis has value
  mxnet::Tuple<int> axes(axis.value());
  for (index_t i = 0; i < axes.ndim(); i++) {
    if (axes[i] < 0) {
      axes[i] += ishape.ndim();
    }
  }
  std::sort(axes.begin(), axes.end());

  for (index_t i = 1; i < axes.ndim(); i++) {
    CHECK_LT(axes[i-1], axes[i])
        << "Reduction axes have duplicates "
        << axes;
  }
  CHECK_LT(axes[axes.ndim()-1], ishape.ndim())
      << "Reduction axis " << axes[axes.ndim()-1]
      << " Exceeds input dimensions " << ishape;
  CHECK_GE(axes[0], 0)
      << "Reduction axis " << axis.value()
      << " Exceeds input dimensions " << ishape;

  TShape oshape;
  if (keepdims) {
    oshape = TShape(ishape);
  } else {
    oshape = TShape(ishape.ndim() - axes.ndim(), -1);
  }

  if (keepdims) {
    for (index_t i = 0; i < axes.ndim(); ++i) {
      oshape[axes[i]] = 1;
    }
  } else {
    for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
      if (j < axes.ndim() && i == axes[j]) {
        ++j;
        continue;
      }
      oshape[k++] = ishape[i];
    }
  }
  return oshape;
}
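
The same shape rules, condensed into a runnable Python sketch for reference (reduce_axes_shape is a hypothetical helper, not part of the commit):

    def reduce_axes_shape(ishape, axis=None, keepdims=False):
        if len(ishape) == 0:                   # scalar in -> scalar out
            return ()
        if axis is None:                       # global reduction
            return tuple(1 for _ in ishape) if keepdims else ()
        axes = sorted(a + len(ishape) if a < 0 else a for a in axis)
        if not axes:                           # axis=() -> identity shape
            return tuple(ishape)
        if keepdims:
            return tuple(1 if i in axes else d for i, d in enumerate(ishape))
        return tuple(d for i, d in enumerate(ishape) if i not in axes)

    assert reduce_axes_shape((2, 3, 4), axis=(0, 2)) == (3,)
    assert reduce_axes_shape((2, 3, 4), axis=(-1,), keepdims=True) == (2, 3, 1)
    assert reduce_axes_shape((2, 3, 4)) == ()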
inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
                                 std::vector<TShape> *in_attrs,
                                 std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  if (!shape_is_known(in_attrs->at(0))) {
    return false;
  }
  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
  return shape_is_known(out_attrs->at(0));
}

template<bool safe_acc_hint = false>
inline bool NeedSafeAcc(int itype, int otype) {
  bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
  return safe_acc_hint && rule;
}
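
Safe accumulation means reducing in a wider type than the I/O dtype (e.g. summing float16 data in a float32 accumulator); NeedSafeAcc requests it whenever the input and output types differ or are narrower than float32, provided the op opted in via safe_acc_hint. A NumPy illustration of the failure mode this guards against (naive sequential accumulation; exact values depend on rounding):

    import numpy as np

    x = np.full(10000, 0.01, dtype=np.float16)   # true sum: 100.0

    acc = np.float16(0)
    for v in x:          # once acc reaches ~32, adding 0.01 rounds to nothing
        acc = np.float16(acc + v)
    print(acc)           # roughly 32, far from 100

    print(x.astype(np.float32).sum())            # ~100.0 with a wider accumulator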
template<typename xpu, typename reducer, bool safe_acc_hint = false, bool normalize = false,
         typename OP = op::mshadow_op::identity>
void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
  if (param.initial.has_value()) {
    LOG(FATAL) << "initial is not supported yet";
  }
  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
  }
  TShape small;
  if (param.keepdims) {
    small = outputs[0].shape_;
  } else {
    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
  }

  if (NeedSafeAcc<safe_acc_hint>(inputs[0].type_flag_, outputs[0].type_flag_)) {
    ReduceAxesComputeImpl<xpu, reducer, true, normalize, OP>(ctx, inputs, req, outputs, small);
  } else {
    ReduceAxesComputeImpl<xpu, reducer, false, normalize, OP>(ctx, inputs, req, outputs, small);
  }
}
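
The kernel always reduces into the keepdims-style shape small (computed with keepdims=true); when the user passed keepdims=False the output tensor appears to carry the squeezed shape over the same elements. The relationship in NumPy terms (illustration only):

    import numpy as np

    a = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
    small = a.sum(axis=(0, 2), keepdims=True)    # reduce into shape (1, 3, 1)
    out = small.reshape(3)                       # keepdims=False result: same 3 values
    assert np.array_equal(out, a.sum(axis=(0, 2)))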
template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                           const OpContext& ctx,
                                           const std::vector<TBlob>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
  TShape small;
  if (param.keepdims) {
    small = inputs[0].shape_;
  } else {
    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
  }

  BroadcastComputeImpl<xpu>(attrs, ctx, inputs, req, outputs, small);
  if (normalize) {
    Stream<xpu> *s = ctx.get_stream<xpu>();
    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
      Tensor<xpu, 1, IType> igrad = outputs[0].FlatTo1D<xpu, IType>(s);
      printf("output size: %lu input_size: %lu\n", outputs[0].Size(), inputs[0].Size());
      igrad /= scalar<IType>(outputs[0].Size()/inputs[0].Size());
    });
  }
}
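
The backward pass broadcasts the output gradient back to the input shape (the gradient of sum is an all-ones broadcast), and with normalize=true it additionally divides by the reduction factor, turning the sum gradient into a mean gradient. A NumPy sketch of that rule (illustrative, not the MXNet kernel):

    import numpy as np

    x_shape = (2, 3, 4)
    dy = np.ones((2, 1, 4), dtype=np.float32)   # ograd of sum over axis=1 (keepdims shape)
    dx = np.broadcast_to(dy, x_shape).copy()    # analogue of BroadcastComputeImpl
    factor = dx.size // dy.size                 # 3 == outputs.Size() / inputs.Size()
    dx /= factor                                # normalize=true: mean-grad instead of sum-grad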
}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
