-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NODE][REFLECTION] Support NDArray as field #1452
Conversation
@jroesch can you review? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modulo a few nitpicks looks good to me, did you write the base64 encoding code yourself?
include/tvm/runtime/ndarray.h
Outdated
uint64_t header = kTVMNDArrayMagic, reserved = 0; | ||
strm->Write(header); | ||
strm->Write(reserved); | ||
// always save data as CPU context |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we maybe expand on this note, we expect that after deserialization the user will reallocate the data with the appropriate device type and id?
/*! | ||
* \brief wrapper node container for exchange. | ||
*/ | ||
struct NDArrayWrapperNode : public ::tvm::Node { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of this struct?
data.push_back(LoadDLTensor(strm)); | ||
tvm::runtime::NDArray temp; | ||
temp.Load(strm); | ||
std::shared_ptr<NDArrayWrapperNode> n |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these wrappers just so we get a 2-tuple like piece of data in Python?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, these are wrappers to get 2-tuple pieces of data, we put it in nnvm so they can be deprecated in future. Used to test the features introduced
python/tvm/_ffi/ndarray.py
Outdated
return not self.__eq__(other) | ||
|
||
def same_as(self, other): | ||
"""check object identity equality""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we capitalize this?
TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size)); | ||
NDArray temp; | ||
temp.Load(strm); | ||
temp.CopyTo(dst); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
@jroesch I have addressed your comments, please check again. base64 encoding was adapted from code that I wrote in previous projects, so we can keep things self-contained. The NDArrayWrapper is mainly used to test the introduced feature. It also simplifies the previous hack in nnvm. We can phase it out later(that is why i didn't put it in tvm) |
👍 everything looked good, sorry took me a while to get back to it, didn't check again until now. |
Support NDArray as a serializable and reflectable field of Node system. This is going to be useful to introduce constant node in the language.
Example Testcase
In C++