Skip to content

Commit

Permalink
[Rust] Some rust cleanups (apache#6116)
Browse files Browse the repository at this point in the history
* Some rust cleanups

* Turn off default features for bindgen
* Upgrade some deps for smaller total dep tree
* Switch (/complete switch) to thiserror
* Remove unnecessary transmutes

* Fix null pointer assert

* Update wasm32 test
  • Loading branch information
binarybana authored and Trevor Morris committed Sep 2, 2020
1 parent f226005 commit 8e84327
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 73 deletions.
5 changes: 3 additions & 2 deletions rust/tvm-graph-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ authors = ["TVM Contributors"]
edition = "2018"

[dependencies]
crossbeam = "0.7.3"
failure = "0.1"
crossbeam-channel = "0.4"
thiserror = "1"

itertools = "0.8"
lazy_static = "1.4"
ndarray="0.12"
Expand Down
20 changes: 10 additions & 10 deletions rust/tvm-graph-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};

use failure::{ensure, Error};
use ndarray;
use tvm_sys::{ffi::DLTensor, Context, DataType};

use crate::allocator::Allocation;
use crate::errors::ArrayError;
use std::alloc::LayoutErr;

/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
Expand All @@ -36,7 +37,7 @@ pub enum Storage<'a> {
}

impl<'a> Storage<'a> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, LayoutErr> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}

Expand Down Expand Up @@ -297,21 +298,20 @@ impl<'a> Tensor<'a> {
macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
tensor.dtype
);
type Error = ArrayError;
fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Self::Error> {
if tensor.dtype != $dtype {
return Err(ArrayError::IncompatibleDataType(tensor.dtype));
}
Ok(ndarray::Array::from_shape_vec(
tensor
.shape
.iter()
.map(|s| *s as usize)
.collect::<Vec<usize>>(),
tensor.to_vec::<$type>(),
)?)
)
.map_err(|_| ArrayError::ShapeError(tensor.shape.clone()))?)
}
}
};
Expand Down
37 changes: 29 additions & 8 deletions rust/tvm-graph-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,39 @@
* under the License.
*/

use failure::Fail;
use thiserror::Error;
use tvm_sys::DataType;

#[derive(Debug, Fail)]
#[derive(Debug, Error)]
pub enum GraphFormatError {
#[fail(display = "Could not parse graph json")]
Parse(#[fail(cause)] failure::Error),
#[fail(display = "Could not parse graph params")]
#[error("Could not parse graph json")]
Parse(#[from] serde_json::Error),
#[error("Could not parse graph params")]
Params,
#[fail(display = "{} is missing attr: {}", 0, 1)]
#[error("{0} is missing attr: {1}")]
MissingAttr(String, String),
#[fail(display = "Missing field: {}", 0)]
#[error("Graph has invalid attr that can't be parsed: {0}")]
InvalidAttr(#[from] std::num::ParseIntError),
#[error("Missing field: {0}")]
MissingField(&'static str),
#[fail(display = "Invalid DLType: {}", 0)]
#[error("Invalid DLType: {0}")]
InvalidDLType(String),
#[error("Unsupported Op: {0}")]
UnsupportedOp(String),
}

#[derive(Debug, Error)]
#[error("Function {0} not found")]
pub struct FunctionNotFound(pub String);

#[derive(Debug, Error)]
#[error("Pointer {0:?} invalid when freeing")]
pub struct InvalidPointer(pub *mut u8);

#[derive(Debug, Error)]
pub enum ArrayError {
#[error("Cannot convert Tensor with dtype {0} to ndarray")]
IncompatibleDataType(DataType),
#[error("Shape error when casting ndarray to TVM Array with shape {0:?}")]
ShapeError(Vec<i64>),
}
51 changes: 29 additions & 22 deletions rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
* under the License.
*/

use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
use std::{
cmp, collections::HashMap, convert::TryFrom, error::Error, iter::FromIterator, mem, str,
};

use failure::{ensure, format_err, Error};
use itertools::izip;
use nom::{
character::complete::{alpha1, digit1},
complete, count, do_parse, length_count, map, named,
number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
opt, tag, take, tuple,
};

use serde::{Deserialize, Serialize};
use serde_json;

use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt};

use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};

use crate::{errors::GraphFormatError, Module, Storage, Tensor};
use crate::{errors::*, Module, Storage, Tensor};

// @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
Expand Down Expand Up @@ -114,7 +114,7 @@ macro_rules! get_node_attr {
}

impl Node {
fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
fn parse_attrs(&self) -> Result<NodeAttrs, GraphFormatError> {
let attrs = self
.attrs
.as_ref()
Expand All @@ -128,15 +128,15 @@ impl Node {
}

impl<'a> TryFrom<&'a String> for Graph {
type Error = Error;
fn try_from(graph_json: &String) -> Result<Self, self::Error> {
type Error = GraphFormatError;
fn try_from(graph_json: &String) -> Result<Self, GraphFormatError> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
}

impl<'a> TryFrom<&'a str> for Graph {
type Error = Error;
type Error = GraphFormatError;
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
Expand Down Expand Up @@ -177,7 +177,7 @@ pub struct GraphExecutor<'m, 't> {
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}

impl<'m, 't> GraphExecutor<'m, 't> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Box<dyn Error>> {
let tensors = Self::setup_storages(&graph)?;
Ok(GraphExecutor {
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
Expand All @@ -194,7 +194,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
}

/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Box<dyn Error>> {
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
let dtypes = graph
Expand All @@ -221,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
let mut storages: Vec<Storage> = storage_num_bytes
.into_iter()
.map(|nbytes| Storage::new(nbytes, align))
.collect::<Result<Vec<Storage>, Error>>()?;
.collect::<Result<Vec<Storage>, std::alloc::LayoutErr>>()?;

let tensors = izip!(storage_ids, shapes, dtypes)
.map(|(storage_id, shape, dtype)| {
Expand All @@ -246,34 +246,40 @@ impl<'m, 't> GraphExecutor<'m, 't> {
graph: &Graph,
lib: &'m M,
tensors: &[Tensor<'t>],
) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
) -> Result<Vec<Box<dyn Fn() + 'm>>, Box<dyn Error + 'static>> {
if !graph.node_row_ptr.is_some() {
return Err(GraphFormatError::MissingField("node_row_ptr").into());
}
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();

let mut op_execs = Vec::new();
for (i, node) in graph.nodes.iter().enumerate() {
if node.op == "null" {
continue;
}
ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
ensure!(node.attrs.is_some(), "Missing node attrs.");
if node.op != "tvm_op" {
return Err(GraphFormatError::UnsupportedOp(node.op.to_owned()).into());
}
if !node.attrs.is_some() {
return Err(GraphFormatError::MissingAttr(node.op.clone(), "".to_string()).into());
}

let attrs = node.parse_attrs()?;
let attrs: NodeAttrs = node.parse_attrs()?.into();

if attrs.func_name == "__nop" {
continue;
}

let func = lib
.get_function(&attrs.func_name)
.ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?;
.ok_or_else(|| FunctionNotFound(attrs.func_name.clone()))?;
let arg_indices = node
.inputs
.iter()
.map(|entry| graph.entry_index(entry))
.chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi)));

let dl_tensors = arg_indices
let dl_tensors: Vec<DLTensor> = arg_indices
.map(|idx| {
let tensor = &tensors[idx?];
Ok(if attrs.flatten_data {
Expand All @@ -282,14 +288,15 @@ impl<'m, 't> GraphExecutor<'m, 't> {
DLTensor::from(tensor)
})
})
.collect::<Result<Vec<DLTensor>, Error>>()
.unwrap();
.collect::<Result<Vec<DLTensor>, GraphFormatError>>()?
.into();
let op: Box<dyn Fn()> = Box::new(move || {
let args = dl_tensors
let args: Vec<ArgValue> = dl_tensors
.iter()
.map(|t| t.into())
.collect::<Vec<ArgValue>>();
func(&args).unwrap();
let err_str = format!("Function {} failed to execute", attrs.func_name);
func(&args).expect(&err_str);
});
op_execs.push(op);
}
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/src/module/dso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ macro_rules! init_context_func {
}

impl<'a> DsoModule<'a> {
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, std::io::Error> {
let lib = libloading::Library::new(filename)?;

init_context_func!(
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/src/threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::{
#[cfg(not(target_arch = "wasm32"))]
use std::env;

use crossbeam::channel::{bounded, Receiver, Sender};
use crossbeam_channel::{bounded, Receiver, Sender};
use tvm_sys::ffi::TVMParallelGroupEnv;

pub(crate) type FTVMParallelLambda =
Expand Down
13 changes: 7 additions & 6 deletions rust/tvm-graph-rt/src/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

use std::{
cell::RefCell,
error::Error,
os::raw::{c_int, c_void},
ptr,
};

use failure::{format_err, Error};

use crate::allocator::Allocation;
use crate::errors::InvalidPointer;
use std::alloc::LayoutErr;

const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`

Expand All @@ -49,13 +50,13 @@ impl WorkspacePool {
}
}

fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}

fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
if self.free.is_empty() {
return self.alloc_new(size);
}
Expand All @@ -82,7 +83,7 @@ impl WorkspacePool {
}
}

fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
fn free(&mut self, ptr: *mut u8) -> Result<(), Box<dyn Error>> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
Expand All @@ -92,7 +93,7 @@ impl WorkspacePool {
break;
}
}
let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?;
let ws_idx = ws_idx.ok_or_else(|| InvalidPointer(ptr))?;
self.free.push(ws_idx);
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-graph-rt/tests/test_wasm32/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ unsafe fn __get_tvm_module_ctx() -> i32 {

extern crate ndarray;
#[macro_use]
extern crate tvm_runtime;
extern crate tvm_graph_rt;

use ndarray::Array;
use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
use tvm_graph_rt::{DLTensor, Module as _, SystemLibModule};

fn main() {
// try static
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ edition = "2018"
proc-macro = true

[dependencies]
goblin = "0.0.24"
goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
Loading

0 comments on commit 8e84327

Please sign in to comment.