From 277bfc86aac94442625e8307aa55a9215e2fddd6 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Tue, 6 Oct 2020 01:52:53 -0700 Subject: [PATCH] [Rust] Improve NDArray, GraphRt, and Relay bindings (#6563) * WIP WIP * Add support for loading Python packed functions * Flesh out Relay AST in Rust * More tweeks for getting functions out * Deploy Rust docs as part of build * Add some more types * Introduce NDArray 2.0 * Work on NDArray 2.0 before restoring tests * Formatting and code fixes to get it to compile * Add more Rust bindings - Converts Conv2d attrs to use tvm::String, so that we can add Rust binding - Uses Type for checked_type in Rust bindings - Fix type key in Rust bindings - Make data field contain NDArray in Rust bindings * Clean up object ptr passing. * WIP * Add debugging for NDArray and fix all test cases * Add breaking test * Dispatch some todos * Format * Fix ndarray size and len * Add BiasAddAttrs rust bindings * Add DenseAttrs rust bindings * Change to TVM string * Add more Rust bindings Add GlobalPool2DAttrs Rust binding Add ExpandDimsAttrs Rust bindings Add MaxPool2DAttrs rust bindings * Fix some test attributes * Improve the NDArray api * Fix some more ndarray stuff * Get the resnet demo kinda working * Add SoftmaxAttrs Rust bindings * Implement Hash and Eq for Relay Exprs * Add underscore to unused function * Fix broken ass resnet script * Improve some ndarray conversions * Make sure the build script runs correctly * Clean up ResNet example tremedously Expose C++ graph runtime via cleaner Rust API rewrite example. * Add ASF header * Format * Format * Format resnet rust python script * Add type files and refactor span * Format * Format * Change types from std::string to tvm::String in packed function * Add ASF header * Fix test w/ ndarray's API change * Fix array test * Fix anyhow import * Put back some anyhow stuff * Clean up * Try and fix tests/scripts/task_rust.sh * Disable ResNet for now * Turn off building of Rust docs until we update CI * Actually disable Co-authored-by: Jared Roesch Co-authored-by: Gus Smith --- include/tvm/relay/adt.h | 1 + include/tvm/relay/attrs/nn.h | 12 +- include/tvm/runtime/ndarray.h | 3 +- include/tvm/tir/data_layout.h | 3 + rust/tvm-graph-rt/src/graph.rs | 8 +- rust/tvm-graph-rt/src/threading.rs | 4 +- rust/tvm-macros/Cargo.toml | 1 + rust/tvm-rt/Cargo.toml | 1 + rust/tvm-rt/src/array.rs | 17 + rust/tvm-rt/src/errors.rs | 2 - rust/tvm-rt/src/ndarray.rs | 322 ++++++----- rust/tvm-rt/src/object/object_ptr.rs | 72 +-- rust/tvm-rt/src/string.rs | 4 +- rust/tvm-rt/src/value.rs | 56 +- rust/tvm-sys/src/packed_func.rs | 15 + rust/tvm/Cargo.toml | 4 + rust/tvm/examples/resnet/Cargo.toml | 3 +- rust/tvm/examples/resnet/build.rs | 10 +- rust/tvm/examples/resnet/src/build_resnet.py | 95 ++-- rust/tvm/examples/resnet/src/main.rs | 130 ++--- rust/tvm/src/ir/arith.rs | 2 +- rust/tvm/src/ir/attrs.rs | 29 + rust/tvm/src/ir/expr.rs | 100 ++++ rust/tvm/src/ir/function.rs | 46 ++ rust/tvm/src/ir/mod.rs | 62 +-- rust/tvm/src/ir/module.rs | 159 ++++++ rust/tvm/src/ir/op.rs | 43 ++ rust/tvm/src/ir/relay/attrs/mod.rs | 21 + rust/tvm/src/ir/relay/attrs/nn.rs | 96 ++++ rust/tvm/src/ir/relay/attrs/transform.rs | 31 ++ rust/tvm/src/ir/relay/mod.rs | 499 ++++++++++++++---- rust/tvm/src/ir/span.rs | 22 + rust/tvm/src/ir/tir.rs | 4 +- rust/tvm/src/ir/ty.rs | 242 +++++++++ rust/tvm/src/lib.rs | 2 + rust/tvm/src/python.rs | 60 +++ rust/tvm/src/runtime/graph_rt.rs | 97 ++++ rust/tvm/src/runtime/mod.rs | 2 + rust/tvm/tests/basics/src/main.rs | 4 +- rust/tvm/tests/callback/src/bin/array.rs | 19 +- src/relay/op/nn/convolution.cc | 4 +- src/relay/qnn/op/convolution.cc | 6 +- .../transforms/combine_parallel_conv2d.cc | 3 +- src/relay/transforms/pattern_util.h | 2 +- src/runtime/ndarray.cc | 5 +- tests/scripts/task_python_docs.sh | 9 + tests/scripts/task_rust.sh | 13 +- 47 files changed, 1803 insertions(+), 542 deletions(-) create mode 100644 rust/tvm/src/ir/attrs.rs create mode 100644 rust/tvm/src/ir/expr.rs create mode 100644 rust/tvm/src/ir/function.rs create mode 100644 rust/tvm/src/ir/module.rs create mode 100644 rust/tvm/src/ir/op.rs create mode 100644 rust/tvm/src/ir/relay/attrs/mod.rs create mode 100644 rust/tvm/src/ir/relay/attrs/nn.rs create mode 100644 rust/tvm/src/ir/relay/attrs/transform.rs create mode 100644 rust/tvm/src/ir/span.rs create mode 100644 rust/tvm/src/ir/ty.rs create mode 100644 rust/tvm/src/python.rs create mode 100644 rust/tvm/src/runtime/graph_rt.rs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 37182abb2681..b5dcab5e0bfc 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -190,6 +190,7 @@ class PatternTuple; /*! \brief PatternVar container node */ class PatternTupleNode : public PatternNode { public: + /* TODO(@jroesch): rename to field_pats */ /*! Sub-patterns to match against each value of the tuple. */ tvm::Array patterns; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index fbe31a305ea5..b2555de6d35e 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -29,6 +29,8 @@ #include +#include "tvm/runtime/container.h" + namespace tvm { namespace relay { @@ -115,9 +117,9 @@ struct Conv2DAttrs : public tvm::AttrsNode { int groups; IndexExpr channels; Array kernel_size; - std::string data_layout; - std::string kernel_layout; - std::string out_layout; + tvm::String data_layout; + tvm::String kernel_layout; + tvm::String out_layout; DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { @@ -681,7 +683,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; Array padding; - std::string layout; + tvm::String layout; bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { @@ -744,7 +746,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { /*! \brief Attributes for global pool operator */ struct GlobalPool2DAttrs : public tvm::AttrsNode { - std::string layout; + tvm::String layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { TVM_ATTR_FIELD(layout).set_default("NCHW").describe( diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 1208190ca7df..92b3857fbec8 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -371,8 +371,9 @@ inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. - return reinterpret_cast(static_cast( + auto ptr = reinterpret_cast(static_cast( static_cast(const_cast(nd.get())))); + return ptr; } inline void NDArray::FFIDecRef(TVMArrayHandle handle) { diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index af384f9b67f9..ee93a0675470 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -124,6 +124,9 @@ class Layout : public ObjectRef { public: explicit Layout(const Array& axes); + /*! \brief construct from a string */ + Layout(const tvm::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) + /*! \brief construct from a string */ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) diff --git a/rust/tvm-graph-rt/src/graph.rs b/rust/tvm-graph-rt/src/graph.rs index 91021dd12bb7..87dd4a76d5e4 100644 --- a/rust/tvm-graph-rt/src/graph.rs +++ b/rust/tvm-graph-rt/src/graph.rs @@ -46,8 +46,10 @@ const _NDARRAY_LIST_MAGIC: u64 = 0xF7E5_8D4F_0504_9CB7; /// /// # Examples /// -/// ```norun -/// let graph_json = fs::read_to_string("graph.json").unwrap(); +/// ```no_run +/// use tvm_graph_rt::Graph; +/// use std::convert::TryFrom; +/// let graph_json = std::fs::read_to_string("graph.json").unwrap(); /// let graph = Graph::try_from(&graph_json).unwrap(); /// ``` #[derive(Serialize, Deserialize, Debug)] @@ -147,7 +149,7 @@ impl<'a> TryFrom<&'a str> for Graph { /// /// # Examples /// -/// ```norun +/// ```no_compile /// use ndarray::Array; /// /// let syslib = SystemLibModule::default(); // a provider of TVM functions diff --git a/rust/tvm-graph-rt/src/threading.rs b/rust/tvm-graph-rt/src/threading.rs index cbb3bf14c31c..03765e0a049b 100644 --- a/rust/tvm-graph-rt/src/threading.rs +++ b/rust/tvm-graph-rt/src/threading.rs @@ -215,7 +215,7 @@ pub unsafe extern "C" fn TVMBackendParallelBarrier( #[cfg(test)] mod tests { - use std::{ptr, thread, time::Duration}; + use std::{thread, time::Duration}; use super::*; @@ -228,7 +228,7 @@ mod tests { assert_eq!(max_concurrency(), 24); } - extern "C" fn flambda( + extern "C" fn _flambda( task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void, diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index a9ac09e6fa68..63b84727c525 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -34,3 +34,4 @@ goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" syn = { version = "1.0.17", features = ["full", "extra-traits"] } +proc-macro-error = "^1.0" diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index 465ae583ab6c..acece5aeec48 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -37,6 +37,7 @@ tvm-macros = { version = "0.1", path = "../tvm-macros" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" +memoffset = "0.5.6" [dev-dependencies] anyhow = "^1.0" diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index d2c82fce0b33..5e19cefd8e97 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -117,3 +117,20 @@ impl<'a, T: IsObjectRef> TryFrom for Array { }) } } + +#[cfg(test)] +mod tests { + use super::Array; + use crate::function::Result; + use crate::string::String; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec: Vec = vec!["foo".into(), "bar".into(), "baz".into()]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.to_string(), "foo"); + assert_eq!(array.get(1)?.to_string(), "bar"); + assert_eq!(array.get(2)?.to_string(), "baz"); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index e194bfa9febd..c884c56fed44 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -33,8 +33,6 @@ pub struct TypeMismatchError { #[derive(Debug, Error)] pub enum NDArrayError { - #[error("Missing NDArray shape.")] - MissingShape, #[error("Cannot convert from an empty array.")] EmptyArray, #[error("Invalid datatype when attempting to convert ndarray.")] diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 24fa5e0dfcbc..ed280ccc2d80 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -38,7 +38,7 @@ //! .unwrap() //! .into_dyn(); // Rust's ndarray //! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); -//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! assert_eq!(nd.shape(), &[2, 2]); //! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); //! assert!(rnd.all_close(&a, 1e-8f32)); //! ``` @@ -47,73 +47,146 @@ //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx -use std::convert::TryInto; use std::ffi::c_void; +use std::{borrow::Cow, convert::TryInto}; use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; -use crate::errors::NDArrayError; - +use mem::size_of; +use tvm_macros::Object; use tvm_sys::ffi::DLTensor; use tvm_sys::{ffi, ByteArray, Context, DataType}; use ndarray::{Array, ArrayD}; use num_traits::Num; +use crate::errors::NDArrayError; + +use crate::object::{Object, ObjectPtr}; + /// See the [`module-level documentation`](../ndarray/index.html) for more details. -/// -/// Wrapper around TVM array handle. -#[derive(Debug)] -pub enum NDArray { - Borrowed { handle: ffi::TVMArrayHandle }, - Owned { handle: *mut c_void }, +#[repr(C)] +#[derive(Object)] +#[ref_name = "NDArray"] +#[type_key = "runtime.NDArray"] +pub struct NDArrayContainer { + base: Object, + // Container Base + dl_tensor: DLTensor, + manager_ctx: *mut c_void, + // TOOD: shape? } -impl NDArray { - pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { - NDArray::Borrowed { handle } +impl NDArrayContainer { + pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option> { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) }; + let object_ptr = ObjectPtr::from_raw(base_ptr.cast()); + object_ptr.map(|ptr| { + ptr.downcast::() + .expect("we know this is an NDArray container") + }) } - pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { - NDArray::Owned { handle } + pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { + &mut *std::mem::ManuallyDrop::new(object_ptr) + .ptr + .as_ptr() + .cast::() + .offset(base_offset) + .cast::() + } } +} - pub fn as_dltensor(&self) -> &DLTensor { - let ptr: *mut DLTensor = match self { - NDArray::Borrowed { ref handle } => *handle, - NDArray::Owned { ref handle } => *handle as *mut DLTensor, - }; +fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { + if std::mem::size_of::() == 64 { + debug_assert!(slice.iter().all(|&x| x >= 0)); + let shape: &[usize] = unsafe { std::mem::transmute(slice) }; + Cow::Borrowed(shape) + } else { + let shape: Vec = slice + .iter() + .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot fit into usize: {}", x))) + .collect(); + Cow::Owned(shape) + } +} - unsafe { std::mem::transmute(ptr) } +impl NDArray { + pub(crate) fn _from_raw(handle: ffi::TVMArrayHandle) -> Self { + let ptr = NDArrayContainer::from_raw(handle); + NDArray(ptr) + } + + // I think these should be marked as unsafe functions? projecting a reference is bad news. + pub fn as_dltensor(&self) -> &DLTensor { + &self.dl_tensor } pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - match self { - NDArray::Borrowed { handle } => *handle, - NDArray::Owned { handle } => *handle as *mut DLTensor, - } + unsafe { std::mem::transmute(self.as_dltensor()) } } pub fn is_view(&self) -> bool { - if let &NDArray::Borrowed { .. } = self { - true - } else { - false - } + false } /// Returns the shape of the NDArray. - pub fn shape(&self) -> Option<&mut [usize]> { + pub fn shape(&self) -> &[i64] { let arr = self.as_dltensor(); if arr.shape.is_null() || arr.data.is_null() { - return None; - }; - let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; - Some(slc) + &[] + } else { + unsafe { slice::from_raw_parts(arr.shape, self.ndim()) } + } + } + + /// Returns the shape of the NDArray as a &[usize] + /// + /// On 64-bit platforms, this is zero-cost and uses the shape from the DLTensor. + /// On other platforms, this copies into a buffer. + pub fn shape_usize(&self) -> Cow<[usize]> { + cow_usize(self.shape()) + } + + /// Returns the strides of the underlying NDArray. + pub fn strides(&self) -> Option<&[i64]> { + let arr = self.as_dltensor(); + if arr.strides.is_null() { + None + } else { + Some(unsafe { slice::from_raw_parts(arr.strides, self.ndim()) }) + } + } + + /// Returns the strides of the NDArray as a &[usize] + /// + /// On 64-bit platforms, this is zero-cost and uses the strides from the DLTensor. + /// On other platforms, this copies into a buffer. + pub fn strides_usize(&self) -> Option> { + self.strides().map(cow_usize) + } + + /// Returns true if the tensor is empty + pub fn is_empty(&self) -> bool { + self.as_dltensor().data.is_null() } /// Returns the total number of entries of the NDArray. - pub fn size(&self) -> Option { - self.shape().map(|v| v.iter().product()) + pub fn len(&self) -> usize { + let len: i64 = self.shape().iter().product(); + usize::try_from(len).unwrap_or_else(|_| panic!("bad len: {}", len)) + } + + /// Returns the total bytes taken up by the data. + /// This is equal to `nd.len() * nd.dtype().itemsize()` + pub fn size(&self) -> usize { + self.len() * self.dtype().itemsize() } /// Returns the context which the NDArray was defined. @@ -134,24 +207,13 @@ impl NDArray { .expect("number of dimensions must always be positive") } - /// Returns the strides of the underlying NDArray. - pub fn strides(&self) -> Option<&[usize]> { - unsafe { - let sz = self.ndim() * mem::size_of::(); - let strides_ptr = self.as_dltensor().strides as *const usize; - let slc = slice::from_raw_parts(strides_ptr, sz); - Some(slc) - } - } - /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result { - Ok(match self.strides() { + pub fn is_contiguous(&self) -> bool { + match self.strides() { None => true, Some(strides) => { // NDArrayError::MissingShape in case shape is not determined self.shape() - .ok_or(NDArrayError::MissingShape)? .iter() .zip(strides) .rfold( @@ -159,13 +221,13 @@ impl NDArray { |(is_contig, expected_stride), (shape, stride)| { ( is_contig && *stride == expected_stride, - expected_stride * (*shape as usize), + expected_stride * shape, ) }, ) .0 } - }) + } } pub fn byte_offset(&self) -> isize { @@ -184,28 +246,19 @@ impl NDArray { /// let ctx = Context::cpu(0); /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); /// ndarray.copy_from_buffer(&mut data); - /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.shape(), shape); /// assert_eq!(ndarray.to_vec::().unwrap(), data); /// ``` pub fn to_vec(&self) -> Result, NDArrayError> { - if !self.shape().is_some() { - return Err(NDArrayError::EmptyArray); - } - let earr = NDArray::empty( - self.shape().ok_or(NDArrayError::MissingShape)?, - Context::cpu(0), - self.dtype(), - ); - let target = self.copy_to_ndarray(earr)?; - let arr = target.as_dltensor(); - let sz = self.size().ok_or(NDArrayError::MissingShape)?; - let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); - unsafe { - v.as_mut_ptr() - .copy_from_nonoverlapping(arr.data as *const T, sz); - v.set_len(sz); - } - Ok(v) + let n = self.size() / size_of::(); + let mut vec: Vec = Vec::with_capacity(n); + + let ptr = vec.as_mut_ptr(); + let slice = unsafe { slice::from_raw_parts_mut(ptr, n) }; + self.copy_to_buffer(slice); + + unsafe { vec.set_len(n) }; + Ok(vec) } /// Converts the NDArray to [`ByteArray`]. @@ -230,7 +283,7 @@ impl NDArray { /// /// *Note*: if something goes wrong during the copy, it will panic /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. - pub fn copy_from_buffer(&mut self, data: &mut [T]) { + pub fn copy_from_buffer(&mut self, data: &[T]) { check_call!(ffi::TVMArrayCopyFromBytes( self.as_raw_dltensor(), data.as_ptr() as *mut _, @@ -238,6 +291,29 @@ impl NDArray { )); } + pub fn copy_to_buffer(&self, data: &mut [T]) { + assert_eq!(self.size(), data.len() * size_of::()); + check_call!(ffi::TVMArrayCopyToBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + self.size(), + )); + } + + pub fn fill_from_iter(&mut self, iter: I) + where + T: Num32, + I: ExactSizeIterator, + { + assert!(self.is_contiguous()); + assert_eq!(self.size(), size_of::() * iter.len()); + let mut ptr: *mut T = self.as_dltensor().data.cast(); + iter.for_each(|x| unsafe { + ptr.write(x); + ptr = ptr.add(1); + }) + } + /// Copies the NDArray to another target NDArray. pub fn copy_to_ndarray(&self, target: NDArray) -> Result { if self.dtype() != target.dtype() { @@ -258,37 +334,29 @@ impl NDArray { /// Copies the NDArray to a target context. pub fn copy_to_ctx(&self, target: &Context) -> Result { - let tmp = NDArray::empty( - self.shape().ok_or(NDArrayError::MissingShape)?, - *target, - self.dtype(), - ); + let tmp = NDArray::empty(self.shape(), *target, self.dtype()); let copy = self.copy_to_ndarray(tmp)?; Ok(copy) } /// Converts a Rust's ndarray to TVM NDArray. pub fn from_rust_ndarray( - rnd: &ArrayD, + input_nd: &ArrayD, ctx: Context, dtype: DataType, ) -> Result { - let shape = rnd.shape().to_vec(); + let shape: Vec = input_nd.shape().iter().map(|&x| x as i64).collect(); let mut nd = NDArray::empty(&shape, ctx, dtype); - let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); - nd.copy_from_buffer( - buf.as_slice_mut() - .expect("Array from iter must be contiguous."), - ); + nd.fill_from_iter(input_nd.iter().copied()); Ok(nd) } /// Allocates and creates an empty NDArray given the shape, context and dtype. - pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + pub fn empty(shape: &[i64], ctx: Context, dtype: DataType) -> NDArray { let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; let dtype: tvm_sys::ffi::DLDataType = dtype.into(); check_call!(ffi::TVMArrayAlloc( - shape.as_ptr() as *const i64, + shape.as_ptr(), shape.len() as c_int, i32::from(dtype.code) as c_int, i32::from(dtype.bits) as c_int, @@ -297,7 +365,19 @@ impl NDArray { ctx.device_id as c_int, &mut handle as *mut _, )); - NDArray::Borrowed { handle: handle } + let ptr = NDArrayContainer::from_raw(handle) + .map(|o| o.downcast().expect("this should never fail")); + NDArray(ptr) + } + + pub fn zeroed(self) -> NDArray { + unsafe { + let dltensor = self.as_raw_dltensor(); + let bytes_ptr: *mut u8 = std::mem::transmute((*dltensor).data); + println!("size {}", self.size()); + std::ptr::write_bytes(bytes_ptr, 0, self.size()); + self + } } } @@ -307,12 +387,9 @@ macro_rules! impl_from_ndarray_rustndarray { type Error = NDArrayError; fn try_from(nd: &NDArray) -> Result, Self::Error> { - if !nd.shape().is_some() { - return Err(NDArrayError::MissingShape); - } assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( - &*nd.shape().ok_or(NDArrayError::MissingShape)?, + &*nd.shape_usize(), nd.to_vec::<$type>()?, )?) } @@ -322,12 +399,9 @@ macro_rules! impl_from_ndarray_rustndarray { type Error = NDArrayError; fn try_from(nd: &mut NDArray) -> Result, Self::Error> { - if !nd.shape().is_some() { - return Err(NDArrayError::MissingShape); - }; assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( - &*nd.shape().ok_or(NDArrayError::MissingShape)?, + &*nd.shape_usize(), nd.to_vec::<$type>()?, )?) } @@ -339,14 +413,6 @@ impl_from_ndarray_rustndarray!(i32, "int"); impl_from_ndarray_rustndarray!(u32, "uint"); impl_from_ndarray_rustndarray!(f32, "float"); -impl Drop for NDArray { - fn drop(&mut self) { - if let &mut NDArray::Owned { .. } = self { - check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); - } - } -} - mod sealed { /// Private trait to prevent other traits from being implemeneted in downstream crates. pub trait Sealed {} @@ -374,14 +440,13 @@ mod tests { #[test] fn basics() { - let shape = &mut [1, 2, 3]; + let shape = &[1, 2, 3]; let ctx = Context::cpu(0); + println!("before empty"); let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - assert_eq!(ndarray.shape().unwrap(), shape); - assert_eq!( - ndarray.size().unwrap(), - shape.to_vec().into_iter().product() - ); + println!("after empty"); + assert_eq!(ndarray.shape(), shape); + assert_eq!(ndarray.len(), shape.iter().product::() as usize); assert_eq!(ndarray.ndim(), 3); assert!(ndarray.strides().is_none()); assert_eq!(ndarray.byte_offset(), 0); @@ -389,16 +454,16 @@ mod tests { #[test] fn copy() { - let shape = &mut [4]; - let mut data = vec![1i32, 2, 3, 4]; + let shape = &[4]; + let data = vec![1i32, 2, 3, 4]; let ctx = Context::cpu(0); - let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); - assert!(ndarray.to_vec::().is_ok()); - ndarray.copy_from_buffer(&mut data); - assert_eq!(ndarray.shape().unwrap(), shape); + let mut ndarray = NDArray::empty(shape, ctx, DataType::int(32, 1)).zeroed(); + assert_eq!(ndarray.to_vec::().unwrap(), vec![0, 0, 0, 0]); + ndarray.copy_from_buffer(&data); + assert_eq!(ndarray.shape(), shape); assert_eq!(ndarray.to_vec::().unwrap(), data); assert_eq!(ndarray.ndim(), 1); - assert!(ndarray.is_contiguous().is_ok()); + assert!(ndarray.is_contiguous()); assert_eq!(ndarray.byte_offset(), 0); let shape = vec![4]; let e = NDArray::empty( @@ -411,17 +476,18 @@ mod tests { assert_eq!(nd.unwrap().to_vec::().unwrap(), data); } - // #[test] - // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - // fn copy_wrong_dtype() { - // let shape = vec![4]; - // let mut data = vec![1f32, 2., 3., 4.]; - // let ctx = Context::cpu(0); - // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); - // nd_float.copy_from_buffer(&mut data); - // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); - // nd_float.copy_to_ndarray(empty_int).unwrap(); - // } + /// This occasionally panics on macOS: https://github.com/rust-lang/rust/issues/71397 + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = Context::cpu(0); + let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } #[test] fn rust_ndarray() { @@ -431,7 +497,7 @@ mod tests { let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) .unwrap(); - assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + assert_eq!(nd.shape(), &[2, 2]); let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); assert!(rnd.all_close(&a, 1e-8f32)); } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 1388d3c96d02..77254d2fbca2 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -276,14 +276,22 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { type Error = Error; fn try_from(ret_value: RetValue) -> Result, Self::Error> { + use crate::ffi::DLTensor; + use crate::ndarray::NDArrayContainer; + match ret_value { RetValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - // println!("back to type {}", optr.count()); optr.downcast() } - _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), + RetValue::NDArrayHandle(handle) => { + let optr: ObjectPtr = + NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); + optr.upcast::().downcast() + } + _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)), } } } @@ -291,9 +299,22 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { impl<'a, T: IsObject> From> for ArgValue<'a> { fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::ObjectHandle(raw_ptr) + let object_ptr = object_ptr.upcast::(); + match T::TYPE_KEY { + "runtime.NDArray" => { + use crate::ndarray::NDArrayContainer; + // TODO(this is probably not optimal) + let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) + as *mut NDArrayContainer as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::NDArrayHandle(raw_ptr) + } + _ => { + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::ObjectHandle(raw_ptr) + } + } } } @@ -301,13 +322,21 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { type Error = Error; fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { + use crate::ffi::DLTensor; + use crate::ndarray::NDArrayContainer; + match arg_value { ArgValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - // println!("count: {}", optr.count()); optr.downcast() } + ArgValue::NDArrayHandle(handle) => { + let optr = + NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); + optr.upcast::().downcast() + } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } } @@ -402,37 +431,8 @@ mod tests { return o; } - // #[test] - // fn test_ref_count_boundary() { - // use super::*; - // use crate::function::{register, Function, Result}; - // // 1 - // let ptr = ObjectPtr::new(Object::base_object::()); - // assert_eq!(ptr.count(), 1); - // // 2 - // let stay = ptr.clone(); - // assert_eq!(ptr.count(), 2); - // register(test_fn, "my_func").unwrap(); - // let func = Function::get("my_func").unwrap(); - // let func = func.to_boxed_fn::) -> Result>>(); - // let same = func(ptr).unwrap(); - // drop(func); - // assert_eq!(stay.count(), 4); - // assert_eq!(same.count(), 4); - // drop(same); - // assert_eq!(stay.count(), 3); - // } - - // fn test_fn2(o: ArgValue<'static>) -> RetValue { - // // The call machinery adds at least 1 extra count while inside the call. - // match o { - // ArgValue::ObjectHandle(ptr) => RetValue::ObjectHandle(ptr), - // _ => panic!() - // } - // } - #[test] - fn test_ref_count_boundary2() { + fn test_ref_count_boundary3() { use super::*; use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base_object::()); diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index a5ee1f183389..6ff24bef3a60 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -114,9 +114,7 @@ impl Hash for String { impl std::fmt::Debug for String { fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // TODO(@mwillsey): remove this clone? - let string: String = self.clone().into(); - formatter.write_fmt(format_args!("{:?}", string)) + formatter.write_fmt(format_args!("{:?}", self.to_string_lossy())) } } diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index 1812c0cfbe45..c49944dc7e33 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -24,7 +24,7 @@ use std::convert::TryFrom; // use std::ffi::c_void; -use crate::{ArgValue, Module, NDArray, RetValue}; +use crate::{ArgValue, Module, RetValue}; use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; macro_rules! impl_handle_val { @@ -72,60 +72,6 @@ macro_rules! impl_handle_val { impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); -impl<'a> From<&'a NDArray> for ArgValue<'a> { - fn from(arg: &'a NDArray) -> Self { - match arg { - &NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), - &NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), - } - } -} - -impl<'a> From<&'a mut NDArray> for ArgValue<'a> { - fn from(arg: &'a mut NDArray) -> Self { - match arg { - &mut NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), - &mut NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), - } - } -} - -impl<'a> TryFrom> for NDArray { - type Error = ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> NDArray, - |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, - |ArgValue::ArrayHandle(val)| { NDArray::new(val) }) - } -} - -impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray { - type Error = ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result { - try_downcast!(val -> NDArray, - |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) }, - |ArgValue::ArrayHandle(val)| { NDArray::new(*val) }) - } -} - -impl From for RetValue { - fn from(val: NDArray) -> RetValue { - match val { - NDArray::Owned { handle } => RetValue::NDArrayHandle(handle), - _ => panic!("NYI"), - } - } -} - -impl TryFrom for NDArray { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> NDArray, - |RetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, - |RetValue::ArrayHandle(val)| { NDArray::new(val) }) - } -} - #[cfg(test)] mod tests { use std::{convert::TryInto, str::FromStr}; diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 358853951fda..f7b289c59675 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -409,3 +409,18 @@ impl<'a> TryFrom> for bool { try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) } } + +impl From<()> for RetValue { + fn from(_: ()) -> Self { + RetValue::Null + } +} + +impl TryFrom for () { + type Error = ValueDowncastError; + + fn try_from(val: RetValue) -> Result<(), Self::Error> { + try_downcast!(val -> bool, + |RetValue::Null| { () }) + } +} diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index ebfb5e64a4a7..55fc1790604e 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -40,6 +40,10 @@ tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" +pyo3 = { version = "0.11.1", optional = true } [features] +default = ["python"] + blas = ["ndarray/blas"] +python = ["pyo3"] diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml index fd10569869d5..646385a6373e 100644 --- a/rust/tvm/examples/resnet/Cargo.toml +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -28,6 +28,7 @@ ndarray = "0.12" tvm = { path = "../../" } image = "0.20" csv = "1.1" +anyhow = "^1.0" [build-dependencies] -anyhow = "^1.0" +anyhow = "1.0" diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index b259a626eb5e..1e5d8a98736d 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -18,7 +18,7 @@ */ use anyhow::{Context, Result}; -use std::{path::Path, process::Command}; +use std::{io::Write, path::Path, process::Command}; fn main() -> Result<()> { let output = Command::new("python3") @@ -26,7 +26,12 @@ fn main() -> Result<()> { .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) .output() .with_context(|| anyhow::anyhow!("failed to run python3"))?; - + if !output.status.success() { + std::io::stdout() + .write_all(&output.stderr) + .context("Failed to write error")?; + panic!("Failed to execute build script"); + } assert!( Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), "Could not prepare demo: {}", @@ -37,7 +42,6 @@ fn main() -> Result<()> { .last() .unwrap_or("") ); - println!( "cargo:rustc-link-search=native={}", env!("CARGO_MANIFEST_DIR") diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index 904f244e0a9a..324bb52e08a9 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -21,6 +21,7 @@ import logging from os import path as osp import sys +import shutil import numpy as np @@ -29,6 +30,9 @@ from tvm import relay from tvm.relay import testing from tvm.contrib import graph_runtime, cc +from PIL import Image +from tvm.contrib.download import download_testdata +from mxnet.gluon.model_zoo.vision import get_model logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -38,7 +42,6 @@ parser = argparse.ArgumentParser(description="Resnet build example") aa = parser.add_argument aa("--build-dir", type=str, required=True, help="directory to put the build artifacts") -aa("--pretrained", action="store_true", help="use a pretrained resnet") aa("--batch-size", type=int, default=1, help="input image batch size") aa( "--opt-level", @@ -54,41 +57,30 @@ build_dir = args.build_dir batch_size = args.batch_size opt_level = args.opt_level -target = tvm.target.Target(args.target) +target = tvm.target.create(args.target) image_shape = tuple(map(int, args.image_shape.split(","))) data_shape = (batch_size,) + image_shape def build(target_dir): """ Compiles resnet18 with TVM""" - deploy_lib = osp.join(target_dir, "deploy_lib.o") - if osp.exists(deploy_lib): - return - - if args.pretrained: - # needs mxnet installed - from mxnet.gluon.model_zoo.vision import get_model - - # if `--pretrained` is enabled, it downloads a pretrained - # resnet18 trained on imagenet1k dataset for image classification task - block = get_model("resnet18_v1", pretrained=True) - net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) - # we want a probability so add a softmax operator - func = net["main"] - net = relay.Function( - func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs - ) - else: - # use random weights from relay.testing - net, params = relay.testing.resnet.get_workload( - num_layers=18, batch_size=batch_size, image_shape=image_shape - ) - - # compile the model - with tvm.transform.PassContext(opt_level=opt_level): - graph, lib, params = relay.build_module.build(net, target, params=params) + # Download the pretrained model in MxNet's format. + block = get_model("resnet18_v1", pretrained=True) + + shape_dict = {"data": (1, 3, 224, 224)} + mod, params = relay.frontend.from_mxnet(block, shape_dict) + # Add softmax to do classification in last layer. + func = mod["main"] + func = relay.Function( + func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs + ) + + target = "llvm" + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(func, target, params=params) # save the model artifacts + deploy_lib = osp.join(target_dir, "deploy_lib.o") lib.save(deploy_lib) cc.create_shared(osp.join(target_dir, "deploy_lib.so"), [osp.join(target_dir, "deploy_lib.o")]) @@ -103,7 +95,6 @@ def download_img_labels(): """ Download an image and imagenet1k class labels for test""" from mxnet.gluon.utils import download - img_name = "cat.png" synset_url = "".join( [ "https://gist.githubusercontent.com/zhreshold/", @@ -113,37 +104,53 @@ def download_img_labels(): ] ) synset_name = "synset.txt" - download("https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true", img_name) - download(synset_url, synset_name) + synset_path = download_testdata(synset_url, synset_name, module="data") - with open(synset_name) as fin: + with open(synset_path) as fin: synset = eval(fin.read()) - with open("synset.csv", "w") as fout: - w = csv.writer(fout) - w.writerows(synset.items()) + with open(synset_name, "w") as f: + for key in synset: + f.write(synset[key]) + f.write("\n") + + return synset + + +def transform_image(image): + image = np.array(image) - np.array([123.0, 117.0, 104.0]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image + + +def get_cat_image(): + img_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true" + img_path = download_testdata(img_url, "cat.png", module="data") + shutil.copyfile(img_path, "cat.png") + img = Image.open(img_path).resize((224, 224)) + return transform_image(img) def test_build(build_dir): - """ Sanity check with random input""" + """ Sanity check with the cat image we download.""" graph = open(osp.join(build_dir, "deploy_graph.json")).read() lib = tvm.runtime.load_module(osp.join(build_dir, "deploy_lib.so")) params = bytearray(open(osp.join(build_dir, "deploy_param.params"), "rb").read()) - input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + input_data = get_cat_image() ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) module.load_params(params) module.run(data=input_data) out = module.get_output(0).asnumpy() + top1 = np.argmax(out[0]) + synset = download_img_labels() + print("TVM prediction top-1:", top1, synset[top1]) if __name__ == "__main__": - logger.info("building the model") + logger.info("Compiling the model to graph runtime.") build(build_dir) - logger.info("build was successful") - logger.info("test the build artifacts") + logger.info("Testing the model's predication on test data.") test_build(build_dir) - logger.info("test was successful") - if args.pretrained: - download_img_labels() - logger.info("image and synset downloads are successful") diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 16ca8c7386f1..f24c358ab52a 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -18,21 +18,25 @@ */ use std::{ - collections::HashMap, - convert::TryInto, fs::{self, File}, + io::{BufRead, BufReader}, path::Path, }; use ::ndarray::{Array, ArrayD, Axis}; use image::{FilterType, GenericImageView}; -use tvm::runtime::ByteArray; +use anyhow::Context as _; +use tvm::runtime::graph_rt::GraphRt; use tvm::*; -fn main() { +fn main() -> anyhow::Result<()> { let ctx = Context::cpu(0); - let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")); + + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")) + .context("Failed to open cat.png")?; + println!("original image dimensions: {:?}", img.dimensions()); // for bigger size images, one needs to first resize to 256x256 // with `img.resize_exact` method and then `image.crop` to 224x224 @@ -52,100 +56,68 @@ fn main() { } } - let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr = Array::from_shape_vec((224, 224, 3), pixels)?; let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); // make arr shape as [1, 3, 224, 224] acceptable to resnet let arr = arr.insert_axis(Axis(0)); // create input tensor from rust's ndarray - let input = NDArray::from_rust_ndarray(&arr, Context::cpu(0), DataType::float(32, 1)).unwrap(); + let input = NDArray::from_rust_ndarray(&arr, Context::cpu(0), DataType::float(32, 1))?; println!( - "input size is {:?}", - input.shape().expect("cannot get the input shape") + "input shape is {:?}, len: {}, size: {}", + input.shape(), + input.len(), + input.size(), ); - let graph = - fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + + let graph = fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")) + .context("Failed to open graph")?; + // load the built module let lib = Module::load(&Path::new(concat!( env!("CARGO_MANIFEST_DIR"), "/deploy_lib.so" - ))) - .unwrap(); - // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); - let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ - graph.into(), - (&lib).into(), - (&ctx.device_type).into(), - (&ctx.device_id).into(), - ]); - - // get graph runtime module - let graph_runtime_module: Module = runtime_create_fn_ret.unwrap().try_into().unwrap(); - - // get the registered `load_params` from runtime module - let ref load_param_fn = graph_runtime_module - .get_function("load_params", false) - .unwrap(); + )))?; + + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, ctx)?; + // parse parameters and convert to TVMByteArray - let params: Vec = - fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); - let barr = ByteArray::from(¶ms); - // load the parameters - load_param_fn.invoke(vec![(&barr).into()]).unwrap(); - // get the set_input function - let ref set_input_fn = graph_runtime_module - .get_function("set_input", false) - .unwrap(); + let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; - set_input_fn - .invoke(vec!["data".into(), (&input).into()]) - .unwrap(); + println!("param bytes: {}", params.len()); + + graph_rt.load_params(¶ms)?; + graph_rt.set_input("data", input)?; + graph_rt.run()?; - // get `run` function from runtime module - let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); - // execute the run function. Note that it has no argument - run_fn.invoke(vec![]).unwrap(); // prepare to get the output - let output_shape = &mut [1, 1000]; + let output_shape = &[1, 1000]; let output = NDArray::empty(output_shape, Context::cpu(0), DataType::float(32, 1)); - // get the `get_output` function from runtime module - let ref get_output_fn = graph_runtime_module - .get_function("get_output", false) - .unwrap(); - // execute the get output function - get_output_fn - .invoke(vec![(&0).into(), (&output).into()]) - .unwrap(); + graph_rt.get_output_into(0, output.clone())?; + // flatten the output as Vec - let output = output.to_vec::().unwrap(); + let output = output.to_vec::()?; + // find the maximum entry in the output and its index - let mut argmax = -1; - let mut max_prob = 0.; - for i in 0..output.len() { - if output[i] > max_prob { - max_prob = output[i]; - argmax = i as i32; - } - } + let (argmax, max_prob) = output + .iter() + .copied() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .unwrap(); + // create a hash map of (class id, class name) - let mut synset: HashMap = HashMap::new(); - let file = File::open("synset.csv").unwrap(); - let mut rdr = csv::ReaderBuilder::new() - .has_headers(true) - .from_reader(file); - - for result in rdr.records() { - let record = result.unwrap(); - let id: i32 = record[0].parse().unwrap(); - let cls = record[1].to_string(); - synset.insert(id, cls); - } + let file = File::open("synset.txt").context("failed to open synset")?; + let synset: Vec = BufReader::new(file) + .lines() + .into_iter() + .map(|x| x.expect("readline failed")) + .collect(); + let label = &synset[argmax]; println!( "input image belongs to the class `{}` with probability {}", - synset - .get(&argmax) - .expect("cannot find the class id for argmax"), - max_prob + label, max_prob ); + + Ok(()) } diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs index c2de24a299f7..f589f2ac25c6 100644 --- a/rust/tvm/src/ir/arith.rs +++ b/rust/tvm/src/ir/arith.rs @@ -19,7 +19,7 @@ use crate::runtime::{Object, ObjectPtr}; -use super::*; +use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { diff --git a/rust/tvm/src/ir/attrs.rs b/rust/tvm/src/ir/attrs.rs new file mode 100644 index 000000000000..5bd027ab4b4c --- /dev/null +++ b/rust/tvm/src/ir/attrs.rs @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::Object; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Attrs"] +#[type_key = "Attrs"] +pub struct BaseAttrsNode { + pub base: Object, +} diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs new file mode 100644 index 000000000000..91c42f0edbcf --- /dev/null +++ b/rust/tvm/src/ir/expr.rs @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::relay; +use crate::runtime::String as TString; +use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectPtr, ObjectRef}; +use crate::DataType; + +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { + pub base: Object, +} + +impl BaseExprNode { + pub fn base() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PrimExpr"] +#[type_key = "PrimExpr"] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl PrimExprNode { + pub fn base(datatype: DataType) -> PrimExprNode { + PrimExprNode { + base: BaseExprNode::base::(), + datatype, + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalVar"] +#[type_key = "GlobalVar"] +pub struct GlobalVarNode { + pub base: relay::ExprNode, + pub name_hint: TString, +} + +impl GlobalVar { + pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + let node = GlobalVarNode { + base: relay::ExprNode::base::(), + name_hint: name_hint.into(), + }; + GlobalVar(Some(ObjectPtr::new(node))) + } +} + +// TODO(@jroesch): update to match TVM +// Move IntImm +// Define FloatImm +// Define Bool +// Define tvm::Integer? +// Define RangeNode + +// TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) +external! { + #[name("ir.AsText")] + fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; +} + +pub fn as_text(object: T) -> String { + let no_func = unsafe { runtime::Function::null() }; + _as_text(object.upcast(), 0, no_func) + .unwrap() + .as_str() + .unwrap() + .into() +} diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs new file mode 100644 index 000000000000..3043bf9e7cff --- /dev/null +++ b/rust/tvm/src/ir/function.rs @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::ir::relay::ExprNode; +use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; + +use tvm_macros::Object; + +// Define Calling Convention. + +// TODO(@jroesch): define DictAttrs +pub type DictAttrs = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseFunc"] +#[type_key = "BaseFunc"] +pub struct BaseFuncNode { + pub base: ExprNode, + pub attrs: DictAttrs, +} + +impl BaseFuncNode { + pub fn base() -> BaseFuncNode { + BaseFuncNode { + base: ExprNode::base::(), + attrs: ::null(), + } + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index b615c1ec588e..126d0faccabb 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -17,60 +17,16 @@ * under the License. */ -use crate::runtime::String as TString; -use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectRef}; -use crate::DataType; -use tvm_macros::Object; - pub mod arith; +pub mod attrs; +pub mod expr; +pub mod function; +pub mod module; +pub mod op; pub mod relay; +pub mod span; pub mod tir; +pub mod ty; -// TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) -external! { - #[name("ir.AsText")] - fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; -} - -pub fn as_text(object: T) -> String { - let no_func = unsafe { runtime::Function::null() }; - _as_text(object.upcast(), 0, no_func) - .unwrap() - .as_str() - .unwrap() - .into() -} - -#[repr(C)] -#[derive(Object)] -#[ref_name = "BaseExpr"] -#[type_key = "Expr"] -pub struct BaseExprNode { - pub base: Object, -} - -impl BaseExprNode { - fn base() -> BaseExprNode { - BaseExprNode { - base: Object::base_object::(), - } - } -} - -#[repr(C)] -#[derive(Object)] -#[ref_name = "PrimExpr"] -#[type_key = "PrimExpr"] -pub struct PrimExprNode { - pub base: BaseExprNode, - pub datatype: DataType, -} - -impl PrimExprNode { - pub fn base(datatype: DataType) -> PrimExprNode { - PrimExprNode { - base: BaseExprNode::base::(), - datatype, - } - } -} +pub use expr::*; +pub use module::IRModule; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs new file mode 100644 index 000000000000..e0444b3101da --- /dev/null +++ b/rust/tvm/src/ir/module.rs @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::array::Array; +use crate::runtime::function::Result; +use crate::runtime::map::Map; +use crate::runtime::string::String as TVMString; +use crate::runtime::{external, Object, ObjectRef}; + +use super::expr::GlobalVar; +use super::function::BaseFunc; + +use std::io::Result as IOResult; +use std::path::Path; + +use tvm_macros::Object; + +// TODO(@jroesch): define type +type TypeData = ObjectRef; +type GlobalTypeVar = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "IRModule"] +#[type_key = "IRModule"] +pub struct IRModuleNode { + pub base: Object, + pub functions: Map, + pub type_definitions: Map, +} + +external! { + // Parser functions + #[name("parser.ParseModule")] + fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; + #[name("parser.ParseExpr")] + fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; + // Module methods + #[name("ir.Module_AddDef")] + fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); + #[name("ir.Module_GetGlobalVar")] + fn module_get_global_var(module: IRModule, name: TVMString) -> GlobalVar; + #[name("ir.Module_GetGlobalVars")] + fn module_get_global_vars(module: IRModule) -> Array; + #[name("ir.Module_Lookup")] + fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc; + #[name("ir.Module_Lookup_str")] + fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; +} + +// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") +// .set_body_method(&IRModuleNode::GetGlobalTypeVars); + +// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") +// .set_body_method(&IRModuleNode::ContainGlobalVar); + +// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") +// .set_body_method(&IRModuleNode::GetGlobalTypeVar); + +// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { +// return mod->LookupTypeDef(var); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { +// return mod->LookupTypeDef(var); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { +// return mod->LookupTag(tag); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_FromExpr") +// .set_body_typed([](RelayExpr e, tvm::Map funcs, +// tvm::Map type_defs) { +// return IRModule::FromExpr(e, funcs, type_defs); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { +// mod->Update(from); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") +// .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); + +// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { +// mod->Import(path); +// }); + +// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { +// mod->ImportFromStd(path); +// }); + +// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { +// auto* node = static_cast(ref.get()); +// p->stream << "IRModuleNode( " << node->functions << ")"; +// }); + +impl IRModule { + pub fn parse(file_name: N, source: S) -> IRModule + where + N: Into, + S: Into, + { + parse_module(file_name.into(), source.into()).expect("failed to call parser") + } + + pub fn parse_file>(file_path: P) -> IOResult { + let file_path = file_path.as_ref(); + let file_path_as_str = file_path.to_str().unwrap().to_string(); + let source = std::fs::read_to_string(file_path)?; + let module = IRModule::parse(file_path_as_str, source); + Ok(module) + } + + pub fn add_def( + &mut self, + type_name: GlobalTypeVar, + type_data: TypeData, + update: bool, + ) -> Result<()> { + module_add_def(self.clone(), type_name, type_data, update) + } + + pub fn get_global_var(&self, name: TVMString) -> Result { + module_get_global_var(self.clone(), name) + } + + pub fn get_global_vars(&self) -> Result> { + module_get_global_vars(self.clone()) + } + + pub fn lookup(&self, var: GlobalVar) -> Result { + module_lookup(self.clone(), var) + } + + pub fn lookup_str(&self, name: S) -> Result + where + S: Into, + { + module_lookup_str(self.clone(), name.into()) + } +} diff --git a/rust/tvm/src/ir/op.rs b/rust/tvm/src/ir/op.rs new file mode 100644 index 000000000000..d81d6a69c1eb --- /dev/null +++ b/rust/tvm/src/ir/op.rs @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::ir::relay::ExprNode; +use crate::runtime::array::Array; +use crate::runtime::ObjectRef; +use crate::runtime::String as TString; +use tvm_macros::Object; + +type FuncType = ObjectRef; +type AttrFieldInfo = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Op"] +#[type_key = "Op"] +pub struct OpNode { + pub base: ExprNode, + pub name: TString, + pub op_type: FuncType, + pub description: TString, + pub arguments: Array, + pub attrs_type_key: TString, + pub attrs_type_index: u32, + pub num_inputs: i32, + pub support_level: i32, +} diff --git a/rust/tvm/src/ir/relay/attrs/mod.rs b/rust/tvm/src/ir/relay/attrs/mod.rs new file mode 100644 index 000000000000..d1bcc0009657 --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/mod.rs @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub mod nn; +pub mod transform; diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs new file mode 100644 index 000000000000..f743534e5f61 --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::ir::attrs::BaseAttrsNode; +use crate::ir::PrimExpr; +use crate::runtime::array::Array; +use crate::runtime::DataType; +use crate::runtime::String as TString; +use tvm_macros::Object; + +type IndexExpr = PrimExpr; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Conv2DAttrs"] +#[type_key = "relay.attrs.Conv2DAttrs"] +pub struct Conv2DAttrsNode { + pub base: BaseAttrsNode, + pub strides: Array, + pub padding: Array, + pub dilation: Array, + // TODO(@gussmith23) groups is "int", what should it be here? + pub groups: i32, + pub channels: IndexExpr, + pub kernel_size: Array, + pub data_layout: TString, + pub kernel_layout: TString, + pub out_layout: TString, + pub out_dtype: DataType, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BiasAddAttrs"] +#[type_key = "relay.attrs.BiasAddAttrs"] +pub struct BiasAddAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "DenseAttrs"] +#[type_key = "relay.attrs.DenseAttrs"] +pub struct DenseAttrsNode { + pub base: BaseAttrsNode, + pub units: IndexExpr, + pub out_dtype: DataType, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalPool2DAttrs"] +#[type_key = "relay.attrs.GlobalPool2DAttrs"] +pub struct GlobalPool2DAttrsNode { + pub base: BaseAttrsNode, + pub layout: TString, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "MaxPool2DAttrs"] +#[type_key = "relay.attrs.MaxPool2DAttrs"] +pub struct MaxPool2DAttrsNode { + pub base: BaseAttrsNode, + pub pool_size: Array, + pub strides: Array, + pub padding: Array, + pub layout: TString, + pub ceil_mode: bool, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "SoftmaxAttrs"] +#[type_key = "relay.attrs.SoftmaxAttrs"] +pub struct SoftmaxAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, +} diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs new file mode 100644 index 000000000000..863f07617778 --- /dev/null +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::ir::attrs::BaseAttrsNode; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "ExpandDimsAttrs"] +#[type_key = "relay.attrs.ExpandDimsAttrs"] +pub struct ExpandDimsAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32, + pub num_newaxis: i32, +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 4f4497ea0fce..e539221d1db6 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -17,88 +17,76 @@ * under the License. */ +pub mod attrs; + +use std::hash::Hash; + use crate::runtime::array::Array; use crate::runtime::{object::*, String as TString}; -use crate::DataType; -use tvm_macros::Object; -#[repr(C)] -#[derive(Object)] -#[ref_name = "Id"] -#[type_key = "relay.Id"] -pub struct IdNode { - pub base: Object, - pub name_hint: TString, -} +use super::attrs::Attrs; +use super::expr::BaseExprNode; +use super::function::BaseFuncNode; +use super::ty::{Type, TypeNode}; -impl Id { - fn new(name_hint: TString) -> Id { - let node = IdNode { - base: Object::base_object::(), - name_hint: name_hint, - }; - Id(Some(ObjectPtr::new(node))) - } -} - -#[repr(C)] -#[derive(Object)] -#[ref_name = "BaseExpr"] -#[type_key = "Expr"] -pub struct BaseExprNode { - pub base: Object, -} - -#[repr(C)] -pub struct PrimExprNode { - pub base: BaseExprNode, - pub datatype: DataType, -} +use tvm_macros::Object; +use tvm_rt::NDArray; -impl BaseExprNode { - fn base() -> BaseExprNode { - BaseExprNode { - base: Object::base_object::(), - } - } -} +pub use super::expr::{GlobalVar, GlobalVarNode}; #[repr(C)] #[derive(Object)] #[ref_name = "Expr"] -#[type_key = "relay.Expr"] -pub struct RelayExpr { +#[type_key = "RelayExpr"] +pub struct ExprNode { pub base: BaseExprNode, pub span: ObjectRef, - pub checked_type: ObjectRef, + pub checked_type: Type, } -impl RelayExpr { - fn base() -> RelayExpr { - RelayExpr { +impl ExprNode { + pub fn base() -> ExprNode { + ExprNode { base: BaseExprNode::base::(), span: ObjectRef::null(), - checked_type: ObjectRef::null(), + checked_type: Type::from(TypeNode { + base: Object::base_object::(), + span: ObjectRef::null(), + }), } } } +impl Hash for Expr { + fn hash(&self, state: &mut H) { + self.as_ptr().unwrap().ptr.hash(state) + } +} + +impl PartialEq for Expr { + fn eq(&self, other: &Self) -> bool { + self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr) + } +} + +impl Eq for Expr {} + #[repr(C)] #[derive(Object)] -#[ref_name = "GlobalVar"] -#[type_key = "GlobalVar"] -pub struct GlobalVarNode { - pub base: RelayExpr, +#[ref_name = "Id"] +#[type_key = "relay.Id"] +pub struct IdNode { + pub base: Object, pub name_hint: TString, } -impl GlobalVar { - pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { - let node = GlobalVarNode { - base: RelayExpr::base::(), - name_hint: name_hint.into(), +impl Id { + fn new(name_hint: TString) -> Id { + let node = IdNode { + base: Object::base_object::(), + name_hint: name_hint, }; - GlobalVar(Some(ObjectPtr::new(node))) + Id(Some(ObjectPtr::new(node))) } } @@ -107,36 +95,55 @@ impl GlobalVar { #[ref_name = "Constant"] #[type_key = "relay.Constant"] pub struct ConstantNode { - pub base: RelayExpr, - pub data: ObjectRef, // make this NDArray. + pub base: ExprNode, + pub data: NDArray, } impl Constant { - pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { + pub fn new(data: NDArray, _span: ObjectRef) -> Constant { let node = ConstantNode { - base: RelayExpr::base::(), + base: ExprNode::base::(), data: data, }; Constant(Some(ObjectPtr::new(node))) } } +#[repr(C)] +#[derive(Object)] +#[ref_name = "Tuple"] +#[type_key = "relay.Tuple"] +pub struct TupleNode { + pub base: ExprNode, + pub fields: Array, +} + +impl Tuple { + pub fn new(fields: Array, _span: ObjectRef) -> Tuple { + let node = TupleNode { + base: ExprNode::base::(), + fields, + }; + Tuple(Some(ObjectPtr::new(node))) + } +} + #[repr(C)] #[derive(Object)] #[ref_name = "Var"] #[type_key = "relay.Var"] pub struct VarNode { - pub base: RelayExpr, + pub base: ExprNode, pub vid: Id, - pub type_annotation: ObjectRef, + pub type_annotation: Type, } impl Var { - pub fn new(name_hint: String, _span: ObjectRef) -> Var { + pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var { let node = VarNode { - base: RelayExpr::base::(), + base: ExprNode::base::(), vid: Id::new(name_hint.into()), - type_annotation: ObjectRef::null(), + type_annotation, }; Var(Some(ObjectPtr::new(node))) } @@ -150,19 +157,16 @@ impl Var { } } -pub type Type = ObjectRef; -pub type Attrs = ObjectRef; - #[repr(C)] #[derive(Object)] #[ref_name = "Call"] #[type_key = "relay.Call"] pub struct CallNode { - pub base: RelayExpr, + pub base: ExprNode, pub op: Expr, pub args: Array, - pub attrs: ObjectRef, - pub type_args: Array, + pub attrs: Attrs, + pub type_args: Array, } impl Call { @@ -170,11 +174,11 @@ impl Call { op: Expr, args: Array, attrs: Attrs, - type_args: Array, + type_args: Array, _span: ObjectRef, ) -> Call { let node = CallNode { - base: RelayExpr::base::(), + base: ExprNode::base::(), op: op, args: args, attrs: attrs, @@ -186,22 +190,297 @@ impl Call { #[repr(C)] #[derive(Object)] -#[ref_name = "BaseFunc"] -#[type_key = "BaseFunc"] -pub struct BaseFuncNode { - pub base: RelayExpr, - pub attrs: ObjectRef, -} - -impl BaseFuncNode { - fn base() -> BaseFuncNode { - BaseFuncNode { - base: RelayExpr::base::(), - attrs: ObjectRef::null(), +#[ref_name = "Let"] +#[type_key = "relay.Let"] +pub struct LetNode { + pub base: ExprNode, + pub var: Var, + pub value: Expr, + pub body: Expr, +} + +impl Let { + pub fn new(var: Var, value: Expr, body: Expr, _span: ObjectRef) -> Let { + let node = LetNode { + base: ExprNode::base::(), + var, + value, + body, + }; + Let(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "If"] +#[type_key = "relay.If"] +pub struct IfNode { + pub base: ExprNode, + pub cond: Expr, + pub true_branch: Expr, + pub false_branch: Expr, +} + +impl If { + pub fn new(cond: Expr, true_branch: Expr, false_branch: Expr, _span: ObjectRef) -> If { + let node = IfNode { + base: ExprNode::base::(), + cond, + true_branch, + false_branch, + }; + If(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TupleGetItem"] +#[type_key = "relay.TupleGetItem"] +pub struct TupleGetItemNode { + pub base: ExprNode, + pub tuple: Expr, + pub index: i32, +} + +impl TupleGetItem { + pub fn new(tuple: Expr, index: i32, _span: ObjectRef) -> TupleGetItem { + let node = TupleGetItemNode { + base: ExprNode::base::(), + tuple, + index, + }; + TupleGetItem(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefCreate"] +#[type_key = "relay.RefCreate"] +pub struct RefCreateNode { + pub base: ExprNode, + pub value: Expr, +} + +impl RefCreate { + pub fn new(value: Expr, _span: ObjectRef) -> RefCreate { + let node = RefCreateNode { + base: ExprNode::base::(), + value, + }; + RefCreate(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefRead"] +#[type_key = "relay.RefRead"] +pub struct RefReadNode { + pub base: ExprNode, + pub ref_value: Expr, +} + +impl RefRead { + pub fn new(ref_value: Expr, _span: ObjectRef) -> RefRead { + let node = RefReadNode { + base: ExprNode::base::(), + ref_value, + }; + RefRead(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefWrite"] +#[type_key = "relay.RefWrite"] +pub struct RefWriteNode { + pub base: ExprNode, + pub ref_value: Expr, + pub value: Expr, +} + +impl RefWrite { + pub fn new(ref_value: Expr, value: Expr, _span: ObjectRef) -> RefWrite { + let node = RefWriteNode { + base: ExprNode::base::(), + ref_value, + value, + }; + RefWrite(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Constructor"] +#[type_key = "relay.Constructor"] +pub struct ConstructorNode { + pub base: ExprNode, + pub name_hint: String, + pub inputs: Array, + pub tag: i32, +} + +impl Constructor { + pub fn new(name_hint: String, inputs: Array, tag: i32, _span: ObjectRef) -> Constructor { + let node = ConstructorNode { + base: ExprNode::base::(), + name_hint, + inputs, + tag, + }; + Constructor(Some(ObjectPtr::new(node))) + } +} + +// TODO(@jroesch): define the type data + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Pattern"] +#[type_key = "relay.Pattern"] +pub struct PatternNode { + pub base: Object, + pub span: ObjectRef, +} + +impl PatternNode { + pub fn base() -> PatternNode { + PatternNode { + base: Object::base_object::(), + span: ObjectRef::null(), } } } +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternWildcard"] +#[type_key = "relay.PatternWildcard"] +pub struct PatternWildcardNode { + pub base: PatternNode, +} + +impl PatternWildcard { + pub fn new(_span: ObjectRef) -> PatternWildcard { + let node = PatternWildcardNode { + base: PatternNode::base::(), + }; + PatternWildcard(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternVar"] +#[type_key = "relay.PatternVar"] +pub struct PatternVarNode { + pub base: PatternNode, + pub var: Var, +} + +impl PatternVar { + pub fn new(var: Var, _span: ObjectRef) -> PatternVar { + let node = PatternVarNode { + base: PatternNode::base::(), + var: var, + }; + PatternVar(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternConstructor"] +#[type_key = "relay.PatternConstructor"] +pub struct PatternConstructorNode { + pub base: PatternNode, + pub constructor: Constructor, + pub patterns: Array, +} + +impl PatternConstructor { + pub fn new( + constructor: Constructor, + patterns: Array, + _span: ObjectRef, + ) -> PatternConstructor { + let node = PatternConstructorNode { + base: PatternNode::base::(), + constructor, + patterns, + }; + PatternConstructor(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PatternTuple"] +#[type_key = "relay.PatternTuple"] +pub struct PatternTupleNode { + pub base: PatternNode, + pub patterns: Array, +} + +impl PatternTuple { + pub fn new(patterns: Array, _span: ObjectRef) -> PatternTuple { + let node = PatternTupleNode { + base: PatternNode::base::(), + patterns, + }; + PatternTuple(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Clause"] +#[type_key = "relay.Clause"] +pub struct ClauseNode { + pub base: Object, + pub lhs: Pattern, + pub rhs: Expr, +} + +impl Clause { + pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { + let node = ClauseNode { + base: Object::base_object::(), + lhs, + rhs, + }; + Clause(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Match"] +#[type_key = "relay.Match"] +pub struct MatchNode { + pub base: ExprNode, + pub data: Expr, + pub clauses: Array, + pub complete: bool, +} + +impl Match { + pub fn new(data: Expr, clauses: Array, complete: bool, _span: ObjectRef) -> Match { + let node = MatchNode { + base: ExprNode::base::(), + data, + clauses, + complete, + }; + Match(Some(ObjectPtr::new(node))) + } +} + #[repr(C)] #[derive(Object)] #[ref_name = "Function"] @@ -258,25 +537,47 @@ mod tests { #[test] fn test_var() -> Result<()> { - let var = Var::new("local".to_string(), ObjectRef::null()); + let var = Var::new("local".to_string(), Type::null(), ObjectRef::null()); let text = as_text(var.clone()); assert!(text.contains("%local")); Ok(()) } - use super::Array; - use crate::ir::relay::Var; - use crate::runtime::object::ObjectRef; - #[test] - fn create_array_and_get() -> Result<()> { - let vec = vec![ - Var::new("foo".into(), ObjectRef::null()), - Var::new("bar".into(), ObjectRef::null()), - ]; - let array = Array::from_vec(vec)?; - assert_eq!(array.get(0)?.name_hint().to_string(), "foo"); - assert_eq!(array.get(1)?.name_hint().to_string(), "bar"); + fn test_parse_constant() -> Result<()> { + let module = crate::ir::module::IRModule::parse( + "", + r#" +#[version = "0.0.5"] +def @main() -> float32 { + 0.01639530062675476f +} +"#, + ); + let main = module + .lookup(module.get_global_var("main".to_string().into()).unwrap()) + .unwrap(); + let func = main.downcast::().unwrap(); + let constant = func + .body + .clone() + .downcast::() + .unwrap(); + let tuple_type = constant + .clone() + .upcast::() + .checked_type + .clone() + .downcast::() + .unwrap(); + // Test type + assert_eq!(tuple_type.shape.len(), 0,); + assert_eq!(tuple_type.dtype, "float32".parse().unwrap(),); + // Check that actual data matches up with type + assert_eq!(constant.data.dtype(), "float32".parse().unwrap(),); + assert_eq!(constant.data.len(), 1); + assert_eq!(constant.data.size(), 4); + assert_eq!(constant.data.shape(), &[]); Ok(()) } } diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs new file mode 100644 index 000000000000..d2e19a25a950 --- /dev/null +++ b/rust/tvm/src/ir/span.rs @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::ObjectRef; + +pub type Span = ObjectRef; diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index ee30c513e9f0..22d4e02054e1 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -17,10 +17,11 @@ * under the License. */ +use super::{PrimExpr, PrimExprNode}; use crate::runtime::String as TVMString; use crate::DataType; -use super::*; +use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { @@ -43,6 +44,7 @@ macro_rules! define_node { } } +// TODO(@jroesch): should move up to expr.rs to mirror TVM. define_node!(IntImm, "IntImm", "IntImm"; IntImmNode { value: i64 }); define_node!(Var, "Var", "tir.Var"; diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs new file mode 100644 index 000000000000..b6a47f553da4 --- /dev/null +++ b/rust/tvm/src/ir/ty.rs @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::span::Span; +use crate::runtime::{IsObject, Object, ObjectPtr}; +use tvm_macros::Object; +use tvm_rt::{array::Array, DataType}; + +use super::PrimExpr; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Type"] +#[type_key = "Type"] +pub struct TypeNode { + pub base: Object, + pub span: Span, +} + +impl TypeNode { + fn base(span: Span) -> Self { + TypeNode { + base: Object::base_object::(), + span, + } + } +} + +/* + * \brief Primitive data types used in the low-level IR. + * + * PrimType represents POD-values and handles that are + * not automatically managed by the runtime. + * + * \sa PrimType + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "PrimType"] +#[type_key = "PrimType"] +pub struct PrimTypeNode { + pub base: TypeNode, + /// The corresponding dtype field. + pub dtype: DataType, +} + +/* + *! + * \brief Low-level raw pointer type. + * + * PointerType represents type hints in the TIR to be + * passed to the final code generator. + * + * PointerType should not occur in the high-level analysis. + * + * \sa PointerType + */ + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PointerType"] +#[type_key = "PointerType"] +pub struct PointerTypeNode { + pub base: TypeNode, + /// The type of the element which the pointer points to. + pub element_type: Type, +} + +/// Possible kinds of type variables. +pub enum TypeKind { + Type = 0, + /// Template variable in shape expression. + ShapeVar = 1, + Constraint = 4, + AdtHandle = 5, + TypeData = 6, +} + +/* + * \brief Type parameter in functions. + * + * A type variable can be viewed as template parameter in c++ template function. + * + * For example, in the following pesudo code, + * the TypeVar of f is TypeVar("n", kind=kShapeVar). + * This function can take in a Tensor with shape=(3, 3) and + * returns a Tensor with shape=(9,) + * + * \code + * + * template + * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] + * + * \endcode + * \sa TypeVar, TypeKind + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "TypeVar"] +#[type_key = "TypeVar"] +pub struct TypeVarNode { + pub base: TypeNode, + pub name_hint: String, + pub kind: TypeKind, +} + +/// A global type variable that is used for defining new types or type aliases. +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalTypeVar"] +#[type_key = "GlobalTypeVar"] +pub struct GlobalTypeVarNode { + pub base: TypeNode, + pub name_hint: String, + pub kind: TypeKind, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TupleType"] +#[type_key = "TupleType"] +pub struct TupleTypeNode { + pub base: TypeNode, + pub fields: Array, +} + +impl TupleType { + pub fn empty() -> TupleType { + todo!() + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TypeConstraint"] +#[type_key = "TypeConstraint"] +pub struct TypeConstraintNode { + pub base: TypeNode, +} + +/// The representation of a polymorphic function type. +#[repr(C)] +#[derive(Object)] +#[ref_name = "FuncType"] +#[type_key = "FuncType"] +pub struct FuncTypeNode { + pub base: TypeNode, + /// The type of arguments. + pub arg_types: Array, + /// The return type of the function. + pub ret_type: Type, + /// ... + pub type_params: Array, + /// Type constraints that must hold when + /// calling this function. + pub type_constraints: Array, +} + +/* + * \brief Intermediate values that is used to indicate incomplete type + * during type inference. + * + * If we view the type relations as "computational graph of types", + * then IncompleteType represents intermediate values of the graph, + * TypeVar represents the input to the graph. + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "IncompleteType"] +#[type_key = "IncompleteType"] +pub struct IncompleteTypeNode { + pub base: TypeNode, + pub kind: TypeKind, +} + +/* + * \brief Reference Type High-level Relay IR. + * + * \sa RelayRefType. + */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "RefType"] +#[type_key = "relay.RefType"] +pub struct RelayRefTypeNode { + pub base: TypeNode, + pub value: Type, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseTensorType"] +#[type_key = "relay.BaseTensorType"] +pub struct BaseTensorTypeNode { + pub base: TypeNode, +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "TensorType"] +#[type_key = "relay.TensorType"] +pub struct TensorTypeNode { + pub base: TypeNode, + pub shape: Array, + pub dtype: DataType, +} + +impl TensorType { + pub fn new(shape: Array, dtype: DataType, span: Span) -> TensorType { + let node = TensorTypeNode { + base: TypeNode::base::(span), + shape, + dtype, + }; + ObjectPtr::new(node).into() + } +} +// TODO(@jroesch): implement these in future. +// +// using TypeCall = tvm::TypeCall; +// using TypeCallNode = tvm::TypeCallNode; +// using TypeRelation = tvm::TypeRelation; +// using TypeRelationNode = tvm::TypeRelationNode; +// using TypeRelationFn = tvm::TypeRelationFn; +// using TypeReporter = tvm::TypeReporter; +// using TypeReporterNode = tvm::TypeReporterNode; diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index 64252a4f9c6f..36c750328249 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -41,6 +41,8 @@ pub use tvm_rt::module; pub use tvm_rt::ndarray; pub use tvm_rt::value; pub mod ir; +#[cfg(feature = "python")] +pub mod python; pub mod runtime; pub mod transform; diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs new file mode 100644 index 000000000000..89558af733b3 --- /dev/null +++ b/rust/tvm/src/python.rs @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use pyo3::prelude::*; + +/// Load the Python interpreter into the address space. +/// +/// This enables the ability for Rust code to call TVM +/// functionality defined in Python. +/// +/// For example registered TVM functions can now be +/// obtained via `Function::get`. +pub fn load() -> Result { + let gil = Python::acquire_gil(); + let py = gil.python(); + load_python_tvm_(py).map_err(|e| { + // We can't display Python exceptions via std::fmt::Display, + // so print the error here manually. + e.print_and_set_sys_last_vars(py); + }) +} + +// const TVMC_CODE: &'static str = include_str!("tvmc.py"); + +fn load_python_tvm_(py: Python) -> PyResult { + let sys = py.import("tvm")?; + let version: String = sys.get("__version__")?.extract()?; + // py.run(TVMC_CODE, None, None)?; + Ok(version) +} + +#[cfg(test)] +mod tests { + use super::load_python_tvm_; + use anyhow::Result; + use pyo3::prelude::*; + + #[ignore] + #[test] + fn test_run() -> Result<()> { + load_python_tvm_(Python::acquire_gil().python()).unwrap(); + Ok(()) + } +} diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm/src/runtime/graph_rt.rs new file mode 100644 index 000000000000..8b26ebb4ca22 --- /dev/null +++ b/rust/tvm/src/runtime/graph_rt.rs @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::convert::TryInto; + +use crate::runtime::Function; +use crate::{runtime::function::Result, runtime::ByteArray, Context, Module, NDArray}; + +/// An instance of the C++ graph runtime. +/// +/// An efficient and light weight runtime for static deep learning models. +pub struct GraphRt { + /// The backing graph runtime module which exposes a set of packed functions + /// which can be invoked by a client. + /// + /// In the graph runtime module, it exposes create, load_params, set_input, get_output, and run. + module: Module, +} + +impl GraphRt { + /// Create a graph runtime from the deprecated graph, lib, ctx triple. + pub fn create_from_parts(graph: &str, lib: Module, ctx: Context) -> Result { + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); + + let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ + graph.into(), + (&lib).into(), + (&ctx.device_type).into(), + // NOTE you must pass the device id in as i32 because that's what TVM expects + (ctx.device_id as i32).into(), + ]); + let graph_runtime_module: Module = runtime_create_fn_ret?.try_into()?; + Ok(Self { + module: graph_runtime_module, + }) + } + + /// Load the parameters of the model into the runtime. + pub fn load_params

(&mut self, params: P) -> Result<()> + where + P: Into, + { + let load_param_fn = self.module.get_function("load_params", false)?; + + let params: ByteArray = params.into(); + + load_param_fn.invoke(vec![(¶ms).into()])?; + + Ok(()) + } + + /// Set the input with name `name` with the value of `input`. + pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { + let ref set_input_fn = self.module.get_function("set_input", false)?; + + set_input_fn.invoke(vec![name.into(), input.into()])?; + Ok(()) + } + + /// Run the graph module, once setting parameters and inputs. + pub fn run(&mut self) -> Result<()> { + let ref run_fn = self.module.get_function("run", false)?; + + // execute the run function. Note that it has no argument + run_fn.invoke(vec![])?; + Ok(()) + } + + /// Extract the ith output from the graph runtime and returns it. + pub fn get_output(&mut self, i: i64) -> Result { + let get_output_fn = self.module.get_function("get_output", false)?; + get_output_fn.invoke(vec![i.into()])?.try_into() + } + + /// Extract the ith output from the graph runtime and write the results into output. + pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { + let get_output_fn = self.module.get_function("get_output", false)?; + get_output_fn.invoke(vec![i.into(), output.into()])?; + Ok(()) + } +} diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs index 69fbb371824a..84da186557f7 100644 --- a/rust/tvm/src/runtime/mod.rs +++ b/rust/tvm/src/runtime/mod.rs @@ -18,3 +18,5 @@ */ pub use tvm_rt::*; + +pub mod graph_rt; diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index 04d8382d3c1f..e4249a491746 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -33,7 +33,7 @@ fn main() { let dtype = DataType::from_str("float32").unwrap(); let mut arr = NDArray::empty(shape, ctx, dtype); arr.copy_from_buffer(data.as_mut_slice()); - let mut ret = NDArray::empty(shape, ctx, dtype); + let ret = NDArray::empty(shape, ctx, dtype); let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); if !fadd.enabled(ctx_name) { return; @@ -44,7 +44,7 @@ fn main() { fadd.entry() .expect("module must have entry point") - .invoke(vec![(&arr).into(), (&arr).into(), (&mut ret).into()]) + .invoke(vec![(&arr).into(), (&arr).into(), (&ret).into()]) .unwrap(); assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index ad41bd18ec8b..2f1848ec6471 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -36,31 +36,28 @@ use tvm::{ fn main() { fn sum(args: Vec>) -> Result { - let mut ret = 0f32; - let shape = &mut [2]; - for arg in args.iter() { - let e = NDArray::empty(shape, Context::cpu(0), DataType::float(32, 1)); + let mut ret = 0.0; + for arg in args { let arg: NDArray = arg.try_into()?; - let arr = arg.copy_to_ndarray(e)?; - let rnd: ArrayD = ArrayD::try_from(&arr)?; + let rnd: ArrayD = ArrayD::try_from(&arg)?; ret += rnd.scalar_sum(); } Ok(RetValue::from(ret)) } - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; + let shape = &[2]; + let data = vec![3.0, 4.0]; let mut arr = NDArray::empty(shape, Context::cpu(0), DataType::float(32, 1)); - arr.copy_from_buffer(data.as_mut_slice()); + arr.copy_from_buffer(data.as_slice()); register_untyped(sum, "sum", true).unwrap(); let func = Function::get("sum").expect("function registered"); let ret: f32 = func - .invoke(vec![(&arr).into(), (&arr).into()]) + .invoke(vec![(&arr).into()]) .unwrap() .try_into() .expect("call should succeed"); - assert_eq!(ret, 7f32); + assert_eq!(ret, 7.0); } diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 438500f45e5e..2b9103b9709a 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -431,8 +431,8 @@ weight transformation in advance. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform") .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, - Array kernel_size, std::string data_layout, - std::string kernel_layout, std::string out_layout, DataType out_dtype) { + Array kernel_size, tvm::String data_layout, + tvm::String kernel_layout, tvm::String out_layout, DataType out_dtype) { return MakeConvGemm( data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform"); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 847f81f72a04..f112a7259552 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -62,13 +62,13 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale // Kernel scale can be a vector of length output_channels or a scalar. if (param->groups == 1) { - size_t axis = param->kernel_layout.find('O'); + size_t axis = param->kernel_layout.operator std::string().find('O'); CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale } else { // Here, total number of output channels depend on depth multiplier. - size_t o_axis = param->kernel_layout.find('O'); - size_t i_axis = param->kernel_layout.find('I'); + size_t o_axis = param->kernel_layout.operator std::string().find('O'); + size_t i_axis = param->kernel_layout.operator std::string().find('I'); CHECK(o_axis != std::string::npos || i_axis != std::string::npos) << "Kernel layout attribute is not defined"; AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis], diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index a639fcd60af6..54aec99f46fb 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -196,7 +196,8 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto channels = GetConv2DSuperChannelsDim(conv2d); num_filters += channels; } - auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); + auto index = + branches[0][0]->attrs.as()->kernel_layout.operator std::string().find('O'); CHECK_NE(index, std::string::npos); return std::make_tuple(MakeConcatenate(Tuple(weights), index), tir::make_const(DataType::Int(32), num_filters)); diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 17e73048f24c..3c653af01e2e 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -198,7 +198,7 @@ inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { auto param = call->attrs.as(); auto tweight = call->args[1]->type_as(); - auto index = param->kernel_layout.find('O'); + auto index = param->kernel_layout.operator std::string().find('O'); CHECK_NE(index, std::string::npos); auto channels = tir::as_const_int(tweight->shape[index]); return *channels; diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 6cf4958a0de0..9c1eeeb973d6 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -273,8 +273,9 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ DLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; - *out = NDArray::Internal::MoveToFFIHandle( - NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx)); + auto ndarray = NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx); + + *out = NDArray::Internal::MoveToFFIHandle(ndarray); API_END(); } diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 29166c627663..98dac93ac98f 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -68,6 +68,13 @@ npm install npm run typedoc cd .. +# TODO(@jroesch): add Rust to CI container +# see: https://github.com/apache/incubator-tvm/issues/6628 +# Rust doc +# cd rust +# cargo doc --workspace --no-deps +# cd .. + # Prepare the doc dir rm -rf _docs mv docs/_build/html _docs @@ -75,6 +82,8 @@ rm -f _docs/.buildinfo mkdir -p _docs/api mv docs/doxygen/html _docs/api/doxygen mv jvm/core/target/site/apidocs _docs/api/javadoc +# See above TODO +# mv rust/target/doc _docs/api/rust mv web/dist/docs _docs/api/typedoc echo "Start creating the docs tarball.." diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index d7b9a5b74406..18361feb03ee 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -20,10 +20,13 @@ set -e set -u export TVM_HOME="$(git rev-parse --show-toplevel)" - +echo "Using TVM_HOME=$TVM_HOME" export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}" -export PYTHONPATH="$TVM_HOME/python" +echo "Using LD_LIBRARY_PATH=$LD_LIBRARY_PATH" +export PYTHONPATH="$TVM_HOME/python:${PYTHONPATH}" +echo "Using PYTHONPATH=$PYTHONPATH" export RUST_DIR="$TVM_HOME/rust" +echo "Using RUST_DIR=$RUST_DIR" export LLVM_CONFIG_DEFAULT=`which llvm-config-10` @@ -107,6 +110,8 @@ cargo run --bin array cargo run --bin string cd - -cd examples/resnet -cargo build +# TODO(@jroesch): we need to renable MxNet in ci-cpu image +# https://github.com/apache/incubator-tvm/pull/6563 +# cd examples/resnet +# cargo build cd -