diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 17d1ba2a5132..8454b04443a1 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Convert type index to type key. + * \param tindex The type index. + * \param out_type_key The output type key. + * \return 0 when success, nonzero when failure happens + */ +TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + /*! * \brief Increase the reference count of an object. * diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index c84d0aab612f..4134da5fe6d9 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -147,8 +147,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { use std::ffi::c_void; let object_ptr = &object_ref.0; match object_ptr { @@ -156,18 +156,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #tvm_rt_crate::ArgValue:: ObjectHandle(std::ptr::null::() as *mut c_void) } - Some(value) => value.clone().into() + Some(value) => value.into() } } } - impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { - let oref: #ref_id = object_ref.clone(); - #tvm_rt_crate::ArgValue::<'a>::from(oref) - } - } - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { type Error = #error; diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index e8902b54f6ef..02c34a1d133f 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -45,19 +45,22 @@ external! { fn array_size(array: ObjectRef) -> i64; } -impl IsObjectRef for Array { +impl IsObjectRef for Array { type Object = Object; fn as_ptr(&self) -> Option<&ObjectPtr> { self.object.as_ptr() } + fn into_ptr(self) -> Option> { self.object.into_ptr() } + fn from_ptr(object_ptr: Option>) -> Self { let object_ref = match object_ptr { Some(o) => o.into(), _ => panic!(), }; + Array { object: object_ref, _data: PhantomData, @@ -67,7 +70,7 @@ impl IsObjectRef for Array { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.into_iter().map(T::into_arg_value).collect(); + let iter = data.iter().map(T::into_arg_value).collect(); let func = Function::get("runtime.Array").expect( "runtime.Array function is not registered, this is most likely a build or linking error", @@ -151,9 +154,9 @@ impl FromIterator for Array { } } -impl<'a, T: IsObjectRef> From> for ArgValue<'a> { - fn from(array: Array) -> ArgValue<'a> { - array.object.into() +impl<'a, T: IsObjectRef> From<&'a Array> for ArgValue<'a> { + fn from(array: &'a Array) -> ArgValue<'a> { + (&array.object).into() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 5db665cc7a48..62474e6650d4 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -35,7 +35,8 @@ use std::{ use crate::errors::Error; -pub use super::to_function::{ToFunction, Typed}; +pub use super::to_function::{RawArgs, ToFunction, Typed}; +use crate::object::AsArgValue; pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; @@ -153,12 +154,12 @@ macro_rules! impl_to_fn { where Error: From, Out: TryFrom, - $($t: Into>),* + $($t: for<'a> AsArgValue<'a>),* { fn from(func: Function) -> Self { #[allow(non_snake_case)] Box::new(move |$($t : $t),*| { - let args = vec![ $($t.into()),* ]; + let args = vec![ $((&$t).as_arg_value()),* ]; Ok(func.invoke(args)?.try_into()?) }) } @@ -196,8 +197,8 @@ impl TryFrom for Function { } } -impl<'a> From for ArgValue<'a> { - fn from(func: Function) -> ArgValue<'a> { +impl<'a> From<&'a Function> for ArgValue<'a> { + fn from(func: &'a Function) -> ArgValue<'a> { if func.handle().is_null() { ArgValue::Null } else { @@ -291,12 +292,12 @@ where } pub fn register_untyped>( - f: fn(Vec>) -> Result, + f: for<'a> fn(Vec>) -> Result, name: S, override_: bool, ) -> Result<()> { - // TODO(@jroesch): can we unify all the code. - let func = f.to_function(); + //TODO(@jroesch): can we unify the untpyed and typed registration functions. + let func = ToFunction::::to_function(f); let name = name.into(); // Not sure about this code let handle = func.handle(); diff --git a/rust/tvm-rt/src/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs index 7db53d466665..53f3210aa742 100644 --- a/rust/tvm-rt/src/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -50,11 +50,12 @@ impl GraphRt { let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - lib.into(), + (&lib).into(), (&dev.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), ]); + let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?; Ok(Self { module: graph_executor_module, @@ -79,7 +80,7 @@ impl GraphRt { pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { let ref set_input_fn = self.module.get_function("set_input", false)?; - set_input_fn.invoke(vec![name.into(), input.into()])?; + set_input_fn.invoke(vec![name.into(), (&input).into()])?; Ok(()) } @@ -101,7 +102,7 @@ impl GraphRt { /// Extract the ith output from the graph executor and write the results into output. pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { let get_output_fn = self.module.get_function("get_output", false)?; - get_output_fn.invoke(vec![i.into(), output.into()])?; + get_output_fn.invoke(vec![i.into(), (&output).into()])?; Ok(()) } } diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 824dc63f0b50..3b7d066e7b78 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -130,16 +130,17 @@ mod tests { ); } - #[test] - fn bytearray() { - let w = vec![1u8, 2, 3, 4, 5]; - let v = ByteArray::from(w.as_slice()); - let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - assert_eq!( - tvm.data(), - w.iter().copied().collect::>().as_slice() - ); - } + // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. + // #[test] + // fn bytearray() { + // let w = vec![1u8, 2, 3, 4, 5]; + // let v = ByteArray::from(w.as_slice()); + // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + // assert_eq!( + // tvm.data(), + // w.iter().copied().collect::>().as_slice() + // ); + // } #[test] fn ty() { diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index d6dfaf3641b8..5594a91dc0f0 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -58,18 +58,18 @@ external! { fn map_items(map: ObjectRef) -> Array; } -impl FromIterator<(K, V)> for Map +impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map where K: IsObjectRef, V: IsObjectRef, { - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.into()); - buffer.push(v.into()) + buffer.push(k.into_arg_value()); + buffer.push(v.into_arg_value()); } Self::from_data(buffer).expect("failed to convert from data") } @@ -202,13 +202,13 @@ where } } -impl<'a, K, V> From> for ArgValue<'a> +impl<'a, K, V> From<&'a Map> for ArgValue<'a> where K: IsObjectRef, V: IsObjectRef, { - fn from(map: Map) -> ArgValue<'a> { - map.object.into() + fn from(map: &'a Map) -> ArgValue<'a> { + (&map.object).into() } } @@ -268,7 +268,7 @@ mod test { let mut std_map: HashMap = HashMap::new(); std_map.insert("key1".into(), "value1".into()); std_map.insert("key2".into(), "value2".into()); - let tvm_map = Map::from_iter(std_map.clone().into_iter()); + let tvm_map = Map::from_iter(std_map.iter()); let back_map = tvm_map.into(); assert_eq!(std_map, back_map); } diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 08dcfe33f28f..80f8f184140c 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -101,6 +101,21 @@ impl NDArrayContainer { .cast::() } } + + pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { + object_ptr + .ptr + .as_ptr() + .cast::() + .offset(base_offset) + .cast::() + } + } } fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 8c07ed9f0853..f5832fcb3ab8 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -29,6 +29,19 @@ mod object_ptr; pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; +pub trait AsArgValue<'a> { + fn as_arg_value(&'a self) -> ArgValue<'a>; +} + +impl<'a, T: 'static> AsArgValue<'a> for T +where + &'a T: Into>, +{ + fn as_arg_value(&'a self) -> ArgValue<'a> { + self.into() + } +} + // TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we // can't because of coherence rules. Instead, we generate them in the macro, and // add what we can (including Into instead of From) as subtraits. @@ -37,8 +50,8 @@ pub trait IsObjectRef: Sized + Clone + Into + + for<'a> AsArgValue<'a> + TryFrom - + for<'a> Into> + for<'a> TryFrom, Error = Error> + std::fmt::Debug { @@ -51,8 +64,8 @@ pub trait IsObjectRef: Self::from_ptr(None) } - fn into_arg_value<'a>(self) -> ArgValue<'a> { - self.into() + fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { + self.as_arg_value() } fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index a093cf5fe3ae..09d6068f1a88 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -20,11 +20,14 @@ use std::convert::TryFrom; use std::ffi::CString; use std::fmt; +use std::os::raw::c_char; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; use tvm_macros::Object; -use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::ffi::{ + self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, +}; use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; @@ -62,10 +65,12 @@ pub struct Object { /// "subtype". /// /// This function just converts the pointer to the correct type -/// and invokes the underlying typed delete function. +/// and reconstructs a Box which then is dropped to deallocate +/// the underlying allocation. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = object as *mut T; - T::typed_delete(typed_object); + let boxed: Box = Box::from_raw(typed_object); + drop(boxed); } fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { @@ -98,6 +103,18 @@ impl Object { } } + fn get_type_key(&self) -> String { + let mut cstring: *mut c_char = std::ptr::null_mut(); + unsafe { + if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { + panic!("{}", crate::get_last_error()); + } + return CString::from_raw(cstring) + .into_string() + .expect("type keys should be valid utf-8"); + } + } + fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); @@ -157,11 +174,6 @@ impl Object { /// to the subtype. pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; - - unsafe extern "C" fn typed_delete(object: *mut Self) { - let object = Box::from_raw(object); - drop(object) - } } /// A smart pointer for types which implement IsObject. @@ -252,13 +264,18 @@ impl ObjectPtr { if is_derived { Ok(unsafe { self.cast() }) } else { - Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) + let type_key = self.as_ref().get_type_key(); + Err(Error::downcast(type_key.into(), U::TYPE_KEY)) } } pub unsafe fn into_raw(self) -> *mut T { self.ptr.as_ptr() } + + pub unsafe fn as_ptr(&self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -308,26 +325,25 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { } } -impl<'a, T: IsObject> From> for ArgValue<'a> { - fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { +impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { + fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.upcast::(); + let object_ptr = object_ptr.clone().upcast::(); match T::TYPE_KEY { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; - // TODO(this is probably not optimal) - let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) - as *mut NDArrayContainer as *mut std::ffi::c_void; + let dcast_ptr = object_ptr.downcast().unwrap(); + let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } "runtime.Module" => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ModuleHandle(raw_ptr) } _ => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } @@ -345,14 +361,22 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); + optr.inc_ref(); + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); optr.downcast() } ArgValue::NDArrayHandle(handle) => { let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.upcast::().downcast() + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); + // TODO(@jroesch): figure out if there is a more optimal way to do this + let object = optr.upcast::(); + object.inc_ref(); + object.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } @@ -440,11 +464,12 @@ mod tests { assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = ptr_clone.into(); + let arg_value: ArgValue = (&ptr_clone).into(); assert_eq!(ptr.count(), 2); let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 2); + assert_eq!(ptr2.count(), 3); assert_eq!(ptr.count(), ptr2.count()); + drop(ptr_clone); assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, @@ -460,26 +485,71 @@ mod tests { Ok(()) } - fn test_fn(o: ObjectPtr) -> ObjectPtr { - // The call machinery adds at least 1 extra count while inside the call. + fn test_fn_raw<'a>( + mut args: crate::to_function::ArgList<'a>, + ) -> crate::function::Result { + let v: ArgValue = args.remove(0); + let v2: ArgValue = args.remove(0); + // assert_eq!(o.count(), 2); + let o: ObjectPtr = v.try_into().unwrap(); + assert_eq!(o.count(), 2); + let o2: ObjectPtr = v2.try_into().unwrap(); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); + Ok(o.into()) + } + + #[test] + fn test_ref_count_raw_fn() { + use super::*; + use crate::function::{register_untyped, Function}; + let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = same.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); + let raw_func = Function::get("test_fn_raw").unwrap(); + let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + } + + fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { assert_eq!(o.count(), 3); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); return o; } #[test] - fn test_ref_count_boundary3() { + fn test_ref_count_typed() { use super::*; use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let output = test_fn_typed(ptr.clone(), ptr.clone()); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register(test_fn_typed, "test_fn_typed").unwrap(); + let typed_func = Function::get("test_fn_typed").unwrap(); + let output = typed_func + .invoke(vec![(&ptr).into(), (&ptr).into()]) + .unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); assert_eq!(ptr.count(), 1); - let stay = ptr.clone(); - assert_eq!(ptr.count(), 2); - register(test_fn, "my_func2").unwrap(); - let func = Function::get("my_func2").unwrap(); - let same = func.invoke(vec![ptr.into()]).unwrap(); - let same: ObjectPtr = same.try_into().unwrap(); - // TODO(@jroesch): normalize RetValue ownership assert_eq!(same.count(), 2); - drop(same); - assert_eq!(stay.count(), 3); } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 7797d2cd23ff..67fbfc996af0 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -44,8 +44,16 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. + +pub type ArgList<'a> = Vec>; + +pub enum Args<'a, I> { + Typed(I), + Raw(ArgList<'a>), +} + pub trait Typed { - fn args(i: Vec>) -> Result; + fn args<'arg>(i: Vec>) -> Result>; fn ret(o: O) -> Result; } @@ -54,7 +62,7 @@ pub trait ToFunction: Sized { fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where Self: Typed; @@ -70,7 +78,7 @@ pub trait ToFunction: Sized { check_call!(ffi::TVMFuncCreateFromCFunc( Some(Self::tvm_callback), resource_handle as *mut _, - None, // Some(Self::tvm_finalizer), + Some(Self::tvm_finalizer), &mut fhandle as *mut ffi::TVMFunctionHandle, )); @@ -102,22 +110,28 @@ pub trait ToFunction: Sized { for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - { - check_call!(ffi::TVMCbArgToReturn( - &mut value as *mut _, - &mut tcode as *mut _ - )); - } + // TODO(@jroesch): I believe it is sound to disable this specialized move rule. + // + // This is used in C++ to deal with moving an RValue or reference to a return value + // directly so you can skip copying. + // + // I believe this is not needed as the move directly occurs into the Rust function. + + // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int + // { + // check_call!(ffi::TVMCbArgToReturn( + // &mut value as *mut _, + // &mut tcode as *mut _ + // )); + // } let arg_value = ArgValue::from_tvm_value(value, tcode as u32); local_args.push(arg_value); } - // Ref-count be 2. let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -125,6 +139,12 @@ pub trait ToFunction: Sized { } }; + // TODO(@jroesch): clean up the handling of the is dec_ref + match rv.clone().try_into() as Result> { + Err(_) => {} + Ok(v) => drop(v), + }; + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; @@ -165,9 +185,11 @@ pub trait ToFunction: Sized { } } -impl Typed>, RetValue> for fn(Vec>) -> Result { - fn args(args: Vec>) -> Result>> { - Ok(args) +pub struct RawArgs; + +impl Typed for for<'a> fn(Vec>) -> Result { + fn args<'arg>(args: Vec>) -> Result> { + Ok(Args::Raw(args)) } fn ret(o: RetValue) -> Result { @@ -175,43 +197,59 @@ impl Typed>, RetValue> for fn(Vec>) -> R } } -impl ToFunction>, RetValue> - for fn(Vec>) -> Result -{ - type Handle = fn(Vec>) -> Result; +impl ToFunction for for<'arg> fn(Vec>) -> Result { + type Handle = for<'arg> fn(Vec>) -> Result; fn into_raw(self) -> *mut Self::Handle { let ptr: Box = Box::new(self); Box::into_raw(ptr) } - fn call(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { (*handle)(args) } + fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { + unsafe { + let func = *handle; + func(args) + } } fn drop(_: *mut Self::Handle) {} } +/// A helper trait which correctly captures the complex conversion and lifetime semantics needed +/// to coerce an ordinary Rust value into `ArgValue`. +pub trait TryFromArgValue: TryFrom { + fn from_arg_value(f: F) -> std::result::Result; +} + +impl<'a, T> TryFromArgValue> for T +where + Self: TryFrom>, + Error: From<>>::Error>, +{ + fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { + Ok(TryFrom::try_from(f)?) + } +} + macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for F + impl Typed<($($t,)*), Out> for Fun where - F: Fn($($t),*) -> Out, + Fun: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( $t: TryFrom>, - Error: From<$t::Error>, )* + $( for<'a> $t: TryFromArgValue>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args(args: Vec>) -> Result<($($t,)*)> { + fn args<'arg>(args: Vec>) -> Result> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), $len, args.len()))) } let mut args = args.into_iter(); - $(let $t = args.next().unwrap().try_into()?;)* - Ok(($($t,)*)) + $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* + Ok(Args::Typed(($($t,)*))) } fn ret(out: Out) -> Result { @@ -220,9 +258,9 @@ macro_rules! impl_typed_and_to_function { } - impl ToFunction<($($t,)*), Out> for F + impl ToFunction<($($t,)*), Out> for Fun where - F: Fn($($t,)*) -> Out + 'static + Fun: Fn($($t,)*) -> Out + 'static { type Handle = Box Out + 'static>; @@ -232,13 +270,18 @@ macro_rules! impl_typed_and_to_function { } #[allow(non_snake_case)] - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - F: Typed<($($t,)*), Out> + Fun: Typed<($($t,)*), Out> { - let ($($t,)*) = F::args(args)?; - let out = unsafe { (*handle)($($t),*) }; - F::ret(out) + let ($($t,)*) = match Fun::args(args)? { + Args::Raw(_) => panic!("impossible case"), + Args::Typed(typed) => typed, + }; + + let fn_ptr = unsafe { &*handle }; + let out = fn_ptr($($t),*); + Fun::ret(out) } fn drop(ptr: *mut Self::Handle) { @@ -255,13 +298,15 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, G); +impl_typed_and_to_function!(6; A, B, C, D, E, F); +impl_typed_and_to_function!(7; A, B, C, D, E, F, G); +impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); #[cfg(test)] mod tests { use super::*; - fn call(f: F, args: Vec>) -> Result + fn call<'a, F, I, O>(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 4b005abee7ef..2903a81d9c36 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -17,10 +17,9 @@ * under the License. */ use std::convert::TryFrom; -use std::os::raw::c_char; use crate::errors::ValueDowncastError; -use crate::ffi::TVMByteArray; +use crate::ffi::{TVMByteArray, TVMByteArrayFree}; use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. @@ -33,20 +32,45 @@ use crate::{ArgValue, RetValue}; /// assert_eq!(barr.len(), v.len()); /// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); /// ``` -pub struct ByteArray { - /// The raw FFI ByteArray. - array: TVMByteArray, +pub enum ByteArray { + Rust(TVMByteArray), + External(TVMByteArray), +} + +impl Drop for ByteArray { + fn drop(&mut self) { + match self { + ByteArray::Rust(bytes) => { + let ptr = bytes.data; + let len = bytes.size as _; + let cap = bytes.size as _; + let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; + drop(data); + } + ByteArray::External(byte_array) => unsafe { + if TVMByteArrayFree(byte_array as _) != 0 { + panic!("error"); + } + }, + } + } } impl ByteArray { /// Gets the underlying byte-array - pub fn data(&self) -> &'static [u8] { - unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size as _) } + pub fn data(&self) -> &[u8] { + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { + std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) + }, + } } /// Gets the length of the underlying byte-array pub fn len(&self) -> usize { - self.array.size as _ + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, + } } /// Converts the underlying byte-array to `Vec` @@ -59,50 +83,49 @@ impl ByteArray { } } -// Needs AsRef for Vec -impl> From for ByteArray { +impl>> From for ByteArray { fn from(arg: T) -> Self { - let arg = arg.as_ref(); - ByteArray { - array: TVMByteArray { - data: arg.as_ptr() as *const c_char, - size: arg.len() as _, - }, - } + let mut incoming_bytes: Vec = arg.into(); + let mut bytes = Vec::with_capacity(incoming_bytes.len()); + bytes.append(&mut incoming_bytes); + + let mut bytes = std::mem::ManuallyDrop::new(bytes); + let ptr = bytes.as_mut_ptr(); + assert_eq!(bytes.len(), bytes.capacity()); + ByteArray::Rust(TVMByteArray { + data: ptr as _, + size: bytes.len() as _, + }) } } impl<'a> From<&'a ByteArray> for ArgValue<'a> { fn from(val: &'a ByteArray) -> ArgValue<'a> { - ArgValue::Bytes(&val.array) - } -} - -impl TryFrom> for ByteArray { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'static>) -> Result { match val { - ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { + ArgValue::Bytes(byte_array) + } } } } -impl From for RetValue { - fn from(val: ByteArray) -> RetValue { - RetValue::Bytes(val.array) - } -} +// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. +// impl From for RetValue { +// fn from(val: ByteArray) -> RetValue { +// match val { +// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { +// // TODO(@jroesch): This requires a little more work, going to land narratives +// RetValue::Bytes(byte_array) +// } +// } +// } +// } impl TryFrom for ByteArray { type Error = ValueDowncastError; fn try_from(val: RetValue) -> Result { match val { - RetValue::Bytes(array) => Ok(ByteArray { array }), + RetValue::Bytes(array) => Ok(ByteArray::External(array)), _ => Err(ValueDowncastError { expected_type: "ByteArray", actual_type: format!("{:?}", val), @@ -118,11 +141,11 @@ mod tests { #[test] fn convert() { let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); let v = b"hello"; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); } diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 6f43b786780a..a74cbe318e2d 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -224,7 +224,7 @@ macro_rules! impl_pod_value { } } - impl<'a, 'v> From<&'a $type> for ArgValue<'v> { + impl<'a> From<&'a $type> for ArgValue<'a> { fn from(val: &'a $type) -> Self { Self::$variant(*val as $inner_ty) } @@ -284,9 +284,9 @@ impl<'a> From<&'a CStr> for ArgValue<'a> { } } -impl<'a> From for ArgValue<'a> { - fn from(s: CString) -> Self { - Self::String(s.into_raw()) +impl<'a> From<&'a CString> for ArgValue<'a> { + fn from(s: &'a CString) -> Self { + Self::String(s.as_ptr() as _) } } @@ -311,14 +311,14 @@ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { } /// Converts an unspecialized handle to a ArgValue. -impl From<*const T> for ArgValue<'static> { +impl<'a, T> From<*const T> for ArgValue<'a> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a ArgValue. -impl From<*mut T> for ArgValue<'static> { +impl<'a, T> From<*mut T> for ArgValue<'a> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } @@ -382,9 +382,9 @@ impl TryFrom for std::ffi::CString { // Implementations for bool. -impl<'a> From for ArgValue<'a> { - fn from(s: bool) -> Self { - (s as i64).into() +impl<'a> From<&bool> for ArgValue<'a> { + fn from(s: &bool) -> Self { + (*s as i64).into() } } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index bd0de1c56ba3..c22d55f2e4da 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -78,24 +78,40 @@ fn main() -> anyhow::Result<()> { "/deploy_lib.so" )))?; - let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; - // parse parameters and convert to TVMByteArray let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; - println!("param bytes: {}", params.len()); - graph_rt.load_params(¶ms)?; + // If you want an easy way to test a memory leak simply replace the program below with: + // let mut output: Vec; + + // loop { + // let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; + // graph_rt.load_params(params.clone())?; + // graph_rt.set_input("data", input.clone())?; + // graph_rt.run()?; + + // // prepare to get the output + // let output_shape = &[1, 1000]; + // let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + // graph_rt.get_output_into(0, output_nd.clone())?; + + // // flatten the output as Vec + // output = output_nd.to_vec::()?; + // } + + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; + graph_rt.load_params(params)?; graph_rt.set_input("data", input)?; graph_rt.run()?; // prepare to get the output let output_shape = &[1, 1000]; - let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - graph_rt.get_output_into(0, output.clone())?; + let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + graph_rt.get_output_into(0, output_nd.clone())?; // flatten the output as Vec - let output = output.to_vec::()?; + let output: Vec = output_nd.to_vec::()?; // find the maximum entry in the output and its index let (argmax, max_prob) = output diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs index 6b5873398cab..8313e47bea20 100644 --- a/rust/tvm/src/compiler/graph_rt.rs +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -51,11 +51,11 @@ fn _compile_module( ) -> Result { // The RAW API is Fn(IRModule, String, String, Map, String); let module = TVM_BUILD.invoke(vec![ - module.into(), - target.into(), - target_host.into(), - params.into(), - module_name.into(), + (&module).into(), + (&target).into(), + (&target_host).into(), + (¶ms).into(), + (&module_name).into(), ])?; let module: RtModule = module.try_into().unwrap(); Ok(module) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 513a906f6db4..ea257af1ebc0 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -99,10 +99,10 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new(funcs: F, types: T) -> Result + pub fn new<'a, F, T>(funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, { module_new(Map::from_iter(funcs), Map::from_iter(types)) } @@ -110,7 +110,7 @@ impl IRModule { pub fn empty() -> Result { let funcs = HashMap::::new(); let types = HashMap::::new(); - IRModule::new(funcs, types) + IRModule::new(funcs.iter(), types.iter()) } pub fn parse(file_name: N, source: S) -> Result @@ -206,10 +206,10 @@ impl IRModule { Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } - pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result + pub fn from_expr_with_items<'a, E, F, T>(expr: E, funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, E: IsObjectRef, E::Object: AsRef<::Object>, { diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index 81ee426d3967..8deae30c076d 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -35,7 +35,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args { let arg: NDArray = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs index 37027af0ca37..f8886a55c3a2 100644 --- a/rust/tvm/tests/callback/src/bin/error.rs +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -26,7 +26,7 @@ use tvm::{ }; fn main() { - fn error(_args: Vec>) -> Result { + fn error<'a>(_args: Vec>) -> Result { Err(errors::NDArrayError::DataTypeMismatch { expected: DataType::int(64, 1), actual: DataType::float(64, 1), diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs index 6fd4f868dc79..d575f47c87cd 100644 --- a/rust/tvm/tests/callback/src/bin/float.rs +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -27,7 +27,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args.into_iter() { let val: f64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs index cdea2e1044c4..fc2e40d8de4d 100644 --- a/rust/tvm/tests/callback/src/bin/int.rs +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -25,7 +25,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs index dbe65ba4c631..4f3d67e95d64 100644 --- a/rust/tvm/tests/callback/src/bin/string.rs +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -26,7 +26,7 @@ use tvm::{ // FIXME fn main() { - fn concat_str(args: Vec>) -> Result { + fn concat_str<'a>(args: Vec>) -> Result { let mut ret = "".to_string(); for arg in args.iter() { let val: &str = arg.try_into()?; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 942bc0d1d44a..2b88f0489321 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -669,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor { ret.lowered_funcs = lowered_module.per_target_module; ret.external_mods = lowered_module.external_mods; - auto target_host_str = target_host_->str(); - if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { - ret.lowered_funcs[target_host_str]->Update(mod_run); + if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { + ret.lowered_funcs[target_host_]->Update(mod_run); } else { - ret.lowered_funcs.Set(target_host_str, mod_run); + ret.lowered_funcs.Set(target_host_, mod_run); } std::vector input_var_names(input_vars_.size()); @@ -778,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return (*it).second.first; } - Map get_irmodule() { return this->output_.lowered_funcs; } + Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; LoweredOutput output_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index af2cbae1f72d..232985bdfb5c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -53,7 +53,11 @@ namespace { struct PairHash { template std::size_t operator()(const std::pair& k) const { - return std::hash()(k.first) ^ std::hash()(k.second); + return dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + template + std::size_t operator()(const std::pair& k) const { + return dmlc::HashCombine(ObjectHash()(k.first), std::hash()(k.second)); } }; @@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. - Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) : mod_(mod), per_target_module_(per_target_module), device_(device), @@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor, */ PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, Target target) { - std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + std::pair packed_func_key(target, tir_fn_var->name_hint); auto packed_itr = compiled_packed_funcs_.find(packed_func_key); if (packed_itr != compiled_packed_funcs_.end()) { // Already compiled. @@ -382,7 +386,7 @@ class Interpreter : public ExprFunctor, // Project out just the function(s) we need. IRModule lowered_projected_mod; - auto mod_itr = per_target_module_.find(target->str()); + auto mod_itr = per_target_module_.find(target); ICHECK(mod_itr != per_target_module_.end()) << "No target module for target '" << target->str() << "'"; const IRModule& target_module = (*mod_itr).second; @@ -407,7 +411,7 @@ class Interpreter : public ExprFunctor, PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint << "' in compiled module for target '" << target->str() << "'"; - compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } // Return just what we need for this call. @@ -874,10 +878,10 @@ class Interpreter : public ExprFunctor, // Map from target key to lowered TIR functions derived from mod_. // Note that primitives are implicitly executed on target_, while shape functions are implicitly // executed on the default 'cpu' host. Thus this map has at most two entries. - Map per_target_module_; + Map per_target_module_; // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. - std::unordered_map, PackedFunc, PairHash> + std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) @@ -895,7 +899,7 @@ class Interpreter : public ExprFunctor, * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -std::pair> Prepare(IRModule mod, Device device, Target target) { +std::pair> Prepare(IRModule mod, Device device, Target target) { // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq({transform::SimplifyInference(), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' @@ -1014,7 +1018,7 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_with_expr, device, target); std::shared_ptr intrp = std::make_shared( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, @@ -1057,7 +1061,7 @@ ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - std::pair> main_and_lowered = + std::pair> main_and_lowered = Prepare(mod_and_global.first, device, target); Interpreter intrp( /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 71ac752ec680..1a244ec728f1 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -85,18 +85,18 @@ class TECompilerImpl : public TECompilerNode { return LowerShapeFuncInternal(key)->cached_func; } - Map GetLoweredFunctions() { - Map lowered_functions; + Map GetLoweredFunctions() { + Map lowered_functions; for (const auto& it : cache_) { auto source_func = it.first; auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions.Set(target, IRModule(Map({}))); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } for (const auto& it : shape_func_cache_) { @@ -104,11 +104,11 @@ class TECompilerImpl : public TECompilerNode { auto lowered_func = it.second; auto target = source_func->target; - if (!lowered_functions.count(target->str())) { - lowered_functions.Set(target->str(), IRModule(Map({}))); + if (!lowered_functions.count(target)) { + lowered_functions.Set(target, IRModule(Map({}))); } - lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + lowered_functions[target]->Update(lowered_func->cached_func->funcs); } return lowered_functions; } @@ -884,7 +884,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) { // Annotate the per-target functions with their target and add them to the unified module for (const auto& kv : mod.per_target_module) { - const String target = kv.first; + const Target target = kv.first; const IRModule target_module = kv.second; // Right now, per-target functions are TIR functions, which don't have type definitions, so @@ -926,7 +926,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->AddTypeDef(kv.first, kv.second); } - Map per_target_modules; + Map per_target_modules; for (const auto& kv : mod->functions) { const GlobalVar& var = kv.first; const BaseFunc& func = kv.second; @@ -934,7 +934,7 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) { main_mod->Add(var, func); } else if (func->IsInstance()) { // Extract target - Optional target = func->GetAttr(tvm::attr::kTarget); + Optional target = func->GetAttr(tvm::attr::kTarget); ICHECK(target) << "Target should be set at this point"; // Put the function in per_target_modules diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index e9cfb0d62e66..1089aa96070b 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -67,7 +67,7 @@ struct EnumClassHash { } }; -// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake +// TODO(@jroesch, @chrisS) these shoumakeld be a tvm::Map for uniformity sake // we should a version of context which works in Map using TargetMap = std::unordered_map; using DeviceMap = @@ -97,7 +97,7 @@ class TECompilerNode : public Object { virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0; /* Return all functions which have been lowered by the compiler, keyed by target. */ - virtual Map GetLoweredFunctions() = 0; + virtual Map GetLoweredFunctions() = 0; /*! * \brief Just in time compile to get a PackedFunc. @@ -144,7 +144,7 @@ struct LoweredModule { /*! \brief The module which contains the Relay code. */ IRModule main_module; /*! \brief The module which contains per target code. */ - Map per_target_module; + Map per_target_module; /*! \brief The external runtime modules which must be combined with the lowered code. */ Array external_mods; // TODO(@electriclilies): THis might need to become a map diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a0c7a5aad26d..bf13715b7d46 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; Map function_metadata; std::unordered_map> params; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 1892ce780a4c..3cd5df613f4a 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -262,3 +262,11 @@ int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } + +int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { + API_BEGIN(); + auto key = tvm::runtime::Object::TypeIndex2Key(tindex); + *out_type_key = static_cast(malloc(key.size() + 1)); + strncpy(*out_type_key, key.c_str(), key.size()); + API_END(); +}