[MXNET-93] Sparse support for Custom Op #10374
Conversation
@piiswrong can you help check if the python API is reasonable?
'''Example of how to use custom op with sparse ndarrays
'''
def forward(self, is_train, req, in_data, out_data, aux):
    #self.assign(out_data[0], req[0], mx.nd.sparse.square(in_data[0]))
Would it make more sense to have an example without using mx.nd.square? For example:

input = in_data[0]
data = input.data
output = sparse.csr_matrix((data * data, input.indices, input.indptr), shape=...)
self.assign(out_data[0], req[0], output)
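(For reference, a fuller hedged sketch of that suggestion; the class name SparseSqr is illustrative, and it assumes in_data[0] is a CSRNDArray. mx.nd.sparse.csr_matrix takes the tuple in (data, indices, indptr) order.)

```python
import mxnet as mx

class SparseSqr(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        inp = in_data[0]
        # square only the stored values and rebuild a CSR array with the
        # same sparsity pattern, so the op never densifies the input
        data = inp.data
        out = mx.nd.sparse.csr_matrix((data * data, inp.indices, inp.indptr),
                                      shape=inp.shape)
        self.assign(out_data[0], req[0], out)
```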
src/operator/operator_common.h
Outdated
@@ -314,6 +314,32 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
}
#endif

/*! \brief allocate ndarrays from existing ndarrays
Not sure if this is the best place to put this function. Would moving it inside custom.h work? If we put it here, it's very likely that somebody will misuse the function given the current doc.
src/operator/custom/custom.cc
Outdated
void Forward(const OpStatePtr& state,
             const OpContext& ctx,
             const std::vector<TBlob>& inputs,
void Forward(const OpStatePtr& state, const OpContext& ctx,
Maybe rename to ForwardEx to follow the convention for ComputeEx?
src/operator/custom/custom.cc
Outdated
const CustomParam& params = state.get_state<CustomParam>();
std::vector<void*> ptrs;
std::vector<int> tags;
std::vector<NDArray> cpys;
std::unordered_set<int> input_tags({0, 4});
Need better documentation to explain what the magic numbers are for...
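(For reference, the tag values appear to identify which list each NDArray pointer belongs to when it is handed to the frontend callback; the mapping below is an assumption to verify against python/mxnet/operator.py.)

```python
# Assumed tag-to-list mapping used when passing NDArray pointers to the
# custom-op frontend callbacks (verify against python/mxnet/operator.py):
ROLE_BY_TAG = {
    0: 'in_data',    # operator inputs
    1: 'out_data',   # operator outputs
    2: 'in_grad',    # gradients w.r.t. inputs (written by backward)
    3: 'out_grad',   # gradients w.r.t. outputs (read by backward)
    4: 'aux',        # auxiliary states
}
# Under this mapping, input_tags({0, 4}) in Forward marks in_data and aux as
# read-only inputs, and input_tags({3, 0, 1, 4}) in Backward marks out_grad,
# in_data, out_data, and aux as read-only inputs.
```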
if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {
  for (size_t i = 0; i < iattr->size(); i++) {
    STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
What if one of the inputs/outputs is sparse? Would the check fail? Shouldn't it only assign stype to the undefined ones?
(same comment for forward stype inference)
This is for backward compatibility with other frontends which don't support sparse for custom ops. It will never enter the if clause for the Python frontend.
So if a Perl user creates a custom op with a sparse ndarray (without a custom infer storage function), would this break?
Yes, because sparse is not supported for Perl. It will have to be added separately; infer_storage_type and infer_storage_type_backward need to be registered.
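(For reference, a hedged sketch of the forward-side interface a frontend exposes once it does support this; the op name and returned stypes are illustrative, and infer_shape/create_operator are omitted for brevity.)

```python
import mxnet as mx

@mx.operator.register("sparse_sqr")
class SparseSqrProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(SparseSqrProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_storage_type(self, in_stype):
        # keep sparse inputs sparse instead of forcing everything to dense;
        # returns (in_stype, out_stype, aux_stype)
        return in_stype, [in_stype[0]], []
```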
src/operator/custom/custom-inl.h
Outdated
bool training,
const std::vector<NDArray>& arrs) {
template <typename Func>
void Push(const Func& func, const OpContext& ctx, bool recording,
Add some doc explaining why we need to pass inputs/outputs array?
# test for backward compatibility, i.e. the correctness of default implementation of
# infer storage in custom operator
class Mult(mx.operator.CustomOp):
        def forward(self, is_train, req, in_data, out_data, aux):
nit: 8 space indentation?
@@ -4059,6 +4059,79 @@ def create_operator(self, ctx, shapes, dtypes):
with mx.contrib.autograd.train_section():
    y = mx.nd.Custom(x, aux, op_type='sqr')
y.backward()
y.wait_to_read()
x.grad.wait_to_read()
Is the test case for sparse input not added here?
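(A hedged sketch of what such a test could look like; the op_type name, its argument list, and the tolerances are illustrative, not taken from this PR.)

```python
import mxnet as mx
import numpy as np

def test_custom_op_sparse_input():
    # feed a CSR input through the custom op and check that the output
    # stays sparse and matches the dense reference
    x_dense = np.random.uniform(0, 1, size=(4, 10))
    x = mx.nd.array(x_dense).tostype('csr')
    y = mx.nd.Custom(x, op_type='sqr')
    assert y.stype == 'csr'
    np.testing.assert_allclose(y.asnumpy(), x_dense ** 2, rtol=1e-5, atol=1e-6)
```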
include/mxnet/ndarray.h
Outdated
auto stype = arr.storage_type();
CHECK(stype == kCSRStorage || stype == kRowSparseStorage)
    << "Only to be used with CSR and RSP storage types";
ptr_->shandle.dptr = arr.ptr_->shandle.dptr;
Would ptr_->shandle = arr.ptr_->shandle be sufficient?
This doesn't work because the Handle struct also stores a pointer to the data. Doing ptr_->shandle = arr.ptr_->shandle would create a copy of dptr pointing to the same data, but sparse updates dptr at runtime and that won't be reflected in the copied shandle.
include/mxnet/ndarray.h
Outdated
@@ -507,6 +507,35 @@ class NDArray {
ret.reuse_ = true;
return ret;
}

inline void SparseUpdateChunk(const NDArray &arr) const {
We definitely need to add a doc comment for this function to prevent others from misusing it.
Thank you for reviewing @eric-haibin-lin. I have addressed your comments.
def forward(self, is_train, req, in_data, out_data, aux):
    inp = in_data[0]
    if inp.stype == 'csr':
        csr_m = inp * inp
Does this work? Did you mean in_data[0].data?
It used to fall back to dense. I have fixed it now.
src/operator/custom/custom.cc
Outdated
@@ -45,6 +46,31 @@ struct CustomParam {
std::shared_ptr<MXCallbackList> info;
};

/*! \brief allocate ndarrays from existing ndarrays
 */
inline void allocate_ndarray_copy(NDArray** nd,
Use CamelCase?
fixed.
inp = in_data[0]
if inp.stype == 'csr':
    csr_m = inp.data
    csr_m = csr_m.reshape(inp.shape[0] * inp.shape[1])
why do you need to reshape?
    self.assign(out_data[0], req[0], out)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    self.assign(in_grad[0], req[0], 2 * in_data[0] * out_grad[0])
Maybe use sparse.elemwise_mul(csr, csr) so that it doesn't fall back to dense?
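(A hedged sketch of the backward along those lines, written as a method of the CustomOp subclass with import mxnet as mx at module level; it assumes in_data[0] and out_grad[0] both arrive as CSR with matching shapes.)

```python
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    # elementwise product of two CSR arrays stays CSR, and the scalar
    # multiply preserves the storage type, so nothing densifies here
    grad = 2 * mx.nd.sparse.elemwise_mul(in_data[0], out_grad[0])
    self.assign(in_grad[0], req[0], grad)
```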
else:
    inp = in_data[0]
    csr_m = inp.data * inp.data
    csr_m = csr_m.reshape(inp.shape[0] * inp.shape[1])
inp.data should already be 1-D
@piiswrong WDYT?
include/mxnet/ndarray.h
Outdated
ptr_->storage_shape = arr.ptr_->storage_shape;
ptr_->storage_type = arr.ptr_->storage_type;
ptr_->ctx = arr.ptr_->ctx;
ptr_->aux_handles = arr.ptr_->aux_handles;
This is causing a memory leak. You should do swaps instead.
include/mxnet/ndarray.h
Outdated
 */
inline void SparseUpdateChunk(const NDArray &arr) const {
  auto stype = arr.storage_type();
  CHECK(stype == kCSRStorage || stype == kRowSparseStorage)
Check that shape and dtype are the same.
rhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
lhs.attach_grad()
rhs.attach_grad()
with mx.contrib.autograd.train_section():
This should be record()
fixed.
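(For reference, the updated pattern with the current autograd API; the op_type name is illustrative.)

```python
lhs.attach_grad()
rhs.attach_grad()
with mx.autograd.record():   # replaces mx.contrib.autograd.train_section()
    y = mx.nd.Custom(lhs, rhs, op_type='mult')
y.backward()
```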
return MultNoGrad()

def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
    return [], [], [], ['default'], []
why are the returned values all empty?
Earlier, my interface allowed empty lists and would infer those as default. After talking to @eric-haibin-lin, we decided to require users who implement the infer_storage_type_backward interface to return lists of the same size as the input lists. Also, any undefined stypes will now throw an exception.
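(Under that contract, a conforming implementation looks roughly like this sketch.)

```python
def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype,
                                igrad_stype, aux_stype):
    # every returned list must match the length of the corresponding input
    # list, and no entry may be left undefined
    igrad_stype = ['default'] * len(igrad_stype)
    return ograd_stype, in_stype, out_stype, igrad_stype, aux_stype
```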
python/mxnet/operator.py
Outdated
aux_stype : list
    list of inferred storage types for auxiliary states.
"""
return list(ograd_stype), list(in_stype), list(out_stype), \
This default implementation didn't do anything
changed.
src/operator/custom/custom.cc
Outdated
std::vector<int> tags;
std::vector<NDArray> cpys;

ptrs.reserve(total);
tags.reserve(total);
cpys.reserve(total);

std::unordered_set<int> input_tags({3, 0, 1, 4});
add some comment?
added.
@piiswrong @eric-haibin-lin I have addressed your comments.
python/mxnet/operator.py
Outdated
Parameters
----------
in_stype : list of stypes, Valid stypes are default, row_sparse and
Valid -> valid
changed.
python/mxnet/operator.py
Outdated
Will raise an error if undefined storage type is returned.
Returned lists have to be the same size as the input lists to infer_storage_type_backward,
otherwise an exception will be thrown. When this interface is not implemented,
all stypes will fallback to default.
I would be careful about using the word "fallback" here since it has a specific meaning for sparse ops.
I'm a little bit confused about the default behavior of forward stype inference vs. backward stype inference:
in forward you replicate the stype of in_stype to the outputs, while in backward you replicate the "default" stype to all outputs.
For the default implementation, only default stypes are supported; that is why I replicated the stype of in_stype. I have now added asserts in infer_storage_type and infer_storage_type_backward to prevent misuse.
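(For reference, the default implementation described above behaves roughly like the following sketch; this is a paraphrase of the intent, not the exact source of python/mxnet/operator.py.)

```python
def infer_storage_type(self, in_stype):
    # default implementation: dense in, dense out; anything sparse must
    # override infer_storage_type explicitly
    for i, stype in enumerate(in_stype):
        assert stype == 'default', \
            "Default infer_storage_type only supports dense inputs; " \
            "got '%s' for input %d. Please implement infer_storage_type." % (stype, i)
    return in_stype, \
        ['default'] * len(self.list_outputs()), \
        ['default'] * len(self.list_auxiliary_states())
```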
* Initial changes
* Add custom op support in backend
* Add operator changes
* Add refactor changes
* Add 3p
* Add custom ops support
* Move out common code to a function
* Fix changes
* Fix custom op changes
* Remove whitespace
* Fix
* Add fix
* Remove test dependency
* Add example for custom sparse sqr
* Remove extra line
* Add comments for InferStorageTypeBackward
* Fix lint
* Address review comments
* Fix for shandle
* Fix for shandle second
* Fix naive engine bug
* Fix
* Remove reshape
* Add swap logic for shandles
* Add rtol atol
* Fix op
* Fix custom op
* Fix pylint
* Add assert
* Fix lint
* Add check for undefined for igrad stypes
Description
Adds sparse support for custom op. Registers the InferStorageType and InferStorageTypeBackward interfaces for custom op, registers Forward and Backward with the FStatefulComputeEx interface, and adds an NDArray API to update the chunk of a sparse ndarray from an existing ndarray.
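A minimal usage sketch from the Python frontend (hedged illustration; the op_type name 'sparse_sqr' is assumed, mirroring the sparse square example added in this PR):

```python
import mxnet as mx
import numpy as np

# assumes a CustomOp/CustomOpProp pair registered under op_type='sparse_sqr'
# that implements forward/backward and infer_storage_type(_backward)
x = mx.nd.array(np.random.uniform(0, 1, size=(4, 10))).tostype('csr')
y = mx.nd.Custom(x, op_type='sparse_sqr')
print(y.stype)   # 'csr' when infer_storage_type keeps the output sparse
```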
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
@piiswrong @eric-haibin-lin