From b4497e79339d33a1e80d25a06db84acab9ad946d Mon Sep 17 00:00:00 2001
From: reminisce
Date: Wed, 6 Mar 2019 11:30:27 -0800
Subject: [PATCH] [numpy] Shape support scalar tensor (#14315)

* Support scalar and zero-size tensors with np.sum

* Add sanity check when ndim is set
---
 include/mxnet/tensor_blob.h                 |   1 -
 include/mxnet/tuple.h                       | 100 +++++++---
 python/mxnet/__init__.py                    |   1 +
 python/mxnet/base.py                        |  29 ++-
 python/mxnet/ndarray/__init__.py            |   2 +-
 python/mxnet/ndarray/numpy.py               |  18 ++
 python/mxnet/numpy/__init__.py              |  20 ++
 python/mxnet/symbol/__init__.py             |   2 +-
 python/mxnet/symbol/numpy.py                |  18 ++
 src/executor/graph_executor.cc              |   3 +-
 src/executor/infer_graph_attr_pass.cc       |   6 +-
 src/nnvm/plan_memory.cc                     |   6 +-
 src/operator/numpy/np_broadcast_reduce_op.h | 186 ++++++++++++++++++
 .../numpy/np_broadcast_reduce_op_value.cc   |  61 ++++++
 .../numpy/np_broadcast_reduce_op_value.cu   |  36 ++++
 src/operator/operator_common.h              |  20 +-
 16 files changed, 466 insertions(+), 43 deletions(-)
 create mode 100644 python/mxnet/ndarray/numpy.py
 create mode 100644 python/mxnet/numpy/__init__.py
 create mode 100644 python/mxnet/symbol/numpy.py
 create mode 100644 src/operator/numpy/np_broadcast_reduce_op.h
 create mode 100644 src/operator/numpy/np_broadcast_reduce_op_value.cc
 create mode 100644 src/operator/numpy/np_broadcast_reduce_op_value.cu

diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 7d059025b03e..45d4c7fda639 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -198,7 +198,6 @@ class TBlob {
         << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
     return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
                                              shape_.FlatTo2D(),
-                                             shape_[shape_.ndim() - 1],
                                              stream);
   }
   /*!
diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h
index 7c1367333630..39c3c185e3c0 100644
--- a/include/mxnet/tuple.h
+++ b/include/mxnet/tuple.h
@@ -17,7 +17,7 @@
  * under the License.
  */
 /*!
- * Copyright (c) 2016 by Contributors
+ * Copyright (c) 2019 by Contributors
 * \file mxnet/tuple.h
 * \brief Data structure Tuple and TShape to store dynamic sized shapes.
 */
@@ -39,11 +39,14 @@ namespace mxnet {
 /*!
 * \brief A dynamic sized array data structure that is optimized for storing
- *        small number of elements with same type.
+ *        a small number of elements of the same type.
 *
 *  Data will be stored in stack when number of elements is small.
 *  It is suitable to hold shape of Tensor.
 *
+ *  The ndim of a valid tuple is an integer in range [0, inf).
+ *  ndim = 0 means the tuple is empty.
+ *
 * \tparam ValueType The type of data stored inside tuple.
 * \sa TShape
 */
@@ -61,7 +64,11 @@ class Tuple {
   * \param s the source tuple
   */
  inline Tuple(const Tuple<ValueType>& s) {
-    this->assign(s.begin(), s.end());
+    if (s.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(s.begin(), s.end());
+    }
  }
  /*!
   * \brief constructor from initializer list
@@ -106,6 +113,7 @@ class Tuple {
  inline void assign(RandomAccessIterator begin,
                     RandomAccessIterator end) {
    this->SetDim(end - begin);
+    CHECK_GE(ndim(), 0);
    std::copy(begin, end, this->begin());
  }
  /*!
@@ -124,7 +132,11 @@ class Tuple {
   * \return reference of self
   */
  inline Tuple<ValueType>& operator=(const Tuple<ValueType>& src) {
-    this->assign(src.begin(), src.end());
+    if (src.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(src.begin(), src.end());
+    }
    return *this;
  }
  /*!
@@ -151,6 +163,7 @@ class Tuple {
   */
  inline bool operator==(const Tuple<ValueType> &s) const {
    if (ndim_ != s.ndim_) return false;
+    if (ndim() == -1) return true;
    return std::equal(begin(), end(), s.begin());
  }
  /*!
@@ -177,7 +190,7 @@ class Tuple {
    return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
  }
  /*! \return number of dimension of the tuple */
-  inline uint32_t ndim() const {
+  inline int ndim() const {
    return ndim_;
  }
  /*!
@@ -185,7 +198,8 @@ class Tuple {
   * \param i dimension index
   * \return the corresponding dimension size
   */
-  inline ValueType& operator[](size_t i) {
+  inline ValueType& operator[](int i) {
+    CHECK(i >= 0 && i < ndim());
    return begin()[i];
  }
  /*!
@@ -193,7 +207,8 @@ class Tuple {
   * \param i dimension index
   * \return the corresponding dimension size
   */
-  inline const ValueType& operator[](size_t i) const {
+  inline const ValueType& operator[](int i) const {
+    CHECK(i >= 0 && i < ndim());
    return begin()[i];
  }
  /*!
@@ -220,6 +235,10 @@ class Tuple {
   * \return the ostream
   */
  friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
+    if (t.ndim() == -1) {
+      os << "UNKNOWN_SHAPE";
+      return os;
+    }
    os << '[';
    const ValueType* begin = t.begin();
    const ValueType* end = t.end();
@@ -316,22 +335,27 @@ class Tuple {

 protected:
  // stack cache size
-  static const uint32_t kStackCache = 4;
+  static const int kStackCache = 4;
  /*! \brief number of dimension of the tuple */
-  uint32_t ndim_{0};
+  int ndim_{0};
  /*! \brief number of cells allocated in data_heap_ */
-  uint32_t num_heap_allocated_{0};
+  int num_heap_allocated_{0};
  /*! \brief in stack space used to store shape when it is small */
  ValueType data_stack_[kStackCache];
  /*! \brief space to store shape when dimension is big */
  ValueType* data_heap_{nullptr};
  // internal function to change the dimension
-  inline void SetDim(uint32_t ndim) {
+  inline void SetDim(int ndim) {
+    CHECK_GE(ndim, -1) << "ndim cannot be less than -1, received " << ndim;
    if (ndim > kStackCache &&
        ndim > num_heap_allocated_) {
      delete [] data_heap_;
      data_heap_ = new ValueType[ndim];
      num_heap_allocated_ = ndim;
+    } else if (ndim == -1 && data_heap_ != nullptr) {
+      delete [] data_heap_;
+      data_heap_ = nullptr;
+      num_heap_allocated_ = 0;
    }
    ndim_ = ndim;
  }
@@ -339,25 +363,47 @@ class Tuple {

/*!
 * \brief A Shape class that is used to represent shape of each tensor.
+ *
+ * The ndim of a valid shape is an integer in range [-1, inf).
+ * ndim = -1 means the shape information is unknown and needs to be inferred.
+ * ndim = 0 means the tensor with the shape is a scalar.
+ *
+ * The dimension size of a valid shape is an integer in range [-1, inf).
+ * dim_size = -1 means the size of that dimension is unknown and needs to be inferred.
+ * dim_size = 0 means that dimension is empty.
+ *
+ * The definition of ndim = 0 and dim_size = 0 is consistent with NumPy.
 */
class TShape : public Tuple<dim_t> {
 public:
  /*! \brief default constructor */
-  TShape() = default;
+  TShape() {
+    this->SetDim(-1);
+  }
  /*!
   * constructor to construct a shape with all 1.
+   * TODO(junwu): The value should default to -1. Need to keep 1 for now
+   * for backward compatibility. Change it to -1 in the future when we can
+   * break backward compatibility.
   * \param ndim the number of dimension
+   * \param value the dimension size for all dims
   */
-  inline TShape(uint32_t ndim) {  // NOLINT(*)
+  inline TShape(int ndim, int value = 1) {  // NOLINT(*)
    this->SetDim(ndim);
-    std::fill_n(begin(), ndim, 1);
+    if (ndim > 0) {
+      std::fill_n(begin(), ndim, value);
+    }
  }
  /*!
   * \brief copy constructor of TShape
   * \param s source shape.
   */
  inline TShape(const Tuple<dim_t>& s) {  // NOLINT(*)
-    this->assign(s.begin(), s.end());
+    if (s.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(s.begin(), s.end());
+    }
  }
  /*!
   * \brief constructor from initializer list
@@ -390,7 +436,11 @@ class TShape : public Tuple<dim_t> {
   * \return self.
   */
  inline TShape& operator=(const Tuple<dim_t>& src) {
-    this->assign(src.begin(), src.end());
+    if (src.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(src.begin(), src.end());
+    }
    return *this;
  }
  /*!
@@ -404,9 +454,11 @@ class TShape : public Tuple<dim_t> {
  }
  /*! \return total number of elements in the shape */
  inline size_t Size() const {
+    CHECK_GE(this->ndim(), 0) << "Shape is unknown.";
    dim_t size = 1;
    const dim_t* start = begin(), *fin = end();
    for (const dim_t* it = start; it != fin; ++it) {
+      CHECK_GE(*it, 0) << "Shape dim size cannot be -1, which means unknown.";
      size *= *it;
    }
    return size;
@@ -417,9 +469,11 @@ class TShape : public Tuple<dim_t> {
   * \param dimend end dimension
   */
  inline size_t ProdShape(int dimstart, int dimend) const {
+    CHECK_GE(this->ndim(), 0) << "Shape is unknown.";
    dim_t num = 1;
    const dim_t *d = this->data();
    for (int i = dimstart; i < dimend; ++i) {
+      CHECK_GE(d[i], 0) << "Shape dim size cannot be -1, which means unknown.";
      num *= d[i];
    }
    return num;
@@ -460,7 +514,7 @@ class TShape : public Tuple<dim_t> {
   */
  template<int dim>
  inline mshadow::Shape<dim> get() const {
-    CHECK_EQ(dim, static_cast<int>(ndim()))
+    CHECK_EQ(dim, ndim())
        << "dimension do not match target dimension " << dim << " vs " << ndim();
    const dim_t *d = this->data();
    mshadow::Shape<dim> s;
@@ -475,11 +529,12 @@ class TShape : public Tuple<dim_t> {
   */
  inline mshadow::Shape<2> FlatTo2D(void) const {
    mshadow::Shape<2> s;
-    if (ndim() == 0) return mshadow::Shape2(0, 0);
+    CHECK_GE(ndim(), 0);
+    if (ndim() == 0) return mshadow::Shape2(1, 1);
    const dim_t *d = this->data();
    s.shape_[1] = d[ndim() - 1];
    dim_t ymax = 1;
-    for (size_t i = 1; i < ndim(); ++i) {
+    for (int i = 1; i < ndim(); ++i) {
      ymax *= d[i - 1];
    }
    s.shape_[0] = ymax;
@@ -494,7 +549,8 @@ class TShape : public Tuple<dim_t> {
  inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
    CHECK(axis_end >= axis_begin);
    mshadow::Shape<3> s;
-    if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
+    CHECK_GE(ndim(), 0);
+    if (ndim() == 0) return mshadow::Shape3(1, 1, 1);
    const dim_t *d = this->data();
    s.shape_[0] = 1;
    s.shape_[1] = 1;
@@ -506,7 +562,7 @@ class TShape : public Tuple<dim_t> {
    for (size_t i = axis_begin; i <= axis_end; ++i) {
      s.shape_[1] *= d[i];
    }
-    for (size_t i = axis_end + 1; i < ndim(); ++i) {
+    for (int i = axis_end + 1; i < ndim(); ++i) {
      s.shape_[2] *= d[i];
    }
    return s;
@@ -627,7 +683,7 @@ struct hash<mxnet::TShape> {
  size_t operator()(const mxnet::TShape& val) const {
    std::hash<uint32_t> hash_uint;
    size_t res = hash_uint(val.ndim());
-    for (uint32_t i = 0; i < val.ndim(); ++i) {
+    for (int i = 0; i < val.ndim(); ++i) {
      res = dmlc::HashCombine(res, val[i]);
    }
    return res;
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 374a3b50bbb5..8db83a286157 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -25,6 +25,7 @@
 from . import engine
 from .base import MXNetError
 from . import base
+from . import numpy
 from . import contrib
 from . import ndarray
 from . import ndarray as nd
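The new Tuple/TShape conventions in include/mxnet/tuple.h above are easiest to see end to end. Below is a minimal sketch of mine, not part of the patch, assuming a build of the patched headers and that dim_t is a signed integer type (which the -1 convention requires); expected output is noted in comments.

// sketch_tuple_semantics.cc (hypothetical file, for illustration only)
#include <mxnet/tuple.h>
#include <iostream>

int main() {
  mxnet::TShape unknown;                     // default constructor: ndim() == -1
  std::cout << unknown << '\n';              // prints "UNKNOWN_SHAPE"

  mxnet::TShape scalar(0, 1);                // ndim() == 0: the shape of a scalar tensor
  std::cout << scalar.Size() << '\n';        // 1: the empty product, as in NumPy

  mxnet::TShape ones(3, 1);                  // ndim() == 3, every dim_size == 1
  std::cout << ones << '\n';                 // prints "[1,1,1]"

  mxnet::TShape copied(unknown);             // copying preserves the unknown state
  std::cout << (copied == unknown) << '\n';  // 1: two unknown shapes compare equal
  return 0;
}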
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index feb4d70b6533..7793deacf44c 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -26,7 +26,7 @@
 import sys
 import inspect
 import platform
-import numpy as np
+import numpy as _np
 from . import libinfo

@@ -44,8 +44,8 @@
     long = int
     # pylint: enable=pointless-statement

-integer_types = (int, long, np.int32, np.int64)
-numeric_types = (float, int, long, np.generic)
+integer_types = (int, long, _np.int32, _np.int64)
+numeric_types = (float, int, long, _np.generic)
 string_types = basestring,

 if sys.version_info[0] > 2:
@@ -216,7 +216,7 @@ def _load_lib():
 mx_uint = ctypes.c_uint
 mx_float = ctypes.c_float
 mx_float_p = ctypes.POINTER(mx_float)
-mx_real_t = np.float32
+mx_real_t = _np.float32
 NDArrayHandle = ctypes.c_void_p
 FunctionHandle = ctypes.c_void_p
 OpHandle = ctypes.c_void_p
@@ -455,7 +455,7 @@ def ctypes2numpy_shared(cptr, shape):
     for s in shape:
         size *= s
     dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
-    return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
+    return _np.frombuffer(dbuffer, dtype=_np.float32).reshape(shape)


 def build_param_doc(arg_names, arg_types, arg_descs, remove_dup=True):
@@ -560,7 +560,7 @@ def _as_list(obj):
         return [obj]


-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']


 def _get_op_name_prefix(op_name):
@@ -606,6 +606,13 @@ def _init_op_module(root_namespace, module_name, make_op_func):
     # use mx.nd.contrib or mx.sym.contrib from now on
     contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
     contrib_module_old = sys.modules[contrib_module_name_old]
+    # special handling of registering numpy ops
+    if module_name == 'ndarray':
+        numpy_module_name = "%s.numpy" % root_namespace
+        numpy_module = sys.modules[numpy_module_name]
+    else:
+        numpy_module_name = None
+        numpy_module = None
     submodule_dict = {}
     for op_name_prefix in _OP_NAME_PREFIX_LIST:
         submodule_dict[op_name_prefix] =\
@@ -644,6 +651,16 @@ def _init_op_module(root_namespace, module_name, make_op_func):
                 function.__module__ = contrib_module_name_old
                 setattr(contrib_module_old, function.__name__, function)
                 contrib_module_old.__all__.append(function.__name__)
+        elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
+            # only register numpy ops under mxnet.numpy in imperative mode
+            hdl = OpHandle()
+            check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+            # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
+            func_name = name[len(op_name_prefix):]
+            function = make_op_func(hdl, name, func_name)
+            function.__module__ = numpy_module_name
+            setattr(numpy_module, function.__name__, function)
+            numpy_module.__all__.append(function.__name__)


 def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index f09908e894d5..a102399521cc 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@

 """NDArray API of MXNet."""

-from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray
+from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import *  # pylint: disable=unused-wildcard-import
diff --git a/python/mxnet/ndarray/numpy.py b/python/mxnet/ndarray/numpy.py
new file mode 100644
index 000000000000..0826ac8aca7f
--- /dev/null
+++ b/python/mxnet/ndarray/numpy.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+__all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
new file mode 100644
index 000000000000..b1139a05791d
--- /dev/null
+++ b/python/mxnet/numpy/__init__.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+__all__ = []
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index f438e4954aa9..326e4f5aff78 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -17,7 +17,7 @@

 """Symbol API of MXNet."""

-from . import _internal, contrib, linalg, op, random, sparse, image, symbol
+from . import _internal, contrib, linalg, op, random, sparse, image, symbol, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import *  # pylint: disable=unused-wildcard-import
diff --git a/python/mxnet/symbol/numpy.py b/python/mxnet/symbol/numpy.py
new file mode 100644
index 000000000000..0826ac8aca7f
--- /dev/null
+++ b/python/mxnet/symbol/numpy.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+__all__ = []
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 3d74bfb9a663..bef644187cf4 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -34,6 +34,7 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "../operator/subgraph/subgraph_property.h"
+#include "../operator/operator_common.h"

 namespace mxnet {
 namespace exec {
@@ -966,7 +967,7 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
       uint32_t oid = head_grad_map_.at(idx[nid].source);
       uint32_t eid = idx.entry_id(idx.outputs()[oid]);
       NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid];
-      CHECK_NE(vshape[eid].ndim(), 0U);
+      CHECK(mxnet::op::shape_is_known(vshape[eid]));
       CHECK_NE(vdtype[eid], -1);
       auto data_eid = idx.entry_id(nid, 0);
       // initialize based on storage_type
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 6a7fde62c2cf..aa72661e78b2 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -648,14 +648,14 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
       std::move(graph), mxnet::TShape(),
       "FInferShape", "shape_inputs", "shape_attr_key",
       "shape", "shape_num_unknown_nodes",
-      [](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
+      [](const mxnet::TShape& s) { return !mxnet::op::shape_is_known(s); },
       [](const mxnet::TShape& s) {
-        if (s.ndim() == 0) {  // TODO(reminisce): Usage of ndim
+        if (s.ndim() == -1) {
           return static_cast<size_t>(1);
         }
         size_t ret = 0;
         for (const auto& val : s) {
-          if (val == 0) {
+          if (val == -1) {
             ++ret;
           }
         }
diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index 2b18f990c845..0dc7e6ddb1d9 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -30,6 +30,7 @@
 #include <mxnet/base.h>
 #include <memory>
 #include "graph_algorithm.h"
+#include "../operator/operator_common.h"

 namespace nnvm {
 namespace pass {
@@ -75,7 +76,7 @@ class GraphAllocator {
   // request a free storage
   StorageID Request(int dev_id, int dtype, mxnet::TShape shape, uint32_t node_id) {
-    if (shape.ndim() == 0) return kBadStorageID;
+    if (!mxnet::op::shape_is_known(shape)) return kBadStorageID;
     // search memory block in [size / match_range_, size * match_range_)
     // TODO(tqchen) add size of the dtype, assume 4 bytes for now
     size_t size = shape.Size() * 4;
@@ -267,8 +268,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
     // only request memory for kBadStorageID
     if (storage[eid] == GraphAllocator::kBadStorageID) {
       auto &eshape = shape_vec[eid];
-      size_t esize = 0;
-      if (eshape.ndim() != 0) esize = eshape.Size();
+      size_t esize = eshape.Size();
       eids.insert(std::make_pair(esize, eid));
     }
   }
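The executor and memory-planner changes above all hinge on one rule: a shape keeps participating in inference until it is fully known. A hedged sketch of mine, restating the unknown-counting lambda passed to InferShape as a named helper (the helper and file names are hypothetical):

// sketch_num_unknown.cc (hypothetical): the counting rule used by InferShape.
#include <mxnet/tuple.h>
#include <cassert>

size_t NumUnknownDims(const mxnet::TShape& s) {
  if (s.ndim() == -1) return 1;   // a fully unknown shape counts once
  size_t ret = 0;
  for (int i = 0; i < s.ndim(); ++i) {
    if (s[i] == -1) ++ret;        // -1 now marks an unknown dim size (0 used to)
  }
  return ret;
}

int main() {
  assert(NumUnknownDims(mxnet::TShape()) == 1);      // ndim itself unknown
  mxnet::TShape s(3, 1);
  s[1] = -1;                                         // e.g. shape (1, ?, 1)
  assert(NumUnknownDims(s) == 1);
  assert(NumUnknownDims(mxnet::TShape(0, 1)) == 0);  // scalar: fully known
  return 0;
}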
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
new file mode 100644
index 000000000000..bb2b7fca231c
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op.h
+ * \brief Function definition of broadcast and reduce operators
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
+
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
+  dmlc::optional<nnvm::Tuple<int>> axis;
+  dmlc::optional<int> dtype;
+  bool keepdims;
+  dmlc::optional<double> initial;
+  DMLC_DECLARE_PARAMETER(NumpyReduceAxesParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<nnvm::Tuple<int>>())
+      .describe(R"code()code");
+    DMLC_DECLARE_FIELD(dtype).set_default(dmlc::optional<int>())
+      .describe("");
+    DMLC_DECLARE_FIELD(keepdims).set_default(false)
+      .describe("If this is set to `True`, the reduced axes are left "
+                "in the result as dimension with size one.");
+  }
+};
+
+inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
+                                       const dmlc::optional<nnvm::Tuple<int>>& axis,
+                                       bool keepdims) {
+  // TODO(junwu): improve the logic
+  // If input is a scalar, output should be a scalar too
+  if (ishape.ndim() == 0) {
+    if (axis.has_value()) {
+      const nnvm::Tuple<int>& axes = axis.value();
+      if (axes.ndim() > 0) {
+        CHECK_EQ(axes.ndim(), 1);
+        CHECK(axes[0] == 0 || axes[0] == -1);
+      }
+    }
+    return TShape(0);
+  }
+
+  // axis=None, do global reduction
+  if (!axis.has_value()) {
+    if (keepdims) {
+      return TShape(ishape.ndim(), 1);
+    } else {
+      return TShape(0);
+    }
+  }
+
+  // axis = (), will return identity(input)
+  if (axis.value().ndim() == 0) {
+    return ishape;
+  }
+
+  // axis has value
+  nnvm::Tuple<int> axes(axis.value());
+  for (index_t i = 0; i < axes.ndim(); i++) {
+    if (axes[i] < 0) {
+      axes[i] += ishape.ndim();
+    }
+  }
+  std::sort(axes.begin(), axes.end());
+
+  for (index_t i = 1; i < axes.ndim(); i++) {
+    CHECK_LT(axes[i-1], axes[i])
+        << "Reduction axes have duplicates " << axes;
+  }
+  CHECK_LT(axes[axes.ndim()-1], ishape.ndim())
+      << "Reduction axis " << axes[axes.ndim()-1]
+      << " exceeds input dimensions " << ishape;
+  CHECK_GE(axes[0], 0)
+      << "Reduction axis " << axis.value()
+      << " exceeds input dimensions " << ishape;
+
+  TShape oshape;
+  if (keepdims) {
+    oshape = TShape(ishape);
+  } else {
+    oshape = TShape(ishape.ndim() - axes.ndim());
+  }
+
+  if (keepdims) {
+    for (index_t i = 0; i < axes.ndim(); ++i) {
+      oshape[axes[i]] = 1;
+    }
+  } else {
+    for (index_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
+      if (j < axes.ndim() && i == axes[j]) {
+        ++j;
+        continue;
+      }
+      oshape[k++] = ishape[i];
+    }
+  }
+  return oshape;
+}
+
+inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
+                                 std::vector<TShape> *in_attrs,
+                                 std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
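For reference, a few input/output pairs for NumpyReduceAxesShapeImpl, mirroring numpy.sum's axis/keepdims behavior. This is a hedged sketch of mine, assuming the patched headers are on the include path; the function and file names are hypothetical:

// sketch_reduce_shapes.cc (hypothetical): expected outputs of the shape rule.
#include <cassert>
#include "np_broadcast_reduce_op.h"

void CheckReduceShapes() {
  using mxnet::TShape;
  using mxnet::op::NumpyReduceAxesShapeImpl;

  const TShape in({2, 3, 4});
  const dmlc::optional<nnvm::Tuple<int>> none;  // axis=None: global reduction
  assert(NumpyReduceAxesShapeImpl(in, none, false).ndim() == 0);   // scalar out
  assert(NumpyReduceAxesShapeImpl(in, none, true) == TShape({1, 1, 1}));

  const dmlc::optional<nnvm::Tuple<int>> axis1{nnvm::Tuple<int>{1}};  // axis=(1,)
  assert(NumpyReduceAxesShapeImpl(in, axis1, false) == TShape({2, 4}));
  assert(NumpyReduceAxesShapeImpl(in, axis1, true) == TShape({2, 1, 4}));

  const dmlc::optional<nnvm::Tuple<int>> empty{nnvm::Tuple<int>{}};  // axis=()
  assert(NumpyReduceAxesShapeImpl(in, empty, false) == in);        // identity
}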
+
+template<typename xpu, typename reducer, bool normalize = false>
+void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  if (param.axis.has_value() && param.axis.value().ndim() == 0) {
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
+    return;  // identity is complete; falling through would reduce a second time
+  }
+  TShape small;
+  if (param.keepdims) {
+    small = outputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
+  }
+
+  ReduceAxesComputeImpl<xpu, reducer, normalize>(ctx, inputs, req, outputs, small);
+}
+
+template<typename xpu, bool normalize = false>
+inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
+                                           const OpContext& ctx,
+                                           const std::vector<TBlob>& inputs,
+                                           const std::vector<OpReqType>& req,
+                                           const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
+  TShape small;
+  if (param.keepdims) {
+    small = inputs[0].shape_;
+  } else {
+    small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
+  }
+
+  BroadcastComputeImpl<xpu>(attrs, ctx, inputs, req, outputs, small);
+  if (normalize) {
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s);
+      igrad /= scalar<DType>(outputs[0].Size()/inputs[0].Size());
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
new file mode 100644
index 000000000000..c028e2368737
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_value.cc
+ * \brief CPU Implementation of broadcast and reduce functions based on value.
+ */
+
+#include "np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
+
+NNVM_REGISTER_OP(_numpy_sum)
+.describe(R"code()code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input")
+.add_arguments(NumpyReduceAxesParam::__FIELDS__())
+.set_attr<FCompute>("FCompute", NumpyReduceAxesCompute<cpu, mshadow::red::sum>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
+
+NNVM_REGISTER_OP(_backward_numpy_sum)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_num_inputs(1)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesBackwardUseNone<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
new file mode 100644
index 000000000000..c975b18226db
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_value.cu
+ * \brief GPU Implementation of reduce functions based on value.
+ */
+#include "np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+NNVM_REGISTER_OP(_numpy_sum)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesCompute<gpu, mshadow::red::sum>);
+
+NNVM_REGISTER_OP(_backward_numpy_sum)
+.set_attr<FCompute>("FCompute", NumpyReduceAxesBackwardUseNone<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
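The normalize template flag wired through NumpyReduceAxesCompute and NumpyReduceAxesBackwardUseNone is unused by _numpy_sum. A hypothetical registration sketch of mine (not in this patch; the op names are invented) shows how a numpy-style mean could reuse the same kernels by flipping that flag:

// Hypothetical: arithmetic mean as a normalized sum, reusing the kernels above.
NNVM_REGISTER_OP(_numpy_mean_sketch)  // invented op name, for illustration only
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesParam::__FIELDS__())
// normalize=true divides the summed result by the number of reduced elements
.set_attr<FCompute>("FCompute", NumpyReduceAxesCompute<cpu, mshadow::red::sum, true>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_mean_sketch"});

NNVM_REGISTER_OP(_backward_numpy_mean_sketch)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
// normalize=true scales the broadcast gradient by the same element-count ratio
.set_attr<FCompute>("FCompute", NumpyReduceAxesBackwardUseNone<cpu, true>);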
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index f629534dabd0..a461d2bc4cef 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -108,6 +108,16 @@ inline bool shape_is_none(const mxnet::TShape& x) {
   return x.ndim() == 0 || x.Size() == 0;
 }

+/*! \brief check if shape is known using the NumPy-compatible definition.
+ * Zero-dim and zero-size tensors are valid. -1 means unknown. */
+inline bool shape_is_known(const TShape& x) {
+  if (x.ndim() == -1) return false;
+  for (int i = 0; i < x.ndim(); ++i) {
+    if (x[i] == -1) return false;
+  }
+  return true;
+}
+
 /*! \brief check if type is none (-1) */
 inline bool type_is_none(const int& x) {
   return x == -1;
@@ -159,16 +169,16 @@ inline std::string type_string(const int& x) {
  * \return whether x and y are compatible.
  */
 inline bool shape_assign(mxnet::TShape *y, const mxnet::TShape& x) {
-  if (y->ndim() == 0) {
+  if (y->ndim() == -1) {
     *y = x;
     return true;
   } else if (y->ndim() != x.ndim()) {
-    return x.ndim() == 0;
+    return x.ndim() == -1;
   } else {
-    for (size_t i = 0; i < y->ndim(); ++i) {
-      if ((*y)[i] == 0) {
+    for (int i = 0; i < y->ndim(); ++i) {
+      if ((*y)[i] == -1) {
         (*y)[i] = x[i];
-      } else if ((*y)[i] != x[i] && x[i] != 0) {
+      } else if ((*y)[i] != x[i] && x[i] >= 0) {
         return false;
       }
     }
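A hedged sketch of mine (names and include path are assumptions; it presumes compilation inside an MXNet build tree) exercising shape_is_known and shape_assign under the new -1 conventions:

// sketch_shape_assign.cc (hypothetical): the helpers above in action.
#include <cassert>
#include "operator_common.h"  // path assumption: compiled next to src/operator/

void CheckShapeAssign() {
  using mxnet::TShape;
  using mxnet::op::shape_assign;
  using mxnet::op::shape_is_known;

  TShape y;                             // ndim() == -1: nothing known yet
  assert(!shape_is_known(y));

  TShape partial(2, -1);                // ndim known, both dim sizes unknown
  assert(shape_assign(&y, partial));    // an unknown y simply adopts x
  assert(!shape_is_known(y));

  TShape full(2, -1);
  full[0] = 3;
  full[1] = 4;
  assert(shape_assign(&y, full));       // -1 entries are filled in from x
  assert(shape_is_known(y) && y == full);

  TShape conflict(2, 5);
  assert(!shape_assign(&y, conflict));  // 3 vs 5 on dim 0: incompatible
}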