This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] Shape support scalar tensor #14315

Merged 2 commits on Mar 6, 2019
1 change: 0 additions & 1 deletion include/mxnet/tensor_blob.h
@@ -198,7 +198,6 @@ class TBlob {
<< "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
shape_.FlatTo2D(),
- shape_[shape_.ndim() - 1],
stream);
}
/*!
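A note on why the stride argument can simply be dropped: mshadow's three-argument Tensor constructor (dptr, shape, stream) already derives the stride from the trailing dimension of the shape, and now that scalar tensors are legal (ndim() == 0) the old expression shape_[shape_.ndim() - 1] would index position -1 and trip the new bounds check added to Tuple::operator[] below. A minimal sketch of the surviving construction, assuming mshadow's tensor.h constructors (not part of this diff):

```cpp
// Sketch only: mirrors the edited TBlob code path, not MXNet's actual API
// surface. Assumes mshadow's Tensor(dptr, shape, stream) constructor,
// which sets stride_ = shape[1] for a 2-D tensor.
#include <mshadow/tensor.h>

template <typename Device, typename DType>
mshadow::Tensor<Device, 2, DType> FlatTo2DSketch(
    DType* dptr, mshadow::Shape<2> shape, mshadow::Stream<Device>* stream) {
  // The stride is derived from shape[1] inside the constructor, so no
  // explicit shape_[ndim() - 1] argument is needed (or safe) any more.
  return mshadow::Tensor<Device, 2, DType>(dptr, shape, stream);
}
```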
100 changes: 78 additions & 22 deletions include/mxnet/tuple.h
@@ -17,7 +17,7 @@
* under the License.
*/
/*!
- * Copyright (c) 2016 by Contributors
+ * Copyright (c) 2019 by Contributors
* \file mxnet/tuple.h
* \brief Data structure Tuple and TShape to store dynamic sized shapes.
*/
@@ -39,11 +39,14 @@ namespace mxnet {

/*!
* \brief A dynamic sized array data structure that is optimized for storing
* small number of elements with same type.
*
* Data will be stored in stack when number of elements is small.
* It is suitable to hold shape of Tensor.
*
+ * The ndim of a valid tuple is an integer in range [0, inf).
+ * ndim = 0 means the tuple is empty.
*
* \tparam ValueType The type of data stored inside tuple.
* \sa TShape
*/
@@ -61,7 +64,11 @@ class Tuple {
* \param s the source tuple
*/
inline Tuple(const Tuple<ValueType>& s) {
- this->assign(s.begin(), s.end());
+ if (s.ndim() == -1) {
+   this->SetDim(-1);
+ } else {
+   this->assign(s.begin(), s.end());
+ }
}
/*!
* \brief constructor from initializer list
@@ -106,6 +113,7 @@
inline void assign(RandomAccessIterator begin,
RandomAccessIterator end) {
this->SetDim(end - begin);
+ CHECK_GE(ndim(), 0);
std::copy(begin, end, this->begin());
}
/*!
@@ -124,7 +132,11 @@
* \return reference of self
*/
inline Tuple<ValueType>& operator=(const Tuple<ValueType>& src) {
- this->assign(src.begin(), src.end());
+ if (src.ndim() == -1) {
+   this->SetDim(-1);
+ } else {
+   this->assign(src.begin(), src.end());
+ }
return *this;
}
/*!
@@ -151,6 +163,7 @@
*/
inline bool operator==(const Tuple<ValueType> &s) const {
if (ndim_ != s.ndim_) return false;
+ if (ndim() == -1) return true;
return std::equal(begin(), end(), s.begin());
}
/*!
@@ -177,23 +190,25 @@
return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_);
}
/*! \return number of dimension of the tuple */
- inline uint32_t ndim() const {
+ inline int ndim() const {
return ndim_;
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
- inline ValueType& operator[](size_t i) {
+ inline ValueType& operator[](int i) {
+   CHECK(i >= 0 && i < ndim());
return begin()[i];
}
/*!
* \brief get corresponding index
* \param i dimension index
* \return the corresponding dimension size
*/
- inline const ValueType& operator[](size_t i) const {
+ inline const ValueType& operator[](int i) const {
+   CHECK(i >= 0 && i < ndim());
return begin()[i];
}
/*!
@@ -220,6 +235,10 @@
* \return the ostream
*/
friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
+ if (t.ndim() == -1) {
+   os << "UNKNOWN_SHAPE";
+   return os;
+ }
os << '[';
const ValueType* begin = t.begin();
const ValueType* end = t.end();
@@ -316,48 +335,75 @@ class Tuple {

protected:
// stack cache size
- static const uint32_t kStackCache = 4;
+ static const int kStackCache = 4;
/*! \brief number of dimension of the tuple */
- uint32_t ndim_{0};
+ int ndim_{0};
/*! \brief number of cells allocated in data_heap_ */
- uint32_t num_heap_allocated_{0};
+ int num_heap_allocated_{0};
/*! \brief in stack space used to store shape when it is small */
ValueType data_stack_[kStackCache];
/*! \brief space to store shape when dimension is big*/
ValueType* data_heap_{nullptr};
// internal function to change the dimension
- inline void SetDim(uint32_t ndim) {
+ inline void SetDim(int ndim) {
+   CHECK_GE(ndim, -1) << "ndim cannot be less than -1, received " << ndim;
if (ndim > kStackCache &&
ndim > num_heap_allocated_) {
delete [] data_heap_;
data_heap_ = new ValueType[ndim];
num_heap_allocated_ = ndim;
+ } else if (ndim == -1 && data_heap_ != nullptr) {
+   delete [] data_heap_;
+   data_heap_ = nullptr;
+   num_heap_allocated_ = 0;
}
ndim_ = ndim;
}
};

/*!
* \brief A Shape class that is used to represent shape of each tensor.
+ *
+ * The ndim of a valid shape is an integer in range [-1, inf).
+ * ndim = -1 means the shape information is unknown and need to be inferred.
+ * ndim = 0 means the tensor with the shape is a scalar.
+ *
+ * The dimension size of a valid shape is an integer in range [-1, inf).
+ * dim_size = -1 means the size of that dimension is unknown and need to be inferred.
+ * dim_size = 0 means that dimension is empty.
+ *
+ * The definition of ndim = 0 and dim_size = 0 is consistent with NumPy.
*/
class TShape : public Tuple<dim_t> {
public:
/*! \brief default constructor */
- TShape() = default;
+ TShape() {
+   this->SetDim(-1);
+ }
/*!
* constructor to construct a shape with all 1.
+ * TODO(junwu): The value should default to -1. Need to keep 1 for now
+ * for backward compatibility. Change it to -1 in the future when we can
+ * break backward compatibility.
* \param ndim the number of dimension
+ * \param value the dimension size for all dims
*/
- inline TShape(uint32_t ndim) {  // NOLINT(*)
+ inline TShape(int ndim, int value = 1) {  // NOLINT(*)
this->SetDim(ndim);
- std::fill_n(begin(), ndim, 1);
+ if (ndim > 0) {
+   std::fill_n(begin(), ndim, value);
+ }
}
/*!
* \brief copy constructor of TShape
* \param s source shape.
*/
inline TShape(const Tuple<dim_t>& s) { // NOLINT(*)
- this->assign(s.begin(), s.end());
+ if (s.ndim() == -1) {
+   this->SetDim(-1);
+ } else {
+   this->assign(s.begin(), s.end());
+ }
}
/*!
* \brief constructor from initializer list
@@ -390,7 +436,11 @@ class TShape : public Tuple<dim_t> {
* \return self.
*/
inline TShape& operator=(const Tuple<dim_t>& src) {
- this->assign(src.begin(), src.end());
+ if (src.ndim() == -1) {
+   this->SetDim(-1);
+ } else {
+   this->assign(src.begin(), src.end());
+ }
return *this;
}
/*!
@@ -404,9 +454,11 @@
}
/*! \return total number of elements in the shape */
inline size_t Size() const {
+ CHECK_GE(this->ndim(), 0) << "Shape is unknown.";
dim_t size = 1;
const dim_t* start = begin(), *fin = end();
for (const dim_t* it = start; it != fin; ++it) {
+ CHECK_GE(*it, 0) << "Shape dim size cannot be -1, which means unknown.";
size *= *it;
}
return size;
@@ -417,9 +469,11 @@
* \param dimend end dimension
*/
inline size_t ProdShape(int dimstart, int dimend) const {
+ CHECK_GE(this->ndim(), 0) << "Shape is unknown.";
dim_t num = 1;
const dim_t *d = this->data();
for (int i = dimstart; i < dimend; ++i) {
+ CHECK_GE(d[i], 0) << "Shape dim size cannot be -1, which means unknown.";
num *= d[i];
}
return num;
@@ -460,7 +514,7 @@
*/
template<int dim>
inline mshadow::Shape<dim> get() const {
- CHECK_EQ(dim, static_cast<int>(ndim()))
+ CHECK_EQ(dim, ndim())
<< "dimension do not match target dimension " << dim << " vs " << ndim();
const dim_t *d = this->data();
mshadow::Shape<dim> s;
@@ -475,11 +529,12 @@
*/
inline mshadow::Shape<2> FlatTo2D(void) const {
mshadow::Shape<2> s;
- if (ndim() == 0) return mshadow::Shape2(0, 0);
+ CHECK_GE(ndim(), 0);
+ if (ndim() == 0) return mshadow::Shape2(1, 1);
const dim_t *d = this->data();
s.shape_[1] = d[ndim() - 1];
dim_t ymax = 1;
- for (size_t i = 1; i < ndim(); ++i) {
+ for (int i = 1; i < ndim(); ++i) {
ymax *= d[i - 1];
}
s.shape_[0] = ymax;
@@ -494,7 +549,8 @@
inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const {
CHECK(axis_end >= axis_begin);
mshadow::Shape<3> s;
- if (ndim() == 0) return mshadow::Shape3(0, 0, 0);
+ CHECK_GE(ndim(), 0);
+ if (ndim() == 0) return mshadow::Shape3(1, 1, 1);
const dim_t *d = this->data();
s.shape_[0] = 1;
s.shape_[1] = 1;
@@ -506,7 +562,7 @@
for (size_t i = axis_begin; i <= axis_end; ++i) {
s.shape_[1] *= d[i];
}
- for (size_t i = axis_end + 1; i < ndim(); ++i) {
+ for (int i = axis_end + 1; i < ndim(); ++i) {
s.shape_[2] *= d[i];
}
return s;
@@ -627,7 +683,7 @@ struct hash<mxnet::TShape> {
size_t operator()(const mxnet::TShape& val) const {
std::hash<uint32_t> hash_uint;
size_t res = hash_uint(val.ndim());
- for (uint32_t i = 0; i < val.ndim(); ++i) {
+ for (int i = 0; i < val.ndim(); ++i) {
res = dmlc::HashCombine(res, val[i]);
}
return res;
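To make the new semantics concrete, here is a minimal usage sketch of the TShape behavior introduced above (assuming it is compiled inside the MXNet source tree at this commit, so that mxnet/tuple.h and its dmlc/mshadow dependencies resolve): a default-constructed shape is unknown (ndim() == -1), ndim() == 0 is a scalar whose Size() is the empty product 1, and FlatTo2D() now maps a scalar to Shape2(1, 1) instead of Shape2(0, 0).

```cpp
// Usage sketch of the new TShape semantics; assumes an in-tree build.
#include <iostream>
#include <mxnet/tuple.h>

int main() {
  mxnet::TShape unknown;                    // default ctor: ndim() == -1
  std::cout << unknown << std::endl;        // prints "UNKNOWN_SHAPE"

  mxnet::TShape scalar(0);                  // ndim() == 0: a scalar tensor
  std::cout << scalar.Size() << std::endl;  // 1 (the empty product)
  mshadow::Shape<2> flat = scalar.FlatTo2D();
  std::cout << flat[0] << 'x' << flat[1] << std::endl;  // 1x1, not 0x0

  mxnet::TShape mat({2, 3});                // ordinary known 2-D shape
  std::cout << mat << ' ' << mat.Size() << std::endl;   // [2,3] 6
  return 0;
}
```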
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
@@ -25,6 +25,7 @@
from . import engine
from .base import MXNetError
from . import base
+ from . import numpy
from . import contrib
from . import ndarray
from . import ndarray as nd
29 changes: 23 additions & 6 deletions python/mxnet/base.py
@@ -26,7 +26,7 @@
import sys
import inspect
import platform
- import numpy as np
+ import numpy as _np

from . import libinfo

@@ -44,8 +44,8 @@
long = int
# pylint: enable=pointless-statement

- integer_types = (int, long, np.int32, np.int64)
- numeric_types = (float, int, long, np.generic)
+ integer_types = (int, long, _np.int32, _np.int64)
+ numeric_types = (float, int, long, _np.generic)
string_types = basestring,

if sys.version_info[0] > 2:
@@ -216,7 +216,7 @@ def _load_lib():
mx_uint = ctypes.c_uint
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
- mx_real_t = np.float32
+ mx_real_t = _np.float32
NDArrayHandle = ctypes.c_void_p
FunctionHandle = ctypes.c_void_p
OpHandle = ctypes.c_void_p
@@ -455,7 +455,7 @@ def ctypes2numpy_shared(cptr, shape):
for s in shape:
size *= s
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
- return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
+ return _np.frombuffer(dbuffer, dtype=_np.float32).reshape(shape)


def build_param_doc(arg_names, arg_types, arg_descs, remove_dup=True):
@@ -560,7 +560,7 @@ def _as_list(obj):
return [obj]


- _OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
+ _OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']


def _get_op_name_prefix(op_name):
@@ -606,6 +606,13 @@ def _init_op_module(root_namespace, module_name, make_op_func):
# use mx.nd.contrib or mx.sym.contrib from now on
contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
contrib_module_old = sys.modules[contrib_module_name_old]
+ # special handling of registering numpy ops
+ if module_name == 'ndarray':
+     numpy_module_name = "%s.numpy" % root_namespace
+     numpy_module = sys.modules[numpy_module_name]
+ else:
+     numpy_module_name = None
+     numpy_module = None
submodule_dict = {}
for op_name_prefix in _OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
@@ -644,6 +651,16 @@ def _init_op_module(root_namespace, module_name, make_op_func):
function.__module__ = contrib_module_name_old
setattr(contrib_module_old, function.__name__, function)
contrib_module_old.__all__.append(function.__name__)
+ elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
+     # only register numpy ops under mxnet.numpy in imperative mode
+     hdl = OpHandle()
+     check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
+     # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
+     func_name = name[len(op_name_prefix):]
+     function = make_op_func(hdl, name, func_name)
+     function.__module__ = numpy_module_name
+     setattr(numpy_module, function.__name__, function)
+     numpy_module.__all__.append(function.__name__)


def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
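The effect of the two hunks above is that a backend operator whose registered name starts with _numpy_ is exposed under mxnet.numpy with the prefix stripped, and only for the ndarray (imperative) module. A self-contained sketch of that routing decision, with hypothetical op names for illustration (target_namespace is not MXNet code):

```python
# Sketch of the routing added in _init_op_module; not MXNet's actual code.
# _OP_NAME_PREFIX_LIST mirrors the list in this diff; the op names below
# are hypothetical examples.
_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_',
                        '_random_', '_numpy_']


def _get_op_name_prefix(op_name):
    """Return the matching prefix of op_name, or '' if none matches."""
    for prefix in _OP_NAME_PREFIX_LIST:
        if op_name.startswith(prefix):
            return prefix
    return ''


def target_namespace(op_name, module_name='ndarray'):
    """Map a backend op name to (python module, exposed function name)."""
    prefix = _get_op_name_prefix(op_name)
    if prefix == '_numpy_' and module_name == 'ndarray':
        # numpy ops are registered under mxnet.numpy, imperative mode only
        return 'mxnet.numpy', op_name[len(prefix):]
    if prefix:
        # e.g. '_contrib_' ops land in mxnet.<module>.contrib
        return 'mxnet.%s.%s' % (module_name, prefix[1:-1]), op_name[len(prefix):]
    return 'mxnet.%s' % module_name, op_name


print(target_namespace('_numpy_sum'))     # ('mxnet.numpy', 'sum')
print(target_namespace('_contrib_fft'))   # ('mxnet.ndarray.contrib', 'fft')
```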
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@

"""NDArray API of MXNet."""

- from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray
+ from . import _internal, contrib, linalg, op, random, sparse, utils, image, ndarray, numpy
# pylint: disable=wildcard-import, redefined-builtin
try:
from .gen_op import * # pylint: disable=unused-wildcard-import