diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6849c039f86f..d9bb3ab065fd 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -29,5 +29,7 @@ members = [ "frontend/tests/callback", "frontend/examples/resnet", "tvm-sys", - "tvm-rt" + "tvm-macros", + "tvm-rt", + "tvm", ] diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml index 1d3373a9e60f..eeead4587de0 100644 --- a/rust/runtime/tests/test_wasm32/Cargo.toml +++ b/rust/runtime/tests/test_wasm32/Cargo.toml @@ -20,7 +20,11 @@ name = "test-wasm32" version = "0.0.0" license = "Apache-2.0" authors = ["TVM Contributors"] +edition = "2018" [dependencies] ndarray="0.12" tvm-runtime = { path = "../../" } + +[build-dependencies] +anyhow = "^1.0" diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs index 8b72be290267..5c816c336825 100644 --- a/rust/runtime/tests/test_wasm32/build.rs +++ b/rust/runtime/tests/test_wasm32/build.rs @@ -19,12 +19,14 @@ use std::{path::PathBuf, process::Command}; -fn main() { +use anyhow::{Context, Result}; + +fn main() -> Result<()> { let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); out_dir.push("lib"); if !out_dir.is_dir() { - std::fs::create_dir(&out_dir).unwrap(); + std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?; } let obj_file = out_dir.join("test.o"); @@ -36,7 +38,8 @@ fn main() { )) .arg(&out_dir) .output() - .expect("Failed to execute command"); + .context("failed to execute Python script for generating TVM library")?; + assert!( obj_file.exists(), "Could not build tvm lib: {}", @@ -49,12 +52,14 @@ fn main() { ); let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8"); + let output = Command::new(ar) .arg("rcs") .arg(&lib_file) .arg(&obj_file) .output() - .expect("Failed to execute command"); + .context("failed to run LLVM_AR command")?; + assert!( lib_file.exists(), "Could not create archive: {}", @@ -68,4 +73,5 @@ fn main() { println!("cargo:rustc-link-lib=static=test_wasm32"); println!("cargo:rustc-link-search=native={}", out_dir.display()); + Ok(()) } diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 8833d6084574..2fcee49d3abd 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -88,7 +88,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let tvm_rt_crate = crate::util::get_tvm_rt_crate(); - let err_type = quote! { #tvm_rt_crate::Error }; + let result_type = quote! { #tvm_rt_crate::function::Result }; let mut items = Vec::new(); @@ -142,9 +142,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { items.push(global); let wrapper = quote! { - pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> { + pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { let func_ref: #tvm_rt_crate::Function = #global_name.clone(); - let func_ref: Box Result<#ret_type, #err_type>> = func_ref.to_boxed_fn(); + let func_ref: Box #result_type<#ret_type>> = func_ref.to_boxed_fn(); let res: #ret_type = func_ref(#(#args),*)?; Ok(res) } diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index bee22c367189..0170e1d71d41 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -27,6 +27,8 @@ use crate::util::get_tvm_rt_crate; pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { let tvm_rt_crate = get_tvm_rt_crate(); + let result = quote! { #tvm_rt_crate::function::Result }; + let error = quote! { #tvm_rt_crate::errors::Error }; let derive_input = syn::parse_macro_input!(input as DeriveInput); let payload_id = derive_input.ident; @@ -77,9 +79,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #[derive(Clone)] pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); - impl #tvm_rt_crate::object::ToObjectRef for #ref_id { - fn to_object_ref(&self) -> ObjectRef { - ObjectRef(self.0.as_ref().map(|o| o.upcast())) + impl #tvm_rt_crate::object::IsObjectRef for #ref_id { + type Object = #payload_id; + + fn as_object_ptr(&self) -> Option<&ObjectPtr> { + self.0.as_ref() + } + + fn from_object_ptr(object_ptr: Option>) -> Self { + #ref_id(object_ptr) } } @@ -92,9 +100,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { - type Error = #tvm_rt_crate::Error; + type Error = #error; - fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> { + fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> { use std::convert::TryInto; let oref: ObjectRef = ret_val.try_into()?; let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?; @@ -125,24 +133,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #tvm_rt_crate::Error; + type Error = #error; - fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> { use std::convert::TryInto; let optr = arg_value.try_into()?; Ok(#ref_id(Some(optr))) } } - impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #tvm_rt_crate::Error; - - fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { - use std::convert::TryInto; - let optr = arg_value.try_into()?; - Ok(#ref_id(Some(optr))) - } - } impl From<#ref_id> for #tvm_rt_crate::RetValue { fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue { diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs new file mode 100644 index 000000000000..128bb879843b --- /dev/null +++ b/rust/tvm-rt/src/array.rs @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::convert::{TryFrom, TryInto}; +use std::marker::PhantomData; + +use crate::errors::Error; +use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef}; +use crate::{ + external, + function::{Function, Result}, + RetValue, +}; + +#[repr(C)] +#[derive(Clone)] +pub struct Array { + object: ObjectRef, + _data: PhantomData, +} + +// TODO(@jroesch): convert to use generics instead of casting inside +// the implementation. +external! { + #[name("node.ArrayGetItem")] + fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; +} + +impl Array { + pub fn from_vec(data: Vec) -> Result> { + let iter = data + .iter() + .map(|element| element.to_object_ref().into()) + .collect(); + + let func = Function::get("node.Array").expect( + "node.Array function is not registered, this is most likely a build or linking error", + ); + + // let array_data = func.invoke(iter)?; + // let array_data: ObjectRef = func.invoke(iter)?.try_into()?; + let array_data: ObjectPtr = func.invoke(iter)?.try_into()?; + + debug_assert!( + array_data.count() >= 1, + "array reference count is {}", + array_data.count() + ); + + Ok(Array { + object: ObjectRef(Some(array_data)), + _data: PhantomData, + }) + } + + pub fn get(&self, index: isize) -> Result + where + T: TryFrom, + { + let oref: ObjectRef = array_get_item(self.object.clone(), index)?; + oref.downcast() + } +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index 0b45ebf445bf..779f04e6daa9 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -66,6 +66,8 @@ pub enum Error { NDArray(#[from] NDArrayError), #[error("{0}")] CallFailed(String), + #[error("this case will never occur")] + Infallible(#[from] std::convert::Infallible), } impl Error { diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index cb8777a6227b..0772e96e4984 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -32,12 +32,12 @@ use std::{ ptr, str, }; -pub use tvm_sys::{ffi, ArgValue, RetValue}; - use crate::errors::Error; use super::to_boxed_fn::ToBoxedFn; -use super::to_function::{ToFunction, Typed}; + +pub use super::to_function::{ToFunction, Typed}; +pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; @@ -65,6 +65,14 @@ impl Function { } } + pub unsafe fn null() -> Self { + Function { + handle: std::ptr::null_mut(), + is_global: false, + from_rust: false, + } + } + /// For a given function, it returns a function by name. pub fn get>(name: S) -> Option { let name = CString::new(name.as_ref()).unwrap(); @@ -171,7 +179,11 @@ impl TryFrom for Function { impl<'a> From for ArgValue<'a> { fn from(func: Function) -> ArgValue<'a> { - ArgValue::FuncHandle(func.handle) + if func.handle.is_null() { + ArgValue::Null + } else { + ArgValue::FuncHandle(func.handle) + } } } diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 10f8317bf7bd..a56a25be82fb 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -91,10 +91,10 @@ pub(crate) fn set_last_error(err: &E) { } } -#[macro_use] -pub mod function; +pub mod array; pub mod context; pub mod errors; +pub mod function; pub mod module; pub mod ndarray; pub mod to_boxed_fn; diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index b7ae4622849d..24fa5e0dfcbc 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -411,17 +411,17 @@ mod tests { assert_eq!(nd.unwrap().to_vec::().unwrap(), data); } - #[test] - #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - fn copy_wrong_dtype() { - let shape = vec![4]; - let mut data = vec![1f32, 2., 3., 4.]; - let ctx = Context::cpu(0); - let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); - nd_float.copy_from_buffer(&mut data); - let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); - nd_float.copy_to_ndarray(empty_int).unwrap(); - } + // #[test] + // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + // fn copy_wrong_dtype() { + // let shape = vec![4]; + // let mut data = vec![1f32, 2., 3., 4.]; + // let ctx = Context::cpu(0); + // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + // nd_float.copy_from_buffer(&mut data); + // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + // nd_float.copy_to_ndarray(empty_int).unwrap(); + // } #[test] fn rust_ndarray() { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index c49f84e2d916..e6375bfa09dd 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -39,13 +39,32 @@ impl ObjectRef { } } -pub trait ToObjectRef { - fn to_object_ref(&self) -> ObjectRef; -} +pub trait IsObjectRef: Sized { + type Object: IsObject; + fn as_object_ptr(&self) -> Option<&ObjectPtr>; + fn from_object_ptr(object_ptr: Option>) -> Self; -impl ToObjectRef for ObjectRef { fn to_object_ref(&self) -> ObjectRef { - self.clone() + let object_ptr = self.as_object_ptr().cloned(); + ObjectRef(object_ptr.map(|ptr| ptr.upcast())) + } + + fn downcast(&self) -> Result { + let ptr = self.as_object_ptr().map(|ptr| ptr.downcast::()); + let ptr = ptr.transpose()?; + Ok(U::from_object_ptr(ptr)) + } +} + +impl IsObjectRef for ObjectRef { + type Object = Object; + + fn as_object_ptr(&self) -> Option<&ObjectPtr> { + self.0.as_ref() + } + + fn from_object_ptr(object_ptr: Option>) -> Self { + ObjectRef(object_ptr) } } @@ -73,39 +92,23 @@ impl<'a> std::convert::TryFrom> for ObjectRef { type Error = Error; fn try_from(arg_value: ArgValue<'a>) -> Result { - let optr = arg_value.try_into()?; + let optr: ObjectPtr = arg_value.try_into()?; + debug_assert!(optr.count() >= 1); Ok(ObjectRef(Some(optr))) } } -impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { - type Error = Error; - - fn try_from(arg_value: &ArgValue<'a>) -> Result { - // TODO(@jroesch): remove the clone - let value: ArgValue<'a> = arg_value.clone(); - ObjectRef::try_from(value) - } -} - impl<'a> From for ArgValue<'a> { fn from(object_ref: ObjectRef) -> ArgValue<'a> { use std::ffi::c_void; - let object_ptr = &object_ref.0; + let object_ptr = object_ref.0; match object_ptr { None => ArgValue::ObjectHandle(std::ptr::null::() as *mut c_void), - Some(value) => value.clone().into(), + Some(value) => value.into(), } } } -impl<'a> From<&ObjectRef> for ArgValue<'a> { - fn from(object_ref: &ObjectRef) -> ArgValue<'a> { - let oref: ObjectRef = object_ref.clone(); - ArgValue::<'a>::from(oref) - } -} - external! { #[name("ir.DebugPrint")] fn debug_print(object: ObjectRef) -> CString; diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 40e218454f6a..ddcbff92c604 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -29,16 +29,36 @@ use crate::errors::Error; type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); +/// A TVM intrusive smart pointer header, in TVM all FFI compatible types +/// start with an Object as their first field. The base object tracks +/// a type_index which is an index into the runtime type information +/// table, an atomic reference count, and a customized deleter which +/// will be invoked when the reference count is zero. +/// #[derive(Debug)] #[repr(C)] pub struct Object { - pub type_index: u32, + /// The index into into TVM's runtime type information table. + pub(self) type_index: u32, // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. // NB: in general we should not touch this in Rust. + /// The reference count of the smart pointer. pub(self) ref_count: AtomicI32, - pub fdeleter: Deleter, + /// The deleter function which is used to deallocate the underlying data + /// when the reference count is zero. This field must always be set for + /// all objects. + /// + /// The common use case is ensuring that the allocator which allocated the + /// data is also the one that deletes it. + pub(self) fdeleter: Deleter, } +/// The default deleter for objects allocated in Rust, we use a bit of +/// trait magic here to get a monomorphized deleter for each object +/// "subtype". +/// +/// This function just transmutes the pointer to the correct type +/// and invokes the underlying typed delete function. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = std::mem::transmute(object); T::typed_delete(typed_object); @@ -63,10 +83,12 @@ impl Object { fn new(type_index: u32, deleter: Deleter) -> Object { Object { type_index, - // Note: do not touch this field directly again, this is - // a critical section, we write a 1 to the atomic which will now - // be managed by the C++ atomics. - // In the future we should probably use C-atomcis. + // NB(@jroesch): I believe it is sound to use Rust atomics + // in conjunction with C++ atomics given the memory model + // is nearly identical. + // + // Of course these are famous last words which I may later + // regret. ref_count: AtomicI32::new(0), fdeleter: deleter, } @@ -75,6 +97,7 @@ impl Object { fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { return 0; } else { @@ -89,11 +112,22 @@ impl Object { } } + pub fn count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. + self.ref_count.load(std::sync::atomic::Ordering::SeqCst) + } + + /// Allocates a base object value for an object subtype of type T. + /// By using associated constants and generics we can provide a + /// type indexed abstraction over allocating objects with the + /// correct index and deleter. pub fn base_object() -> Object { let index = Object::get_type_index::(); Object::new(index, delete::) } + /// Increases the object's reference count by one. pub(self) fn inc_ref(&self) { unsafe { let raw_ptr = std::mem::transmute(self); @@ -101,6 +135,7 @@ impl Object { } } + /// Decreases the object's reference count by one. pub(self) fn dec_ref(&self) { unsafe { let raw_ptr = std::mem::transmute(self); @@ -109,6 +144,13 @@ impl Object { } } +/// An unsafe trait which should be implemented for an object +/// subtype. +/// +/// The trait contains the type key needed to compute the type +/// index, a method for accessing the base object given the +/// subtype, and a typed delete method which is specialized +/// to the subtype. pub unsafe trait IsObject { const TYPE_KEY: &'static str; @@ -128,6 +170,10 @@ unsafe impl IsObject for Object { } } +/// A smart pointer for types which implement IsObject. +/// This type directly corresponds to TVM's C++ type ObjectPtr. +/// +/// See object.h for more details. #[repr(C)] pub struct ObjectPtr { pub ptr: NonNull, @@ -144,7 +190,10 @@ fn dec_ref(ptr: NonNull) { impl ObjectPtr { fn from_raw(object_ptr: *mut Object) -> Option> { let non_null = NonNull::new(object_ptr); - non_null.map(|ptr| ObjectPtr { ptr }) + non_null.map(|ptr| { + debug_assert!(unsafe { ptr.as_ref().count() } >= 0); + ObjectPtr { ptr } + }) } } @@ -207,9 +256,9 @@ impl ObjectPtr { }; if is_derived { - Ok(ObjectPtr { - ptr: self.ptr.cast(), - }) + let ptr = self.ptr.cast(); + inc_ref(ptr); + Ok(ObjectPtr { ptr }) } else { Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } @@ -240,6 +289,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { RetValue::ObjectHandle(handle) => { let handle: *mut Object = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), @@ -249,7 +299,9 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { impl<'a, T: IsObject> From> for ArgValue<'a> { fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { + debug_assert!(object_ptr.count() >= 1); let raw_object_ptr = ObjectPtr::leak(object_ptr); + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; ArgValue::ObjectHandle(void_ptr) } @@ -263,21 +315,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { ArgValue::ObjectHandle(handle) => { let handle = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; - optr.downcast() - } - _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), - } - } -} - -impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr { - type Error = Error; - - fn try_from(arg_value: &ArgValue<'a>) -> Result, Self::Error> { - match arg_value { - ArgValue::ObjectHandle(handle) => { - let handle = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + debug_assert!(optr.count() >= 1); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), @@ -305,6 +343,8 @@ mod tests { let ptr = ObjectPtr::new(Object::base_object::()); let ret_value: RetValue = ptr.clone().into(); let ptr2: ObjectPtr = ret_value.try_into()?; + assert_eq!(ptr.count(), ptr2.count()); + assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, "type indices do not match" @@ -321,6 +361,8 @@ mod tests { let ptr = ObjectPtr::new(Object::base_object::()); let arg_value: ArgValue = ptr.clone().into(); let ptr2: ObjectPtr = arg_value.try_into()?; + assert_eq!(ptr.count(), ptr2.count()); + assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, "type indices do not match" @@ -333,6 +375,7 @@ mod tests { } fn test_fn(o: ObjectPtr) -> ObjectPtr { + // The call machinery adds at least 1 extra count while inside the call. assert_eq!(o.count(), 2); return o; } @@ -341,13 +384,19 @@ mod tests { fn test_ref_count_boundary() { use super::*; use crate::function::{register, Function, Result}; + // 1 let ptr = ObjectPtr::new(Object::base_object::()); + assert_eq!(ptr.count(), 1); + // 2 let stay = ptr.clone(); assert_eq!(ptr.count(), 2); register(test_fn, "my_func").unwrap(); let func = Function::get("my_func").unwrap(); let func = func.to_boxed_fn::) -> Result>>(); - func(ptr).unwrap(); + let same = func(ptr).unwrap(); + assert_eq!(stay.count(), 2); + assert_eq!(same.count(), 2); + drop(same); assert_eq!(stay.count(), 1); } } diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 26758b1170e7..7727e4be2409 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -36,7 +36,7 @@ pub struct StringObj { } impl String { - pub fn new(string: std::string::String) -> Result { + pub fn new(string: std::string::String) -> Result { let cstring = CString::new(string)?; // The string is being corrupted. @@ -69,24 +69,24 @@ impl String { } } -// #[cfg(test)] -// mod tests { -// use super::String; -// use crate::object::debug_print; -// use crate::ToObjectRef; -// use anyhow::{ensure, Result}; +#[cfg(test)] +mod tests { + use super::String; + use crate::object::debug_print; + use crate::IsObjectRef; + use anyhow::{ensure, Result}; -// #[test] -// fn test_string_debug() -> Result<()> { -// let s = String::new("foo".to_string()).unwrap(); -// let object_ref = s.to_object_ref(); -// println!("about to call"); -// let string = debug_print(object_ref)?; -// println!("after call"); -// ensure!( -// string.into_string().expect("is cstring").contains("foo"), -// "string content is invalid" -// ); -// Ok(()) -// } -// } + #[test] + fn test_string_debug() -> Result<()> { + let s = String::new("foo".to_string()).unwrap(); + let object_ref = s.to_object_ref(); + println!("about to call"); + let string = debug_print(object_ref)?; + println!("after call"); + ensure!( + string.into_string().expect("is cstring").contains("foo"), + "string content is invalid" + ); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 4814d098238a..4fc021adb5ab 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -46,28 +46,32 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// And the implementation of it to `ToFunction`. pub trait Typed { fn args(i: &[ArgValue<'static>]) -> Result; - fn ret(o: O) -> RetValue; + fn ret(o: O) -> Result; } -impl> Typed<(), O> for F +impl Typed<(), O> for F where F: Fn() -> O, + Error: From, + O: TryInto, { fn args(_args: &[ArgValue<'static>]) -> Result<()> { debug_assert!(_args.len() == 0); Ok(()) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } -impl, E> Typed<(A,), O> for F +impl Typed<(A,), O> for F where F: Fn(A) -> O, - Error: From, - A: TryFrom, Error = E>, + Error: From, + Error: From, + A: TryFrom, Error = E1>, + O: TryInto, { fn args(args: &[ArgValue<'static>]) -> Result<(A,)> { debug_assert!(args.len() == 1); @@ -75,17 +79,19 @@ where Ok((a,)) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } -impl, E> Typed<(A, B), O> for F +impl Typed<(A, B), O> for F where F: Fn(A, B) -> O, - Error: From, - A: TryFrom, Error = E>, - B: TryFrom, Error = E>, + Error: From, + Error: From, + A: TryFrom, Error = E1>, + B: TryFrom, Error = E1>, + O: TryInto, { fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> { debug_assert!(args.len() == 2); @@ -94,18 +100,20 @@ where Ok((a, b)) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } -impl, E> Typed<(A, B, C), O> for F +impl Typed<(A, B, C), O> for F where F: Fn(A, B, C) -> O, - Error: From, - A: TryFrom, Error = E>, - B: TryFrom, Error = E>, - C: TryFrom, Error = E>, + Error: From, + Error: From, + A: TryFrom, Error = E1>, + B: TryFrom, Error = E1>, + C: TryFrom, Error = E1>, + O: TryInto, { fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> { debug_assert!(args.len() == 3); @@ -115,8 +123,8 @@ where Ok((a, b, c)) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } @@ -230,7 +238,7 @@ where { // Ideally we shouldn't need to clone, probably doesn't really matter. let out = unsafe { (*handle)() }; - Ok(F::ret(out)) + F::ret(out) } fn drop(_: *mut Self::Handle) {} @@ -253,7 +261,7 @@ macro_rules! to_function_instance { let out = unsafe { (*handle)($(args.$index),+) }; - Ok(F::ret(out)) + F::ret(out) } fn drop(_: *mut Self::Handle) {} diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs index 0f455e726d26..231569ba682e 100644 --- a/rust/tvm-sys/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -57,3 +57,15 @@ pub use context::{Context, DeviceType}; pub use datatype::DataType; pub use errors::*; pub use packed_func::{ArgValue, RetValue}; + +impl std::convert::TryFrom> for RetValue +where + RetValue: std::convert::TryFrom, + E: From<>::Error>, +{ + type Error = E; + + fn try_from(val: Result) -> Result { + val.and_then(|t| RetValue::try_from(t).map_err(|e| e.into())) + } +} diff --git a/rust/tvm/.gitignore b/rust/tvm/.gitignore new file mode 100644 index 000000000000..2430329c78b6 --- /dev/null +++ b/rust/tvm/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm/.travis.yml b/rust/tvm/.travis.yml new file mode 100644 index 000000000000..e963b7c0ede5 --- /dev/null +++ b/rust/tvm/.travis.yml @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml new file mode 100644 index 000000000000..ebfb5e64a4a7 --- /dev/null +++ b/rust/tvm/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust frontend support for TVM" +repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +lazy_static = "1.1" +ndarray = "0.12" +num-traits = "0.2" +tvm-rt = { version = "0.1", path = "../tvm-rt/" } +tvm-sys = { version = "0.1", path = "../tvm-sys/" } +tvm-macros = { version = "*", path = "../tvm-macros/" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm/README.md b/rust/tvm/README.md new file mode 100644 index 000000000000..01e088f2ea81 --- /dev/null +++ b/rust/tvm/README.md @@ -0,0 +1,235 @@ + + + + + + + + + + + + + + + + + +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` + +## What Does This Crate Offer? + +Here is a major workflow + +1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) +2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example. + +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) +# compile the model +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) +# same the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) +``` + +Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demostrated in the following Rust snippet + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec = fs::read("deploy_param.params")?; + let barr = TVMByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec + let output = output.to_vec::()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installations + +Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. + +## Supported TVM Functionalities + +### Use TVM to Generate Shared Library + +One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. + +```python +import os +import tvm +from tvm import te +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.runtime.enabled("cuda"): + print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) + return + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the tests and examples custom `build.rs` for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use `register_global_func!` macro to convert and register a Rust +function of type `fn(&[TVMArgValue]) -> Result` to a global TVM **packed function** as follows + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = TVMRetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); +} +``` diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs new file mode 100644 index 000000000000..4fe13a32ea35 --- /dev/null +++ b/rust/tvm/src/ir/mod.rs @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::String as TString; +use crate::runtime::{self, external, IsObjectRef, Object, ObjectRef}; +use crate::DataType; + +pub mod relay; + +// TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) +external! { + #[name("ir.AsText")] + fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; +} + +pub fn as_text(object: T) -> String { + let no_func = unsafe { runtime::Function::null() }; + _as_text(object.to_object_ref(), 0, no_func) + .unwrap() + .to_string() + .unwrap() +} + +#[repr(C)] +pub struct PrimExprNode { + pub base: Object, + pub dtype: DataType, +} + +#[repr(C)] +pub struct IntImmNode { + pub base: PrimExprNode, + pub value: i64, +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs new file mode 100644 index 000000000000..cad41acfc307 --- /dev/null +++ b/rust/tvm/src/ir/relay/mod.rs @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::array::Array; +use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString}; +use crate::DataType; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Id"] +#[type_key = "relay.Id"] +pub struct IdNode { + pub base: Object, + pub name_hint: TString, +} + +impl Id { + fn new(name_hint: TString) -> Id { + let node = IdNode { + base: Object::base_object::(), + name_hint: name_hint, + }; + Id(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { + pub base: Object, +} + +#[repr(C)] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl BaseExprNode { + fn base() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Expr"] +#[type_key = "relay.Expr"] +pub struct RelayExpr { + pub base: BaseExprNode, + pub span: ObjectRef, + pub checked_type: ObjectRef, +} + +impl RelayExpr { + fn base() -> RelayExpr { + RelayExpr { + base: BaseExprNode::base::(), + span: ObjectRef::null(), + checked_type: ObjectRef::null(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalVar"] +#[type_key = "GlobalVar"] +pub struct GlobalVarNode { + pub base: RelayExpr, + pub name_hint: TString, +} + +impl GlobalVar { + pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + let node = GlobalVarNode { + base: RelayExpr::base::(), + name_hint: TString::new(name_hint).unwrap(), + }; + GlobalVar(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Constant"] +#[type_key = "relay.Constant"] +pub struct ConstantNode { + pub base: RelayExpr, + pub data: ObjectRef, // make this NDArray. +} + +impl Constant { + pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { + let node = ConstantNode { + base: RelayExpr::base::(), + data: data, + }; + Constant(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Var"] +#[type_key = "relay.Var"] +pub struct VarNode { + pub base: RelayExpr, + pub vid: Id, + pub type_annotation: ObjectRef, +} + +impl Var { + pub fn new(name_hint: String, _span: ObjectRef) -> Var { + let node = VarNode { + base: RelayExpr::base::(), + vid: Id::new(TString::new(name_hint.to_string()).unwrap()), + type_annotation: ObjectRef::null(), + }; + Var(Some(ObjectPtr::new(node))) + } + + pub fn name_hint(&self) -> &TString { + &self.vid.0.as_ref().unwrap().name_hint + } + + pub fn to_expr(self) -> Expr { + unsafe { Expr(std::mem::transmute(self.0)) } + } +} + +pub type Type = ObjectRef; +pub type Attrs = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Call"] +#[type_key = "relay.Call"] +pub struct CallNode { + pub base: RelayExpr, + pub op: Expr, + pub args: Array, + pub attrs: ObjectRef, + pub type_args: Array, +} + +impl Call { + pub fn new( + op: Expr, + args: Array, + attrs: Attrs, + type_args: Array, + _span: ObjectRef, + ) -> Call { + let node = CallNode { + base: RelayExpr::base::(), + op: op, + args: args, + attrs: attrs, + type_args: type_args, + }; + Call(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseFunc"] +#[type_key = "BaseFunc"] +pub struct BaseFuncNode { + pub base: RelayExpr, + pub attrs: ObjectRef, +} + +impl BaseFuncNode { + fn base() -> BaseFuncNode { + BaseFuncNode { + base: RelayExpr::base::(), + attrs: ObjectRef::null(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Function"] +#[type_key = "relay.Function"] +pub struct FunctionNode { + pub base: BaseFuncNode, + pub params: Array, + pub body: Expr, + pub ret_type: Type, + pub type_params: Array, +} + +impl Function { + pub fn new( + params: Array, + body: Expr, + ret_type: Type, + type_params: Array, + ) -> Function { + let node = FunctionNode { + base: BaseFuncNode::base::(), + params: params, + body: body, + ret_type: ret_type, + type_params: type_params, + }; + Function(Some(ObjectPtr::new(node))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ir::as_text; + use crate::runtime::String as TString; + use anyhow::Result; + + #[test] + fn test_id() -> Result<()> { + let string = TString::new("foo".to_string()).expect("bar"); + let id = Id::new(string); + let text = as_text(id.clone()); + assert!(text.contains("relay.Id")); + Ok(()) + } + + #[test] + fn test_global() -> Result<()> { + let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); + let text = as_text(gv.clone()); + assert!(text.contains("@main")); + Ok(()) + } + + #[test] + fn test_var() -> Result<()> { + let var = Var::new("local".to_string(), ObjectRef::null()); + let text = as_text(var.clone()); + assert!(text.contains("%local")); + Ok(()) + } + + use super::Array; + use crate::ir::relay::Var; + use crate::runtime::object::ObjectRef; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec = vec![ + Var::new("foo".into(), ObjectRef::null()), + Var::new("bar".into(), ObjectRef::null()), + ]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); + assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); + Ok(()) + } +} diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs new file mode 100644 index 000000000000..64252a4f9c6f --- /dev/null +++ b/rust/tvm/src/lib.rs @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime frontend. +//! +//! One particular use case is that given optimized deep learning model artifacts, +//! (compiled with TVM) which include a shared library +//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them +//! in Rust idomatically to create a TVM Graph Runtime and +//! run the model for some inputs and get the +//! desired predictions *all in Rust*. +//! +//! Checkout the `examples` repository for more details. + +pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray}; + +pub use tvm_rt::{Context, DataType, DeviceType}; + +pub use tvm_rt::context; +pub use tvm_rt::errors; +pub use tvm_rt::function; +pub use tvm_rt::module; +pub use tvm_rt::ndarray; +pub use tvm_rt::value; +pub mod ir; +pub mod runtime; +pub mod transform; + +pub use runtime::version; diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs new file mode 100644 index 000000000000..69fbb371824a --- /dev/null +++ b/rust/tvm/src/runtime/mod.rs @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub use tvm_rt::*; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs new file mode 100644 index 000000000000..ab84202af4fa --- /dev/null +++ b/rust/tvm/src/transform.rs @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::ir::relay::Function; +use crate::runtime::array::Array; +use crate::runtime::{ + external, + function::{self, Result, ToFunction}, + String as TString, +}; +use crate::runtime::{Object, ObjectPtr, ObjectRef}; + +use tvm_macros::Object; + +pub type Pass = ObjectRef; +pub type IRModule = ObjectRef; +pub type PassContext = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PassInfo"] +#[type_key = "transform.PassInfo"] +pub struct PassInfoNode { + pub base: Object, + pub opt_level: i32, + pub name: TString, + pub required: Array, +} + +impl PassInfo { + pub fn new(opt_level: i32, name: String, required: Vec) -> Result { + let required: Result<_> = required + .into_iter() + .map(|name| TString::new(name)) + .collect(); + + let required = Array::from_vec(required?)?; + + let node = PassInfoNode { + base: Object::base_object::(), + opt_level, + name: TString::new(name).unwrap(), + required, + }; + + Ok(PassInfo(Some(ObjectPtr::new(node)))) + } +} + +external! { + #[name("relay._transform.MakeFunctionPass")] + fn create_func_pass(func: function::Function, pass_info: PassInfo) -> Pass; +} + +pub fn function_pass Function + 'static>( + pass_fn: F, + pass_info: PassInfo, +) -> Result { + let func = pass_fn.to_function(); + create_func_pass(func, pass_info) +} + +#[macro_export] +macro_rules! export_pass { + ($name:literal,$func:expr) => { + #[no_mangle] + pub unsafe extern "C" fn initialize( + args: *mut tvm_sys::ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: tvm_sys::ffi::TVMRetValueHandle, + ) -> c_int { + register($func, $name).unwrap(); + return 0; + } + }; +} diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bf40f4bdb672..ee11548edf29 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -831,9 +831,7 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { } TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { - std::cout << "The program: " << node << std::endl; auto text = AsText(node, false, nullptr); - std::cout << "The text " << text; return text; });