diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index c84d0aab612fe..4134da5fe6d9c 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -147,8 +147,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { use std::ffi::c_void; let object_ptr = &object_ref.0; match object_ptr { @@ -156,18 +156,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #tvm_rt_crate::ArgValue:: ObjectHandle(std::ptr::null::() as *mut c_void) } - Some(value) => value.clone().into() + Some(value) => value.into() } } } - impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { - let oref: #ref_id = object_ref.clone(); - #tvm_rt_crate::ArgValue::<'a>::from(oref) - } - } - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { type Error = #error; diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index e8902b54f6eff..02c34a1d133f7 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -45,19 +45,22 @@ external! { fn array_size(array: ObjectRef) -> i64; } -impl IsObjectRef for Array { +impl IsObjectRef for Array { type Object = Object; fn as_ptr(&self) -> Option<&ObjectPtr> { self.object.as_ptr() } + fn into_ptr(self) -> Option> { self.object.into_ptr() } + fn from_ptr(object_ptr: Option>) -> Self { let object_ref = match object_ptr { Some(o) => o.into(), _ => panic!(), }; + Array { object: object_ref, _data: PhantomData, @@ -67,7 +70,7 @@ impl IsObjectRef for Array { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.into_iter().map(T::into_arg_value).collect(); + let iter = data.iter().map(T::into_arg_value).collect(); let func = Function::get("runtime.Array").expect( "runtime.Array function is not registered, this is most likely a build or linking error", @@ -151,9 +154,9 @@ impl FromIterator for Array { } } -impl<'a, T: IsObjectRef> From> for ArgValue<'a> { - fn from(array: Array) -> ArgValue<'a> { - array.object.into() +impl<'a, T: IsObjectRef> From<&'a Array> for ArgValue<'a> { + fn from(array: &'a Array) -> ArgValue<'a> { + (&array.object).into() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 1152f3f235b81..3a32c7ed9409a 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -37,6 +37,7 @@ use crate::errors::Error; pub use super::to_function::{ToFunction, Typed}; pub use tvm_sys::{ffi, ArgValue, RetValue}; +use crate::object::AsArgValue; pub type Result = std::result::Result; @@ -141,24 +142,6 @@ impl Function { let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); - // // This is a temporary patch to ensure that the arguments are correclty dropped. - // let args: Vec = values.into_iter().zip(type_codes.into_iter()).map(|(value, type_code)| { - // ArgValue::from_tvm_value(value, type_code) - // }).collect(); - - // let mut objects_to_drop: Vec = vec![]; - // for arg in args { - // match arg { - // ArgValue::ObjectHandle(_) | ArgValue::ModuleHandle(_) | ArgValue::NDArrayHandle(_) => objects_to_drop.push(arg.try_into().unwrap()), - // _ => {} - // } - // } - - // drop(objects_to_drop); - - let obj: crate::ObjectRef = rv.clone().try_into().unwrap(); - println!("rv: {}", obj.count()); - Ok(rv) } } @@ -171,12 +154,12 @@ macro_rules! impl_to_fn { where Error: From, Out: TryFrom, - $($t: Into>),* + $($t: for<'a> AsArgValue<'a>),* { fn from(func: Function) -> Self { #[allow(non_snake_case)] Box::new(move |$($t : $t),*| { - let args = vec![ $($t.into()),* ]; + let args = vec![ $((&$t).as_arg_value()),* ]; Ok(func.invoke(args)?.try_into()?) }) } @@ -281,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { pub fn register>(f: F, name: S) -> Result<()> where F: ToFunction, - F: Typed, + F: for<'a> Typed<'a, I, O>, { register_override(f, name, false) } @@ -292,7 +275,7 @@ where pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> where F: ToFunction, - F: Typed, + F: for<'a> Typed<'a, I, O>, { let func = f.to_function(); let name = name.into(); @@ -309,22 +292,23 @@ where } pub fn register_untyped>( - f: fn(Vec>) -> Result, + f: for<'a> fn(Vec>) -> Result, name: S, override_: bool, ) -> Result<()> { - // TODO(@jroesch): can we unify all the code. - let func = f.to_function(); - let name = name.into(); - // Not sure about this code - let handle = func.handle(); - let name = CString::new(name)?; - check_call!(ffi::TVMFuncRegisterGlobal( - name.into_raw(), - handle, - override_ as c_int - )); - Ok(()) + panic!("foo") + // // TODO(@jroesch): can we unify all the code. + // let func = ToFunction::, RetValue>::to_function(f); + // let name = name.into(); + // // Not sure about this code + // let handle = func.handle(); + // let name = CString::new(name)?; + // check_call!(ffi::TVMFuncRegisterGlobal( + // name.into_raw(), + // handle, + // override_ as c_int + // )); + // Ok(()) } #[cfg(test)] diff --git a/rust/tvm-rt/src/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs index 7db53d4666657..5ac9710424e04 100644 --- a/rust/tvm-rt/src/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -50,7 +50,7 @@ impl GraphRt { let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - lib.into(), + (&lib).into(), (&dev.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), @@ -79,7 +79,7 @@ impl GraphRt { pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { let ref set_input_fn = self.module.get_function("set_input", false)?; - set_input_fn.invoke(vec![name.into(), input.into()])?; + set_input_fn.invoke(vec![name.into(), (&input).into()])?; Ok(()) } @@ -101,7 +101,7 @@ impl GraphRt { /// Extract the ith output from the graph executor and write the results into output. pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { let get_output_fn = self.module.get_function("get_output", false)?; - get_output_fn.invoke(vec![i.into(), output.into()])?; + get_output_fn.invoke(vec![i.into(), (&output).into()])?; Ok(()) } } diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index d6dfaf3641b88..dbfac6f205b3b 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -58,18 +58,18 @@ external! { fn map_items(map: ObjectRef) -> Array; } -impl FromIterator<(K, V)> for Map +impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map where K: IsObjectRef, V: IsObjectRef, { - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.into()); - buffer.push(v.into()) + buffer.push(k.into_arg_value()); + buffer.push(v.into_arg_value()); } Self::from_data(buffer).expect("failed to convert from data") } @@ -202,13 +202,13 @@ where } } -impl<'a, K, V> From> for ArgValue<'a> +impl<'a, K, V> From<&'a Map> for ArgValue<'a> where K: IsObjectRef, V: IsObjectRef, { - fn from(map: Map) -> ArgValue<'a> { - map.object.into() + fn from(map: &'a Map) -> ArgValue<'a> { + (&map.object).into() } } diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 08dcfe33f28f7..80f8f184140c3 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -101,6 +101,21 @@ impl NDArrayContainer { .cast::() } } + + pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { + object_ptr + .ptr + .as_ptr() + .cast::() + .offset(base_offset) + .cast::() + } + } } fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 8c07ed9f0853c..075ef46f35e29 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -29,6 +29,16 @@ mod object_ptr; pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; +pub trait AsArgValue<'a> { + fn as_arg_value(&'a self) -> ArgValue<'a>; +} + +impl<'a, T: 'static> AsArgValue<'a> for T where &'a T: Into> { + fn as_arg_value(&'a self) -> ArgValue<'a> { + self.into() + } +} + // TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we // can't because of coherence rules. Instead, we generate them in the macro, and // add what we can (including Into instead of From) as subtraits. @@ -37,8 +47,8 @@ pub trait IsObjectRef: Sized + Clone + Into + + for<'a> AsArgValue<'a> + TryFrom - + for<'a> Into> + for<'a> TryFrom, Error = Error> + std::fmt::Debug { @@ -51,8 +61,8 @@ pub trait IsObjectRef: Self::from_ptr(None) } - fn into_arg_value<'a>(self) -> ArgValue<'a> { - self.into() + fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { + self.as_arg_value() } fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index a093cf5fe3aef..38dc99a885148 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -259,6 +259,10 @@ impl ObjectPtr { pub unsafe fn into_raw(self) -> *mut T { self.ptr.as_ptr() } + + pub unsafe fn as_ptr(&self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -308,26 +312,26 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { } } -impl<'a, T: IsObject> From> for ArgValue<'a> { - fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { +impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { + fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.upcast::(); + let object_ptr = object_ptr.clone().upcast::(); match T::TYPE_KEY { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; - // TODO(this is probably not optimal) - let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) - as *mut NDArrayContainer as *mut std::ffi::c_void; + let dcast_ptr = object_ptr.downcast().unwrap(); + let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) + as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } "runtime.Module" => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ModuleHandle(raw_ptr) } _ => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 7797d2cd23ff1..0a053b7af539b 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -32,6 +32,7 @@ use std::{ }; use super::{function::Result, Function}; +use crate::AsArgValue; use crate::errors::Error; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -44,25 +45,39 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. -pub trait Typed { - fn args(i: Vec>) -> Result; +pub trait Typed<'a, I, O> { + fn args(i: Vec>) -> Result; fn ret(o: O) -> Result; } +trait AsArgValueErased where Self: for<'a> AsArgValue<'a> { + fn as_arg_value<'a>(&'a self) -> ArgValue<'a>; +} + +struct ArgList { + args: Vec> +} + +impl AsArgValueErased for T where T: for<'a> AsArgValue<'a> { + fn as_arg_value<'a>(&'a self) -> ArgValue<'a> { + AsArgValue::as_arg_value(self) + } +} + pub trait ToFunction: Sized { type Handle; fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - Self: Typed; + Self: for<'arg> Typed<'arg, I, O>; fn drop(handle: *mut Self::Handle); fn to_function(self) -> Function where - Self: Typed, + Self: for<'a> Typed<'a, I, O>, { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; let resource_handle = self.into_raw(); @@ -87,7 +102,7 @@ pub trait ToFunction: Sized { resource_handle: *mut c_void, ) -> c_int where - Self: Typed, + Self: for <'a> Typed<'a, I, O>, { #![allow(unused_assignments, unused_unsafe)] let result = std::panic::catch_unwind(|| { @@ -165,45 +180,48 @@ pub trait ToFunction: Sized { } } -impl Typed>, RetValue> for fn(Vec>) -> Result { - fn args(args: Vec>) -> Result>> { - Ok(args) - } +// impl<'a> Typed<'a, Vec>, RetValue> for for<'arg> fn(Vec>) -> Result { +// fn args(args: Vec>) -> Result>> { +// Ok(args) +// } - fn ret(o: RetValue) -> Result { - Ok(o) - } -} +// fn ret(o: RetValue) -> Result { +// Ok(o) +// } +// } -impl ToFunction>, RetValue> - for fn(Vec>) -> Result -{ - type Handle = fn(Vec>) -> Result; +// impl ToFunction +// for fn(ArgList) -> Result +// { +// type Handle = for<'a> fn(Vec>) -> Result; - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(self); - Box::into_raw(ptr) - } +// fn into_raw(self) -> *mut Self::Handle { +// let ptr: Box = Box::new(self); +// Box::into_raw(ptr) +// } - fn call(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { (*handle)(args) } - } +// fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result { +// unsafe { +// let func = (*handle); +// func(args) +// } +// } - fn drop(_: *mut Self::Handle) {} -} +// fn drop(_: *mut Self::Handle) {} +// } macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for F + impl<'a, F, Out, $($t),*> Typed<'a, ($($t,)*), Out> for F where F: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( $t: TryFrom>, - Error: From<$t::Error>, )* + $( $t: TryFrom>, + Error: From<<$t as TryFrom>>::Error>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args(args: Vec>) -> Result<($($t,)*)> { + fn args(args: Vec>) -> Result<($($t,)*)> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), @@ -232,9 +250,9 @@ macro_rules! impl_typed_and_to_function { } #[allow(non_snake_case)] - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - F: Typed<($($t,)*), Out> + F: for<'arg> Typed<'arg, ($($t,)*), Out> { let ($($t,)*) = F::args(args)?; let out = unsafe { (*handle)($($t),*) }; @@ -261,7 +279,7 @@ impl_typed_and_to_function!(6; A, B, C, D, E, G); mod tests { use super::*; - fn call(f: F, args: Vec>) -> Result + fn call<'a, F, I, O>(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 4b005abee7ef1..7da6145797f74 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -20,7 +20,7 @@ use std::convert::TryFrom; use std::os::raw::c_char; use crate::errors::ValueDowncastError; -use crate::ffi::TVMByteArray; +use crate::ffi::{TVMByteArray, TVMByteArrayFree}; use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. @@ -38,6 +38,11 @@ pub struct ByteArray { array: TVMByteArray, } +impl Drop for ByteArray { + fn drop(&mut self) { + } +} + impl ByteArray { /// Gets the underlying byte-array pub fn data(&self) -> &'static [u8] { @@ -59,6 +64,7 @@ impl ByteArray { } } + // Needs AsRef for Vec impl> From for ByteArray { fn from(arg: T) -> Self { @@ -78,20 +84,6 @@ impl<'a> From<&'a ByteArray> for ArgValue<'a> { } } -impl TryFrom> for ByteArray { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'static>) -> Result { - match val { - ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), - } - } -} - impl From for RetValue { fn from(val: ByteArray) -> RetValue { RetValue::Bytes(val.array) diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 6f43b786780a1..e996b9ddf3b70 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -284,9 +284,9 @@ impl<'a> From<&'a CStr> for ArgValue<'a> { } } -impl<'a> From for ArgValue<'a> { - fn from(s: CString) -> Self { - Self::String(s.into_raw()) +impl<'a> From<&'a CString> for ArgValue<'a> { + fn from(s: &'a CString) -> Self { + Self::String(s.as_ptr() as _) } } @@ -311,14 +311,14 @@ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { } /// Converts an unspecialized handle to a ArgValue. -impl From<*const T> for ArgValue<'static> { +impl<'a, T> From<*const T> for ArgValue<'a> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a ArgValue. -impl From<*mut T> for ArgValue<'static> { +impl<'a, T> From<*mut T> for ArgValue<'a> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } @@ -382,9 +382,9 @@ impl TryFrom for std::ffi::CString { // Implementations for bool. -impl<'a> From for ArgValue<'a> { - fn from(s: bool) -> Self { - (s as i64).into() +impl<'a> From<&bool> for ArgValue<'a> { + fn from(s: &bool) -> Self { + (*s as i64).into() } } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 96d74e2260a48..22933e0cc5af1 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -83,11 +83,10 @@ fn main() -> anyhow::Result<()> { println!("param bytes: {}", params.len()); let mut output: Vec; + let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; + graph_rt.load_params(¶ms)?; loop { - let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; - - graph_rt.load_params(¶ms)?; graph_rt.set_input("data", input.clone())?; graph_rt.run()?;