Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Some rust cleanups #6116

Merged
merged 3 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rust/tvm-graph-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ authors = ["TVM Contributors"]
edition = "2018"

[dependencies]
crossbeam = "0.7.3"
failure = "0.1"
crossbeam-channel = "0.4"
thiserror = "1"

itertools = "0.8"
lazy_static = "1.4"
ndarray="0.12"
Expand Down
20 changes: 10 additions & 10 deletions rust/tvm-graph-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};

use failure::{ensure, Error};
use ndarray;
use tvm_sys::{ffi::DLTensor, Context, DataType};

use crate::allocator::Allocation;
use crate::errors::ArrayError;
use std::alloc::LayoutErr;

/// A `Storage` is a container which holds `Tensor` data.
#[derive(PartialEq)]
Expand All @@ -36,7 +37,7 @@ pub enum Storage<'a> {
}

impl<'a> Storage<'a> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, LayoutErr> {
Ok(Storage::Owned(Allocation::new(size, align)?))
}

Expand Down Expand Up @@ -297,21 +298,20 @@ impl<'a> Tensor<'a> {
macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'t> TryFrom<Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
tensor.dtype
);
type Error = ArrayError;
fn try_from(tensor: Tensor) -> Result<ndarray::ArrayD<$type>, Self::Error> {
if tensor.dtype != $dtype {
return Err(ArrayError::IncompatibleDataType(tensor.dtype));
}
Ok(ndarray::Array::from_shape_vec(
tensor
.shape
.iter()
.map(|s| *s as usize)
.collect::<Vec<usize>>(),
tensor.to_vec::<$type>(),
)?)
)
.map_err(|_| ArrayError::ShapeError(tensor.shape.clone()))?)
}
}
};
Expand Down
37 changes: 29 additions & 8 deletions rust/tvm-graph-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,39 @@
* under the License.
*/

use failure::Fail;
use thiserror::Error;
use tvm_sys::DataType;

#[derive(Debug, Fail)]
#[derive(Debug, Error)]
pub enum GraphFormatError {
#[fail(display = "Could not parse graph json")]
Parse(#[fail(cause)] failure::Error),
#[fail(display = "Could not parse graph params")]
#[error("Could not parse graph json")]
Parse(#[from] serde_json::Error),
#[error("Could not parse graph params")]
Params,
#[fail(display = "{} is missing attr: {}", 0, 1)]
#[error("{0} is missing attr: {1}")]
MissingAttr(String, String),
#[fail(display = "Missing field: {}", 0)]
#[error("Graph has invalid attr that can't be parsed: {0}")]
InvalidAttr(#[from] std::num::ParseIntError),
#[error("Missing field: {0}")]
MissingField(&'static str),
#[fail(display = "Invalid DLType: {}", 0)]
#[error("Invalid DLType: {0}")]
InvalidDLType(String),
#[error("Unsupported Op: {0}")]
UnsupportedOp(String),
}

/// Error returned when a function name cannot be resolved in a loaded module.
///
/// Carries the name of the function that was looked up and not found.
#[derive(Debug, Error)]
#[error("Function {0} not found")]
pub struct FunctionNotFound(pub String);

/// Error returned when a pointer passed to a free routine does not match
/// any tracked allocation (see the workspace pool's `free`).
#[derive(Debug, Error)]
#[error("Pointer {0:?} invalid when freeing")]
pub struct InvalidPointer(pub *mut u8);

/// Errors that can occur when converting between `Tensor` and `ndarray` types.
#[derive(Debug, Error)]
pub enum ArrayError {
    // Raised when the Tensor's dtype does not match the requested ndarray element type.
    #[error("Cannot convert Tensor with dtype {0} to ndarray")]
    IncompatibleDataType(DataType),
    // Raised when the Tensor's shape cannot be used to build the target ndarray;
    // carries the offending shape for diagnostics.
    #[error("Shape error when casting ndarray to TVM Array with shape {0:?}")]
    ShapeError(Vec<i64>),
}
51 changes: 29 additions & 22 deletions rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
* under the License.
*/

use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
use std::{
cmp, collections::HashMap, convert::TryFrom, error::Error, iter::FromIterator, mem, str,
};

use failure::{ensure, format_err, Error};
use itertools::izip;
use nom::{
character::complete::{alpha1, digit1},
complete, count, do_parse, length_count, map, named,
number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
opt, tag, take, tuple,
};

use serde::{Deserialize, Serialize};
use serde_json;

use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt};

use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType};

use crate::{errors::GraphFormatError, Module, Storage, Tensor};
use crate::{errors::*, Module, Storage, Tensor};

// @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F;
Expand Down Expand Up @@ -114,7 +114,7 @@ macro_rules! get_node_attr {
}

impl Node {
fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
fn parse_attrs(&self) -> Result<NodeAttrs, GraphFormatError> {
let attrs = self
.attrs
.as_ref()
Expand All @@ -128,15 +128,15 @@ impl Node {
}

impl<'a> TryFrom<&'a String> for Graph {
type Error = Error;
fn try_from(graph_json: &String) -> Result<Self, self::Error> {
type Error = GraphFormatError;
fn try_from(graph_json: &String) -> Result<Self, GraphFormatError> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
}
}

impl<'a> TryFrom<&'a str> for Graph {
type Error = Error;
type Error = GraphFormatError;
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
Expand Down Expand Up @@ -177,7 +177,7 @@ pub struct GraphExecutor<'m, 't> {
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}

impl<'m, 't> GraphExecutor<'m, 't> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Box<dyn Error>> {
let tensors = Self::setup_storages(&graph)?;
Ok(GraphExecutor {
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
Expand All @@ -194,7 +194,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
}

/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Box<dyn Error>> {
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
let dtypes = graph
Expand All @@ -221,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
let mut storages: Vec<Storage> = storage_num_bytes
.into_iter()
.map(|nbytes| Storage::new(nbytes, align))
.collect::<Result<Vec<Storage>, Error>>()?;
.collect::<Result<Vec<Storage>, std::alloc::LayoutErr>>()?;

let tensors = izip!(storage_ids, shapes, dtypes)
.map(|(storage_id, shape, dtype)| {
Expand All @@ -246,34 +246,40 @@ impl<'m, 't> GraphExecutor<'m, 't> {
graph: &Graph,
lib: &'m M,
tensors: &[Tensor<'t>],
) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
) -> Result<Vec<Box<dyn Fn() + 'm>>, Box<dyn Error + 'static>> {
if !graph.node_row_ptr.is_some() {
return Err(GraphFormatError::MissingField("node_row_ptr").into());
}
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();

let mut op_execs = Vec::new();
for (i, node) in graph.nodes.iter().enumerate() {
if node.op == "null" {
continue;
}
ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
ensure!(node.attrs.is_some(), "Missing node attrs.");
if node.op != "tvm_op" {
return Err(GraphFormatError::UnsupportedOp(node.op.to_owned()).into());
}
if !node.attrs.is_some() {
return Err(GraphFormatError::MissingAttr(node.op.clone(), "".to_string()).into());
}

let attrs = node.parse_attrs()?;
let attrs: NodeAttrs = node.parse_attrs()?.into();

if attrs.func_name == "__nop" {
continue;
}

let func = lib
.get_function(&attrs.func_name)
.ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?;
.ok_or_else(|| FunctionNotFound(attrs.func_name.clone()))?;
let arg_indices = node
.inputs
.iter()
.map(|entry| graph.entry_index(entry))
.chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi)));

let dl_tensors = arg_indices
let dl_tensors: Vec<DLTensor> = arg_indices
.map(|idx| {
let tensor = &tensors[idx?];
Ok(if attrs.flatten_data {
Expand All @@ -282,14 +288,15 @@ impl<'m, 't> GraphExecutor<'m, 't> {
DLTensor::from(tensor)
})
})
.collect::<Result<Vec<DLTensor>, Error>>()
.unwrap();
.collect::<Result<Vec<DLTensor>, GraphFormatError>>()?
.into();
let op: Box<dyn Fn()> = Box::new(move || {
let args = dl_tensors
let args: Vec<ArgValue> = dl_tensors
.iter()
.map(|t| t.into())
.collect::<Vec<ArgValue>>();
func(&args).unwrap();
let err_str = format!("Function {} failed to execute", attrs.func_name);
func(&args).expect(&err_str);
});
op_execs.push(op);
}
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/src/module/dso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ macro_rules! init_context_func {
}

impl<'a> DsoModule<'a> {
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, failure::Error> {
pub fn new<P: AsRef<std::ffi::OsStr>>(filename: P) -> Result<Pin<Box<Self>>, std::io::Error> {
let lib = libloading::Library::new(filename)?;

init_context_func!(
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/src/threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::{
#[cfg(not(target_arch = "wasm32"))]
use std::env;

use crossbeam::channel::{bounded, Receiver, Sender};
use crossbeam_channel::{bounded, Receiver, Sender};
use tvm_sys::ffi::TVMParallelGroupEnv;

pub(crate) type FTVMParallelLambda =
Expand Down
13 changes: 7 additions & 6 deletions rust/tvm-graph-rt/src/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

use std::{
cell::RefCell,
error::Error,
os::raw::{c_int, c_void},
ptr,
};

use failure::{format_err, Error};

use crate::allocator::Allocation;
use crate::errors::InvalidPointer;
use std::alloc::LayoutErr;

const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`

Expand All @@ -49,13 +50,13 @@ impl WorkspacePool {
}
}

fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> {
/// Creates a brand-new workspace of `size` bytes aligned to `WS_ALIGN`,
/// records it as in use, and returns a raw pointer to its storage.
///
/// Fails with a `LayoutErr` if a valid allocation layout cannot be formed.
fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
    let allocation = Allocation::new(size, Some(WS_ALIGN))?;
    self.workspaces.push(allocation);
    let newest = self.workspaces.len() - 1;
    self.in_use.push(newest);
    Ok(self.workspaces[newest].as_mut_ptr())
}

fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> {
fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
if self.free.is_empty() {
return self.alloc_new(size);
}
Expand All @@ -82,7 +83,7 @@ impl WorkspacePool {
}
}

fn free(&mut self, ptr: *mut u8) -> Result<(), Error> {
fn free(&mut self, ptr: *mut u8) -> Result<(), Box<dyn Error>> {
let mut ws_idx = None;
for i in 0..self.in_use.len() {
let idx = self.in_use[i];
Expand All @@ -92,7 +93,7 @@ impl WorkspacePool {
break;
}
}
let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?;
let ws_idx = ws_idx.ok_or_else(|| InvalidPointer(ptr))?;
self.free.push(ws_idx);
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-graph-rt/tests/test_wasm32/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ unsafe fn __get_tvm_module_ctx() -> i32 {

extern crate ndarray;
#[macro_use]
extern crate tvm_runtime;
extern crate tvm_graph_rt;

use ndarray::Array;
use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
use tvm_graph_rt::{DLTensor, Module as _, SystemLibModule};

fn main() {
// try static
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ edition = "2018"
proc-macro = true

[dependencies]
goblin = "0.0.24"
goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
Loading