-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Rust] Improve NDArray, GraphRt, and Relay bindings #6563
Conversation
- Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings
Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings
v.set_len(sz); | ||
} | ||
Ok(v) | ||
let n = self.size() / size_of::<T>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could a user ever request T ~ ()
, or some other zero-sized item, giving a divide-by-zero error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably need to constrain T anyways, @mwillsey and I were trying to move the arrays on to the new bindings but it seems like there are so many unsafe gotchas from previous bindings we are trying to work through. In general I think we should probably change NDArray to be like NDArray where T: DataType or something like that. We could probably just slap a : Sized
on it for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah i think panicking on div-by-zero would be the correct behavior. Follow up on jared, I think a goal should be to move to NDArray, because right now these conversions are all effectively transmutes.
let ctx = Context::cpu(0); | ||
println!("before empty"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stale helpers, not needed anymore?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with println! in tests, since the output is hidden by default, but that's just me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep! They don't bug me, just flagged them in case they were leftovers.
@@ -184,28 +246,19 @@ impl NDArray { | |||
/// let ctx = Context::cpu(0); | |||
/// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); | |||
/// ndarray.copy_from_buffer(&mut data); | |||
/// assert_eq!(ndarray.shape(), Some(&mut shape[..])); | |||
/// assert_eq!(ndarray.shape(), shape); | |||
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); | |||
/// ``` | |||
pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused about how this function works. If I call it as my_array.to_vec<u8>
, I'll get back a vector with more elements than if I call it as my_array.to_vec::<u32>
, which seems like one or the other would not correspond to a flattened version of self
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, see above. The problem is that NDArray isn't NDArray. So we trust the type of T given to us by to_vec::<T>
. That makes it effectively a transmute of the buffer.
@@ -40,6 +40,10 @@ tvm-macros = { version = "*", path = "../tvm-macros/" } | |||
paste = "0.1" | |||
mashup = "0.1" | |||
once_cell = "^1.3.1" | |||
pyo3 = { version = "0.11.1", optional = true } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: ^0.11.1
, to pick up new non-breaking changes? Or do we want 0.11.1
specifically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe that "^" is implicit, no? https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies-from-cratesio
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TIL :)
@imalsogreg @jroesch Just to break this out into a comment. Many of the remaining issues on this PR that @imalsogreg commented on stem from the fact that Should we block this PR on fixing this? I'm leaning toward no. |
I think we should just add a Rust bindings stabilization tracking issue for all these remaining items and set a stabilization goal on the current API, etc at which we will impose more strict coding standards and clean up. |
We can put them here: #6604. |
@mwillsey 👍 for merge from me. I didn't mean to block anything, mostly I'm trying to get oriented. |
f223c50
to
9e580aa
Compare
We need to update the docker images to match MxNet see #6628 |
* WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch <jroesch@octoml.ai> Co-authored-by: Gus Smith <guscomps@gmail.com>
* WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch <jroesch@octoml.ai> Co-authored-by: Gus Smith <guscomps@gmail.com>
* WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch <jroesch@octoml.ai> Co-authored-by: Gus Smith <guscomps@gmail.com>
* WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch <jroesch@octoml.ai> Co-authored-by: Gus Smith <guscomps@gmail.com>
* WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch <jroesch@octoml.ai> Co-authored-by: Gus Smith <guscomps@gmail.com>
* WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch <jroesch@octoml.ai> Co-authored-by: Gus Smith <guscomps@gmail.com>
This PR packages up some varied work that @jroesch, @gussmith23, and I have been working on to improve the Rust bindings.