diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index 46de10ac9cc0..f905bf8f7222 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -30,6 +30,7 @@
 #include <vector>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
+#include "../../src/operator/tensor/init_op.h"
 #ifdef __CUDACC__
 #include "./cast_storage-inl.cuh"
 #endif  // __CUDACC__
@@ -328,6 +329,50 @@ void CastStorageCsrDnsImpl(const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief Casts a csr matrix to another csr.
+ */
+template <typename xpu>
+void CastStorageCsrCsrImpl(const OpContext& ctx, const NDArray& csr,
+                           NDArray* output) {
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!csr.storage_initialized()) {
+    FillZerosCsrImpl(s, *output);
+    return;
+  }
+  std::vector<TShape> aux_shapes({csr.aux_shape(csr::kIndPtr), csr.aux_shape(csr::kIdx)});
+  output->CheckAndAlloc(aux_shapes);
+  const TBlob& val = output->data();
+  const TBlob& indptr = output->aux_data(csr::kIndPtr);
+  const TBlob& idx = output->aux_data(csr::kIdx);
+  mxnet_op::copy(s, val, csr.data());
+  mxnet_op::copy(s, indptr, csr.aux_data(csr::kIndPtr));
+  mxnet_op::copy(s, idx, csr.aux_data(csr::kIdx));
+}
+
+/*!
+ * \brief Casts a rsp matrix to another rsp.
+ */
+template <typename xpu>
+void CastStorageRspRspImpl(const OpContext& ctx, const NDArray& rsp,
+                           NDArray* output) {
+  CHECK_EQ(rsp.storage_type(), output->storage_type())
+    << "Copying with different storage type";
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!rsp.storage_initialized()) {
+    FillZerosRspImpl(s, *output);
+    return;
+  }
+  auto aux_shape = rsp.aux_shape(rowsparse::kIdx);
+  output->CheckAndAlloc({aux_shape});
+  const TBlob& val = output->data();
+  const TBlob& idx = output->aux_data(rowsparse::kIdx);
+  const TBlob& from_val = rsp.data();
+  const TBlob& from_idx = rsp.aux_data(rowsparse::kIdx);
+  mxnet_op::copy(s, val, from_val);
+  mxnet_op::copy(s, idx, from_idx);
+}
+
 template <typename xpu>
 void CastStorageComputeImpl(const OpContext& ctx,
                             const NDArray& input,
@@ -346,6 +391,12 @@ void CastStorageComputeImpl(const OpContext& ctx,
   } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) {
     TBlob ret = output.data();
     CastStorageCsrDnsImpl(ctx, input, &ret);
+  } else if (src_stype == kCSRStorage && dst_stype == kCSRStorage) {
+    NDArray ret = output;
+    CastStorageCsrCsrImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) {
+    NDArray ret = output;
+    CastStorageRspRspImpl<xpu>(ctx, input, &ret);
 #if MXNET_USE_MKLDNN == 1
   } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) {
     CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc
index 9f257b140f7b..f77a50a7e70f 100644
--- a/src/operator/tensor/cast_storage.cc
+++ b/src/operator/tensor/cast_storage.cc
@@ -46,6 +46,8 @@ The storage type of ``cast_storage`` output depends on stype parameter:
 - cast_storage(row_sparse, 'default') = default
 - cast_storage(default, 'csr') = csr
 - cast_storage(default, 'row_sparse') = row_sparse
+- cast_storage(csr, 'csr') = csr
+- cast_storage(row_sparse, 'row_sparse') = row_sparse
 
 Example::
 
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 9417df317482..5ad5215036d6 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1177,10 +1177,13 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=
         shape_3d = rand_shape_3d()
         check_cast_storage(shape_2d, d, 'csr', 'default')
         check_cast_storage(shape_2d, d, 'default', 'csr')
+        check_cast_storage(shape_2d, d, 'csr', 'csr')
         check_cast_storage(shape_2d, d, 'row_sparse', 'default')
         check_cast_storage(shape_2d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse')
         check_cast_storage(shape_3d, d, 'row_sparse', 'default')
         check_cast_storage(shape_3d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse')
         for i in range(4, 6):
             shape = rand_shape_nd(i, 5)
             check_cast_storage(shape, d, 'default', 'row_sparse')