From cb2a6b55a1e8e70f88ab565db8f5e3201725b11f Mon Sep 17 00:00:00 2001 From: Jason Knight Date: Thu, 23 Jul 2020 14:04:30 -0700 Subject: [PATCH] [Rust] Some rust cleanups (#6116) * Some rust cleanups * Turn off default features for bindgen * Upgrade some deps for smaller total dep tree * Switch (/complete switch) to thiserror * Remove unnecessary transmutes * Fix null pointer assert * Update wasm32 test --- rust/tvm-graph-rt/Cargo.toml | 5 +- rust/tvm-graph-rt/src/array.rs | 20 ++++---- rust/tvm-graph-rt/src/errors.rs | 37 +++++++++++--- rust/tvm-graph-rt/src/graph.rs | 51 +++++++++++-------- rust/tvm-graph-rt/src/module/dso.rs | 2 +- rust/tvm-graph-rt/src/threading.rs | 2 +- rust/tvm-graph-rt/src/workspace.rs | 13 ++--- .../tests/test_wasm32/src/main.rs | 4 +- rust/tvm-macros/Cargo.toml | 2 +- rust/tvm-rt/src/object/object_ptr.rs | 33 +++++------- rust/tvm-sys/Cargo.toml | 2 +- 11 files changed, 98 insertions(+), 73 deletions(-) diff --git a/rust/tvm-graph-rt/Cargo.toml b/rust/tvm-graph-rt/Cargo.toml index 0cf2ac139ff79..d8dfcdb73269d 100644 --- a/rust/tvm-graph-rt/Cargo.toml +++ b/rust/tvm-graph-rt/Cargo.toml @@ -28,8 +28,9 @@ authors = ["TVM Contributors"] edition = "2018" [dependencies] -crossbeam = "0.7.3" -failure = "0.1" +crossbeam-channel = "0.4" +thiserror = "1" + itertools = "0.8" lazy_static = "1.4" ndarray="0.12" diff --git a/rust/tvm-graph-rt/src/array.rs b/rust/tvm-graph-rt/src/array.rs index 1ed0f3cc47570..b911aa816489e 100644 --- a/rust/tvm-graph-rt/src/array.rs +++ b/rust/tvm-graph-rt/src/array.rs @@ -19,11 +19,12 @@ use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; -use failure::{ensure, Error}; use ndarray; use tvm_sys::{ffi::DLTensor, Context, DataType}; use crate::allocator::Allocation; +use crate::errors::ArrayError; +use std::alloc::LayoutErr; /// A `Storage` is a container which holds `Tensor` data. #[derive(PartialEq)] @@ -36,7 +37,7 @@ pub enum Storage<'a> { } impl<'a> Storage<'a> { - pub fn new(size: usize, align: Option) -> Result, Error> { + pub fn new(size: usize, align: Option) -> Result, LayoutErr> { Ok(Storage::Owned(Allocation::new(size, align)?)) } @@ -297,13 +298,11 @@ impl<'a> Tensor<'a> { macro_rules! impl_ndarray_try_from_tensor { ($type:ty, $dtype:expr) => { impl<'t> TryFrom> for ndarray::ArrayD<$type> { - type Error = Error; - fn try_from(tensor: Tensor) -> Result, Error> { - ensure!( - tensor.dtype == $dtype, - "Cannot convert Tensor with dtype {:?} to ndarray", - tensor.dtype - ); + type Error = ArrayError; + fn try_from(tensor: Tensor) -> Result, Self::Error> { + if tensor.dtype != $dtype { + return Err(ArrayError::IncompatibleDataType(tensor.dtype)); + } Ok(ndarray::Array::from_shape_vec( tensor .shape @@ -311,7 +310,8 @@ macro_rules! impl_ndarray_try_from_tensor { .map(|s| *s as usize) .collect::>(), tensor.to_vec::<$type>(), - )?) + ) + .map_err(|_| ArrayError::ShapeError(tensor.shape.clone()))?) } } }; diff --git a/rust/tvm-graph-rt/src/errors.rs b/rust/tvm-graph-rt/src/errors.rs index d82da15f87f4d..2ca97bdabb6bd 100644 --- a/rust/tvm-graph-rt/src/errors.rs +++ b/rust/tvm-graph-rt/src/errors.rs @@ -17,18 +17,39 @@ * under the License. */ -use failure::Fail; +use thiserror::Error; +use tvm_sys::DataType; -#[derive(Debug, Fail)] +#[derive(Debug, Error)] pub enum GraphFormatError { - #[fail(display = "Could not parse graph json")] - Parse(#[fail(cause)] failure::Error), - #[fail(display = "Could not parse graph params")] + #[error("Could not parse graph json")] + Parse(#[from] serde_json::Error), + #[error("Could not parse graph params")] Params, - #[fail(display = "{} is missing attr: {}", 0, 1)] + #[error("{0} is missing attr: {1}")] MissingAttr(String, String), - #[fail(display = "Missing field: {}", 0)] + #[error("Graph has invalid attr that can't be parsed: {0}")] + InvalidAttr(#[from] std::num::ParseIntError), + #[error("Missing field: {0}")] MissingField(&'static str), - #[fail(display = "Invalid DLType: {}", 0)] + #[error("Invalid DLType: {0}")] InvalidDLType(String), + #[error("Unsupported Op: {0}")] + UnsupportedOp(String), +} + +#[derive(Debug, Error)] +#[error("Function {0} not found")] +pub struct FunctionNotFound(pub String); + +#[derive(Debug, Error)] +#[error("Pointer {0:?} invalid when freeing")] +pub struct InvalidPointer(pub *mut u8); + +#[derive(Debug, Error)] +pub enum ArrayError { + #[error("Cannot convert Tensor with dtype {0} to ndarray")] + IncompatibleDataType(DataType), + #[error("Shape error when casting ndarray to TVM Array with shape {0:?}")] + ShapeError(Vec), } diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs index 895739de62d85..91021dd12bb76 100644 --- a/rust/tvm-graph-rt/src/graph.rs +++ b/rust/tvm-graph-rt/src/graph.rs @@ -17,9 +17,10 @@ * under the License. */ -use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; +use std::{ + cmp, collections::HashMap, convert::TryFrom, error::Error, iter::FromIterator, mem, str, +}; -use failure::{ensure, format_err, Error}; use itertools::izip; use nom::{ character::complete::{alpha1, digit1}, @@ -27,7 +28,6 @@ use nom::{ number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8}, opt, tag, take, tuple, }; - use serde::{Deserialize, Serialize}; use serde_json; @@ -35,7 +35,7 @@ use tvm_sys::ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCod use tvm_sys::{ffi::DLTensor, ArgValue, Context, DataType, DeviceType}; -use crate::{errors::GraphFormatError, Module, Storage, Tensor}; +use crate::{errors::*, Module, Storage, Tensor}; // @see `kTVMNDArrayMagic` in `ndarray.h` const _NDARRAY_MAGIC: u64 = 0xDD5E_40F0_96B4_A13F; @@ -114,7 +114,7 @@ macro_rules! get_node_attr { } impl Node { - fn parse_attrs(&self) -> Result { + fn parse_attrs(&self) -> Result { let attrs = self .attrs .as_ref() @@ -128,15 +128,15 @@ impl Node { } impl<'a> TryFrom<&'a String> for Graph { - type Error = Error; - fn try_from(graph_json: &String) -> Result { + type Error = GraphFormatError; + fn try_from(graph_json: &String) -> Result { let graph = serde_json::from_str(graph_json)?; Ok(graph) } } impl<'a> TryFrom<&'a str> for Graph { - type Error = Error; + type Error = GraphFormatError; fn try_from(graph_json: &'a str) -> Result { let graph = serde_json::from_str(graph_json)?; Ok(graph) @@ -177,7 +177,7 @@ pub struct GraphExecutor<'m, 't> { unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} impl<'m, 't> GraphExecutor<'m, 't> { - pub fn new(graph: Graph, lib: &'m M) -> Result { + pub fn new(graph: Graph, lib: &'m M) -> Result> { let tensors = Self::setup_storages(&graph)?; Ok(GraphExecutor { op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, @@ -194,7 +194,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. - fn setup_storages<'a>(graph: &'a Graph) -> Result>, Error> { + fn setup_storages<'a>(graph: &'a Graph) -> Result>, Box> { let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; let dtypes = graph @@ -221,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { let mut storages: Vec = storage_num_bytes .into_iter() .map(|nbytes| Storage::new(nbytes, align)) - .collect::, Error>>()?; + .collect::, std::alloc::LayoutErr>>()?; let tensors = izip!(storage_ids, shapes, dtypes) .map(|(storage_id, shape, dtype)| { @@ -246,8 +246,10 @@ impl<'m, 't> GraphExecutor<'m, 't> { graph: &Graph, lib: &'m M, tensors: &[Tensor<'t>], - ) -> Result>, Error> { - ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); + ) -> Result>, Box> { + if !graph.node_row_ptr.is_some() { + return Err(GraphFormatError::MissingField("node_row_ptr").into()); + } let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); let mut op_execs = Vec::new(); @@ -255,10 +257,14 @@ impl<'m, 't> GraphExecutor<'m, 't> { if node.op == "null" { continue; } - ensure!(node.op == "tvm_op", "Only TVM ops are supported."); - ensure!(node.attrs.is_some(), "Missing node attrs."); + if node.op != "tvm_op" { + return Err(GraphFormatError::UnsupportedOp(node.op.to_owned()).into()); + } + if !node.attrs.is_some() { + return Err(GraphFormatError::MissingAttr(node.op.clone(), "".to_string()).into()); + } - let attrs = node.parse_attrs()?; + let attrs: NodeAttrs = node.parse_attrs()?.into(); if attrs.func_name == "__nop" { continue; @@ -266,14 +272,14 @@ impl<'m, 't> GraphExecutor<'m, 't> { let func = lib .get_function(&attrs.func_name) - .ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?; + .ok_or_else(|| FunctionNotFound(attrs.func_name.clone()))?; let arg_indices = node .inputs .iter() .map(|entry| graph.entry_index(entry)) .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i] + oi))); - let dl_tensors = arg_indices + let dl_tensors: Vec = arg_indices .map(|idx| { let tensor = &tensors[idx?]; Ok(if attrs.flatten_data { @@ -282,14 +288,15 @@ impl<'m, 't> GraphExecutor<'m, 't> { DLTensor::from(tensor) }) }) - .collect::, Error>>() - .unwrap(); + .collect::, GraphFormatError>>()? + .into(); let op: Box = Box::new(move || { - let args = dl_tensors + let args: Vec = dl_tensors .iter() .map(|t| t.into()) .collect::>(); - func(&args).unwrap(); + let err_str = format!("Function {} failed to execute", attrs.func_name); + func(&args).expect(&err_str); }); op_execs.push(op); } diff --git a/rust/tvm-graph-rt/src/module/dso.rs b/rust/tvm-graph-rt/src/module/dso.rs index 51645d5b8111c..f1145da4b4ded 100644 --- a/rust/tvm-graph-rt/src/module/dso.rs +++ b/rust/tvm-graph-rt/src/module/dso.rs @@ -59,7 +59,7 @@ macro_rules! init_context_func { } impl<'a> DsoModule<'a> { - pub fn new>(filename: P) -> Result>, failure::Error> { + pub fn new>(filename: P) -> Result>, std::io::Error> { let lib = libloading::Library::new(filename)?; init_context_func!( diff --git a/rust/tvm-graph-rt/src/threading.rs b/rust/tvm-graph-rt/src/threading.rs index 9b83ff37116ec..cbb3bf14c31c5 100644 --- a/rust/tvm-graph-rt/src/threading.rs +++ b/rust/tvm-graph-rt/src/threading.rs @@ -29,7 +29,7 @@ use std::{ #[cfg(not(target_arch = "wasm32"))] use std::env; -use crossbeam::channel::{bounded, Receiver, Sender}; +use crossbeam_channel::{bounded, Receiver, Sender}; use tvm_sys::ffi::TVMParallelGroupEnv; pub(crate) type FTVMParallelLambda = diff --git a/rust/tvm-graph-rt/src/workspace.rs b/rust/tvm-graph-rt/src/workspace.rs index 35cfe91423d4e..cf264974bc032 100644 --- a/rust/tvm-graph-rt/src/workspace.rs +++ b/rust/tvm-graph-rt/src/workspace.rs @@ -19,13 +19,14 @@ use std::{ cell::RefCell, + error::Error, os::raw::{c_int, c_void}, ptr, }; -use failure::{format_err, Error}; - use crate::allocator::Allocation; +use crate::errors::InvalidPointer; +use std::alloc::LayoutErr; const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` @@ -49,13 +50,13 @@ impl WorkspacePool { } } - fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> { + fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> { self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); self.in_use.push(self.workspaces.len() - 1); Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) } - fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> { + fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> { if self.free.is_empty() { return self.alloc_new(size); } @@ -82,7 +83,7 @@ impl WorkspacePool { } } - fn free(&mut self, ptr: *mut u8) -> Result<(), Error> { + fn free(&mut self, ptr: *mut u8) -> Result<(), Box> { let mut ws_idx = None; for i in 0..self.in_use.len() { let idx = self.in_use[i]; @@ -92,7 +93,7 @@ impl WorkspacePool { break; } } - let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?; + let ws_idx = ws_idx.ok_or_else(|| InvalidPointer(ptr))?; self.free.push(ws_idx); Ok(()) } diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs b/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs index a46cfa979becd..67ef21779cde5 100644 --- a/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs +++ b/rust/tvm-graph-rt/tests/test_wasm32/src/main.rs @@ -30,10 +30,10 @@ unsafe fn __get_tvm_module_ctx() -> i32 { extern crate ndarray; #[macro_use] -extern crate tvm_runtime; +extern crate tvm_graph_rt; use ndarray::Array; -use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; +use tvm_graph_rt::{DLTensor, Module as _, SystemLibModule}; fn main() { // try static diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index 7abc9ae64f7c6..a9ac09e6fa68b 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -30,7 +30,7 @@ edition = "2018" proc-macro = true [dependencies] -goblin = "0.0.24" +goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" syn = { version = "1.0.17", features = ["full", "extra-traits"] } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 7d133fac18d96..68808241250b2 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -38,7 +38,7 @@ type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); #[derive(Debug)] #[repr(C)] pub struct Object { - /// The index into into TVM's runtime type information table. + /// The index into TVM's runtime type information table. pub(self) type_index: u32, // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. // NB: in general we should not touch this in Rust. @@ -57,10 +57,10 @@ pub struct Object { /// trait magic here to get a monomorphized deleter for each object /// "subtype". /// -/// This function just transmutes the pointer to the correct type +/// This function just converts the pointer to the correct type /// and invokes the underlying typed delete function. unsafe extern "C" fn delete(object: *mut Object) { - let typed_object: *mut T = std::mem::transmute(object); + let typed_object: *mut T = object as *mut T; T::typed_delete(typed_object); } @@ -104,8 +104,7 @@ impl Object { } else { let mut index = 0; unsafe { - let index_ptr = std::mem::transmute(&mut index); - if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 { + if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 { panic!(crate::get_last_error()) } } @@ -130,16 +129,16 @@ impl Object { /// Increases the object's reference count by one. pub(self) fn inc_ref(&self) { + let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void; unsafe { - let raw_ptr = std::mem::transmute(self); assert_eq!(TVMObjectRetain(raw_ptr), 0); } } /// Decreases the object's reference count by one. pub(self) fn dec_ref(&self) { + let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void; unsafe { - let raw_ptr = std::mem::transmute(self); assert_eq!(TVMObjectFree(raw_ptr), 0); } } @@ -277,10 +276,9 @@ impl std::ops::Deref for ObjectPtr { impl<'a, T: IsObject> From> for RetValue { fn from(object_ptr: ObjectPtr) -> RetValue { - let raw_object_ptr = ObjectPtr::leak(object_ptr); - let void_ptr: *mut std::ffi::c_void = unsafe { std::mem::transmute(raw_object_ptr) }; - assert!(!void_ptr.is_null()); - RetValue::ObjectHandle(void_ptr) + let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; + assert!(!raw_object_ptr.is_null()); + RetValue::ObjectHandle(raw_object_ptr) } } @@ -290,8 +288,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { fn try_from(ret_value: RetValue) -> Result, Self::Error> { match ret_value { RetValue::ObjectHandle(handle) => { - let handle: *mut Object = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); println!("back to type {}", optr.count()); optr.downcast() @@ -304,10 +301,9 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { impl<'a, T: IsObject> From> for ArgValue<'a> { fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let raw_object_ptr = ObjectPtr::leak(object_ptr); - let void_ptr: *mut std::ffi::c_void = unsafe { std::mem::transmute(raw_object_ptr) }; - assert!(!void_ptr.is_null()); - ArgValue::ObjectHandle(void_ptr) + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::ObjectHandle(raw_ptr) } } @@ -317,8 +313,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { match arg_value { ArgValue::ObjectHandle(handle) => { - let handle = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); println!("count: {}", optr.count()); optr.downcast() diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index fe4d0bf987bf7..faddce48d15d0 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -32,4 +32,4 @@ ndarray = "0.12" enumn = "^0.1" [build-dependencies] -bindgen = "0.51" +bindgen = { version="0.51", default-features=false }