Skip to content

Commit

Permalink
Rewrite the bindings to fix the ArgValue lifetime issue
Browse files Browse the repository at this point in the history
There are still quite a few issues left to resolve in this patch, but I believe the runtime
changes stablize memory consumption as long as the parameters are only set once. ByteArray
also has some totally broken unsafe code which I am unsure of how it was introduced.
  • Loading branch information
jroesch committed Aug 12, 2021
1 parent d135154 commit 3df983f
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 131 deletions.
13 changes: 3 additions & 10 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,20 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
}

impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
use std::ffi::c_void;
let object_ptr = &object_ref.0;
match object_ptr {
None => {
#tvm_rt_crate::ArgValue::
ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
}
Some(value) => value.clone().into()
Some(value) => value.into()
}
}
}

impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> {
let oref: #ref_id = object_ref.clone();
#tvm_rt_crate::ArgValue::<'a>::from(oref)
}
}

impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #error;

Expand Down
13 changes: 8 additions & 5 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,22 @@ external! {
fn array_size(array: ObjectRef) -> i64;
}

impl<T: IsObjectRef> IsObjectRef for Array<T> {
impl<T: IsObjectRef + 'static> IsObjectRef for Array<T> {
type Object = Object;
fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.object.as_ptr()
}

fn into_ptr(self) -> Option<ObjectPtr<Self::Object>> {
self.object.into_ptr()
}

fn from_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
let object_ref = match object_ptr {
Some(o) => o.into(),
_ => panic!(),
};

Array {
object: object_ref,
_data: PhantomData,
Expand All @@ -67,7 +70,7 @@ impl<T: IsObjectRef> IsObjectRef for Array<T> {

impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
let iter = data.into_iter().map(T::into_arg_value).collect();
let iter = data.iter().map(T::into_arg_value).collect();

let func = Function::get("runtime.Array").expect(
"runtime.Array function is not registered, this is most likely a build or linking error",
Expand Down Expand Up @@ -151,9 +154,9 @@ impl<T: IsObjectRef> FromIterator<T> for Array<T> {
}
}

impl<'a, T: IsObjectRef> From<Array<T>> for ArgValue<'a> {
fn from(array: Array<T>) -> ArgValue<'a> {
array.object.into()
impl<'a, T: IsObjectRef> From<&'a Array<T>> for ArgValue<'a> {
fn from(array: &'a Array<T>) -> ArgValue<'a> {
(&array.object).into()
}
}

Expand Down
54 changes: 19 additions & 35 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::errors::Error;

pub use super::to_function::{ToFunction, Typed};
pub use tvm_sys::{ffi, ArgValue, RetValue};
use crate::object::AsArgValue;

pub type Result<T> = std::result::Result<T, Error>;

Expand Down Expand Up @@ -141,24 +142,6 @@ impl Function {

let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);

// // This is a temporary patch to ensure that the arguments are correclty dropped.
// let args: Vec<ArgValue> = values.into_iter().zip(type_codes.into_iter()).map(|(value, type_code)| {
// ArgValue::from_tvm_value(value, type_code)
// }).collect();

// let mut objects_to_drop: Vec<crate::ObjectRef> = vec![];
// for arg in args {
// match arg {
// ArgValue::ObjectHandle(_) | ArgValue::ModuleHandle(_) | ArgValue::NDArrayHandle(_) => objects_to_drop.push(arg.try_into().unwrap()),
// _ => {}
// }
// }

// drop(objects_to_drop);

let obj: crate::ObjectRef = rv.clone().try_into().unwrap();
println!("rv: {}", obj.count());

Ok(rv)
}
}
Expand All @@ -171,12 +154,12 @@ macro_rules! impl_to_fn {
where
Error: From<Err>,
Out: TryFrom<RetValue, Error = Err>,
$($t: Into<ArgValue<'static>>),*
$($t: for<'a> AsArgValue<'a>),*
{
fn from(func: Function) -> Self {
#[allow(non_snake_case)]
Box::new(move |$($t : $t),*| {
let args = vec![ $($t.into()),* ];
let args = vec![ $((&$t).as_arg_value()),* ];
Ok(func.invoke(args)?.try_into()?)
})
}
Expand Down Expand Up @@ -281,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
where
F: ToFunction<I, O>,
F: Typed<I, O>,
F: for<'a> Typed<'a, I, O>,
{
register_override(f, name, false)
}
Expand All @@ -292,7 +275,7 @@ where
pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()>
where
F: ToFunction<I, O>,
F: Typed<I, O>,
F: for<'a> Typed<'a, I, O>,
{
let func = f.to_function();
let name = name.into();
Expand All @@ -309,22 +292,23 @@ where
}

pub fn register_untyped<S: Into<String>>(
f: fn(Vec<ArgValue<'static>>) -> Result<RetValue>,
f: for<'a> fn(Vec<ArgValue<'a>>) -> Result<RetValue>,
name: S,
override_: bool,
) -> Result<()> {
// TODO(@jroesch): can we unify all the code.
let func = f.to_function();
let name = name.into();
// Not sure about this code
let handle = func.handle();
let name = CString::new(name)?;
check_call!(ffi::TVMFuncRegisterGlobal(
name.into_raw(),
handle,
override_ as c_int
));
Ok(())
panic!("foo")
// // TODO(@jroesch): can we unify all the code.
// let func = ToFunction::<Vec<ArgValue>, RetValue>::to_function(f);
// let name = name.into();
// // Not sure about this code
// let handle = func.handle();
// let name = CString::new(name)?;
// check_call!(ffi::TVMFuncRegisterGlobal(
// name.into_raw(),
// handle,
// override_ as c_int
// ));
// Ok(())
}

#[cfg(test)]
Expand Down
6 changes: 3 additions & 3 deletions rust/tvm-rt/src/graph_rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl GraphRt {

let runtime_create_fn_ret = runtime_create_fn.invoke(vec![
graph.into(),
lib.into(),
(&lib).into(),
(&dev.device_type).into(),
// NOTE you must pass the device id in as i32 because that's what TVM expects
(dev.device_id as i32).into(),
Expand Down Expand Up @@ -79,7 +79,7 @@ impl GraphRt {
pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> {
let ref set_input_fn = self.module.get_function("set_input", false)?;

set_input_fn.invoke(vec![name.into(), input.into()])?;
set_input_fn.invoke(vec![name.into(), (&input).into()])?;
Ok(())
}

Expand All @@ -101,7 +101,7 @@ impl GraphRt {
/// Extract the ith output from the graph executor and write the results into output.
pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> {
let get_output_fn = self.module.get_function("get_output", false)?;
get_output_fn.invoke(vec![i.into(), output.into()])?;
get_output_fn.invoke(vec![i.into(), (&output).into()])?;
Ok(())
}
}
14 changes: 7 additions & 7 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ external! {
fn map_items(map: ObjectRef) -> Array<ObjectRef>;
}

impl<K, V> FromIterator<(K, V)> for Map<K, V>
impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
fn from_iter<T: IntoIterator<Item = (&'a K, &'a V)>>(iter: T) -> Self {
let iter = iter.into_iter();
let (lower_bound, upper_bound) = iter.size_hint();
let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
for (k, v) in iter {
buffer.push(k.into());
buffer.push(v.into())
buffer.push(k.into_arg_value());
buffer.push(v.into_arg_value());
}
Self::from_data(buffer).expect("failed to convert from data")
}
Expand Down Expand Up @@ -202,13 +202,13 @@ where
}
}

impl<'a, K, V> From<Map<K, V>> for ArgValue<'a>
impl<'a, K, V> From<&'a Map<K, V>> for ArgValue<'a>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from(map: Map<K, V>) -> ArgValue<'a> {
map.object.into()
fn from(map: &'a Map<K, V>) -> ArgValue<'a> {
(&map.object).into()
}
}

Expand Down
15 changes: 15 additions & 0 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ impl NDArrayContainer {
.cast::<NDArrayContainer>()
}
}

pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr<NDArrayContainer>) -> *mut NDArrayContainer
where
NDArrayContainer: 'a,
{
let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize;
unsafe {
object_ptr
.ptr
.as_ptr()
.cast::<u8>()
.offset(base_offset)
.cast::<NDArrayContainer>()
}
}
}

fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> {
Expand Down
16 changes: 13 additions & 3 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ mod object_ptr;

pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef};

pub trait AsArgValue<'a> {
fn as_arg_value(&'a self) -> ArgValue<'a>;
}

impl<'a, T: 'static> AsArgValue<'a> for T where &'a T: Into<ArgValue<'a>> {
fn as_arg_value(&'a self) -> ArgValue<'a> {
self.into()
}
}

// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we
// can't because of coherence rules. Instead, we generate them in the macro, and
// add what we can (including Into instead of From) as subtraits.
Expand All @@ -37,8 +47,8 @@ pub trait IsObjectRef:
Sized
+ Clone
+ Into<RetValue>
+ for<'a> AsArgValue<'a>
+ TryFrom<RetValue, Error = Error>
+ for<'a> Into<ArgValue<'a>>
+ for<'a> TryFrom<ArgValue<'a>, Error = Error>
+ std::fmt::Debug
{
Expand All @@ -51,8 +61,8 @@ pub trait IsObjectRef:
Self::from_ptr(None)
}

fn into_arg_value<'a>(self) -> ArgValue<'a> {
self.into()
fn into_arg_value<'a>(&'a self) -> ArgValue<'a> {
self.as_arg_value()
}

fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result<Self, Error> {
Expand Down
20 changes: 12 additions & 8 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ impl<T: IsObject> ObjectPtr<T> {
pub unsafe fn into_raw(self) -> *mut T {
self.ptr.as_ptr()
}

pub unsafe fn as_ptr(&self) -> *mut T {
self.ptr.as_ptr()
}
}

impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
Expand Down Expand Up @@ -308,26 +312,26 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
}
}

impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
impl<'a, T: IsObject> From<&'a ObjectPtr<T>> for ArgValue<'a> {
fn from(object_ptr: &'a ObjectPtr<T>) -> ArgValue<'a> {
debug_assert!(object_ptr.count() >= 1);
let object_ptr = object_ptr.upcast::<Object>();
let object_ptr = object_ptr.clone().upcast::<Object>();
match T::TYPE_KEY {
"runtime.NDArray" => {
use crate::ndarray::NDArrayContainer;
// TODO(this is probably not optimal)
let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap())
as *mut NDArrayContainer as *mut std::ffi::c_void;
let dcast_ptr = object_ptr.downcast().unwrap();
let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr)
as *mut std::ffi::c_void;
assert!(!raw_ptr.is_null());
ArgValue::NDArrayHandle(raw_ptr)
}
"runtime.Module" => {
let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void;
let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void;
assert!(!raw_ptr.is_null());
ArgValue::ModuleHandle(raw_ptr)
}
_ => {
let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void;
let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void;
assert!(!raw_ptr.is_null());
ArgValue::ObjectHandle(raw_ptr)
}
Expand Down
Loading

0 comments on commit 3df983f

Please sign in to comment.