
[numpy] Shape support scalar tensor (#14315)
* Support scalar and zero-size tensors with np.sum

* Add sanity check when ndim is set
reminisce committed Apr 6, 2019
1 parent b68f18c commit 0f7bcf8
Showing 16 changed files with 466 additions and 43 deletions.
1 change: 0 additions & 1 deletion include/mxnet/tensor_blob.h
@@ -198,7 +198,6 @@ class TBlob {
         << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
     return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
                                              shape_.FlatTo2D(),
-                                             shape_[shape_.ndim() - 1],
                                              stream);
   }
   /*!
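The deleted line passed shape_[shape_.ndim() - 1] to mshadow::Tensor as an explicit stride; once ndim() == 0 (a scalar) becomes a legal shape, that index has no valid value, so the call now relies on the constructor overload that derives the stride from the shape itself. A plain-NumPy analogue of the failure mode this avoids (illustrative only, not MXNet code; C++ would be undefined behavior rather than an exception):

import numpy as np

shape = np.array(3.14).shape      # a scalar array has shape (), i.e. ndim == 0
try:
    shape[len(shape) - 1]         # the C++ code indexed shape_[ndim() - 1]
except IndexError as err:
    print("last-dim lookup on a scalar shape fails:", err)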
100 changes: 78 additions & 22 deletions include/mxnet/tuple.h
@@ -17,7 +17,7 @@
  * under the License.
  */
 /*!
- * Copyright (c) 2016 by Contributors
+ * Copyright (c) 2019 by Contributors
  * \file mxnet/tuple.h
  * \brief Data structure Tuple and TShape to store dynamic sized shapes.
  */
@@ -39,11 +39,14 @@ namespace mxnet {

 /*!
  * \brief A dynamic sized array data structure that is optimized for storing
  *  small number of elements with same type.
  *
  * Data will be stored in stack when number of elements is small.
  * It is suitable to hold shape of Tensor.
  *
+ * The ndim of a valid tuple is an integer in range [0, inf).
+ * ndim = 0 means the tuple is empty.
+ *
  * \tparam ValueType The type of data stored inside tuple.
  * \sa TShape
  */
@@ -61,7 +64,11 @@ class Tuple {
    * \param s the source tuple
    */
   inline Tuple(const Tuple<ValueType>& s) {
-    this->assign(s.begin(), s.end());
+    if (s.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(s.begin(), s.end());
+    }
   }
   /*!
    * \brief constructor from initializer list
@@ -106,6 +113,7 @@
   inline void assign(RandomAccessIterator begin,
                      RandomAccessIterator end) {
     this->SetDim(end - begin);
+    CHECK_GE(ndim(), 0);
     std::copy(begin, end, this->begin());
   }
   /*!
@@ -124,7 +132,11 @@
    * \return reference of self
    */
   inline Tuple<ValueType>& operator=(const Tuple<ValueType>& src) {
-    this->assign(src.begin(), src.end());
+    if (src.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(src.begin(), src.end());
+    }
     return *this;
   }
   /*!
@@ -151,6 +163,7 @@
    */
   inline bool operator==(const Tuple<ValueType> &s) const {
     if (ndim_ != s.ndim_) return false;
+    if (ndim() == -1) return true;
     return std::equal(begin(), end(), s.begin());
   }
   /*!
@@ -177,23 +190,25 @@
     return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
   }
   /*! \return number of dimension of the tuple */
-  inline uint32_t ndim() const {
+  inline int ndim() const {
     return ndim_;
   }
   /*!
    * \brief get corresponding index
    * \param i dimension index
    * \return the corresponding dimension size
    */
-  inline ValueType& operator[](size_t i) {
+  inline ValueType& operator[](int i) {
+    CHECK(i >= 0 && i < ndim());
     return begin()[i];
   }
   /*!
    * \brief get corresponding index
    * \param i dimension index
    * \return the corresponding dimension size
    */
-  inline const ValueType& operator[](size_t i) const {
+  inline const ValueType& operator[](int i) const {
+    CHECK(i >= 0 && i < ndim());
     return begin()[i];
   }
   /*!
@@ -220,6 +235,10 @@
    * \return the ostream
    */
   friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
+    if (t.ndim() == -1) {
+      os << "UNKNOWN_SHAPE";
+      return os;
+    }
     os << '[';
     const ValueType* begin = t.begin();
     const ValueType* end = t.end();
@@ -316,48 +335,75 @@

  protected:
   // stack cache size
-  static const uint32_t kStackCache = 4;
+  static const int kStackCache = 4;
   /*! \brief number of dimension of the tuple */
-  uint32_t ndim_{0};
+  int ndim_{0};
   /*! \brief number of cells allocated in data_heap_ */
-  uint32_t num_heap_allocated_{0};
+  int num_heap_allocated_{0};
   /*! \brief in stack space used to store shape when it is small */
   ValueType data_stack_[kStackCache];
   /*! \brief space to store shape when dimension is big*/
   ValueType* data_heap_{nullptr};
   // internal function to change the dimension
-  inline void SetDim(uint32_t ndim) {
+  inline void SetDim(int ndim) {
+    CHECK_GE(ndim, -1) << "ndim cannot be less than -1, received " << ndim;
     if (ndim > kStackCache &&
         ndim > num_heap_allocated_) {
       delete [] data_heap_;
       data_heap_ = new ValueType[ndim];
       num_heap_allocated_ = ndim;
+    } else if (ndim == -1 && data_heap_ != nullptr) {
+      delete [] data_heap_;
+      data_heap_ = nullptr;
+      num_heap_allocated_ = 0;
     }
     ndim_ = ndim;
   }
 };

 /*!
  * \brief A Shape class that is used to represent shape of each tensor.
+ *
+ * The ndim of a valid shape is an integer in range [-1, inf).
+ * ndim = -1 means the shape information is unknown and need to be inferred.
+ * ndim = 0 means the tensor with the shape is a scalar.
+ *
+ * The dimension size of a valid shape is an integer in range [-1, inf).
+ * dim_size = -1 means the size of that dimension is unknown and need to be inferred.
+ * dim_size = 0 means that dimension is empty.
+ *
+ * The definition of ndim = 0 and dim_size = 0 is consistent with NumPy.
  */
 class TShape : public Tuple<dim_t> {
  public:
   /*! \brief default constructor */
-  TShape() = default;
+  TShape() {
+    this->SetDim(-1);
+  }
   /*!
    * constructor to construct a shape with all 1.
+   * TODO(junwu): The value should default to -1. Need to keep 1 for now
+   * for backward compatibility. Change it to -1 in the future when we can
+   * break backward compatibility.
    * \param ndim the number of dimension
+   * \param value the dimension size for all dims
    */
-  inline TShape(uint32_t ndim) {  // NOLINT(*)
+  inline TShape(int ndim, int value = 1) {  // NOLINT(*)
     this->SetDim(ndim);
-    std::fill_n(begin(), ndim, 1);
+    if (ndim > 0) {
+      std::fill_n(begin(), ndim, value);
+    }
   }
   /*!
    * \brief copy constructor of TShape
    * \param s source shape.
    */
   inline TShape(const Tuple<dim_t>& s) {  // NOLINT(*)
-    this->assign(s.begin(), s.end());
+    if (s.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(s.begin(), s.end());
+    }
   }
   /*!
    * \brief constructor from initializer list
@@ -390,7 +436,11 @@ class TShape : public Tuple<dim_t> {
    * \return self.
    */
   inline TShape& operator=(const Tuple<dim_t>& src) {
-    this->assign(src.begin(), src.end());
+    if (src.ndim() == -1) {
+      this->SetDim(-1);
+    } else {
+      this->assign(src.begin(), src.end());
+    }
     return *this;
   }
   /*!
@@ -404,9 +454,11 @@
   }
   /*! \return total number of elements in the shape */
   inline size_t Size() const {
+    CHECK_GE(this->ndim(), 0) << "Shape is unknown.";
     dim_t size = 1;
     const dim_t* start = begin(), *fin = end();
     for (const dim_t* it = start; it != fin; ++it) {
+      CHECK_GE(*it, 0) << "Shape dim size cannot be -1, which means unknown.";
       size *= *it;
     }
     return size;
@@ -417,9 +469,11 @@
    * \param dimend end dimension
    */
   inline size_t ProdShape(int dimstart, int dimend) const {
+    CHECK_GE(this->ndim(), 0) << "Shape is unknown.";
     dim_t num = 1;
     const dim_t *d = this->data();
     for (int i = dimstart; i < dimend; ++i) {
+      CHECK_GE(d[i], 0) << "Shape dim size cannot be -1, which means unknown.";
       num *= d[i];
     }
     return num;
@@ -460,7 +514,7 @@
    */
   template<int dim>
   inline mshadow::Shape<dim> get() const {
-    CHECK_EQ(dim, static_cast<int>(ndim()))
+    CHECK_EQ(dim, ndim())
         << "dimension do not match target dimension " << dim << " vs " << ndim();
     const dim_t *d = this->data();
     mshadow::Shape<dim> s;
@@ -475,11 +529,12 @@
    */
   inline mshadow::Shape<2> FlatTo2D(void) const {
     mshadow::Shape<2> s;
-    if (ndim() == 0) return mshadow::Shape2(0, 0);
+    CHECK_GE(ndim(), 0);
+    if (ndim() == 0) return mshadow::Shape2(1, 1);
     const dim_t *d = this->data();
     s.shape_[1] = d[ndim() - 1];
     dim_t ymax = 1;
-    for (size_t i = 1; i < ndim(); ++i) {
+    for (int i = 1; i < ndim(); ++i) {
       ymax *= d[i - 1];
     }
     s.shape_[0] = ymax;
@@ -494,7 +549,8 @@
   inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
     CHECK(axis_end >= axis_begin);
     mshadow::Shape<3> s;
-    if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
+    CHECK_GE(ndim(), 0);
+    if (ndim() == 0) return mshadow::Shape3(1, 1, 1);
     const dim_t *d = this->data();
     s.shape_[0] = 1;
     s.shape_[1] = 1;
@@ -506,7 +562,7 @@
     for (size_t i = axis_begin; i <= axis_end; ++i) {
       s.shape_[1] *= d[i];
     }
-    for (size_t i = axis_end + 1; i < ndim(); ++i) {
+    for (int i = axis_end + 1; i < ndim(); ++i) {
       s.shape_[2] *= d[i];
     }
     return s;
@@ -627,7 +683,7 @@ struct hash<mxnet::TShape> {
   size_t operator()(const mxnet::TShape& val) const {
     std::hash<uint32_t> hash_uint;
     size_t res = hash_uint(val.ndim());
-    for (uint32_t i = 0; i < val.ndim(); ++i) {
+    for (int i = 0; i < val.ndim(); ++i) {
       res = dmlc::HashCombine(res, val[i]);
     }
     return res;
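The sentinel semantics documented in the new comments (ndim = -1 for unknown, ndim = 0 for a scalar, dim_size = 0 for an empty dimension) are stated to follow NumPy; a short plain-NumPy session showing the behavior TShape now mirrors (illustrative only):

import numpy as np

a = np.array(3.14)                # scalar tensor: ndim == 0, shape ()
print(a.ndim, a.shape, a.size)    # -> 0 () 1; TShape(0) likewise has Size() == 1

b = np.zeros((2, 0, 3))           # zero-size tensor: one dimension has size 0
print(b.ndim, b.size)             # -> 3 0

print(np.sum(a), np.sum(b))       # -> 3.14 0.0; the np.sum cases this commit targets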
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
@@ -25,6 +25,7 @@
 from . import engine
 from .base import MXNetError
 from . import base
+from . import numpy
 from . import contrib
 from . import ndarray
 from . import ndarray as nd
29 changes: 23 additions & 6 deletions python/mxnet/base.py
@@ -26,7 +26,7 @@
 import sys
 import inspect
 import platform
-import numpy as np
+import numpy as _np

 from . import libinfo

@@ -44,8 +44,8 @@
 long = int
 # pylint: enable=pointless-statement

-integer_types = (int, long, np.int32, np.int64)
-numeric_types = (float, int, long, np.generic)
+integer_types = (int, long, _np.int32, _np.int64)
+numeric_types = (float, int, long, _np.generic)
 string_types = basestring,

 if sys.version_info[0] > 2:
@@ -216,7 +216,7 @@ def _load_lib():
 mx_uint = ctypes.c_uint
 mx_float = ctypes.c_float
 mx_float_p = ctypes.POINTER(mx_float)
-mx_real_t = np.float32
+mx_real_t = _np.float32
 NDArrayHandle = ctypes.c_void_p
 FunctionHandle = ctypes.c_void_p
 OpHandle = ctypes.c_void_p
@@ -455,7 +455,7 @@ def ctypes2numpy_shared(cptr, shape):
     for s in shape:
         size *= s
     dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
-    return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
+    return _np.frombuffer(dbuffer, dtype=_np.float32).reshape(shape)


 def build_param_doc(arg_names, arg_types, arg_descs, remove_dup=True):
@@ -560,7 +560,7 @@ def _as_list(obj):
     return [obj]


-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']


 def _get_op_name_prefix(op_name):
@@ -606,6 +606,13 @@ def _init_op_module(root_namespace, module_name, make_op_func):
     # use mx.nd.contrib or mx.sym.contrib from now on
     contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
     contrib_module_old = sys.modules[contrib_module_name_old]
+    # special handling of registering numpy ops
+    if module_name == 'ndarray':
+        numpy_module_name = "%s.numpy" % root_namespace
+        numpy_module = sys.modules[numpy_module_name]
+    else:
+        numpy_module_name = None
+        numpy_module = None
     submodule_dict = {}
     for op_name_prefix in _OP_NAME_PREFIX_LIST:
         submodule_dict[op_name_prefix] =\
@@ -644,6 +651,16 @@ def _init_op_module(root_namespace, module_name, make_op_func):
                 function.__module__ = contrib_module_name_old
                 setattr(contrib_module_old, function.__name__, function)
                 contrib_module_old.__all__.append(function.__name__)
+            elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
+                # only register numpy ops under mxnet.numpy in imperative mode
+                hdl = OpHandle()
+                check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+                # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
+                func_name = name[len(op_name_prefix):]
+                function = make_op_func(hdl, name, func_name)
+                function.__module__ = numpy_module_name
+                setattr(numpy_module, function.__name__, function)
+                numpy_module.__all__.append(function.__name__)


 def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
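The new elif branch above registers any backend operator whose name starts with '_numpy_' under mxnet.numpy, with the prefix stripped off. A minimal standalone sketch of that name-dispatch step (the operator names below are hypothetical stand-ins, not confirmed backend names):

_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']

def _get_op_name_prefix(op_name):
    """Return the registered prefix that op_name starts with, or ''."""
    for prefix in _OP_NAME_PREFIX_LIST:
        if op_name.startswith(prefix):
            return prefix
    return ''

for name in ['_numpy_sum', '_contrib_fft', 'dot']:
    prefix = _get_op_name_prefix(name)
    func_name = name[len(prefix):]            # same slicing as in the diff above
    module = 'mxnet.numpy' if prefix == '_numpy_' else 'other modules'
    print('%s -> %s as %s' % (name, module, func_name))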
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@

 """NDArray API of MXNet."""

-from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray
+from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import *  # pylint: disable=unused-wildcard-import
