Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add support for cast storage on same stypes (#10400)
Browse files Browse the repository at this point in the history
* Add cast storage support for same stypes

* Add imports

* Fix cast

* Fix doc for cast_storage

* Fix
  • Loading branch information
anirudh2290 authored and eric-haibin-lin committed Apr 5, 2018
1 parent 6d28d0a commit 2cd09a0
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/operator/tensor/cast_storage-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <algorithm>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../../src/operator/tensor/init_op.h"
#ifdef __CUDACC__
#include "./cast_storage-inl.cuh"
#endif // __CUDACC__
Expand Down Expand Up @@ -328,6 +329,50 @@ void CastStorageCsrDnsImpl(const OpContext& ctx,
});
}

/*!
* \brief Casts a csr matrix to another csr.
*/
template <typename xpu>
void CastStorageCsrCsrImpl(const OpContext& ctx, const NDArray& csr,
NDArray* output) {
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
if (!csr.storage_initialized()) {
FillZerosCsrImpl(s, *output);
return;
}
std::vector<TShape> aux_shapes({csr.aux_shape(csr::kIndPtr), csr.aux_shape(csr::kIdx)});
output->CheckAndAlloc(aux_shapes);
const TBlob& val = output->data();
const TBlob& indptr = output->aux_data(csr::kIndPtr);
const TBlob& idx = output->aux_data(csr::kIdx);
mxnet_op::copy(s, val, csr.data());
mxnet_op::copy(s, indptr, csr.aux_data(csr::kIndPtr));
mxnet_op::copy(s, idx, csr.aux_data(csr::kIdx));
}

/*!
* \brief Casts a rsp matrix to another rsp.
*/
template <typename xpu>
void CastStorageRspRspImpl(const OpContext& ctx, const NDArray& rsp,
NDArray* output) {
CHECK_EQ(rsp.storage_type(), output->storage_type())
<< "Copying with different storage type";
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
if (!rsp.storage_initialized()) {
FillZerosRspImpl(s, *output);
return;
}
auto aux_shape = rsp.aux_shape(rowsparse::kIdx);
output->CheckAndAlloc({aux_shape});
const TBlob& val = output->data();
const TBlob& idx = output->aux_data(rowsparse::kIdx);
const TBlob& from_val = rsp.data();
const TBlob& from_idx = rsp.aux_data(rowsparse::kIdx);
mxnet_op::copy(s, val, from_val);
mxnet_op::copy(s, idx, from_idx);
}

template<typename xpu>
void CastStorageComputeImpl(const OpContext& ctx,
const NDArray& input,
Expand All @@ -346,6 +391,12 @@ void CastStorageComputeImpl(const OpContext& ctx,
} else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) {
TBlob ret = output.data();
CastStorageCsrDnsImpl<xpu>(ctx, input, &ret);
} else if (src_stype == kCSRStorage && dst_stype == kCSRStorage) {
NDArray ret = output;
CastStorageCsrCsrImpl<xpu>(ctx, input, &ret);
} else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) {
NDArray ret = output;
CastStorageRspRspImpl<xpu>(ctx, input, &ret);
#if MXNET_USE_MKLDNN == 1
} else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) {
CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
Expand Down
2 changes: 2 additions & 0 deletions src/operator/tensor/cast_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The storage type of ``cast_storage`` output depends on stype parameter:
- cast_storage(row_sparse, 'default') = default
- cast_storage(default, 'csr') = csr
- cast_storage(default, 'row_sparse') = row_sparse
- cast_storage(csr, 'csr') = csr
- cast_storage(row_sparse, 'row_sparse') = row_sparse
Example::
Expand Down
3 changes: 3 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,10 +1177,13 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=
shape_3d = rand_shape_3d()
check_cast_storage(shape_2d, d, 'csr', 'default')
check_cast_storage(shape_2d, d, 'default', 'csr')
check_cast_storage(shape_2d, d, 'csr', 'csr')
check_cast_storage(shape_2d, d, 'row_sparse', 'default')
check_cast_storage(shape_2d, d, 'default', 'row_sparse')
check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse')
check_cast_storage(shape_3d, d, 'row_sparse', 'default')
check_cast_storage(shape_3d, d, 'default', 'row_sparse')
check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse')
for i in range(4, 6):
shape = rand_shape_nd(i, 5)
check_cast_storage(shape, d, 'default', 'row_sparse')
Expand Down

0 comments on commit 2cd09a0

Please sign in to comment.