Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Improve NDArray, GraphRt, and Relay bindings #6563

Merged
merged 50 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
8c36f88
WIP
jroesch Aug 25, 2020
b0b2b59
Add support for loading Python packed functions
jroesch Aug 31, 2020
d282ef5
Flesh out Relay AST in Rust
jroesch Sep 7, 2020
a80a9e4
More tweeks for getting functions out
jroesch Sep 7, 2020
5f9b521
Deploy Rust docs as part of build
jroesch Sep 10, 2020
3343f61
Add some more types
mwillsey Sep 21, 2020
772375f
Introduce NDArray 2.0
jroesch Sep 11, 2020
2b982c1
Work on NDArray 2.0 before restoring tests
jroesch Sep 11, 2020
dbe32f1
Formatting and code fixes to get it to compile
gussmith23 Sep 18, 2020
f7c36ca
Add more Rust bindings
gussmith23 Sep 22, 2020
5363ff4
Clean up object ptr passing.
jroesch Sep 22, 2020
930e29e
WIP
jroesch Sep 22, 2020
d35c03c
Add debugging for NDArray and fix all test cases
jroesch Sep 23, 2020
5eb46c2
Add breaking test
gussmith23 Sep 23, 2020
518b230
Dispatch some todos
mwillsey Sep 23, 2020
ba92c43
Format
mwillsey Sep 23, 2020
f99154d
Fix ndarray size and len
mwillsey Sep 23, 2020
4830621
Add BiasAddAttrs rust bindings
gussmith23 Sep 23, 2020
53d8377
Add DenseAttrs rust bindings
gussmith23 Sep 23, 2020
5062ecb
Change to TVM string
gussmith23 Sep 24, 2020
66fcc91
Add more Rust bindings
gussmith23 Sep 24, 2020
4519acd
Fix some test attributes
mwillsey Sep 24, 2020
0b922f2
Improve the NDArray api
mwillsey Sep 24, 2020
15f88fd
Fix some more ndarray stuff
mwillsey Sep 25, 2020
a435cb8
Get the resnet demo kinda working
mwillsey Sep 25, 2020
d119570
Add SoftmaxAttrs Rust bindings
gussmith23 Sep 24, 2020
f48282a
Implement Hash and Eq for Relay Exprs
gussmith23 Sep 25, 2020
4979041
Add underscore to unused function
gussmith23 Sep 25, 2020
0ef7e3e
Fix broken ass resnet script
jroesch Sep 25, 2020
f01dcfc
Improve some ndarray conversions
mwillsey Sep 25, 2020
84f864e
Make sure the build script runs correctly
mwillsey Sep 25, 2020
c702cf4
Clean up ResNet example tremedously
jroesch Sep 26, 2020
3b6edf9
Add ASF header
jroesch Sep 26, 2020
f0af06e
Format
gussmith23 Sep 29, 2020
70e8a3e
Format
gussmith23 Sep 29, 2020
3e96484
Format resnet rust python script
mwillsey Sep 29, 2020
e893b57
Add type files and refactor span
jroesch Sep 29, 2020
b6f3962
Format
jroesch Sep 29, 2020
ed326e8
Format
jroesch Sep 29, 2020
49a42ba
Change types from std::string to tvm::String in packed function
gussmith23 Sep 29, 2020
54ed9b1
Add ASF header
gussmith23 Sep 29, 2020
644b746
Fix test w/ ndarray's API change
gussmith23 Sep 30, 2020
5be2063
Fix array test
mwillsey Sep 30, 2020
83eb87f
Fix anyhow import
mwillsey Sep 30, 2020
72bce31
Put back some anyhow stuff
mwillsey Sep 30, 2020
9e580aa
Clean up
jroesch Oct 1, 2020
778f3ba
Try and fix tests/scripts/task_rust.sh
jroesch Oct 3, 2020
2266ddd
Disable ResNet for now
jroesch Oct 5, 2020
d93d134
Turn off building of Rust docs until we update CI
jroesch Oct 6, 2020
422e970
Actually disable
jroesch Oct 6, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class PatternTuple;
/*! \brief PatternVar container node */
class PatternTupleNode : public PatternNode {
public:
/* TODO(@jroesch): rename to field_pats */
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;

Expand Down
12 changes: 7 additions & 5 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include <string>

#include "tvm/runtime/container.h"

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -115,9 +117,9 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
Expand Down Expand Up @@ -681,7 +683,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
tvm::String layout;
bool ceil_mode;

TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
Expand Down Expand Up @@ -744,7 +746,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {

/*! \brief Attributes for global pool operator */
struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
std::string layout;
tvm::String layout;

TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,9 @@ inline ObjectPtr<Object> NDArray::FFIDataFromHandle(TVMArrayHandle handle) {
inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) {
// NOTE: it is necessary to cast to container then to base
// so that the FFI handle uses the ContainerBase address.
return reinterpret_cast<TVMArrayHandle>(static_cast<NDArray::ContainerBase*>(
auto ptr = reinterpret_cast<TVMArrayHandle>(static_cast<NDArray::ContainerBase*>(
static_cast<NDArray::Container*>(const_cast<Object*>(nd.get()))));
return ptr;
}

inline void NDArray::FFIDecRef(TVMArrayHandle handle) {
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class Layout : public ObjectRef {
public:
explicit Layout(const Array<tir::IterVar>& axes);

/*! \brief construct from a string */
Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*)

/*! \brief construct from a string */
Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)

Expand Down
8 changes: 5 additions & 3 deletions rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7;
///
/// # Examples
///
/// ```norun
/// let graph_json = fs::read_to_string("graph.json").unwrap();
/// ```no_run
/// use tvm_graph_rt::Graph;
/// use std::convert::TryFrom;
/// let graph_json = std::fs::read_to_string("graph.json").unwrap();
/// let graph = Graph::try_from(&graph_json).unwrap();
/// ```
#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -147,7 +149,7 @@ impl<'a> TryFrom<&'a str> for Graph {
///
/// # Examples
///
/// ```norun
/// ```no_compile
/// use ndarray::Array;
///
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-graph-rt/src/threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ pub unsafe extern "C" fn TVMBackendParallelBarrier(

#[cfg(test)]
mod tests {
use std::{ptr, thread, time::Duration};
use std::{thread, time::Duration};

use super::*;

Expand All @@ -228,7 +228,7 @@ mod tests {
assert_eq!(max_concurrency(), 24);
}

extern "C" fn flambda(
extern "C" fn _flambda(
task_id: usize,
penv: *const TVMParallelGroupEnv,
cdata: *const c_void,
Expand Down
1 change: 1 addition & 0 deletions rust/tvm-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
proc-macro-error = "^1.0"
1 change: 1 addition & 0 deletions rust/tvm-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ tvm-macros = { version = "0.1", path = "../tvm-macros" }
paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
memoffset = "0.5.6"

[dev-dependencies]
anyhow = "^1.0"
Expand Down
17 changes: 17 additions & 0 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,20 @@ impl<'a, T: IsObjectRef> TryFrom<RetValue> for Array<T> {
})
}
}

#[cfg(test)]
mod tests {
use super::Array;
use crate::function::Result;
use crate::string::String;

#[test]
fn create_array_and_get() -> Result<()> {
let vec: Vec<String> = vec!["foo".into(), "bar".into(), "baz".into()];
let array = Array::from_vec(vec)?;
assert_eq!(array.get(0)?.to_string(), "foo");
assert_eq!(array.get(1)?.to_string(), "bar");
assert_eq!(array.get(2)?.to_string(), "baz");
Ok(())
}
}
2 changes: 0 additions & 2 deletions rust/tvm-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ pub struct TypeMismatchError {

#[derive(Debug, Error)]
pub enum NDArrayError {
#[error("Missing NDArray shape.")]
MissingShape,
#[error("Cannot convert from an empty array.")]
EmptyArray,
#[error("Invalid datatype when attempting to convert ndarray.")]
Expand Down
Loading