From c6447b6a90f3c837ed1348b3bbbb906b574345b1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 23 Aug 2021 21:11:24 -0700 Subject: [PATCH 01/14] # This is a combination of 2 commits. # This is the 1st commit message: Initial changes # This is the commit message #2: Ftarget string -> Target object works! --- include/tvm/runtime/c_runtime_api.h | 8 ++ rust/tvm-macros/src/object.rs | 13 +- rust/tvm-rt/src/array.rs | 13 +- rust/tvm-rt/src/function.rs | 17 +-- rust/tvm-rt/src/graph_rt.rs | 7 +- rust/tvm-rt/src/lib.rs | 21 ++-- rust/tvm-rt/src/map.rs | 16 +-- rust/tvm-rt/src/ndarray.rs | 15 +++ rust/tvm-rt/src/object/mod.rs | 19 ++- rust/tvm-rt/src/object/object_ptr.rs | 138 ++++++++++++++++------ rust/tvm-rt/src/to_function.rs | 125 +++++++++++++------- rust/tvm-sys/src/byte_array.rs | 99 ++++++++++------ rust/tvm-sys/src/packed_func.rs | 18 +-- rust/tvm/examples/resnet/src/main.rs | 30 +++-- rust/tvm/src/compiler/graph_rt.rs | 10 +- rust/tvm/src/ir/module.rs | 14 +-- rust/tvm/tests/callback/src/bin/array.rs | 2 +- rust/tvm/tests/callback/src/bin/error.rs | 2 +- rust/tvm/tests/callback/src/bin/float.rs | 2 +- rust/tvm/tests/callback/src/bin/int.rs | 2 +- rust/tvm/tests/callback/src/bin/string.rs | 2 +- src/relay/backend/aot_executor_codegen.cc | 9 +- src/relay/backend/interpreter.cc | 24 ++-- src/relay/backend/te_compiler.cc | 22 ++-- src/relay/backend/te_compiler.h | 6 +- src/relay/backend/utils.h | 2 +- src/runtime/object.cc | 8 ++ 27 files changed, 422 insertions(+), 222 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 17d1ba2a5132..8454b04443a1 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Convert type index to type key. + * \param tindex The type index. + * \param out_type_key The output type key. + * \return 0 when success, nonzero when failure happens + */ +TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + /*! * \brief Increase the reference count of an object. * diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index c84d0aab612f..4134da5fe6d9 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -147,8 +147,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { use std::ffi::c_void; let object_ptr = &object_ref.0; match object_ptr { @@ -156,18 +156,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #tvm_rt_crate::ArgValue:: ObjectHandle(std::ptr::null::() as *mut c_void) } - Some(value) => value.clone().into() + Some(value) => value.into() } } } - impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { - let oref: #ref_id = object_ref.clone(); - #tvm_rt_crate::ArgValue::<'a>::from(oref) - } - } - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { type Error = #error; diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index e8902b54f6ef..02c34a1d133f 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -45,19 +45,22 @@ external! { fn array_size(array: ObjectRef) -> i64; } -impl IsObjectRef for Array { +impl IsObjectRef for Array { type Object = Object; fn as_ptr(&self) -> Option<&ObjectPtr> { self.object.as_ptr() } + fn into_ptr(self) -> Option> { self.object.into_ptr() } + fn from_ptr(object_ptr: Option>) -> Self { let object_ref = match object_ptr { Some(o) => o.into(), _ => panic!(), }; + Array { object: object_ref, _data: PhantomData, @@ -67,7 +70,7 @@ impl IsObjectRef for Array { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.into_iter().map(T::into_arg_value).collect(); + let iter = data.iter().map(T::into_arg_value).collect(); let func = Function::get("runtime.Array").expect( "runtime.Array function is not registered, this is most likely a build or linking error", @@ -151,9 +154,9 @@ impl FromIterator for Array { } } -impl<'a, T: IsObjectRef> From> for ArgValue<'a> { - fn from(array: Array) -> ArgValue<'a> { - array.object.into() +impl<'a, T: IsObjectRef> From<&'a Array> for ArgValue<'a> { + fn from(array: &'a Array) -> ArgValue<'a> { + (&array.object).into() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 5db665cc7a48..62474e6650d4 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -35,7 +35,8 @@ use std::{ use crate::errors::Error; -pub use super::to_function::{ToFunction, Typed}; +pub use super::to_function::{RawArgs, ToFunction, Typed}; +use crate::object::AsArgValue; pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; @@ -153,12 +154,12 @@ macro_rules! impl_to_fn { where Error: From, Out: TryFrom, - $($t: Into>),* + $($t: for<'a> AsArgValue<'a>),* { fn from(func: Function) -> Self { #[allow(non_snake_case)] Box::new(move |$($t : $t),*| { - let args = vec![ $($t.into()),* ]; + let args = vec![ $((&$t).as_arg_value()),* ]; Ok(func.invoke(args)?.try_into()?) }) } @@ -196,8 +197,8 @@ impl TryFrom for Function { } } -impl<'a> From for ArgValue<'a> { - fn from(func: Function) -> ArgValue<'a> { +impl<'a> From<&'a Function> for ArgValue<'a> { + fn from(func: &'a Function) -> ArgValue<'a> { if func.handle().is_null() { ArgValue::Null } else { @@ -291,12 +292,12 @@ where } pub fn register_untyped>( - f: fn(Vec>) -> Result, + f: for<'a> fn(Vec>) -> Result, name: S, override_: bool, ) -> Result<()> { - // TODO(@jroesch): can we unify all the code. - let func = f.to_function(); + //TODO(@jroesch): can we unify the untpyed and typed registration functions. + let func = ToFunction::::to_function(f); let name = name.into(); // Not sure about this code let handle = func.handle(); diff --git a/rust/tvm-rt/src/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs index 7db53d466665..53f3210aa742 100644 --- a/rust/tvm-rt/src/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -50,11 +50,12 @@ impl GraphRt { let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - lib.into(), + (&lib).into(), (&dev.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), ]); + let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?; Ok(Self { module: graph_executor_module, @@ -79,7 +80,7 @@ impl GraphRt { pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { let ref set_input_fn = self.module.get_function("set_input", false)?; - set_input_fn.invoke(vec![name.into(), input.into()])?; + set_input_fn.invoke(vec![name.into(), (&input).into()])?; Ok(()) } @@ -101,7 +102,7 @@ impl GraphRt { /// Extract the ith output from the graph executor and write the results into output. pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { let get_output_fn = self.module.get_function("get_output", false)?; - get_output_fn.invoke(vec![i.into(), output.into()])?; + get_output_fn.invoke(vec![i.into(), (&output).into()])?; Ok(()) } } diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 824dc63f0b50..3b7d066e7b78 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -130,16 +130,17 @@ mod tests { ); } - #[test] - fn bytearray() { - let w = vec![1u8, 2, 3, 4, 5]; - let v = ByteArray::from(w.as_slice()); - let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - assert_eq!( - tvm.data(), - w.iter().copied().collect::>().as_slice() - ); - } + // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. + // #[test] + // fn bytearray() { + // let w = vec![1u8, 2, 3, 4, 5]; + // let v = ByteArray::from(w.as_slice()); + // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + // assert_eq!( + // tvm.data(), + // w.iter().copied().collect::>().as_slice() + // ); + // } #[test] fn ty() { diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index d6dfaf3641b8..5594a91dc0f0 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -58,18 +58,18 @@ external! { fn map_items(map: ObjectRef) -> Array; } -impl FromIterator<(K, V)> for Map +impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map where K: IsObjectRef, V: IsObjectRef, { - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.into()); - buffer.push(v.into()) + buffer.push(k.into_arg_value()); + buffer.push(v.into_arg_value()); } Self::from_data(buffer).expect("failed to convert from data") } @@ -202,13 +202,13 @@ where } } -impl<'a, K, V> From> for ArgValue<'a> +impl<'a, K, V> From<&'a Map> for ArgValue<'a> where K: IsObjectRef, V: IsObjectRef, { - fn from(map: Map) -> ArgValue<'a> { - map.object.into() + fn from(map: &'a Map) -> ArgValue<'a> { + (&map.object).into() } } @@ -268,7 +268,7 @@ mod test { let mut std_map: HashMap = HashMap::new(); std_map.insert("key1".into(), "value1".into()); std_map.insert("key2".into(), "value2".into()); - let tvm_map = Map::from_iter(std_map.clone().into_iter()); + let tvm_map = Map::from_iter(std_map.iter()); let back_map = tvm_map.into(); assert_eq!(std_map, back_map); } diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 08dcfe33f28f..80f8f184140c 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -101,6 +101,21 @@ impl NDArrayContainer { .cast::() } } + + pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { + object_ptr + .ptr + .as_ptr() + .cast::() + .offset(base_offset) + .cast::() + } + } } fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 8c07ed9f0853..f5832fcb3ab8 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -29,6 +29,19 @@ mod object_ptr; pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; +pub trait AsArgValue<'a> { + fn as_arg_value(&'a self) -> ArgValue<'a>; +} + +impl<'a, T: 'static> AsArgValue<'a> for T +where + &'a T: Into>, +{ + fn as_arg_value(&'a self) -> ArgValue<'a> { + self.into() + } +} + // TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we // can't because of coherence rules. Instead, we generate them in the macro, and // add what we can (including Into instead of From) as subtraits. @@ -37,8 +50,8 @@ pub trait IsObjectRef: Sized + Clone + Into + + for<'a> AsArgValue<'a> + TryFrom - + for<'a> Into> + for<'a> TryFrom, Error = Error> + std::fmt::Debug { @@ -51,8 +64,8 @@ pub trait IsObjectRef: Self::from_ptr(None) } - fn into_arg_value<'a>(self) -> ArgValue<'a> { - self.into() + fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { + self.as_arg_value() } fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index a093cf5fe3ae..09d6068f1a88 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -20,11 +20,14 @@ use std::convert::TryFrom; use std::ffi::CString; use std::fmt; +use std::os::raw::c_char; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; use tvm_macros::Object; -use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::ffi::{ + self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, +}; use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; @@ -62,10 +65,12 @@ pub struct Object { /// "subtype". /// /// This function just converts the pointer to the correct type -/// and invokes the underlying typed delete function. +/// and reconstructs a Box which then is dropped to deallocate +/// the underlying allocation. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = object as *mut T; - T::typed_delete(typed_object); + let boxed: Box = Box::from_raw(typed_object); + drop(boxed); } fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { @@ -98,6 +103,18 @@ impl Object { } } + fn get_type_key(&self) -> String { + let mut cstring: *mut c_char = std::ptr::null_mut(); + unsafe { + if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { + panic!("{}", crate::get_last_error()); + } + return CString::from_raw(cstring) + .into_string() + .expect("type keys should be valid utf-8"); + } + } + fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); @@ -157,11 +174,6 @@ impl Object { /// to the subtype. pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; - - unsafe extern "C" fn typed_delete(object: *mut Self) { - let object = Box::from_raw(object); - drop(object) - } } /// A smart pointer for types which implement IsObject. @@ -252,13 +264,18 @@ impl ObjectPtr { if is_derived { Ok(unsafe { self.cast() }) } else { - Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) + let type_key = self.as_ref().get_type_key(); + Err(Error::downcast(type_key.into(), U::TYPE_KEY)) } } pub unsafe fn into_raw(self) -> *mut T { self.ptr.as_ptr() } + + pub unsafe fn as_ptr(&self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -308,26 +325,25 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { } } -impl<'a, T: IsObject> From> for ArgValue<'a> { - fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { +impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { + fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.upcast::(); + let object_ptr = object_ptr.clone().upcast::(); match T::TYPE_KEY { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; - // TODO(this is probably not optimal) - let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) - as *mut NDArrayContainer as *mut std::ffi::c_void; + let dcast_ptr = object_ptr.downcast().unwrap(); + let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } "runtime.Module" => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ModuleHandle(raw_ptr) } _ => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } @@ -345,14 +361,22 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); + optr.inc_ref(); + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); optr.downcast() } ArgValue::NDArrayHandle(handle) => { let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.upcast::().downcast() + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); + // TODO(@jroesch): figure out if there is a more optimal way to do this + let object = optr.upcast::(); + object.inc_ref(); + object.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } @@ -440,11 +464,12 @@ mod tests { assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = ptr_clone.into(); + let arg_value: ArgValue = (&ptr_clone).into(); assert_eq!(ptr.count(), 2); let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 2); + assert_eq!(ptr2.count(), 3); assert_eq!(ptr.count(), ptr2.count()); + drop(ptr_clone); assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, @@ -460,26 +485,71 @@ mod tests { Ok(()) } - fn test_fn(o: ObjectPtr) -> ObjectPtr { - // The call machinery adds at least 1 extra count while inside the call. + fn test_fn_raw<'a>( + mut args: crate::to_function::ArgList<'a>, + ) -> crate::function::Result { + let v: ArgValue = args.remove(0); + let v2: ArgValue = args.remove(0); + // assert_eq!(o.count(), 2); + let o: ObjectPtr = v.try_into().unwrap(); + assert_eq!(o.count(), 2); + let o2: ObjectPtr = v2.try_into().unwrap(); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); + Ok(o.into()) + } + + #[test] + fn test_ref_count_raw_fn() { + use super::*; + use crate::function::{register_untyped, Function}; + let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = same.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); + let raw_func = Function::get("test_fn_raw").unwrap(); + let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + } + + fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { assert_eq!(o.count(), 3); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); return o; } #[test] - fn test_ref_count_boundary3() { + fn test_ref_count_typed() { use super::*; use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let output = test_fn_typed(ptr.clone(), ptr.clone()); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register(test_fn_typed, "test_fn_typed").unwrap(); + let typed_func = Function::get("test_fn_typed").unwrap(); + let output = typed_func + .invoke(vec![(&ptr).into(), (&ptr).into()]) + .unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); assert_eq!(ptr.count(), 1); - let stay = ptr.clone(); - assert_eq!(ptr.count(), 2); - register(test_fn, "my_func2").unwrap(); - let func = Function::get("my_func2").unwrap(); - let same = func.invoke(vec![ptr.into()]).unwrap(); - let same: ObjectPtr = same.try_into().unwrap(); - // TODO(@jroesch): normalize RetValue ownership assert_eq!(same.count(), 2); - drop(same); - assert_eq!(stay.count(), 3); } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 7797d2cd23ff..67fbfc996af0 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -44,8 +44,16 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. + +pub type ArgList<'a> = Vec>; + +pub enum Args<'a, I> { + Typed(I), + Raw(ArgList<'a>), +} + pub trait Typed { - fn args(i: Vec>) -> Result; + fn args<'arg>(i: Vec>) -> Result>; fn ret(o: O) -> Result; } @@ -54,7 +62,7 @@ pub trait ToFunction: Sized { fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where Self: Typed; @@ -70,7 +78,7 @@ pub trait ToFunction: Sized { check_call!(ffi::TVMFuncCreateFromCFunc( Some(Self::tvm_callback), resource_handle as *mut _, - None, // Some(Self::tvm_finalizer), + Some(Self::tvm_finalizer), &mut fhandle as *mut ffi::TVMFunctionHandle, )); @@ -102,22 +110,28 @@ pub trait ToFunction: Sized { for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - { - check_call!(ffi::TVMCbArgToReturn( - &mut value as *mut _, - &mut tcode as *mut _ - )); - } + // TODO(@jroesch): I believe it is sound to disable this specialized move rule. + // + // This is used in C++ to deal with moving an RValue or reference to a return value + // directly so you can skip copying. + // + // I believe this is not needed as the move directly occurs into the Rust function. + + // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int + // { + // check_call!(ffi::TVMCbArgToReturn( + // &mut value as *mut _, + // &mut tcode as *mut _ + // )); + // } let arg_value = ArgValue::from_tvm_value(value, tcode as u32); local_args.push(arg_value); } - // Ref-count be 2. let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -125,6 +139,12 @@ pub trait ToFunction: Sized { } }; + // TODO(@jroesch): clean up the handling of the is dec_ref + match rv.clone().try_into() as Result> { + Err(_) => {} + Ok(v) => drop(v), + }; + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; @@ -165,9 +185,11 @@ pub trait ToFunction: Sized { } } -impl Typed>, RetValue> for fn(Vec>) -> Result { - fn args(args: Vec>) -> Result>> { - Ok(args) +pub struct RawArgs; + +impl Typed for for<'a> fn(Vec>) -> Result { + fn args<'arg>(args: Vec>) -> Result> { + Ok(Args::Raw(args)) } fn ret(o: RetValue) -> Result { @@ -175,43 +197,59 @@ impl Typed>, RetValue> for fn(Vec>) -> R } } -impl ToFunction>, RetValue> - for fn(Vec>) -> Result -{ - type Handle = fn(Vec>) -> Result; +impl ToFunction for for<'arg> fn(Vec>) -> Result { + type Handle = for<'arg> fn(Vec>) -> Result; fn into_raw(self) -> *mut Self::Handle { let ptr: Box = Box::new(self); Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { (*handle)(args) } + fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { + unsafe { + let func = *handle; + func(args) + } } fn drop(_: *mut Self::Handle) {} } +/// A helper trait which correctly captures the complex conversion and lifetime semantics needed +/// to coerce an ordinary Rust value into `ArgValue`. +pub trait TryFromArgValue: TryFrom { + fn from_arg_value(f: F) -> std::result::Result; +} + +impl<'a, T> TryFromArgValue> for T +where + Self: TryFrom>, + Error: From<>>::Error>, +{ + fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { + Ok(TryFrom::try_from(f)?) + } +} + macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for F + impl Typed<($($t,)*), Out> for Fun where - F: Fn($($t),*) -> Out, + Fun: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( $t: TryFrom>, - Error: From<$t::Error>, )* + $( for<'a> $t: TryFromArgValue>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args(args: Vec>) -> Result<($($t,)*)> { + fn args<'arg>(args: Vec>) -> Result> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), $len, args.len()))) } let mut args = args.into_iter(); - $(let $t = args.next().unwrap().try_into()?;)* - Ok(($($t,)*)) + $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* + Ok(Args::Typed(($($t,)*))) } fn ret(out: Out) -> Result { @@ -220,9 +258,9 @@ macro_rules! impl_typed_and_to_function { } - impl ToFunction<($($t,)*), Out> for F + impl ToFunction<($($t,)*), Out> for Fun where - F: Fn($($t,)*) -> Out + 'static + Fun: Fn($($t,)*) -> Out + 'static { type Handle = Box Out + 'static>; @@ -232,13 +270,18 @@ macro_rules! impl_typed_and_to_function { } #[allow(non_snake_case)] - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - F: Typed<($($t,)*), Out> + Fun: Typed<($($t,)*), Out> { - let ($($t,)*) = F::args(args)?; - let out = unsafe { (*handle)($($t),*) }; - F::ret(out) + let ($($t,)*) = match Fun::args(args)? { + Args::Raw(_) => panic!("impossible case"), + Args::Typed(typed) => typed, + }; + + let fn_ptr = unsafe { &*handle }; + let out = fn_ptr($($t),*); + Fun::ret(out) } fn drop(ptr: *mut Self::Handle) { @@ -255,13 +298,15 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, G); +impl_typed_and_to_function!(6; A, B, C, D, E, F); +impl_typed_and_to_function!(7; A, B, C, D, E, F, G); +impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); #[cfg(test)] mod tests { use super::*; - fn call(f: F, args: Vec>) -> Result + fn call<'a, F, I, O>(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 4b005abee7ef..2903a81d9c36 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -17,10 +17,9 @@ * under the License. */ use std::convert::TryFrom; -use std::os::raw::c_char; use crate::errors::ValueDowncastError; -use crate::ffi::TVMByteArray; +use crate::ffi::{TVMByteArray, TVMByteArrayFree}; use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. @@ -33,20 +32,45 @@ use crate::{ArgValue, RetValue}; /// assert_eq!(barr.len(), v.len()); /// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); /// ``` -pub struct ByteArray { - /// The raw FFI ByteArray. - array: TVMByteArray, +pub enum ByteArray { + Rust(TVMByteArray), + External(TVMByteArray), +} + +impl Drop for ByteArray { + fn drop(&mut self) { + match self { + ByteArray::Rust(bytes) => { + let ptr = bytes.data; + let len = bytes.size as _; + let cap = bytes.size as _; + let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; + drop(data); + } + ByteArray::External(byte_array) => unsafe { + if TVMByteArrayFree(byte_array as _) != 0 { + panic!("error"); + } + }, + } + } } impl ByteArray { /// Gets the underlying byte-array - pub fn data(&self) -> &'static [u8] { - unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size as _) } + pub fn data(&self) -> &[u8] { + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { + std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) + }, + } } /// Gets the length of the underlying byte-array pub fn len(&self) -> usize { - self.array.size as _ + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, + } } /// Converts the underlying byte-array to `Vec` @@ -59,50 +83,49 @@ impl ByteArray { } } -// Needs AsRef for Vec -impl> From for ByteArray { +impl>> From for ByteArray { fn from(arg: T) -> Self { - let arg = arg.as_ref(); - ByteArray { - array: TVMByteArray { - data: arg.as_ptr() as *const c_char, - size: arg.len() as _, - }, - } + let mut incoming_bytes: Vec = arg.into(); + let mut bytes = Vec::with_capacity(incoming_bytes.len()); + bytes.append(&mut incoming_bytes); + + let mut bytes = std::mem::ManuallyDrop::new(bytes); + let ptr = bytes.as_mut_ptr(); + assert_eq!(bytes.len(), bytes.capacity()); + ByteArray::Rust(TVMByteArray { + data: ptr as _, + size: bytes.len() as _, + }) } } impl<'a> From<&'a ByteArray> for ArgValue<'a> { fn from(val: &'a ByteArray) -> ArgValue<'a> { - ArgValue::Bytes(&val.array) - } -} - -impl TryFrom> for ByteArray { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'static>) -> Result { match val { - ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { + ArgValue::Bytes(byte_array) + } } } } -impl From for RetValue { - fn from(val: ByteArray) -> RetValue { - RetValue::Bytes(val.array) - } -} +// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. +// impl From for RetValue { +// fn from(val: ByteArray) -> RetValue { +// match val { +// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { +// // TODO(@jroesch): This requires a little more work, going to land narratives +// RetValue::Bytes(byte_array) +// } +// } +// } +// } impl TryFrom for ByteArray { type Error = ValueDowncastError; fn try_from(val: RetValue) -> Result { match val { - RetValue::Bytes(array) => Ok(ByteArray { array }), + RetValue::Bytes(array) => Ok(ByteArray::External(array)), _ => Err(ValueDowncastError { expected_type: "ByteArray", actual_type: format!("{:?}", val), @@ -118,11 +141,11 @@ mod tests { #[test] fn convert() { let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); let v = b"hello"; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); } diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 6f43b786780a..a74cbe318e2d 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -224,7 +224,7 @@ macro_rules! impl_pod_value { } } - impl<'a, 'v> From<&'a $type> for ArgValue<'v> { + impl<'a> From<&'a $type> for ArgValue<'a> { fn from(val: &'a $type) -> Self { Self::$variant(*val as $inner_ty) } @@ -284,9 +284,9 @@ impl<'a> From<&'a CStr> for ArgValue<'a> { } } -impl<'a> From for ArgValue<'a> { - fn from(s: CString) -> Self { - Self::String(s.into_raw()) +impl<'a> From<&'a CString> for ArgValue<'a> { + fn from(s: &'a CString) -> Self { + Self::String(s.as_ptr() as _) } } @@ -311,14 +311,14 @@ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { } /// Converts an unspecialized handle to a ArgValue. -impl From<*const T> for ArgValue<'static> { +impl<'a, T> From<*const T> for ArgValue<'a> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a ArgValue. -impl From<*mut T> for ArgValue<'static> { +impl<'a, T> From<*mut T> for ArgValue<'a> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } @@ -382,9 +382,9 @@ impl TryFrom for std::ffi::CString { // Implementations for bool. -impl<'a> From for ArgValue<'a> { - fn from(s: bool) -> Self { - (s as i64).into() +impl<'a> From<&bool> for ArgValue<'a> { + fn from(s: &bool) -> Self { + (*s as i64).into() } } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index bd0de1c56ba3..c22d55f2e4da 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -78,24 +78,40 @@ fn main() -> anyhow::Result<()> { "/deploy_lib.so" )))?; - let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; - // parse parameters and convert to TVMByteArray let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; - println!("param bytes: {}", params.len()); - graph_rt.load_params(¶ms)?; + // If you want an easy way to test a memory leak simply replace the program below with: + // let mut output: Vec; + + // loop { + // let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; + // graph_rt.load_params(params.clone())?; + // graph_rt.set_input("data", input.clone())?; + // graph_rt.run()?; + + // // prepare to get the output + // let output_shape = &[1, 1000]; + // let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + // graph_rt.get_output_into(0, output_nd.clone())?; + + // // flatten the output as Vec + // output = output_nd.to_vec::()?; + // } + + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; + graph_rt.load_params(params)?; graph_rt.set_input("data", input)?; graph_rt.run()?; // prepare to get the output let output_shape = &[1, 1000]; - let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - graph_rt.get_output_into(0, output.clone())?; + let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + graph_rt.get_output_into(0, output_nd.clone())?; // flatten the output as Vec - let output = output.to_vec::()?; + let output: Vec = output_nd.to_vec::()?; // find the maximum entry in the output and its index let (argmax, max_prob) = output diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs index 6b5873398cab..8313e47bea20 100644 --- a/rust/tvm/src/compiler/graph_rt.rs +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -51,11 +51,11 @@ fn _compile_module( ) -> Result { // The RAW API is Fn(IRModule, String, String, Map, String); let module = TVM_BUILD.invoke(vec![ - module.into(), - target.into(), - target_host.into(), - params.into(), - module_name.into(), + (&module).into(), + (&target).into(), + (&target_host).into(), + (¶ms).into(), + (&module_name).into(), ])?; let module: RtModule = module.try_into().unwrap(); Ok(module) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 513a906f6db4..ea257af1ebc0 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -99,10 +99,10 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new(funcs: F, types: T) -> Result + pub fn new<'a, F, T>(funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, { module_new(Map::from_iter(funcs), Map::from_iter(types)) } @@ -110,7 +110,7 @@ impl IRModule { pub fn empty() -> Result { let funcs = HashMap::::new(); let types = HashMap::::new(); - IRModule::new(funcs, types) + IRModule::new(funcs.iter(), types.iter()) } pub fn parse(file_name: N, source: S) -> Result @@ -206,10 +206,10 @@ impl IRModule { Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } - pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result + pub fn from_expr_with_items<'a, E, F, T>(expr: E, funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, E: IsObjectRef, E::Object: AsRef<::Object>, { diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index 81ee426d3967..8deae30c076d 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -35,7 +35,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args { let arg: NDArray = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs index 37027af0ca37..f8886a55c3a2 100644 --- a/rust/tvm/tests/callback/src/bin/error.rs +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -26,7 +26,7 @@ use tvm::{ }; fn main() { - fn error(_args: Vec>) -> Result { + fn error<'a>(_args: Vec>) -> Result { Err(errors::NDArrayError::DataTypeMismatch { expected: DataType::int(64, 1), actual: DataType::float(64, 1), diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs index 6fd4f868dc79..d575f47c87cd 100644 --- a/rust/tvm/tests/callback/src/bin/float.rs +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -27,7 +27,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args.into_iter() { let val: f64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs index cdea2e1044c4..fc2e40d8de4d 100644 --- a/rust/tvm/tests/callback/src/bin/int.rs +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -25,7 +25,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs index dbe65ba4c631..4f3d67e95d64 100644 --- a/rust/tvm/tests/callback/src/bin/string.rs +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -26,7 +26,7 @@ use tvm::{ // FIXME fn main() { - fn concat_str(args: Vec>) -> Result { + fn concat_str<'a>(args: Vec>) -> Result { let mut ret = "".to_string(); for arg in args.iter() { let val: &str = arg.try_into()?; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 942bc0d1d44a..2b88f0489321 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -669,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor { ret.lowered_funcs = lowered_module.per_target_module; ret.external_mods = lowered_module.external_mods; - auto target_host_str = target_host_->str(); - if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_str]->Update(mod_run); + if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_]->Update(mod_run); } else { - ret.lowered_funcs.Set(target_host_str, mod_run); + ret.lowered_funcs.Set(target_host_, mod_run); } std::vector input_var_names(input_vars_.size()); @@ -778,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return (*it).second.first; } - Map get_irmodule() { return this->output_.lowered_funcs; } + Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; LoweredOutput output_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index af2cbae1f72d..232985bdfb5c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -53,7 +53,11 @@ namespace { struct PairHash { template std::size_t operator()(const std::pair& k) const { - return std::hash()(k.first) ^ std::hash()(k.second); + return dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + template + std::size_t operator()(const std::pair& k) const { + return dmlc::HashCombine(ObjectHash()(k.first), std::hash()(k.second)); } }; @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. - Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), device_(device), @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor, */ PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, Target target) { - std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + std::pair packed_func_key(target, tir_fn_var->name_hint); auto packed_itr = compiled_packed_funcs_.find(packed_func_key); if (packed_itr != compiled_packed_funcs_.end()) { // Already compiled. @@ -382,7 +386,7 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target->str()); + auto mod_itr = per_target_module_.find(target); ICHECK(mod_itr != per_target_module_.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; @@ -407,7 +411,7 @@ class Interpreter : public ExprFunctor, PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint << "' in compiled module for target '" << target->str() << "'"; - compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } // Return just what we need for this call. @@ -874,10 +878,10 @@ class Interpreter : public ExprFunctor, // Map from target key to lowered TIR functions derived from mod_. // Note that primitives are implicitly executed on target_, while shape functions are implicitly // executed on the default 'cpu' host. Thus this map has at most two entries. - Map per_target_module_; + Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) @@ -895,7 +899,7 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -std::pair> Prepare(IRModule mod, Device device, Target target) { +std::pair> Prepare(IRModule mod, Device device, Target target) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq({transform::SimplifyInference(), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' @@ -1014,7 +1018,7 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_with_expr, device, target); std::shared_ptr intrp = std::make_shared( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, @@ -1057,7 +1061,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_and_global.first, device, target); Interpreter intrp( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 71ac752ec680..1a244ec728f1 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,18 +85,18 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - Map lowered_functions; + Map GetLoweredFunctions() { + Map lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions.Set(target, IRModule(Map({}))); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } for (const auto& it : shape_func_cache_) { @@ -104,11 +104,11 @@ class TECompilerImpl : public TECompilerNode { auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions.Set(target, IRModule(Map({}))); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } return lowered_functions; } @@ -884,7 +884,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { // Annotate the per-target functions with their target and add them to the unified module for (const auto& kv : mod.per_target_module) { - const String target = kv.first; + const Target target = kv.first; const IRModule target_module = kv.second; // Right now, per-target functions are TIR functions, which don't have type definitions, so @@ -926,7 +926,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->AddTypeDef(kv.first, kv.second); } - Map per_target_modules; + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; @@ -934,7 +934,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - Optional target = func->GetAttr(tvm::attr::kTarget); + Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e9cfb0d62e66..1089aa96070b 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -67,7 +67,7 @@ struct EnumClassHash { } }; -// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// TODO(@jroesch, @chrisS) these shoumakeld be a tvm::Map for uniformity sake // we should a version of context which works in Map using TargetMap = std::unordered_map; using DeviceMap = @@ -97,7 +97,7 @@ class TECompilerNode : public Object { virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + virtual Map GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -144,7 +144,7 @@ struct LoweredModule { /*! \brief The module which contains the Relay code. */ IRModule main_module; /*! \brief The module which contains per target code. */ - Map per_target_module; + Map per_target_module; /*! \brief The external runtime modules which must be combined with the lowered code. */ Array external_mods; // TODO(@electriclilies): THis might need to become a map diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a0c7a5aad26d..bf13715b7d46 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; Map function_metadata; std::unordered_map> params; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 1892ce780a4c..3cd5df613f4a 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -262,3 +262,11 @@ int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } + +int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { + API_BEGIN(); + auto key = tvm::runtime::Object::TypeIndex2Key(tindex); + *out_type_key = static_cast(malloc(key.size() + 1)); + strncpy(*out_type_key, key.c_str(), key.size()); + API_END(); +} From 3f19fcae8c3dc5540cbeb307cc397daf1dbbad7c Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 23 Aug 2021 23:42:54 -0700 Subject: [PATCH 02/14] Fix remaining target strings --- src/relay/backend/build_module.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b2b73e9bad02..367b91bbdcb8 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -92,8 +92,8 @@ struct ExecutorCodegen { return CallFunc>("get_external_modules", nullptr); } - Map GetIRModule() { - return CallFunc>("get_irmodule", nullptr); + Map GetIRModule() { + return CallFunc>("get_irmodule", nullptr); } runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } @@ -490,9 +490,10 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); + // TODO(@electriclilies): How do I check if target object is ext_dev? // No need to build for external functions. - if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) { - lowered_funcs.Set("ext_dev", IRModule()); + if (lowered_funcs.find(tvm::Target("ext_dev")) != lowered_funcs.end()) { + lowered_funcs.Set(tvm::Target("ext_dev"), IRModule()); } // Generate a placeholder function that attaches linked params as its arguments. @@ -510,10 +511,10 @@ class RelayBuildModule : public runtime::ModuleNode { DictAttrs attrs{dict}; auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), Map(), attrs); - if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) { - lowered_funcs.Set(target_host->str(), IRModule(Map({}))); + if (lowered_funcs.find(target_host) == lowered_funcs.end()) { + lowered_funcs.Set(target_host, IRModule(Map({}))); } - lowered_funcs[target_host->str()]->Add( + lowered_funcs[target_host]->Add( GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); } From 8aaee0c24f0544eb7b08e1365cc91d51e17ae39f Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 24 Aug 2021 10:22:38 -0700 Subject: [PATCH 03/14] fix bad rebase --- include/tvm/runtime/c_runtime_api.h | 8 -- rust/tvm-macros/src/object.rs | 13 +- rust/tvm-rt/src/array.rs | 13 +- rust/tvm-rt/src/function.rs | 17 ++- rust/tvm-rt/src/graph_rt.rs | 7 +- rust/tvm-rt/src/lib.rs | 21 ++-- rust/tvm-rt/src/map.rs | 16 +-- rust/tvm-rt/src/ndarray.rs | 15 --- rust/tvm-rt/src/object/mod.rs | 19 +-- rust/tvm-rt/src/object/object_ptr.rs | 138 ++++++---------------- rust/tvm-rt/src/to_function.rs | 125 +++++++------------- rust/tvm-sys/src/byte_array.rs | 99 ++++++---------- rust/tvm-sys/src/packed_func.rs | 18 +-- rust/tvm/examples/resnet/src/main.rs | 30 ++--- rust/tvm/src/compiler/graph_rt.rs | 10 +- rust/tvm/src/ir/module.rs | 14 +-- rust/tvm/tests/callback/src/bin/array.rs | 2 +- rust/tvm/tests/callback/src/bin/error.rs | 2 +- rust/tvm/tests/callback/src/bin/float.rs | 2 +- rust/tvm/tests/callback/src/bin/int.rs | 2 +- rust/tvm/tests/callback/src/bin/string.rs | 2 +- 21 files changed, 192 insertions(+), 381 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 8454b04443a1..17d1ba2a5132 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -520,14 +520,6 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); -/*! - * \brief Convert type index to type key. - * \param tindex The type index. - * \param out_type_key The output type key. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); - /*! * \brief Increase the reference count of an object. * diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index 4134da5fe6d9..c84d0aab612f 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -147,8 +147,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { use std::ffi::c_void; let object_ptr = &object_ref.0; match object_ptr { @@ -156,11 +156,18 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #tvm_rt_crate::ArgValue:: ObjectHandle(std::ptr::null::() as *mut c_void) } - Some(value) => value.into() + Some(value) => value.clone().into() } } } + impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { + let oref: #ref_id = object_ref.clone(); + #tvm_rt_crate::ArgValue::<'a>::from(oref) + } + } + impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { type Error = #error; diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 02c34a1d133f..e8902b54f6ef 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -45,22 +45,19 @@ external! { fn array_size(array: ObjectRef) -> i64; } -impl IsObjectRef for Array { +impl IsObjectRef for Array { type Object = Object; fn as_ptr(&self) -> Option<&ObjectPtr> { self.object.as_ptr() } - fn into_ptr(self) -> Option> { self.object.into_ptr() } - fn from_ptr(object_ptr: Option>) -> Self { let object_ref = match object_ptr { Some(o) => o.into(), _ => panic!(), }; - Array { object: object_ref, _data: PhantomData, @@ -70,7 +67,7 @@ impl IsObjectRef for Array { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.iter().map(T::into_arg_value).collect(); + let iter = data.into_iter().map(T::into_arg_value).collect(); let func = Function::get("runtime.Array").expect( "runtime.Array function is not registered, this is most likely a build or linking error", @@ -154,9 +151,9 @@ impl FromIterator for Array { } } -impl<'a, T: IsObjectRef> From<&'a Array> for ArgValue<'a> { - fn from(array: &'a Array) -> ArgValue<'a> { - (&array.object).into() +impl<'a, T: IsObjectRef> From> for ArgValue<'a> { + fn from(array: Array) -> ArgValue<'a> { + array.object.into() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 62474e6650d4..5db665cc7a48 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -35,8 +35,7 @@ use std::{ use crate::errors::Error; -pub use super::to_function::{RawArgs, ToFunction, Typed}; -use crate::object::AsArgValue; +pub use super::to_function::{ToFunction, Typed}; pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; @@ -154,12 +153,12 @@ macro_rules! impl_to_fn { where Error: From, Out: TryFrom, - $($t: for<'a> AsArgValue<'a>),* + $($t: Into>),* { fn from(func: Function) -> Self { #[allow(non_snake_case)] Box::new(move |$($t : $t),*| { - let args = vec![ $((&$t).as_arg_value()),* ]; + let args = vec![ $($t.into()),* ]; Ok(func.invoke(args)?.try_into()?) }) } @@ -197,8 +196,8 @@ impl TryFrom for Function { } } -impl<'a> From<&'a Function> for ArgValue<'a> { - fn from(func: &'a Function) -> ArgValue<'a> { +impl<'a> From for ArgValue<'a> { + fn from(func: Function) -> ArgValue<'a> { if func.handle().is_null() { ArgValue::Null } else { @@ -292,12 +291,12 @@ where } pub fn register_untyped>( - f: for<'a> fn(Vec>) -> Result, + f: fn(Vec>) -> Result, name: S, override_: bool, ) -> Result<()> { - //TODO(@jroesch): can we unify the untpyed and typed registration functions. - let func = ToFunction::::to_function(f); + // TODO(@jroesch): can we unify all the code. + let func = f.to_function(); let name = name.into(); // Not sure about this code let handle = func.handle(); diff --git a/rust/tvm-rt/src/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs index 53f3210aa742..7db53d466665 100644 --- a/rust/tvm-rt/src/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -50,12 +50,11 @@ impl GraphRt { let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - (&lib).into(), + lib.into(), (&dev.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), ]); - let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?; Ok(Self { module: graph_executor_module, @@ -80,7 +79,7 @@ impl GraphRt { pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { let ref set_input_fn = self.module.get_function("set_input", false)?; - set_input_fn.invoke(vec![name.into(), (&input).into()])?; + set_input_fn.invoke(vec![name.into(), input.into()])?; Ok(()) } @@ -102,7 +101,7 @@ impl GraphRt { /// Extract the ith output from the graph executor and write the results into output. pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { let get_output_fn = self.module.get_function("get_output", false)?; - get_output_fn.invoke(vec![i.into(), (&output).into()])?; + get_output_fn.invoke(vec![i.into(), output.into()])?; Ok(()) } } diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 3b7d066e7b78..824dc63f0b50 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -130,17 +130,16 @@ mod tests { ); } - // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. - // #[test] - // fn bytearray() { - // let w = vec![1u8, 2, 3, 4, 5]; - // let v = ByteArray::from(w.as_slice()); - // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - // assert_eq!( - // tvm.data(), - // w.iter().copied().collect::>().as_slice() - // ); - // } + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = ByteArray::from(w.as_slice()); + let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + assert_eq!( + tvm.data(), + w.iter().copied().collect::>().as_slice() + ); + } #[test] fn ty() { diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 5594a91dc0f0..d6dfaf3641b8 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -58,18 +58,18 @@ external! { fn map_items(map: ObjectRef) -> Array; } -impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map +impl FromIterator<(K, V)> for Map where K: IsObjectRef, V: IsObjectRef, { - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.into_arg_value()); - buffer.push(v.into_arg_value()); + buffer.push(k.into()); + buffer.push(v.into()) } Self::from_data(buffer).expect("failed to convert from data") } @@ -202,13 +202,13 @@ where } } -impl<'a, K, V> From<&'a Map> for ArgValue<'a> +impl<'a, K, V> From> for ArgValue<'a> where K: IsObjectRef, V: IsObjectRef, { - fn from(map: &'a Map) -> ArgValue<'a> { - (&map.object).into() + fn from(map: Map) -> ArgValue<'a> { + map.object.into() } } @@ -268,7 +268,7 @@ mod test { let mut std_map: HashMap = HashMap::new(); std_map.insert("key1".into(), "value1".into()); std_map.insert("key2".into(), "value2".into()); - let tvm_map = Map::from_iter(std_map.iter()); + let tvm_map = Map::from_iter(std_map.clone().into_iter()); let back_map = tvm_map.into(); assert_eq!(std_map, back_map); } diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 80f8f184140c..08dcfe33f28f 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -101,21 +101,6 @@ impl NDArrayContainer { .cast::() } } - - pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer - where - NDArrayContainer: 'a, - { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - unsafe { - object_ptr - .ptr - .as_ptr() - .cast::() - .offset(base_offset) - .cast::() - } - } } fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index f5832fcb3ab8..8c07ed9f0853 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -29,19 +29,6 @@ mod object_ptr; pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; -pub trait AsArgValue<'a> { - fn as_arg_value(&'a self) -> ArgValue<'a>; -} - -impl<'a, T: 'static> AsArgValue<'a> for T -where - &'a T: Into>, -{ - fn as_arg_value(&'a self) -> ArgValue<'a> { - self.into() - } -} - // TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we // can't because of coherence rules. Instead, we generate them in the macro, and // add what we can (including Into instead of From) as subtraits. @@ -50,8 +37,8 @@ pub trait IsObjectRef: Sized + Clone + Into - + for<'a> AsArgValue<'a> + TryFrom + + for<'a> Into> + for<'a> TryFrom, Error = Error> + std::fmt::Debug { @@ -64,8 +51,8 @@ pub trait IsObjectRef: Self::from_ptr(None) } - fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { - self.as_arg_value() + fn into_arg_value<'a>(self) -> ArgValue<'a> { + self.into() } fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 09d6068f1a88..a093cf5fe3ae 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -20,14 +20,11 @@ use std::convert::TryFrom; use std::ffi::CString; use std::fmt; -use std::os::raw::c_char; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; use tvm_macros::Object; -use tvm_sys::ffi::{ - self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, -}; +use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; @@ -65,12 +62,10 @@ pub struct Object { /// "subtype". /// /// This function just converts the pointer to the correct type -/// and reconstructs a Box which then is dropped to deallocate -/// the underlying allocation. +/// and invokes the underlying typed delete function. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = object as *mut T; - let boxed: Box = Box::from_raw(typed_object); - drop(boxed); + T::typed_delete(typed_object); } fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { @@ -103,18 +98,6 @@ impl Object { } } - fn get_type_key(&self) -> String { - let mut cstring: *mut c_char = std::ptr::null_mut(); - unsafe { - if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { - panic!("{}", crate::get_last_error()); - } - return CString::from_raw(cstring) - .into_string() - .expect("type keys should be valid utf-8"); - } - } - fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); @@ -174,6 +157,11 @@ impl Object { /// to the subtype. pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; + + unsafe extern "C" fn typed_delete(object: *mut Self) { + let object = Box::from_raw(object); + drop(object) + } } /// A smart pointer for types which implement IsObject. @@ -264,18 +252,13 @@ impl ObjectPtr { if is_derived { Ok(unsafe { self.cast() }) } else { - let type_key = self.as_ref().get_type_key(); - Err(Error::downcast(type_key.into(), U::TYPE_KEY)) + Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } } pub unsafe fn into_raw(self) -> *mut T { self.ptr.as_ptr() } - - pub unsafe fn as_ptr(&self) -> *mut T { - self.ptr.as_ptr() - } } impl std::ops::Deref for ObjectPtr { @@ -325,25 +308,26 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { } } -impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { - fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { +impl<'a, T: IsObject> From> for ArgValue<'a> { + fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.clone().upcast::(); + let object_ptr = object_ptr.upcast::(); match T::TYPE_KEY { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; - let dcast_ptr = object_ptr.downcast().unwrap(); - let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; + // TODO(this is probably not optimal) + let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) + as *mut NDArrayContainer as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } "runtime.Module" => { - let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ModuleHandle(raw_ptr) } _ => { - let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } @@ -361,22 +345,14 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - optr.inc_ref(); - // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must - // bump the reference count by one. - assert!(optr.count() >= 1); + debug_assert!(optr.count() >= 1); optr.downcast() } ArgValue::NDArrayHandle(handle) => { let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must - // bump the reference count by one. - assert!(optr.count() >= 1); - // TODO(@jroesch): figure out if there is a more optimal way to do this - let object = optr.upcast::(); - object.inc_ref(); - object.downcast() + debug_assert!(optr.count() >= 1); + optr.upcast::().downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } @@ -464,12 +440,11 @@ mod tests { assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = (&ptr_clone).into(); + let arg_value: ArgValue = ptr_clone.into(); assert_eq!(ptr.count(), 2); let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 3); + assert_eq!(ptr2.count(), 2); assert_eq!(ptr.count(), ptr2.count()); - drop(ptr_clone); assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, @@ -485,71 +460,26 @@ mod tests { Ok(()) } - fn test_fn_raw<'a>( - mut args: crate::to_function::ArgList<'a>, - ) -> crate::function::Result { - let v: ArgValue = args.remove(0); - let v2: ArgValue = args.remove(0); - // assert_eq!(o.count(), 2); - let o: ObjectPtr = v.try_into().unwrap(); - assert_eq!(o.count(), 2); - let o2: ObjectPtr = v2.try_into().unwrap(); - assert_eq!(o2.count(), 3); - drop(o2); - assert_eq!(o.count(), 2); - Ok(o.into()) - } - - #[test] - fn test_ref_count_raw_fn() { - use super::*; - use crate::function::{register_untyped, Function}; - let ptr = ObjectPtr::new(Object::base::()); - // Call the function without the wrapping for TVM. - assert_eq!(ptr.count(), 1); - let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = same.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - - register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); - let raw_func = Function::get("test_fn_raw").unwrap(); - let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = output.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - } - - fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { + fn test_fn(o: ObjectPtr) -> ObjectPtr { + // The call machinery adds at least 1 extra count while inside the call. assert_eq!(o.count(), 3); - assert_eq!(o2.count(), 3); - drop(o2); - assert_eq!(o.count(), 2); return o; } #[test] - fn test_ref_count_typed() { + fn test_ref_count_boundary3() { use super::*; use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base::()); - // Call the function without the wrapping for TVM. - assert_eq!(ptr.count(), 1); - let output = test_fn_typed(ptr.clone(), ptr.clone()); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - - register(test_fn_typed, "test_fn_typed").unwrap(); - let typed_func = Function::get("test_fn_typed").unwrap(); - let output = typed_func - .invoke(vec![(&ptr).into(), (&ptr).into()]) - .unwrap(); - let output: ObjectPtr = output.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); assert_eq!(ptr.count(), 1); + let stay = ptr.clone(); + assert_eq!(ptr.count(), 2); + register(test_fn, "my_func2").unwrap(); + let func = Function::get("my_func2").unwrap(); + let same = func.invoke(vec![ptr.into()]).unwrap(); + let same: ObjectPtr = same.try_into().unwrap(); + // TODO(@jroesch): normalize RetValue ownership assert_eq!(same.count(), 2); + drop(same); + assert_eq!(stay.count(), 3); } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 67fbfc996af0..7797d2cd23ff 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -44,16 +44,8 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. - -pub type ArgList<'a> = Vec>; - -pub enum Args<'a, I> { - Typed(I), - Raw(ArgList<'a>), -} - pub trait Typed { - fn args<'arg>(i: Vec>) -> Result>; + fn args(i: Vec>) -> Result; fn ret(o: O) -> Result; } @@ -62,7 +54,7 @@ pub trait ToFunction: Sized { fn into_raw(self) -> *mut Self::Handle; - fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result + fn call(handle: *mut Self::Handle, args: Vec>) -> Result where Self: Typed; @@ -78,7 +70,7 @@ pub trait ToFunction: Sized { check_call!(ffi::TVMFuncCreateFromCFunc( Some(Self::tvm_callback), resource_handle as *mut _, - Some(Self::tvm_finalizer), + None, // Some(Self::tvm_finalizer), &mut fhandle as *mut ffi::TVMFunctionHandle, )); @@ -110,28 +102,22 @@ pub trait ToFunction: Sized { for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - // TODO(@jroesch): I believe it is sound to disable this specialized move rule. - // - // This is used in C++ to deal with moving an RValue or reference to a return value - // directly so you can skip copying. - // - // I believe this is not needed as the move directly occurs into the Rust function. - - // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - // { - // check_call!(ffi::TVMCbArgToReturn( - // &mut value as *mut _, - // &mut tcode as *mut _ - // )); - // } + if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int + || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int + { + check_call!(ffi::TVMCbArgToReturn( + &mut value as *mut _, + &mut tcode as *mut _ + )); + } let arg_value = ArgValue::from_tvm_value(value, tcode as u32); local_args.push(arg_value); } + // Ref-count be 2. let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -139,12 +125,6 @@ pub trait ToFunction: Sized { } }; - // TODO(@jroesch): clean up the handling of the is dec_ref - match rv.clone().try_into() as Result> { - Err(_) => {} - Ok(v) => drop(v), - }; - let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; @@ -185,11 +165,9 @@ pub trait ToFunction: Sized { } } -pub struct RawArgs; - -impl Typed for for<'a> fn(Vec>) -> Result { - fn args<'arg>(args: Vec>) -> Result> { - Ok(Args::Raw(args)) +impl Typed>, RetValue> for fn(Vec>) -> Result { + fn args(args: Vec>) -> Result>> { + Ok(args) } fn ret(o: RetValue) -> Result { @@ -197,59 +175,43 @@ impl Typed for for<'a> fn(Vec>) -> Result for for<'arg> fn(Vec>) -> Result { - type Handle = for<'arg> fn(Vec>) -> Result; +impl ToFunction>, RetValue> + for fn(Vec>) -> Result +{ + type Handle = fn(Vec>) -> Result; fn into_raw(self) -> *mut Self::Handle { let ptr: Box = Box::new(self); Box::into_raw(ptr) } - fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { - let func = *handle; - func(args) - } + fn call(handle: *mut Self::Handle, args: Vec>) -> Result { + unsafe { (*handle)(args) } } fn drop(_: *mut Self::Handle) {} } -/// A helper trait which correctly captures the complex conversion and lifetime semantics needed -/// to coerce an ordinary Rust value into `ArgValue`. -pub trait TryFromArgValue: TryFrom { - fn from_arg_value(f: F) -> std::result::Result; -} - -impl<'a, T> TryFromArgValue> for T -where - Self: TryFrom>, - Error: From<>>::Error>, -{ - fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { - Ok(TryFrom::try_from(f)?) - } -} - macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for Fun + impl Typed<($($t,)*), Out> for F where - Fun: Fn($($t),*) -> Out, + F: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( for<'a> $t: TryFromArgValue>, )* + $( $t: TryFrom>, + Error: From<$t::Error>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args<'arg>(args: Vec>) -> Result> { + fn args(args: Vec>) -> Result<($($t,)*)> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), $len, args.len()))) } let mut args = args.into_iter(); - $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* - Ok(Args::Typed(($($t,)*))) + $(let $t = args.next().unwrap().try_into()?;)* + Ok(($($t,)*)) } fn ret(out: Out) -> Result { @@ -258,9 +220,9 @@ macro_rules! impl_typed_and_to_function { } - impl ToFunction<($($t,)*), Out> for Fun + impl ToFunction<($($t,)*), Out> for F where - Fun: Fn($($t,)*) -> Out + 'static + F: Fn($($t,)*) -> Out + 'static { type Handle = Box Out + 'static>; @@ -270,18 +232,13 @@ macro_rules! impl_typed_and_to_function { } #[allow(non_snake_case)] - fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result + fn call(handle: *mut Self::Handle, args: Vec>) -> Result where - Fun: Typed<($($t,)*), Out> + F: Typed<($($t,)*), Out> { - let ($($t,)*) = match Fun::args(args)? { - Args::Raw(_) => panic!("impossible case"), - Args::Typed(typed) => typed, - }; - - let fn_ptr = unsafe { &*handle }; - let out = fn_ptr($($t),*); - Fun::ret(out) + let ($($t,)*) = F::args(args)?; + let out = unsafe { (*handle)($($t),*) }; + F::ret(out) } fn drop(ptr: *mut Self::Handle) { @@ -298,15 +255,13 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, F); -impl_typed_and_to_function!(7; A, B, C, D, E, F, G); -impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); +impl_typed_and_to_function!(6; A, B, C, D, E, G); #[cfg(test)] mod tests { use super::*; - fn call<'a, F, I, O>(f: F, args: Vec>) -> Result + fn call(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 2903a81d9c36..4b005abee7ef 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -17,9 +17,10 @@ * under the License. */ use std::convert::TryFrom; +use std::os::raw::c_char; use crate::errors::ValueDowncastError; -use crate::ffi::{TVMByteArray, TVMByteArrayFree}; +use crate::ffi::TVMByteArray; use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. @@ -32,45 +33,20 @@ use crate::{ArgValue, RetValue}; /// assert_eq!(barr.len(), v.len()); /// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); /// ``` -pub enum ByteArray { - Rust(TVMByteArray), - External(TVMByteArray), -} - -impl Drop for ByteArray { - fn drop(&mut self) { - match self { - ByteArray::Rust(bytes) => { - let ptr = bytes.data; - let len = bytes.size as _; - let cap = bytes.size as _; - let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; - drop(data); - } - ByteArray::External(byte_array) => unsafe { - if TVMByteArrayFree(byte_array as _) != 0 { - panic!("error"); - } - }, - } - } +pub struct ByteArray { + /// The raw FFI ByteArray. + array: TVMByteArray, } impl ByteArray { /// Gets the underlying byte-array - pub fn data(&self) -> &[u8] { - match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { - std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) - }, - } + pub fn data(&self) -> &'static [u8] { + unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size as _) } } /// Gets the length of the underlying byte-array pub fn len(&self) -> usize { - match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, - } + self.array.size as _ } /// Converts the underlying byte-array to `Vec` @@ -83,49 +59,50 @@ impl ByteArray { } } -impl>> From for ByteArray { +// Needs AsRef for Vec +impl> From for ByteArray { fn from(arg: T) -> Self { - let mut incoming_bytes: Vec = arg.into(); - let mut bytes = Vec::with_capacity(incoming_bytes.len()); - bytes.append(&mut incoming_bytes); - - let mut bytes = std::mem::ManuallyDrop::new(bytes); - let ptr = bytes.as_mut_ptr(); - assert_eq!(bytes.len(), bytes.capacity()); - ByteArray::Rust(TVMByteArray { - data: ptr as _, - size: bytes.len() as _, - }) + let arg = arg.as_ref(); + ByteArray { + array: TVMByteArray { + data: arg.as_ptr() as *const c_char, + size: arg.len() as _, + }, + } } } impl<'a> From<&'a ByteArray> for ArgValue<'a> { fn from(val: &'a ByteArray) -> ArgValue<'a> { + ArgValue::Bytes(&val.array) + } +} + +impl TryFrom> for ByteArray { + type Error = ValueDowncastError; + + fn try_from(val: ArgValue<'static>) -> Result { match val { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { - ArgValue::Bytes(byte_array) - } + ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), + _ => Err(ValueDowncastError { + expected_type: "ByteArray", + actual_type: format!("{:?}", val), + }), } } } -// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. -// impl From for RetValue { -// fn from(val: ByteArray) -> RetValue { -// match val { -// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { -// // TODO(@jroesch): This requires a little more work, going to land narratives -// RetValue::Bytes(byte_array) -// } -// } -// } -// } +impl From for RetValue { + fn from(val: ByteArray) -> RetValue { + RetValue::Bytes(val.array) + } +} impl TryFrom for ByteArray { type Error = ValueDowncastError; fn try_from(val: RetValue) -> Result { match val { - RetValue::Bytes(array) => Ok(ByteArray::External(array)), + RetValue::Bytes(array) => Ok(ByteArray { array }), _ => Err(ValueDowncastError { expected_type: "ByteArray", actual_type: format!("{:?}", val), @@ -141,11 +118,11 @@ mod tests { #[test] fn convert() { let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(v.to_vec()); + let barr = ByteArray::from(&v); assert_eq!(barr.len(), v.len()); assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); let v = b"hello"; - let barr = ByteArray::from(v.to_vec()); + let barr = ByteArray::from(&v); assert_eq!(barr.len(), v.len()); assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); } diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index a74cbe318e2d..6f43b786780a 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -224,7 +224,7 @@ macro_rules! impl_pod_value { } } - impl<'a> From<&'a $type> for ArgValue<'a> { + impl<'a, 'v> From<&'a $type> for ArgValue<'v> { fn from(val: &'a $type) -> Self { Self::$variant(*val as $inner_ty) } @@ -284,9 +284,9 @@ impl<'a> From<&'a CStr> for ArgValue<'a> { } } -impl<'a> From<&'a CString> for ArgValue<'a> { - fn from(s: &'a CString) -> Self { - Self::String(s.as_ptr() as _) +impl<'a> From for ArgValue<'a> { + fn from(s: CString) -> Self { + Self::String(s.into_raw()) } } @@ -311,14 +311,14 @@ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { } /// Converts an unspecialized handle to a ArgValue. -impl<'a, T> From<*const T> for ArgValue<'a> { +impl From<*const T> for ArgValue<'static> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a ArgValue. -impl<'a, T> From<*mut T> for ArgValue<'a> { +impl From<*mut T> for ArgValue<'static> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } @@ -382,9 +382,9 @@ impl TryFrom for std::ffi::CString { // Implementations for bool. -impl<'a> From<&bool> for ArgValue<'a> { - fn from(s: &bool) -> Self { - (*s as i64).into() +impl<'a> From for ArgValue<'a> { + fn from(s: bool) -> Self { + (s as i64).into() } } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index c22d55f2e4da..bd0de1c56ba3 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -78,40 +78,24 @@ fn main() -> anyhow::Result<()> { "/deploy_lib.so" )))?; + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; + // parse parameters and convert to TVMByteArray let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; - println!("param bytes: {}", params.len()); - - // If you want an easy way to test a memory leak simply replace the program below with: - // let mut output: Vec; - // loop { - // let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; - // graph_rt.load_params(params.clone())?; - // graph_rt.set_input("data", input.clone())?; - // graph_rt.run()?; - - // // prepare to get the output - // let output_shape = &[1, 1000]; - // let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - // graph_rt.get_output_into(0, output_nd.clone())?; - - // // flatten the output as Vec - // output = output_nd.to_vec::()?; - // } + println!("param bytes: {}", params.len()); - let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; - graph_rt.load_params(params)?; + graph_rt.load_params(¶ms)?; graph_rt.set_input("data", input)?; graph_rt.run()?; // prepare to get the output let output_shape = &[1, 1000]; - let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - graph_rt.get_output_into(0, output_nd.clone())?; + let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + graph_rt.get_output_into(0, output.clone())?; // flatten the output as Vec - let output: Vec = output_nd.to_vec::()?; + let output = output.to_vec::()?; // find the maximum entry in the output and its index let (argmax, max_prob) = output diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs index 8313e47bea20..6b5873398cab 100644 --- a/rust/tvm/src/compiler/graph_rt.rs +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -51,11 +51,11 @@ fn _compile_module( ) -> Result { // The RAW API is Fn(IRModule, String, String, Map, String); let module = TVM_BUILD.invoke(vec![ - (&module).into(), - (&target).into(), - (&target_host).into(), - (¶ms).into(), - (&module_name).into(), + module.into(), + target.into(), + target_host.into(), + params.into(), + module_name.into(), ])?; let module: RtModule = module.try_into().unwrap(); Ok(module) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index ea257af1ebc0..513a906f6db4 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -99,10 +99,10 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new<'a, F, T>(funcs: F, types: T) -> Result + pub fn new(funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, { module_new(Map::from_iter(funcs), Map::from_iter(types)) } @@ -110,7 +110,7 @@ impl IRModule { pub fn empty() -> Result { let funcs = HashMap::::new(); let types = HashMap::::new(); - IRModule::new(funcs.iter(), types.iter()) + IRModule::new(funcs, types) } pub fn parse(file_name: N, source: S) -> Result @@ -206,10 +206,10 @@ impl IRModule { Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } - pub fn from_expr_with_items<'a, E, F, T>(expr: E, funcs: F, types: T) -> Result + pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, E: IsObjectRef, E::Object: AsRef<::Object>, { diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index 8deae30c076d..81ee426d3967 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -35,7 +35,7 @@ use tvm::{ }; fn main() { - fn sum<'a>(args: Vec>) -> Result { + fn sum(args: Vec>) -> Result { let mut ret = 0.0; for arg in args { let arg: NDArray = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs index f8886a55c3a2..37027af0ca37 100644 --- a/rust/tvm/tests/callback/src/bin/error.rs +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -26,7 +26,7 @@ use tvm::{ }; fn main() { - fn error<'a>(_args: Vec>) -> Result { + fn error(_args: Vec>) -> Result { Err(errors::NDArrayError::DataTypeMismatch { expected: DataType::int(64, 1), actual: DataType::float(64, 1), diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs index d575f47c87cd..6fd4f868dc79 100644 --- a/rust/tvm/tests/callback/src/bin/float.rs +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -27,7 +27,7 @@ use tvm::{ }; fn main() { - fn sum<'a>(args: Vec>) -> Result { + fn sum(args: Vec>) -> Result { let mut ret = 0.0; for arg in args.into_iter() { let val: f64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs index fc2e40d8de4d..cdea2e1044c4 100644 --- a/rust/tvm/tests/callback/src/bin/int.rs +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -25,7 +25,7 @@ use tvm::{ }; fn main() { - fn sum<'a>(args: Vec>) -> Result { + fn sum(args: Vec>) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs index 4f3d67e95d64..dbe65ba4c631 100644 --- a/rust/tvm/tests/callback/src/bin/string.rs +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -26,7 +26,7 @@ use tvm::{ // FIXME fn main() { - fn concat_str<'a>(args: Vec>) -> Result { + fn concat_str(args: Vec>) -> Result { let mut ret = "".to_string(); for arg in args.iter() { let val: &str = arg.try_into()?; From eb288ac2d5f97b20b0154501a34a5a9bd86e9fca Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 24 Aug 2021 10:28:20 -0700 Subject: [PATCH 04/14] Fix typo --- src/relay/backend/build_module.cc | 1 - src/relay/backend/te_compiler.h | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 367b91bbdcb8..b2099bd409c0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -490,7 +490,6 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); - // TODO(@electriclilies): How do I check if target object is ext_dev? // No need to build for external functions. if (lowered_funcs.find(tvm::Target("ext_dev")) != lowered_funcs.end()) { lowered_funcs.Set(tvm::Target("ext_dev"), IRModule()); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 1089aa96070b..65ba67ac7e1b 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -67,7 +67,7 @@ struct EnumClassHash { } }; -// TODO(@jroesch, @chrisS) these shoumakeld be a tvm::Map for uniformity sake +// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake // we should a version of context which works in Map using TargetMap = std::unordered_map; using DeviceMap = From 76257bbdd24ea8a9032e27afad85c774386285ce Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 24 Aug 2021 10:29:34 -0700 Subject: [PATCH 05/14] 1 more bad rebase fix --- src/runtime/object.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 3cd5df613f4a..1892ce780a4c 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -262,11 +262,3 @@ int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } - -int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { - API_BEGIN(); - auto key = tvm::runtime::Object::TypeIndex2Key(tindex); - *out_type_key = static_cast(malloc(key.size() + 1)); - strncpy(*out_type_key, key.c_str(), key.size()); - API_END(); -} From d67e8855f10f0e9e6d8ec7746130004a2343eb56 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 24 Aug 2021 10:35:16 -0700 Subject: [PATCH 06/14] Lint --- src/relay/backend/build_module.cc | 4 ++-- src/relay/backend/interpreter.cc | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b2099bd409c0..7f781dbf3c51 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -513,8 +513,8 @@ class RelayBuildModule : public runtime::ModuleNode { if (lowered_funcs.find(target_host) == lowered_funcs.end()) { lowered_funcs.Set(target_host, IRModule(Map({}))); } - lowered_funcs[target_host]->Add( - GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); + lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), + prim); } // When there is no lowered_funcs due to reasons such as optimization. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 232985bdfb5c..9ba2ab78591b 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -881,8 +881,7 @@ class Interpreter : public ExprFunctor, Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> - compiled_packed_funcs_; + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) Device device_; From ee7881e80f94d94c59c9235b843e7f75723cdd0f Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 24 Aug 2021 10:42:20 -0700 Subject: [PATCH 07/14] typo --- src/relay/backend/build_module.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 7f781dbf3c51..02dba236cc7e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -491,8 +491,8 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); // No need to build for external functions. - if (lowered_funcs.find(tvm::Target("ext_dev")) != lowered_funcs.end()) { - lowered_funcs.Set(tvm::Target("ext_dev"), IRModule()); + if (lowered_funcs.find(Target("ext_dev")) != lowered_funcs.end()) { + lowered_funcs.Set(Target("ext_dev"), IRModule()); } // Generate a placeholder function that attaches linked params as its arguments. From e3ca300cf014a374f56cbd4142ecd7e745f32558 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 24 Aug 2021 10:51:41 -0700 Subject: [PATCH 08/14] Forgot to commit this --- src/relay/backend/build_module.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 02dba236cc7e..69dced36295e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = executor_codegen_->GetIRModule(); // No need to build for external functions. - if (lowered_funcs.find(Target("ext_dev")) != lowered_funcs.end()) { - lowered_funcs.Set(Target("ext_dev"), IRModule()); + Target ext_dev("ext_dev"); + if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) { + lowered_funcs.Set(ext_dev, IRModule()); } // Generate a placeholder function that attaches linked params as its arguments. From ee60645b73524100d7e81d6874b25146289bb591 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 27 Aug 2021 08:28:55 -0700 Subject: [PATCH 09/14] Add TargetStrHash and Map* target, Target* host); * \param host The Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Map* target, Target* host); + +// TODO(@electriclilies): Move to somewhere in backend and add note about appropriate use + +/*! \brief Target hash function */ +struct TargetStrHash { + /*! + * \brief Calculate the hash code of a Target based on the string + * \param a The given Target + * \return String hash of the target + */ + size_t operator()(const Target& target) const { + return String::HashBytes(target->str().c_str(), target->str().size()); + } +}; + +/*! \brief Target equality functino based on string */ +struct TargetStrEqual { + /*! + * \brief Check if the two Targets are equal + * \param a One Target + * \param b The other Target + * \return String equality of the targets + */ + const bool operator()(const Target& a, const Target& b) const { + TargetStrHash target_hash = TargetStrHash(); + return target_hash(a) == target_hash(b); + } +}; + +// TODO(@electriclilies): Add documentation +std::unordered_map TargetModuleMapToStdMap(Map input_map); + + } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bfea3e7b67c0..dd99a6e5429e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -59,6 +59,7 @@ Target DefaultTargetHost(Target target) { return target; } else { if (LLVMEnabled()) { + // Target("llvm") created here return Target("llvm"); } else { return Target("stackvm"); @@ -522,6 +523,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar CheckAndUpdateHostConsistency(&target, &target_host); Optional device = target->GetAttr("device"); if (device.defined() && device.value() == "vta") { + // Target("ext_dev") created here target = Target("ext_dev"); } updated_inputs.Set(target, it.second); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 69dced36295e..c27ec8b13e6b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -384,6 +384,7 @@ class RelayBuildModule : public runtime::ModuleNode { */ Target CreateDefaultTarget(int device_type) { std::string name = runtime::DeviceName(device_type); + // Target("llvm") created here if (name == "cpu") return Target("llvm"); if (name == "cuda") return Target("cuda"); return Target(name); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6142e8323dea..21a3d9893dc5 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -191,6 +191,7 @@ class CompileEngineImpl : public CompileEngineNode { const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = std::string(name_node.value()); + // Target("ext_dev") created here auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 9ba2ab78591b..2cacf87f41f0 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -386,10 +386,13 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target); - ICHECK(mod_itr != per_target_module_.end()) + std::unordered_map per_target_module_std_map_ = TargetModuleMapToStdMap(per_target_module_); + auto mod_itr = per_target_module_std_map_.find(target); + ICHECK(mod_itr != per_target_module_std_map_.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; + std::cout << "Target module map: " << per_target_module_ << std::endl; + std::cout << "Target module: " << target_module << std::endl; for (const auto& var : all_tir_fn_vars) { ICHECK(target_module->ContainGlobalVar(var->name_hint)) << "No global var for '" << var->name_hint << "' in module for target '" << target->str() diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 1a244ec728f1..dc83265bb5a3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -201,6 +201,7 @@ class TECompilerImpl : public TECompilerNode { const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = GetUniqueName(name_node.value(), &name_map_); + // Target("ext_dev") created here auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3eab91d202c..308371356005 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -527,6 +527,7 @@ class VMFunctionCompiler : ExprFunctor { Target target; if (func->GetAttr(attr::kCompiler).defined()) { + // Target("ext_dev") created here target = Target("ext_dev"); } else { // Next generate the invoke instruction. diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index d545518c1c3c..7a6500d9493e 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -232,6 +232,7 @@ class ConstantFolder : public MixedModeMutator { Device dev; dev.device_type = kDLCPU; dev.device_id = 0; + // Target("llvm") created here Target target = Target("llvm"); // use a fresh build context in case we are already in a build context. diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 12c7a3132947..8e296666b890 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -324,6 +324,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { return; } if (!target_.defined()) { + // Target("llvm") created here target_ = Target("llvm"); } llvm::EngineBuilder builder(std::move(module_)); diff --git a/src/target/target.cc b/src/target/target.cc index e0b9539380d7..c18b4b9c2525 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -826,6 +826,15 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, return output; } +// Helper to convert the tvm::Map to a std::unordered_map +std::unordered_map TargetModuleMapToStdMap(Map input_map) { + std::unordered_map std_map; + for (auto kv : input_map) { + std_map[kv.first] = kv.second; + } + return std_map; +} + /********** Registry **********/ TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index f774fc01a5f4..815442a7ce7e 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -50,6 +50,7 @@ TEST(BuildModule, Basic) { auto args = Array({A, B, C}); std::unordered_map binds; + // Target("llvm") created here auto target = Target("llvm"); auto lowered = LowerSchedule(s, args, "func", binds); @@ -88,6 +89,7 @@ TEST(BuildModule, Heterogeneous) { return; } + // Target("llvm") created here auto target_llvm = Target("llvm"); auto target_cuda = Target("cuda"); From 8da2c5446f08f48deecb911447376bcd730be306 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 27 Aug 2021 11:33:57 -0700 Subject: [PATCH 10/14] Passing most tests, yay --- include/tvm/target/target.h | 32 ++++++++++++++++++++++++----- src/driver/driver_api.cc | 2 -- src/relay/backend/compile_engine.cc | 1 - src/relay/backend/interpreter.cc | 5 ++--- src/relay/backend/te_compiler.cc | 9 ++++---- src/relay/backend/vm/compiler.cc | 1 - src/target/target.cc | 12 ++++++++++- 7 files changed, 44 insertions(+), 18 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 7969904590d0..5fa2e0a59549 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -32,6 +32,7 @@ #include #include +#include #include namespace tvm { @@ -209,16 +210,19 @@ void CheckAndUpdateHostConsistency(Map* target, Target* host); /*! \brief Target hash function */ struct TargetStrHash { /*! - * \brief Calculate the hash code of a Target based on the string + * \brief Calculate the hash code of a Target based on the string value of the Target + This will be removed when maps from Targets to IRModules are removed from the codebase. * \param a The given Target * \return String hash of the target */ size_t operator()(const Target& target) const { - return String::HashBytes(target->str().c_str(), target->str().size()); + return String::HashBytes(target->str().c_str(), target->str().size()); } }; -/*! \brief Target equality functino based on string */ +/*! \brief Target equality function based on the string value of Target +This will be removed when maps from Targets to IRModules are removed from the +codebase.*/ struct TargetStrEqual { /*! * \brief Check if the two Targets are equal @@ -232,9 +236,27 @@ struct TargetStrEqual { } }; -// TODO(@electriclilies): Add documentation -std::unordered_map TargetModuleMapToStdMap(Map input_map); +/*! + * \brief Convert a Map to std::unordered_map Target equality is currently based on pointer equality, which is a problem since + * we have a lot of Map in the codebase. This function converts the map to a + * version that is keyed based on string value of the Target instead. Note that once we remove + * Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map); +/*! + * \brief Convert a std::unordered_map to + * Map This function is a helper that undoes TargetModuleMapToTargetStr. Note that + * once we remove Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index dd99a6e5429e..bfea3e7b67c0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -59,7 +59,6 @@ Target DefaultTargetHost(Target target) { return target; } else { if (LLVMEnabled()) { - // Target("llvm") created here return Target("llvm"); } else { return Target("stackvm"); @@ -523,7 +522,6 @@ runtime::Module build(const Map& inputs_arg, const Target& tar CheckAndUpdateHostConsistency(&target, &target_host); Optional device = target->GetAttr("device"); if (device.defined() && device.value() == "vta") { - // Target("ext_dev") created here target = Target("ext_dev"); } updated_inputs.Set(target, it.second); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 21a3d9893dc5..6142e8323dea 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -191,7 +191,6 @@ class CompileEngineImpl : public CompileEngineNode { const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = std::string(name_node.value()); - // Target("ext_dev") created here auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 2cacf87f41f0..b7c9a98f6e01 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -386,13 +386,12 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - std::unordered_map per_target_module_std_map_ = TargetModuleMapToStdMap(per_target_module_); + std::unordered_map per_target_module_std_map_ = + TargetModuleMapToTargetStrModuleMap(per_target_module_); auto mod_itr = per_target_module_std_map_.find(target); ICHECK(mod_itr != per_target_module_std_map_.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; - std::cout << "Target module map: " << per_target_module_ << std::endl; - std::cout << "Target module: " << target_module << std::endl; for (const auto& var : all_tir_fn_vars) { ICHECK(target_module->ContainGlobalVar(var->name_hint)) << "No global var for '" << var->name_hint << "' in module for target '" << target->str() diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index dc83265bb5a3..2e95232086f6 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -86,14 +86,14 @@ class TECompilerImpl : public TECompilerNode { } Map GetLoweredFunctions() { - Map lowered_functions; + std::unordered_map lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; if (!lowered_functions.count(target)) { - lowered_functions.Set(target, IRModule(Map({}))); + lowered_functions[target] = IRModule(Map({})); } lowered_functions[target]->Update(lowered_func->cached_func->funcs); @@ -105,12 +105,12 @@ class TECompilerImpl : public TECompilerNode { auto target = source_func->target; if (!lowered_functions.count(target)) { - lowered_functions.Set(target, IRModule(Map({}))); + lowered_functions[target] = IRModule(Map({})); } lowered_functions[target]->Update(lowered_func->cached_func->funcs); } - return lowered_functions; + return TargetStrModuleMapToTargetModuleMap(lowered_functions); } Array LowerExternalFunctions() { @@ -201,7 +201,6 @@ class TECompilerImpl : public TECompilerNode { const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = GetUniqueName(name_node.value(), &name_map_); - // Target("ext_dev") created here auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 308371356005..b3eab91d202c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -527,7 +527,6 @@ class VMFunctionCompiler : ExprFunctor { Target target; if (func->GetAttr(attr::kCompiler).defined()) { - // Target("ext_dev") created here target = Target("ext_dev"); } else { // Next generate the invoke instruction. diff --git a/src/target/target.cc b/src/target/target.cc index c18b4b9c2525..aa2717d41bd7 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -827,7 +827,8 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, } // Helper to convert the tvm::Map to a std::unordered_map -std::unordered_map TargetModuleMapToStdMap(Map input_map) { +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map) { std::unordered_map std_map; for (auto kv : input_map) { std_map[kv.first] = kv.second; @@ -835,6 +836,15 @@ std::unordered_map TargetModule return std_map; } +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map) { + Map tvm_map; + for (auto kv : input_map) { + tvm_map.Set(kv.first, kv.second); + } + return tvm_map; +} + /********** Registry **********/ TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher); From 1ebe6236bbd8bccf82a4219e19d01cf7b7cfbaf2 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 27 Aug 2021 11:40:50 -0700 Subject: [PATCH 11/14] remove some comments --- src/relay/backend/build_module.cc | 1 - src/relay/transforms/fold_constant.cc | 1 - src/target/llvm/llvm_module.cc | 1 - tests/cpp/build_module_test.cc | 2 -- 4 files changed, 5 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index c27ec8b13e6b..69dced36295e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -384,7 +384,6 @@ class RelayBuildModule : public runtime::ModuleNode { */ Target CreateDefaultTarget(int device_type) { std::string name = runtime::DeviceName(device_type); - // Target("llvm") created here if (name == "cpu") return Target("llvm"); if (name == "cuda") return Target("cuda"); return Target(name); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 7a6500d9493e..d545518c1c3c 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -232,7 +232,6 @@ class ConstantFolder : public MixedModeMutator { Device dev; dev.device_type = kDLCPU; dev.device_id = 0; - // Target("llvm") created here Target target = Target("llvm"); // use a fresh build context in case we are already in a build context. diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 8e296666b890..12c7a3132947 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -324,7 +324,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { return; } if (!target_.defined()) { - // Target("llvm") created here target_ = Target("llvm"); } llvm::EngineBuilder builder(std::move(module_)); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 815442a7ce7e..f774fc01a5f4 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -50,7 +50,6 @@ TEST(BuildModule, Basic) { auto args = Array({A, B, C}); std::unordered_map binds; - // Target("llvm") created here auto target = Target("llvm"); auto lowered = LowerSchedule(s, args, "func", binds); @@ -89,7 +88,6 @@ TEST(BuildModule, Heterogeneous) { return; } - // Target("llvm") created here auto target_llvm = Target("llvm"); auto target_cuda = Target("cuda"); From 4a65400114a760a11bf15d578659d70b7630d0ed Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 27 Aug 2021 11:43:11 -0700 Subject: [PATCH 12/14] lint --- include/tvm/target/target.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 5fa2e0a59549..1a5e38f781e9 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -31,8 +31,8 @@ #include #include -#include #include +#include #include namespace tvm { From 4205389385cdeb368b5233cb5ef3f079e9a58d7f Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 27 Aug 2021 11:50:46 -0700 Subject: [PATCH 13/14] target-str-to-target-object --- include/tvm/target/target.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 1a5e38f781e9..18a9d5a779f3 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -212,7 +212,7 @@ struct TargetStrHash { /*! * \brief Calculate the hash code of a Target based on the string value of the Target This will be removed when maps from Targets to IRModules are removed from the codebase. - * \param a The given Target + * \param target The Target to hash * \return String hash of the target */ size_t operator()(const Target& target) const { @@ -226,13 +226,13 @@ codebase.*/ struct TargetStrEqual { /*! * \brief Check if the two Targets are equal - * \param a One Target - * \param b The other Target + * \param target One Target + * \param other_target The other Target * \return String equality of the targets */ - const bool operator()(const Target& a, const Target& b) const { + const bool operator()(const Target& target, const Target& other_target) const { TargetStrHash target_hash = TargetStrHash(); - return target_hash(a) == target_hash(b); + return target_hash(target) == target_hash(other_target); } }; From 29f802ce10138c95a87efdfd9bce37ad76efd427 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 30 Aug 2021 11:07:55 -0700 Subject: [PATCH 14/14] Respond to change requests --- include/tvm/target/target.h | 53 ------------------------------- src/relay/backend/interpreter.cc | 9 +++--- src/relay/backend/te_compiler.cc | 5 +-- src/relay/backend/utils.cc | 18 +++++++++++ src/relay/backend/utils.h | 54 ++++++++++++++++++++++++++++++++ src/target/target.cc | 19 ----------- 6 files changed, 80 insertions(+), 78 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 18a9d5a779f3..deec662e74ad 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -205,58 +205,5 @@ void CheckAndUpdateHostConsistency(Map* target, Target* host); */ void CheckAndUpdateHostConsistency(Map* target, Target* host); -// TODO(@electriclilies): Move to somewhere in backend and add note about appropriate use - -/*! \brief Target hash function */ -struct TargetStrHash { - /*! - * \brief Calculate the hash code of a Target based on the string value of the Target - This will be removed when maps from Targets to IRModules are removed from the codebase. - * \param target The Target to hash - * \return String hash of the target - */ - size_t operator()(const Target& target) const { - return String::HashBytes(target->str().c_str(), target->str().size()); - } -}; - -/*! \brief Target equality function based on the string value of Target -This will be removed when maps from Targets to IRModules are removed from the -codebase.*/ -struct TargetStrEqual { - /*! - * \brief Check if the two Targets are equal - * \param target One Target - * \param other_target The other Target - * \return String equality of the targets - */ - const bool operator()(const Target& target, const Target& other_target) const { - TargetStrHash target_hash = TargetStrHash(); - return target_hash(target) == target_hash(other_target); - } -}; - -/*! - * \brief Convert a Map to std::unordered_map Target equality is currently based on pointer equality, which is a problem since - * we have a lot of Map in the codebase. This function converts the map to a - * version that is keyed based on string value of the Target instead. Note that once we remove - * Map, this function will be removed. - * \param input_map The map to convert - * \return The converted map - */ -std::unordered_map -TargetModuleMapToTargetStrModuleMap(Map input_map); - -/*! - * \brief Convert a std::unordered_map to - * Map This function is a helper that undoes TargetModuleMapToTargetStr. Note that - * once we remove Map, this function will be removed. - * \param input_map The map to convert - * \return The converted map - */ -Map TargetStrModuleMapToTargetModuleMap( - std::unordered_map input_map); - } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index b7c9a98f6e01..76b6f9186eb5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -386,10 +386,11 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - std::unordered_map per_target_module_std_map_ = - TargetModuleMapToTargetStrModuleMap(per_target_module_); - auto mod_itr = per_target_module_std_map_.find(target); - ICHECK(mod_itr != per_target_module_std_map_.end()) + std::unordered_map + per_target_module_std_map = + backend::TargetModuleMapToTargetStrModuleMap(per_target_module_); + auto mod_itr = per_target_module_std_map.find(target); + ICHECK(mod_itr != per_target_module_std_map.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; for (const auto& var : all_tir_fn_vars) { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 2e95232086f6..06d862b781e1 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -86,7 +86,8 @@ class TECompilerImpl : public TECompilerNode { } Map GetLoweredFunctions() { - std::unordered_map lowered_functions; + std::unordered_map + lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; @@ -110,7 +111,7 @@ class TECompilerImpl : public TECompilerNode { lowered_functions[target]->Update(lowered_func->cached_func->funcs); } - return TargetStrModuleMapToTargetModuleMap(lowered_functions); + return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions); } Array LowerExternalFunctions() { diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 4b4844599e29..ea0ab093aa1d 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -187,6 +187,24 @@ Array GetPassPrefix(const Map& targets, bool is return pass_seqs; } +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map) { + std::unordered_map std_map; + for (auto kv : input_map) { + std_map[kv.first] = kv.second; + } + return std_map; +} + +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map) { + Map tvm_map; + for (auto kv : input_map) { + tvm_map.Set(kv.first, kv.second); + } + return tvm_map; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index bf13715b7d46..cf8a2dd4b8e0 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -427,6 +427,60 @@ inline bool IsCompileEngineCacheDisabled() { */ Array GetPassPrefix(const Map& targets, bool is_vm); +/*! \brief Target hash function */ +struct TargetStrHash { + /*! + * \brief Calculate the hash code of a Target based on the string value of the Target. + Note that this hash should NOT be used in new usecases, equality of targets based on their + value is not well-defined. + This will be removed when maps from Targets to IRModules are removed from the codebase. + * \param target The Target to hash + * \return String hash of the target + */ + size_t operator()(const Target& target) const { + return String::HashBytes(target->str().c_str(), target->str().size()); + } +}; + +/*! \brief Target equality function based on the string value of Target +Note that this equality function should NOT be used in new usecases, equality of targets based on +their value is not well-defined. This will be removed when maps from Targets to IRModules are +removed from the codebase.*/ +struct TargetStrEqual { + /*! + * \brief Check if the two Targets are equal + * \param target One Target + * \param other_target The other Target + * \return String equality of the targets + */ + const bool operator()(const Target& target, const Target& other_target) const { + TargetStrHash target_hash = TargetStrHash(); + return target_hash(target) == target_hash(other_target); + } +}; + +/*! + * \brief Convert a Map to std::unordered_map Target equality is currently based on pointer equality, which is a problem since + * we have a lot of Map in the codebase. This function converts the map to a + * version that is keyed based on string value of the Target instead. Note that once we remove + * Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +std::unordered_map +TargetModuleMapToTargetStrModuleMap(Map input_map); + +/*! + * \brief Convert a std::unordered_map to + * Map This function is a helper that undoes TargetModuleMapToTargetStr. Note that + * once we remove Map, this function will be removed. + * \param input_map The map to convert + * \return The converted map + */ +Map TargetStrModuleMapToTargetModuleMap( + std::unordered_map input_map); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/target/target.cc b/src/target/target.cc index aa2717d41bd7..e0b9539380d7 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -826,25 +826,6 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, return output; } -// Helper to convert the tvm::Map to a std::unordered_map -std::unordered_map -TargetModuleMapToTargetStrModuleMap(Map input_map) { - std::unordered_map std_map; - for (auto kv : input_map) { - std_map[kv.first] = kv.second; - } - return std_map; -} - -Map TargetStrModuleMapToTargetModuleMap( - std::unordered_map input_map) { - Map tvm_map; - for (auto kv : input_map) { - tvm_map.Set(kv.first, kv.second); - } - return tvm_map; -} - /********** Registry **********/ TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher);