From bd0a62fd6e6fb6c090612d500d1b258812e5a733 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 21 Jun 2017 20:55:23 +0000 Subject: [PATCH] move init dns zeros to init_op.h for kvstore to use --- src/operator/optimizer_op-inl.h | 24 +++--------------------- src/operator/tensor/init_op.h | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 272ebcc44c4c..176da461f31f 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -18,6 +18,7 @@ #include "./mshadow_op.h" #include "./elemwise_op_common.h" #include "mxnet_op.h" +#include "./tensor/init_op.h" namespace mxnet { namespace op { @@ -414,25 +415,6 @@ struct SGDMomRspDnsKernel { } }; -template -inline void InitDnsZeros(mshadow::Stream *s, NDArray *out) { - using namespace rowsparse; - using namespace mshadow::expr; - using namespace mshadow; - using namespace mxnet_op; - CHECK_EQ(out->storage_type(), kRowSparseStorage); - MSHADOW_REAL_TYPE_SWITCH(out->dtype(), DType, { - MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, { - auto num_rows = out->shape()[0]; - out->CheckAndAlloc({Shape1(num_rows)}); - auto idx = out->aux_data(kIdx).FlatTo1D(s); - auto val = out->data(); - Kernel::Launch(s, val.Size(), val.dptr()); - ASSIGN_DISPATCH(idx, kWriteTo, range(0, num_rows, 1, 1)) - }); - }); -} - template inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param, const OpContext &ctx, @@ -452,7 +434,7 @@ inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param, // fill mom with zero values if not initialized yet if (!mom.storage_initialized()) { NDArray mom_zeros = mom; - InitDnsZeros(s, &mom_zeros); + FillDnsZerosRspImpl(s, &mom_zeros); } // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only // feed in kWriteTo as req for all operators. @@ -493,7 +475,7 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, // fill mom with zero values in order to reuse the sgd mom dns impl if (!mom.storage_initialized()) { NDArray mom_zeros = mom; - InitDnsZeros(s, &mom_zeros); + FillDnsZerosRspImpl(s, &mom_zeros); } // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only // feed in kWriteTo as req for all operators. diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 5b72a69fd72f..bc885f3cecf5 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -128,6 +128,27 @@ void FillCompute(const nnvm::NodeAttrs& attrs, }); } +// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray, +// instead of the usual compact representation. +template +inline void FillDnsZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + using namespace rowsparse; + using namespace mshadow::expr; + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(dst->storage_type(), kRowSparseStorage); + MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(dst->aux_type(kIdx), IType, { + auto num_rows = dst->shape()[0]; + dst->CheckAndAlloc({Shape1(num_rows)}); + auto idx = dst->aux_data(kIdx).FlatTo1D(s); + auto val = dst->data(); + Kernel::Launch(s, val.Size(), val.dptr()); + ASSIGN_DISPATCH(idx, kWriteTo, range(0, num_rows, 1, 1)) + }); + }); +} + // Fill a rsp NDArray with zeros by updating the aux shape. template void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst) {