From e65e5778432efd09c355cb00a9bc0d3eab91e83d Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Tue, 25 Apr 2017 04:54:03 +0800 Subject: [PATCH] refactor conversion api (#25) * add multiple stage test * refactor GetFCompute * adjust jenkins python version. pass all unit test * register cast storage type op * refactor conversion interface --- Jenkinsfile | 6 +- include/mxnet/c_api.h | 5 -- include/mxnet/ndarray.h | 48 ++--------- include/mxnet/op_attr_types.h | 2 - nnvm | 2 +- python/mxnet/sparse_ndarray.py | 53 ++++++------ src/c_api/c_api.cc | 12 +-- src/c_api/c_api_common.h | 2 +- src/c_api/c_api_ndarray.cc | 49 +++++------- src/c_api/c_api_symbolic.cc | 2 +- src/common/utils.h | 58 +++++++++++++- src/executor/attach_op_execs_pass.cc | 80 ++++++------------- src/executor/exec_pass.h | 2 +- src/executor/graph_executor.cc | 26 +++--- src/operator/elemwise_op_common.h | 35 +++++--- src/operator/operator_common.h | 9 ++- src/operator/tensor/elemwise_binary_op.h | 28 +++---- .../tensor/elemwise_binary_op_basic.cc | 3 +- src/operator/tensor/elemwise_unary_op.cc | 15 ++++ src/operator/tensor/elemwise_unary_op.h | 54 +++++++++++-- src/operator/tensor/init_op.h | 9 +-- tests/cpp/ndarray_test.cc | 79 ++++++++++-------- tests/python/unittest/test_sparse_ndarray.py | 2 +- tests/python/unittest/test_sparse_operator.py | 34 ++++++++ 24 files changed, 332 insertions(+), 283 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index c94ade88614a..f8974933d034 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -201,9 +201,9 @@ del /Q *.7z // Python unittest for CPU def python_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest" // sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest" - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train" } } @@ -211,7 +211,7 @@ def python_ut(docker_type) { // both CPU and GPU def python_gpu_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu" // sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu" } } diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 428c11e5b57d..ad1d578035a5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -272,11 +272,6 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, int *aux_types, NDArrayHandle *out); -// TEMPORARY API FOR TESTING PURPOSE. Conversion should be an op instead -MXNET_DLL int MXNDArrayConvert(NDArrayHandle in, - int storage_type, - NDArrayHandle *out); - /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 75ff00156e32..8e8d9634d177 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -98,7 +98,7 @@ class NDArray { } /*! 
\brief constructor for NDArray with chunk type */ - NDArray(NDArrayStorageType storage_type, const TShape &shape, Context ctx, + NDArray(const NDArrayStorageType storage_type, const TShape &shape, Context ctx, bool delay_alloc = true, int dtype = mshadow::default_type_flag, std::vector aux_types = {}) : shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) { @@ -127,6 +127,7 @@ class NDArray { Mkl_mem_ = std::make_shared(); #endif } + // TODO this constructor should be removed NDArray(NDArray data, const std::vector aux_data, Context ctx, NDArrayStorageType storage_type, const TShape &shape) : ptr_(std::make_shared(data, aux_data, ctx, storage_type)), shape_(shape), @@ -137,11 +138,6 @@ class NDArray { CHECK(aux_data.size() == 1) << "Multiple aux_data not supported yet"; } - template - NDArray ConvertTo(NDArrayStorageType storage_type, mshadow::Stream *s) const { - CHECK_EQ(storage_type, kDefaultStorage) << "other storage type not supported yet"; - return ToDefault(s); - } /*! * \return the shape of current NDArray. */ @@ -487,44 +483,10 @@ class NDArray { private: friend class autograd::AutogradRuntime; - // Make a copy of the ndarray in dense format - template - NDArray ToDefault(mshadow::Stream* s) const { - NDArray result(shape_, ptr_->ctx, false, dtype()); - this->WaitToRead(); - if (storage_type() == kDefaultStorage) { - MSHADOW_TYPE_SWITCH(dtype(), DType, { - mshadow::Copy(result.data().FlatTo1D(), data().FlatTo1D()); - }); - return result; - } - CHECK(storage_type() == kRowSparseStorage); - MSHADOW_TYPE_SWITCH(dtype(), DType, { - MSHADOW_TYPE_SWITCH(aux_type(rowsparse::kIdx), AuxType, { - // Fill in zeros - result.data().FlatTo1D(s) = 0; - result.data().shape_ = shape_; - // data() is not empty - if (storage_shape().ndim() != 0) { - // Copy over - auto in_data = data().FlatTo2D(s); - auto out_data = result.data().FlatTo2D(s); - auto num_rows = aux_shape(rowsparse::kIdx)[0]; - auto in_idx = aux_data(rowsparse::kIdx).FlatTo1D(s); - for (size_t i = 0; i < num_rows; i += 1) { - mshadow::Copy(out_data[in_idx[i]], in_data[i], s); - } - } - }); - }); - return result; - } - /*! \brief the real data chunk that backs NDArray */ // shandle is used to store the actual values in the NDArray // aux_handles store the aux data(such as indices) if it's needed by non-default storage. struct Chunk { - // TODO(haibin) Also specify the capacity & size of the chunk, we don't want to resize it // every time a new element is added to a non default storage /*! \brief storage handle from storage engine. for non-default storage, shandle stores the data(value) array. @@ -551,7 +513,7 @@ class NDArray { // context of data Context ctx; // The shape of the chunk data. - // This might not be the same shape as the NDArray, since the chunk may be sparse. + // This might not be the same shape as the NDArray, since the storage may be sparse. TShape storage_shape; // The shape of aux data. The default value for the shape is 0. 
std::vector aux_shapes; @@ -660,7 +622,7 @@ class NDArray { CHECK_EQ(storage_type, kRowSparseStorage) << "Not yet implemented"; // calculate size, perform allocation if (delay_alloc) { - // For row sparse chunk, aux_shape indicates the number of rows to allocate + // For row sparse storage, aux_shape indicates the number of rows to allocate auto aux_shape = aux_shapes[0]; CHECK_EQ(aux_shape.ndim(), 1); auto num_rows = aux_shape[0]; @@ -670,7 +632,7 @@ class NDArray { shandle = Storage::Get()->Alloc(dbytes, ctx); aux_handles.push_back(Storage::Get()->Alloc(aux_bytes, ctx)); delay_alloc = false; - // Initialize aux_shape and shape + // Initialize shapes this->aux_shapes = aux_shapes; storage_shape = shape; storage_shape[0] = num_rows; diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index d446e7a46d91..51c921859e26 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -66,8 +66,6 @@ using FCompute = std::function" and "FComputeEx" * e.g FComputeEx - * TODO should probably change const std::vector& outputs to - std::vector *outputs */ using FComputeEx = std::function' % (self.__class__.__name__) + #def __repr__(self): def __reduce__(self): return (SparseNDArray, (None,), self.__getstate__()) def __add__(self, other): @@ -160,26 +156,22 @@ def _sync_copyfrom(self, source_array): def _slice(self, start, stop): raise Exception('Not implemented for SparseND yet!') def _at(self, idx): - raise Exception('Not implemented for SparseND yet!') + raise Exception('at operator for SparseND is not supported.') def reshape(self, shape): raise Exception('Not implemented for SparseND yet!') def broadcast_to(self, shape): raise Exception('Not implemented for SparseND yet!') #def wait_to_read(self): - #inherited from parent #@property #def shape(self): - #inherited from parent @property def size(self): raise Exception('Not implemented for SparseND yet!') - @property - def context(self): - raise Exception('Not implemented for SparseND yet!') - @property - def dtype(self): - raise Exception('Not implemented for SparseND yet!') + #@property + #def context(self): + #@property + #def dtype(self): @property # pylint: disable= invalid-name, undefined-variable def T(self): @@ -202,9 +194,8 @@ def as_in_context(self, context): def to_dense(self): return to_dense(self) -#TODO(haibin) also add aux_types. Not tested yet. 
-#We need a to_dense method to test it -def csr(values, indptr, idx, shape, ctx=Context.default_ctx, dtype=mx_real_t): +#TODO We need a to_dense method to test it +def csr(values, indptr, idx, shape, ctx=Context.default_ctx, dtype=mx_real_t, aux_types=None): ''' constructor ''' hdl = NDArrayHandle() #TODO currently only supports NDArray input @@ -212,6 +203,7 @@ def csr(values, indptr, idx, shape, ctx=Context.default_ctx, dtype=mx_real_t): assert(isinstance(index, NDArray)) indices = c_array(NDArrayHandle, [idx.handle, indptr.handle]) num_aux = mx_uint(2) + # TODO create an empty handle with specified types, then assign values check_call(_LIB.MXNDArrayCreateSparse( values.handle, num_aux, indices, c_array(mx_uint, shape), @@ -226,13 +218,14 @@ def csr(values, indptr, idx, shape, ctx=Context.default_ctx, dtype=mx_real_t): # pylint: enable= no-member #TODO(haibin) also specify aux_types -def row_sparse(values, index, shape, ctx=Context.default_ctx, dtype=mx_real_t): +def row_sparse(values, index, shape, ctx=Context.default_ctx, dtype=mx_real_t, aux_types=None): ''' constructor ''' hdl = NDArrayHandle() assert(isinstance(values, NDArray)) assert(isinstance(index, NDArray)) indices = c_array(NDArrayHandle, [index.handle]) num_aux = mx_uint(1) + # TODO create an empty handle with specified types, then assign values check_call(_LIB.MXNDArrayCreateSparse( values.handle, num_aux, indices, c_array(mx_uint, shape), @@ -245,7 +238,7 @@ def row_sparse(values, index, shape, ctx=Context.default_ctx, dtype=mx_real_t): ctypes.byref(hdl))) return SparseNDArray(hdl) -def array(values, index_list, storage_type, shape, ctx=None, dtype=mx_real_t): +def array(values, index_list, storage_type, shape, ctx=None, dtype=mx_real_t, aux_types=None): # TODO check input array types. Assume NDArray class for now # TODO support other types assert(storage_type == 'row_sparse') @@ -253,18 +246,19 @@ def array(values, index_list, storage_type, shape, ctx=None, dtype=mx_real_t): shape = (shape, ) if ctx is None: ctx = Context.default_ctx - arr = row_sparse(values, index_list[0], shape, ctx=ctx, dtype=dtype) + arr = row_sparse(values, index_list[0], shape, ctx=ctx, dtype=dtype, aux_types=aux_types) return arr # Temporary function for testing purpose def to_dense(source): - hdl = NDArrayHandle() + return ndarray.cast_storage(source, storage_type=1) + '''hdl = NDArrayHandle() check_call(_LIB.MXNDArrayConvert( source.handle, _STORAGE_TYPE_STR_TO_ID['default'], ctypes.byref(hdl))) - return ndarray.NDArray(handle=hdl, writable=True) + return ndarray.NDArray(handle=hdl, writable=True)''' -def zeros(shape, storage_type, ctx=None, dtype=mx_real_t): +def zeros(shape, storage_type, ctx=None, dtype=mx_real_t, aux_types=None): """Return a new array of given shape and type, filled with zeros. 
     Parameters
@@ -294,12 +288,13 @@
     """
     if ctx is None:
         ctx = Context.default_ctx
-    if storage_type == 'row_sparse':
-        # pylint: disable= no-member, protected-access
-        out = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx,
-                            aux_types=_STORAGE_AUX_TYPES['row_sparse']))
-        return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out)
-    return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype)
+    assert(storage_type == 'row_sparse')
+    if aux_types is None:
+        aux_types = _STORAGE_AUX_TYPES['row_sparse']
+    # pylint: disable= no-member, protected-access
+    out = SparseNDArray(_new_alloc_handle(storage_type, shape, ctx,
+                        aux_types=aux_types))
+    return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out)
 # pylint: enable= no-member, protected-access
 
 _STORAGE_TYPE_TO_ND_CLASS = {
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index f1ba19f7ef35..646ef05c2a1f 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -131,6 +131,7 @@ int MXNDArrayCreate(const mx_uint *shape,
   API_END();
 }
 
+// TODO remove this API
 int MXNDArrayCreateSparse(NDArrayHandle data,
                           mx_uint num_aux,
                           NDArrayHandle *aux_vec,
@@ -155,16 +156,6 @@ int MXNDArrayCreateSparse(NDArrayHandle data,
   API_END();
 }
 
-// TODO(haibin) Should also consider context
-int MXNDArrayConvert(NDArrayHandle in,
-                     int storage_type,
-                     NDArrayHandle *out) {
-  API_BEGIN();
-  NDArray* nd = reinterpret_cast<NDArray*>(in);
-  *out = new NDArray(nd->ConvertTo<cpu>(static_cast<NDArrayStorageType>(storage_type), nullptr));
-  API_END();
-}
-
 int MXNDArrayCreateEx(const mx_uint *shape,
                       mx_uint ndim,
                       int dev_type,
@@ -363,7 +354,6 @@ int MXNDArrayGetStorageType(NDArrayHandle handle,
                             int *out_storage_type) {
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
-  // Check is_none?
   if (!arr->is_none()) {
     *out_storage_type = arr->storage_type();
   } else {
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index cd825afaaf63..27bce311f980 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -58,7 +58,7 @@ struct MXAPIThreadLocalEntry {
   std::vector<TShape> arg_shapes, out_shapes, aux_shapes;
   /*! \brief result holder for returning type flags */
   std::vector<int> arg_types, out_types, aux_types;
-  /*! \brief result holder for returning chunk types */
+  /*! \brief result holder for returning storage types */
   std::vector<int> arg_storage_types, out_storage_types, aux_storage_types;
   /*!
\brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index e6dd977dc5bb..9112fbfb99f3 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -131,7 +131,7 @@ void SetShapeType(const nnvm::Op* op, std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); - static auto& inferchunktype = nnvm::Op::GetAttr("FInferStorageType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -183,8 +183,8 @@ void SetShapeType(const nnvm::Op* op, } out_storage_types.push_back(storage_type); } - if (inferchunktype.count(op)) { - CHECK(inferchunktype[op](attrs, &in_storage_types, &out_storage_types)); + if (inferstorage.count(op)) { + CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); CHECK_EQ(out_storage_types.size(), static_cast(infered_num_outputs)); } else { // LOG(INFO) << "FInferStorageType not present."; @@ -291,18 +291,20 @@ void PushFCompute(const FCompute& fn, std::vector input_blobs, output_blobs; std::vector tmp_nds; - if (ctx.dev_mask() == gpu::kDevMask) { - // mshadow::Stream *s = rctx.get_stream(); - // common::PrepDefaultBlobs(ndinputs, ndoutputs, &input_blobs, &output_blobs, - // &tmp_nds, true, s); - } else { - mshadow::Stream *s = rctx.get_stream(); - common::PrepDefaultBlobs(ndinputs, ndoutputs, &input_blobs, &output_blobs, - &tmp_nds, true, s); - } OpContext opctx{false, rctx, engine::CallbackOnComplete(), requested}; + if (ctx.dev_mask() == gpu::kDevMask) { +#if MXNET_USE_CUDA + common::PrepDefaultBlobs(ndinputs, ndoutputs, &input_blobs, + &output_blobs, &tmp_nds, true, opctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + common::PrepDefaultBlobs(ndinputs, ndoutputs, &input_blobs, + &output_blobs, &tmp_nds, true, opctx); + } std::vector req(output_blobs.size(), kWriteTo); fn(attrs, opctx, input_blobs, req, output_blobs); if (ctx.dev_mask() == gpu::kDevMask) { @@ -408,9 +410,6 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, int num_params, const char **param_keys, const char **param_vals) { - static auto& fcpu = nnvm::Op::GetAttr("FCompute"); - static auto& fnd_cpu_row_sparse = nnvm::Op::GetAttr("FComputeEx"); - static auto& fgpu = nnvm::Op::GetAttr("FCompute"); static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); static auto& createop = nnvm::Op::GetAttr("FCreateLayerOp"); const nnvm::Op* op = static_cast(creator); @@ -446,23 +445,11 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, SetDependency(&read_vars, &write_vars, &requested, &auxidx, op, attrs, ctx, ndinputs, ndoutputs); - FCompute fn; - FComputeEx fn_nd; - // dispatch based on ctx and storage_type - // std::cout << "I - dispatch: " << storage_type << " for op " << op->name << std::endl; - if (ctx.dev_mask() == cpu::kDevMask && fnd_cpu_row_sparse.count(op) && - storage_type == kRowSparseStorage) { - fn_nd = fnd_cpu_row_sparse[op]; - // std::cout << "I - fnd_cpu_row_sparse dispatched." << std::endl; - } else if (ctx.dev_mask() == cpu::kDevMask && fcpu.count(op)) { - // std::cout << "I - fcpu dispatched." 
<< std::endl; - fn = fcpu[op]; - } else if (ctx.dev_mask() == gpu::kDevMask && fgpu.count(op)) { - fn = fgpu[op]; - } + FCompute fn = common::GetFCompute(op, ctx); + FComputeEx fcompute_ex = common::GetFComputeEx(op, ctx, storage_type); - if (fn_nd) { - PushFComputeEx(fn_nd, op, attrs, ctx, read_vars, write_vars, + if (fcompute_ex) { + PushFComputeEx(fcompute_ex, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs); } else if (fn) { if (AutogradRuntime::Get()->IsRecording()) { diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 3c6b0b032997..66adeae16bf0 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -497,7 +497,7 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } - +// TODO(haibin) refactor with infer_type int MXSymbolInferStorageType(SymbolHandle sym, mx_uint num_args, const char** keys, diff --git a/src/common/utils.h b/src/common/utils.h index 8138495e7ae1..c76abe36545d 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -19,12 +19,24 @@ #include #include #include +#include #include namespace mxnet { +// forward declaration +namespace op { +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +} + namespace common { #if DMLC_USE_CXX11 +// TODO move to op_utils.h template inline void PrepDefaultBlobs(const std::vector& ndinputs, const std::vector& ndoutputs, @@ -32,10 +44,11 @@ inline void PrepDefaultBlobs(const std::vector& ndinputs, std::vector *output_blobs, std::vector *tmp_nds, bool alloc_outputs, - mshadow::Stream *s) { + const OpContext& ctx) { for (auto& i : ndinputs) { if (i.storage_type() != kDefaultStorage) { - NDArray tmp_nd = i.ConvertTo(kDefaultStorage, s); + NDArray tmp_nd(i.shape(), i.ctx(), false); + op::CastStorageComputeEx({}, ctx, {i}, {}, {tmp_nd}); tmp_nds->push_back(tmp_nd); input_blobs->push_back(tmp_nd.data()); } else { @@ -56,7 +69,7 @@ inline void PrepVars(const std::vector &nds, } } -// Only check input storage type for now. +// Only dispatch based on input storage type for now. 
inline NDArrayStorageType GetDispatchStorageType(const nnvm::StorageTypeVector& vstorage_type) { NDArrayStorageType dispatch_storage_type = kDefaultStorage; for (auto& i : vstorage_type) { @@ -69,6 +82,45 @@ inline NDArrayStorageType GetDispatchStorageType(const nnvm::StorageTypeVector& return dispatch_storage_type; } +inline FCompute GetFCompute(const Op* op, Context ctx) { + static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); + static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); + if (ctx.dev_mask() == cpu::kDevMask) { + return fcompute_cpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fcompute_gpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + +inline FComputeEx GetFComputeEx(const Op* op, Context ctx, + NDArrayStorageType storage_type) { + static auto& fcpu_rs = nnvm::Op::GetAttr("FComputeEx"); + static auto& fgpu_rs = nnvm::Op::GetAttr("FComputeEx"); + static auto& fcpu_csr = nnvm::Op::GetAttr("FComputeEx"); + static auto& fgpu_csr = nnvm::Op::GetAttr("FComputeEx"); + if (storage_type == kDefaultStorage) return nullptr; + if (ctx.dev_mask() == cpu::kDevMask) { + if (storage_type == kRowSparseStorage) return fcpu_rs.get(op, nullptr); + if (storage_type == kCSRStorage) return fcpu_csr.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + if (storage_type == kRowSparseStorage) return fgpu_rs.get(op, nullptr); + if (storage_type == kCSRStorage) return fgpu_csr.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + +inline bool HasDefaultStorage(const std::vector& ndarrays) { + for (auto &nd : ndarrays) { + if (nd.storage_type() == kDefaultStorage) { + return true; + } + } + return false; +} + // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { // This is resource efficient option. diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index cff55aa97b46..3cc8b01a3bba 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -26,7 +26,7 @@ namespace exec { // forward executor class ForwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); #if MKL_EXPERIMENTAL == 1 @@ -69,7 +69,7 @@ class ForwardOpExecutor : public OpExecutor { // backward executor class BackwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; op_->Backward(op_ctx, out_grad_, in_data_, out_data_, req, in_grad_, aux_data_); @@ -137,14 +137,21 @@ class BackwardOpExecutor : public OpExecutor { // fcompute executor executor class FComputeExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { // std::cout << "FCompute::Run" << std::endl; op_ctx.run_ctx = rctx; - // TODO(haibin) Get stream? 
- // mshadow::Stream *s = rctx.get_stream(); if (!initialized) { - common::PrepDefaultBlobs(in_array, out_array, &in_data_, &out_data_, - &tmp_nds_, true, nullptr); + if (is_gpu) { +#if MXNET_USE_CUDA + common::PrepDefaultBlobs(in_array, out_array, &in_data_, + &out_data_, &tmp_nds_, true, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + common::PrepDefaultBlobs(in_array, out_array, &in_data_, + &out_data_, &tmp_nds_, true, op_ctx); + } initialized = true; } fcompute_(attrs_, op_ctx, in_data_, req, out_data_); @@ -165,21 +172,6 @@ class FComputeExecutor : public OpExecutor { : fcompute_(fcompute), attrs_(attrs) { } - static FCompute GetFCompute(const Op* op, Context ctx) { - static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); - static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); - if (ctx.dev_mask() == cpu::kDevMask) { - // if (fcompute_cpu.get(op, nullptr) != nullptr) - // std::cout << "FCompute for op " << op->name << std::endl; - return fcompute_cpu.get(op, nullptr); - } else if (ctx.dev_mask() == gpu::kDevMask) { - return fcompute_gpu.get(op, nullptr); - } else { - LOG(FATAL) << "Unknown device mask"; - return nullptr; - } - } - private: FCompute fcompute_; NodeAttrs attrs_; @@ -191,7 +183,7 @@ class FComputeExecutor : public OpExecutor { // fcomputend executor class FComputeExExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { // std::cout << "FComputeExExecutor::Run" << std::endl; op_ctx.run_ctx = rctx; fcompute_(attrs_, op_ctx, in_data_, req, out_data_); @@ -208,27 +200,6 @@ class FComputeExExecutor : public OpExecutor { : fcompute_(fcompute), attrs_(attrs) { } - static FComputeEx GetFComputeEx(const Op* op, Context ctx, - NDArrayStorageType dispatch_storage_type) { - static auto& fcompute_cpu = nnvm::Op::GetAttr("FComputeEx"); - static auto& fcompute_gpu = nnvm::Op::GetAttr("FComputeEx"); - if (dispatch_storage_type != kRowSparseStorage) { - return nullptr; - } - if (ctx.dev_mask() == cpu::kDevMask) { -#if EXECUTOR_DEBUG - if (fcompute_cpu.get(op, nullptr) != nullptr) - LOG(INFO) << "FComputeEx for op " << op->name; -#endif - return fcompute_cpu.get(op, nullptr); - } else if (ctx.dev_mask() == gpu::kDevMask) { - return fcompute_gpu.get(op, nullptr); - } else { - LOG(FATAL) << "Unknown device mask"; - return nullptr; - } - } - private: FComputeEx fcompute_; NodeAttrs attrs_; @@ -266,9 +237,9 @@ Graph AttachOpExecs(Graph g) { mutate_index = fmutate_inputs[inode.source->op()](inode.source->attrs); } NDArrayStorageType dispatch_stype = static_cast(dispatch_stypes[i]); - FCompute fcompute = FComputeExecutor::GetFCompute(inode.source->op(), vctx[i]); - FComputeEx fcompute_ndarray = - FComputeExExecutor::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stype); + FCompute fcompute = common::GetFCompute(inode.source->op(), vctx[i]); + FComputeEx fcompute_ex = + common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stype); if (fcreate_layer_op.count(inode.source->op())) { std::vector ishape; std::vector itype; @@ -293,15 +264,16 @@ Graph AttachOpExecs(Graph g) { dynamic_cast(ret[fwd_id].get())->op_, mxnet::op::OpPropGetOpProperty(inode.source->attrs), mutate_index); - } else if (fcompute_ndarray != nullptr) { - // Also check the storage type - // std::cout << "S - fcompute_ndarray" << std::endl; - ret[i] = std::make_shared(fcompute_ndarray, inode.source->attrs); + } else if (fcompute_ex != nullptr) { +#if EXECUTOR_DEBUG + LOG(INFO) << "FComputeEx for op " << 
inode.source->op()->name; +#endif + ret[i] = std::make_shared(fcompute_ex, inode.source->attrs); } else if (fcompute != nullptr) { - // std::cout << "S - fcompute" << std::endl; +#if EXECUTOR_DEBUG + LOG(INFO) << "FCompute for op " << inode.source->op()->name; +#endif ret[i] = std::make_shared(fcompute, inode.source->attrs); - } else { - // LOG(INFO) << "FCompute not registered " << inode.source->op()->name; } } g.attrs["op_execs"] = std::make_shared(ret); diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 4b79f76d4d26..f32908b428d2 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -51,7 +51,7 @@ class OpExecutor { * This function call do not synchronize the stream. * \param rctx The runtime context passed in by environment. */ - virtual void Run(RunContext rctx) = 0; + virtual void Run(RunContext rctx, bool is_gpu) = 0; /*! \return the execution type */ virtual Operator::ExecType exec_type() const = 0; }; diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 5f71979af424..6415db0b5c82 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -658,20 +658,15 @@ void GraphExecutor::InitCachedOps() { for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { uint32_t eid = idx.entry_id(nid, index); exec->out_array.push_back(data_entry_[eid]); - if (vstorage_type[eid] != kDefaultStorage) { - // FIXME temporarily disable inplace for sparse ndarrays - exec->req.push_back(kWriteTo); + if (addto_entry.at(eid) != 0) { + exec->req.push_back(kAddTo); + } else if (vstorage_inplace[eid] >= 0) { + exec->req.push_back(kWriteInplace); + } else if (vstorage_inplace[eid] == -2) { + // -2 indicate that the entry is never referenced. + exec->req.push_back(kNullOp); } else { - if (addto_entry.at(eid) != 0) { - exec->req.push_back(kAddTo); - } else if (vstorage_inplace[eid] >= 0) { - exec->req.push_back(kWriteInplace); - } else if (vstorage_inplace[eid] == -2) { - // -2 indicate that the entry is never referenced. - exec->req.push_back(kNullOp); - } else { - exec->req.push_back(kWriteTo); - } + exec->req.push_back(kWriteTo); } } } @@ -709,7 +704,6 @@ void GraphExecutor::InitCachedOps() { std::inserter(all_vars, all_vars.end())); // setup exec vars Engine::Get()->PushSync([exec](RunContext rctx) { - // LOG(INFO) << "Setup.. 
"; exec->Setup(); }, Context::CPU(), {}, all_vars, FnProperty::kNormal, 0, PROFILER_MESSAGE("SetupExec")); @@ -719,7 +713,7 @@ void GraphExecutor::InitCachedOps() { if (is_async) { exec->op_ctx.async_on_complete = on_complete; } - exec->Run(ctx); + exec->Run(ctx, is_gpu); // call on complete only if it is async op if (!is_async) { if (is_gpu) { @@ -951,7 +945,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, RunContext ctx, Engine::CallbackOnComplete on_complete) { // Run all opr in the sub-graph for (auto &exec : exec_list) { - exec->Run(ctx); + exec->Run(ctx, is_gpu); } if (is_gpu) { #if MXNET_USE_CUDA diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index b7c29c95be72..d296d177abce 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -50,32 +50,31 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } -// Only inferring output chunk types from input chunk types now -// It never returns false -// Implemented for add & sub now +// Only inferring output storage types from input for now template + bool (*assign)(AttrType*, const AttrType&), bool reverse_infer, + bool enable_fallback> inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs, const AttrType& none) { + // LOG(INFO) << "ElemwiseStorageAttr for " << attrs.name; auto deduce = [&](std::vector *vec, const char *name, AttrType& result, - bool &fallback) { + bool fallback) { for (size_t i = 0; i < vec->size(); ++i) { // LOG(INFO) << "deduce " << (*vec)[i]; - if (assign(&result, (*vec)[i]) == false) { - fallback = true; + CHECK_NE((*vec)[i], -1) << "ElemwiseStorageAttr assumes all input storage types are known"; + if (assign(&result, (*vec)[i]) == false && fallback) { result = kDefaultStorage; // LOG(INFO) << "ElemwiseStorageAttr Fallback"; return; } } }; - bool fallback = false; AttrType dattr = none; - deduce(in_attrs, "input", dattr, fallback); + deduce(in_attrs, "input", dattr, enable_fallback); if (reverse_infer) { - // TODO(haibin) also do reverse pass + LOG(FATAL) << "not implemented yet"; } auto write = [&](std::vector *vec, const char *name) { for (size_t i = 0; i < vec->size(); ++i) { @@ -84,7 +83,6 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, << name << ": " << "expected " << dattr << ", got " << (*vec)[i]; } }; - // write(in_attrs, "input"); write(out_attrs, "output"); if (is_none(dattr)) return false; return true; @@ -116,8 +114,19 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; - // TODO(haibin) replace type_is_none to storage_type_is_none & type_assign - return ElemwiseStorageAttr( + // TODO(haibin) not doing inverse infer yet + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs, -1); +} + +// Useful for binary multiplication / division +template +inline bool ElemwiseSameStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + return ElemwiseStorageAttr( attrs, in_attrs, out_attrs, -1); } diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index fd3b84c68448..e9ba5f81c339 100644 --- 
a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -287,21 +287,22 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } -template +// TODO move to op_util.h +template void FComputeExFallback(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs, - mshadow::Stream* s, FCompute fcompute) { std::vector input_blobs, output_blobs; std::vector tmp_nds; - common::PrepDefaultBlobs(inputs, outputs, &input_blobs, &output_blobs, - &tmp_nds, false, s); + common::PrepDefaultBlobs(inputs, outputs, &input_blobs, + &output_blobs, &tmp_nds, false, ctx); fcompute(attrs, ctx, input_blobs, req, output_blobs); } + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index e35e3681d145..f04f6f8030b2 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -32,9 +32,10 @@ void BinaryCompute(const nnvm::NodeAttrs& attrs, }); } -// TODO(haibin) This is temporary implementation. Make use of templated OP +// TODO(haibin) This is an inefficient temporary implementation +// Binary Compute between two row-sparse ndarray template -void BinaryComputeExSpSp(const nnvm::NodeAttrs& attrs, +void BinaryComputeExRsRs(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -51,7 +52,7 @@ void BinaryComputeExSpSp(const nnvm::NodeAttrs& attrs, auto num_rows_r = nd_r.aux_shape(rowsparse::kIdx)[0]; // This is (roughly) the number of result rows output.CheckAndAlloc({TShape({num_rows_l + num_rows_r})}); - // LOG(INFO) << "BinaryComputeExSpSp" << output.aux_shape(rowsparse::kIdx)[0]; + // LOG(INFO) << "BinaryComputeExRsRs" << output.aux_shape(rowsparse::kIdx)[0]; // Indices mshadow::Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(output.dtype(), DType, { @@ -111,24 +112,17 @@ void BinaryComputeEx(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - // std::cout << "BinaryComputeEx invoked\n"; using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - // Check if any input is dense - bool fallback = false; - for (auto &nd : inputs) { - if (nd.storage_type() == kDefaultStorage) { - fallback = true; - } - } - if (fallback) { - FComputeExFallback(attrs, ctx, inputs, req, outputs, s, BinaryCompute); + // If any input is dense, fallback to FCompute + if (common::HasDefaultStorage(inputs)) { + FComputeExFallback(attrs, ctx, inputs, req, outputs, BinaryCompute); return; } - // Call SpSp function + // Call RsRs function CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; - BinaryComputeExSpSp(attrs, ctx, inputs, req, outputs); + BinaryComputeExRsRs(attrs, ctx, inputs, req, outputs); } template @@ -149,6 +143,7 @@ void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, }); } +// Only implemented for _backward_add for now template void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -162,7 +157,8 @@ void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "BinaryBackwardUseNoneEx fallback not implemented yet"; } // LOG(INFO) << "BinaryBackwardUseNoneEx"; - // WARNING: Assume identity op. 
Assume same shape + // The following code assumes LOP == mshadow_op::identity == ROP + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage); TShape shape = inputs[0].aux_shape(rowsparse::kIdx); outputs[0].CheckAndAlloc({shape}); outputs[1].CheckAndAlloc({shape}); diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 4cf7f71ef591..8edfacc66865 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -32,7 +32,7 @@ NNVM_REGISTER_OP(_backward_add) mshadow_op::identity>) .set_attr("FComputeEx", BinaryBackwardUseNoneEx) -.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") @@ -63,6 +63,7 @@ NNVM_REGISTER_OP(_backward_mul) [](const NodeAttrs& attrs){ return std::vector >{{0, 1}}; }) +//.set_attr("FInferStorageType", ElemwiseSameStorageType<1, 2>) .set_attr("FCompute", BinaryBackwardUseIn); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index e76dde140255..e7c38644380f 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -9,6 +9,7 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(CastParam); +DMLC_REGISTER_PARAMETER(CastStorageParam); // copy MXNET_OPERATOR_REGISTER_UNARY(_copy) @@ -102,6 +103,20 @@ NNVM_REGISTER_OP(_backward_cast) .set_attr("TIsBackward", true) .set_attr("FCompute", CastCompute); +// TODO(haibin) declare backward op for cast storage. Also add FCompute(identity compute) +NNVM_REGISTER_OP(cast_storage) +.describe(R"code(Casts tensor storage type to the new type. +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FComputeEx", CastStorageComputeEx) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(CastStorageParam::__FIELDS__()); + + // negative MXNET_OPERATOR_REGISTER_UNARY(negative) .MXNET_DESCRIBE("Negate src") diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 5e6e51167ed4..19587c9ee8ad 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -59,6 +59,8 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, }); } +// FIXME the index is hard coded for _identity_with_attr_like_rhs op +// Only implemented for row_sparse for now template void IdentityComputeEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -69,24 +71,19 @@ void IdentityComputeEx(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; Stream *s = ctx.get_stream(); // LOG(INFO) << "IdentityComputeEx"; - // FIXME the input index is hard coded for _identity_with_attr_like_rhs op NDArrayStorageType storage_type = inputs[1].storage_type(); CHECK_EQ(storage_type, kRowSparseStorage) << "storage type " << storage_type << " not supported yet"; if (req[0] == kNullOp) { - LOG(FATAL) << "kNullOp in IdentityComputeEx"; + LOG(FATAL) << "kNullOp in IdentityComputeEx not supported yet"; return; } if (req[0] == kWriteInplace) { LOG(FATAL) << "kWriteInplace for sparse storage not supported yet"; // CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_); return; } - // FIXME probably need an interface to check if a sparse tensor is all zero TShape shape = inputs[1].aux_shape(rowsparse::kIdx); - if 
(shape.ndim() == 0) { - // LOG(INFO) << "Identify for all zero sparse ndarray"; - return; - } + if (shape.ndim() == 0) return; outputs[0].CheckAndAlloc({shape}); MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { MSHADOW_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), AuxType, { @@ -142,6 +139,49 @@ void CastCompute(const nnvm::NodeAttrs& attrs, }); } +struct CastStorageParam : public dmlc::Parameter { + // use int for enumeration + // TODO(haibin) add enum for storage_type. Probably also aux-types + int storage_type; + DMLC_DECLARE_PARAMETER(CastStorageParam) { + DMLC_DECLARE_FIELD(storage_type) + .describe("Output storage type."); + } +}; + +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + auto out = outputs[0]; + auto in = inputs[0]; + CHECK(in.storage_type() == kRowSparseStorage); + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + MSHADOW_TYPE_SWITCH(in.aux_type(rowsparse::kIdx), AuxType, { + // Fill in zeros. SLOW + out.data().FlatTo1D(s) = 0; + // data() is not empty + if (in.storage_shape().ndim() != 0) { + // Copy over + auto in_data = in.data().FlatTo2D(s); + auto out_data = out.data().FlatTo2D(s); + auto num_rows = in.aux_shape(rowsparse::kIdx)[0]; + auto in_idx = in.aux_data(rowsparse::kIdx).FlatTo1D(s); + for (size_t i = 0; i < num_rows; i += 1) { + mshadow::Copy(out_data[in_idx[i]], in_data[i], s); + } + } + }); + }); +} + #define MXNET_OPERATOR_REGISTER_UNARY(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 73e002a09488..5f873dc21a89 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -124,12 +124,9 @@ void FillComputeEx(const nnvm::NodeAttrs& attrs, if (value == 0 && outputs[0].storage_type() != kDefaultStorage) { return; } - // TODO(haibin) Fallback - /* - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor out = outputs[0].FlatTo1D(s); - ASSIGN_DISPATCH(out, req[0], scalar(value)); - });*/ + CHECK_EQ(value, 0) << "Not implemented yet"; + CHECK_EQ(inputs.size(), 0); + CHECK_NE(outputs[0].storage_type(), kDefaultStorage); } template diff --git a/tests/cpp/ndarray_test.cc b/tests/cpp/ndarray_test.cc index 17406c4f5964..e2aacfa9987e 100644 --- a/tests/cpp/ndarray_test.cc +++ b/tests/cpp/ndarray_test.cc @@ -9,6 +9,7 @@ #include #include "../src/executor/graph_executor.h" #include "../src/operator/tensor/elemwise_binary_op.h" +#include "../src/operator/tensor/elemwise_unary_op.h" #define TEST_DTYPE float #define TEST_AUX_TYPE int32_t @@ -44,6 +45,20 @@ NDArray GetDenseND(const TShape shape, const Context ctx, const std::vectorPushSync([src, converted](RunContext ctx) { + // TODO provide type in attrs, which is empty now + OpContext op_ctx; + op_ctx.run_ctx = ctx; + std::vector inputs({src}), outputs({converted}); + op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + }, src.ctx(), {src.var()}, {converted.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + converted.WaitToRead(); + return converted; +} void BasicTest() { Context ctx; @@ -67,22 +82,20 @@ void BinaryDenseSparseTest() { Engine::Get()->WaitForAll(); NDArray output(kRowSparseStorage, output_shape, ctx); - // Push the right vars! 
FIXME std::vector const_vars; const_vars.push_back(raw_data0.var()); const_vars.push_back(index0.var()); - // TODO Add switch stmt - Engine::Get()->PushSync([input_nd0, input_nd1, output](RunContext ctx) { - nnvm::NodeAttrs attrs; - OpContext op_ctx; - std::vector inputs, outputs; - std::vector req; - inputs.push_back(input_nd0); - inputs.push_back(input_nd1); - outputs.push_back(output); - op::BinaryComputeEx(attrs, op_ctx, inputs, req, outputs); - }, input_nd0.ctx(), const_vars, {output.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + Engine::Get()->PushSync([input_nd0, input_nd1, output](RunContext ctx) { + nnvm::NodeAttrs attrs; + OpContext op_ctx; + std::vector inputs, outputs; + std::vector req; + inputs.push_back(input_nd0); + inputs.push_back(input_nd1); + outputs.push_back(output); + op::BinaryComputeEx(attrs, op_ctx, inputs, req, outputs); + }, input_nd0.ctx(), const_vars, {output.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); std::vector output_vals({11, 12, 3, 4, 15, 16}); NDArray out_data = GetDenseND(output_shape, ctx, output_vals); Engine::Get()->WaitForAll(); @@ -100,7 +113,7 @@ void SetValueTest() { CheckDataRegion(nd0.data(), nd1.data()); } -void BinarySpSpTest() { +void BinaryRsRsTest() { Context ctx = Context::CPU(); TShape index_shape({2}); @@ -122,21 +135,20 @@ void BinarySpSpTest() { const_vars.push_back(input_nd0.var()); const_vars.push_back(input_nd1.var()); - Engine::Get()->PushSync([input_nd0, input_nd1, output](RunContext ctx) { - nnvm::NodeAttrs attrs; - OpContext op_ctx; - std::vector inputs, outputs; - std::vector req; - inputs.push_back(input_nd0); - inputs.push_back(input_nd1); - outputs.push_back(output); - op::BinaryComputeExSpSp(attrs, op_ctx, inputs, req, outputs); - }, input_nd0.ctx(), const_vars, {output.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + Engine::Get()->PushSync([input_nd0, input_nd1, output](RunContext ctx) { + OpContext op_ctx; + std::vector inputs, outputs; + std::vector req; + inputs.push_back(input_nd0); + inputs.push_back(input_nd1); + outputs.push_back(output); + op::BinaryComputeExRsRs({}, op_ctx, inputs, req, outputs); + }, input_nd0.ctx(), const_vars, {output.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); // Check the data region of output ndarray NDArray dense_output = GetDenseND(output_shape, ctx, {15, 15, 10, 10, 5, 5, 0, 0}); - NDArray copy = output.ConvertTo(kDefaultStorage, nullptr); + NDArray copy = Convert(kDefaultStorage, output); CheckDataRegion(dense_output.data(), copy.data()); } @@ -156,7 +168,7 @@ void InferElemwiseStorageTest() { TEST(NDArray, basics) { BasicTest(); - BinarySpSpTest(); + BinaryRsRsTest(); //Wait for all operations to finish Engine::Get()->WaitForAll(); InferElemwiseStorageTest(); @@ -166,9 +178,10 @@ TEST(NDArray, basics) { void TestDenseToDenseConversion() { Context ctx; TShape shape({2, 2}); - NDArray nd = GetDenseND(shape, ctx, {1, 2, 3, 4}); - auto nd_copy = nd.ConvertTo(kDefaultStorage, nullptr); - CheckDataRegion(nd_copy.data(), nd.data()); + NDArray nd = GetDenseND(shape, ctx, {1, 2, 3, 10}); + // TODO dense to dense conversion is not implemented yet + //auto nd_copy = Convert(kDefaultStorage, nd); + //CheckDataRegion(nd_copy.data(), nd.data()); } // sparse to dense conversion @@ -184,10 +197,8 @@ void TestSparseToDenseConversion() { // Dense ndarray NDArray dense_nd = GetDenseND(shape, ctx, {1, 1, 0, 0}); - - auto converted_nd = nd.ConvertTo(kDefaultStorage, nullptr); - auto converted_data = converted_nd.data(); - 
CheckDataRegion(converted_data, dense_nd.data());
+  NDArray converted = Convert(kDefaultStorage, nd);
+  CheckDataRegion(converted.data(), dense_nd.data());
 }
 
 TEST(NDArray, conversion) {
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 86bcef11866a..cd30359d9611 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -34,6 +34,7 @@ def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list=
     assert_almost_equal(out1, out2)
 
 def test_ndarray_elementwise_add():
+    # TODO initialize with rand number
     dense_np = np.array([[1,2],[3,4],[5,6]])
     sparse_np1 = np.array([[5,10],[0,0],[0,0]])
     dense_nd = mx.nd.array(dense_np)
@@ -42,7 +43,6 @@ def test_ndarray_elementwise_add():
     idx = mx.nd.array([0], dtype=np.int32);
     sparse_nd1 = mx.sparse_nd.row_sparse(val, idx, (3,2))
     sparse_nd2 = mx.sparse_nd.row_sparse(val, idx, (3,2))
-    #TODO register under mx.sparse_nd namespace
     # dense - dense addition
     dense_plus_dense = mx.nd.elemwise_add(dense_nd, dense_nd);
     assert_almost_equal(dense_plus_dense.asnumpy(), dense_np + dense_np)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index b725c8527366..1b1b3ac2c13b 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -69,7 +69,41 @@ def test_elemwise_add_sparse_sparse():
     exec_test.backward(out_grads = exec_test.outputs)
     assert_almost_equal(arr_grad1.asnumpy(), arr_grad2.asnumpy())
 
+def test_elemwise_add_multiple_stages():
+    # prep data
+    shape = (4, 2)
+    ds_np = np.array([[1,2],[3,4],[5,6],[7,8]])
+    sp_np1 = np.array([[5,10],[0,0],[0,0],[0,0]])
+    sp_np2 = np.array([[0,0],[5,10],[0,0],[0,0]])
+
+    val1 = mx.nd.array([5, 10])
+    val2 = mx.nd.array([5, 10])
+    idx1 = mx.nd.array([0], dtype=np.int32)
+    idx2 = mx.nd.array([1], dtype=np.int32)
+    sp_nd1 = mx.sparse_nd.row_sparse(val1, idx1, shape)
+    sp_nd2 = mx.sparse_nd.row_sparse(val2, idx2, shape)
+    ds_nd = mx.nd.array(ds_np)
+
+    # sparse + sparse = sparse
+    sp_data1 = mx.symbol.Variable('sp_data1', storage_type='row_sparse')
+    sp_data2 = mx.symbol.Variable('sp_data2', storage_type='row_sparse')
+    ds_data = mx.symbol.Variable('ds_data')
+    plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus')
+    # sparse + dense = dense
+    test = mx.symbol.elemwise_add(plus, ds_data)
+    check_symbolic_forward(test, {'sp_data1':sp_nd1, 'sp_data2':sp_nd2,
+                          'ds_data':ds_nd}, [sp_np1 + sp_np2 + ds_np])
+
+    arr_grads = [mx.nd.zeros(shape) for i in xrange(3)]
+    exec_test = test.bind(default_context(), args={'sp_data1':sp_nd1, 'sp_data2':sp_nd2,
+                          'ds_data':ds_nd}, args_grad=arr_grads)
+    exec_test.forward(is_train=True)
+    assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np)
+    exec_test.backward(out_grads=exec_test.outputs)
+    assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy())
+
 if __name__ == '__main__':
     test_elemwise_add_dense()
     test_elemwise_add_dense_sparse()
     test_elemwise_add_sparse_sparse()
+    test_elemwise_add_multiple_stages()
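
Reviewer note (illustrative sketch, not part of the patch): the end-to-end conversion flow this change enables can be exercised from Python roughly as follows, assuming a build of this branch with the mx.sparse_nd module and the cast_storage-backed to_dense helper shown above.

import numpy as np
import mxnet as mx

# Build a 3x2 row_sparse ndarray with a single non-zero row (row 0),
# mirroring the construction used in test_sparse_ndarray.py.
val = mx.nd.array([5, 10])
idx = mx.nd.array([0], dtype=np.int32)
sparse_nd = mx.sparse_nd.row_sparse(val, idx, (3, 2))

# to_dense now routes through the registered cast_storage operator
# instead of the removed MXNDArrayConvert C API.
dense_nd = mx.sparse_nd.to_dense(sparse_nd)
print(dense_nd.asnumpy())  # expected: [[5, 10], [0, 0], [0, 0]]

# Sparse zeros with default aux_types, per the new zeros() signature.
zeros_rsp = mx.sparse_nd.zeros((3, 2), 'row_sparse')

# Mixed sparse/dense addition falls back to the dense FCompute kernel
# via PrepDefaultBlobs.
mixed_sum = mx.nd.elemwise_add(dense_nd, sparse_nd)

Under the hood, to_dense simply calls the new cast_storage operator with storage_type=1 (default storage), matching the registration in elemwise_unary_op.cc.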