Skip to content

Commit

Permalink
[Rust] Improve NDArray, GraphRt, and Relay bindings (#6563)
Browse files Browse the repository at this point in the history
* WIP

WIP

* Add support for loading Python packed functions

* Flesh out Relay AST in Rust

* More tweeks for getting functions out

* Deploy Rust docs as part of build

* Add some more types

* Introduce NDArray 2.0

* Work on NDArray 2.0 before restoring tests

* Formatting and code fixes to get it to compile

* Add more Rust bindings

- Converts Conv2d attrs to use tvm::String, so that we can add Rust binding
- Uses Type for checked_type in Rust bindings
- Fix type key in Rust bindings
- Make data field contain NDArray in Rust bindings

* Clean up object ptr passing.

* WIP

* Add debugging for NDArray and fix all test cases

* Add breaking test

* Dispatch some todos

* Format

* Fix ndarray size and len

* Add BiasAddAttrs rust bindings

* Add DenseAttrs rust bindings

* Change to TVM string

* Add more Rust bindings

Add GlobalPool2DAttrs Rust binding

Add ExpandDimsAttrs Rust bindings

Add MaxPool2DAttrs rust bindings

* Fix some test attributes

* Improve the NDArray api

* Fix some more ndarray stuff

* Get the resnet demo kinda working

* Add SoftmaxAttrs Rust bindings

* Implement Hash and Eq for Relay Exprs

* Add underscore to unused function

* Fix broken ass resnet script

* Improve some ndarray conversions

* Make sure the build script runs correctly

* Clean up ResNet example tremedously

Expose C++ graph runtime via cleaner Rust API rewrite example.

* Add ASF header

* Format

* Format

* Format resnet rust python script

* Add type files and refactor span

* Format

* Format

* Change types from std::string to tvm::String in packed function

* Add ASF header

* Fix test w/ ndarray's API change

* Fix array test

* Fix anyhow import

* Put back some anyhow stuff

* Clean up

* Try and fix tests/scripts/task_rust.sh

* Disable ResNet for now

* Turn off building of Rust docs until we update CI

* Actually disable

Co-authored-by: Jared Roesch <jroesch@octoml.ai>
Co-authored-by: Gus Smith <guscomps@gmail.com>
  • Loading branch information
3 people authored Oct 6, 2020
1 parent feb041d commit 277bfc8
Show file tree
Hide file tree
Showing 47 changed files with 1,803 additions and 542 deletions.
1 change: 1 addition & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class PatternTuple;
/*! \brief PatternVar container node */
class PatternTupleNode : public PatternNode {
public:
/* TODO(@jroesch): rename to field_pats */
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;

Expand Down
12 changes: 7 additions & 5 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include <string>

#include "tvm/runtime/container.h"

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -115,9 +117,9 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
Expand Down Expand Up @@ -681,7 +683,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
tvm::String layout;
bool ceil_mode;

TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
Expand Down Expand Up @@ -744,7 +746,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {

/*! \brief Attributes for global pool operator */
struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
std::string layout;
tvm::String layout;

TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,9 @@ inline ObjectPtr<Object> NDArray::FFIDataFromHandle(TVMArrayHandle handle) {
inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) {
// NOTE: it is necessary to cast to container then to base
// so that the FFI handle uses the ContainerBase address.
return reinterpret_cast<TVMArrayHandle>(static_cast<NDArray::ContainerBase*>(
auto ptr = reinterpret_cast<TVMArrayHandle>(static_cast<NDArray::ContainerBase*>(
static_cast<NDArray::Container*>(const_cast<Object*>(nd.get()))));
return ptr;
}

inline void NDArray::FFIDecRef(TVMArrayHandle handle) {
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class Layout : public ObjectRef {
public:
explicit Layout(const Array<tir::IterVar>& axes);

/*! \brief construct from a string */
Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*)

/*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)

Expand Down
8 changes: 5 additions & 3 deletions rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7;
///
/// # Examples
///
/// ```norun
/// let graph_json = fs::read_to_string("graph.json").unwrap();
/// ```no_run
/// use tvm_graph_rt::Graph;
/// use std::convert::TryFrom;
/// let graph_json = std::fs::read_to_string("graph.json").unwrap();
/// let graph = Graph::try_from(&graph_json).unwrap();
/// ```
#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -147,7 +149,7 @@ impl<'a> TryFrom<&'a str> for Graph {
///
/// # Examples
///
/// ```norun
/// ```no_compile
/// use ndarray::Array;
///
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-graph-rt/src/threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ pub unsafe extern "C" fn TVMBackendParallelBarrier(

#[cfg(test)]
mod tests {
use std::{ptr, thread, time::Duration};
use std::{thread, time::Duration};

use super::*;

Expand All @@ -228,7 +228,7 @@ mod tests {
assert_eq!(max_concurrency(), 24);
}

extern "C" fn flambda(
extern "C" fn _flambda(
task_id: usize,
penv: *const TVMParallelGroupEnv,
cdata: *const c_void,
Expand Down
1 change: 1 addition & 0 deletions rust/tvm-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
proc-macro-error = "^1.0"
1 change: 1 addition & 0 deletions rust/tvm-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ tvm-macros = { version = "0.1", path = "../tvm-macros" }
paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
memoffset = "0.5.6"

[dev-dependencies]
anyhow = "^1.0"
Expand Down
17 changes: 17 additions & 0 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,20 @@ impl<'a, T: IsObjectRef> TryFrom<RetValue> for Array<T> {
})
}
}

#[cfg(test)]
mod tests {
use super::Array;
use crate::function::Result;
use crate::string::String;

#[test]
fn create_array_and_get() -> Result<()> {
let vec: Vec<String> = vec!["foo".into(), "bar".into(), "baz".into()];
let array = Array::from_vec(vec)?;
assert_eq!(array.get(0)?.to_string(), "foo");
assert_eq!(array.get(1)?.to_string(), "bar");
assert_eq!(array.get(2)?.to_string(), "baz");
Ok(())
}
}
2 changes: 0 additions & 2 deletions rust/tvm-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ pub struct TypeMismatchError {

#[derive(Debug, Error)]
pub enum NDArrayError {
#[error("Missing NDArray shape.")]
MissingShape,
#[error("Cannot convert from an empty array.")]
EmptyArray,
#[error("Invalid datatype when attempting to convert ndarray.")]
Expand Down
Loading

0 comments on commit 277bfc8

Please sign in to comment.