Skip to content

Commit

Permalink
Fix bad rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Aug 31, 2021
1 parent 21333ba commit a6aa3b8
Show file tree
Hide file tree
Showing 22 changed files with 389 additions and 192 deletions.
8 changes: 8 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);

/*!
* \brief Convert type index to type key.
* \param tindex The type index.
* \param out_type_key The output type key.
* \return 0 when success, nonzero when failure happens
*/
TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key);

/*!
* \brief Increase the reference count of an object.
*
Expand Down
13 changes: 3 additions & 10 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,20 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
}

impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
use std::ffi::c_void;
let object_ptr = &object_ref.0;
match object_ptr {
None => {
#tvm_rt_crate::ArgValue::
ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
}
Some(value) => value.clone().into()
Some(value) => value.into()
}
}
}

impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> {
let oref: #ref_id = object_ref.clone();
#tvm_rt_crate::ArgValue::<'a>::from(oref)
}
}

impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #error;

Expand Down
13 changes: 8 additions & 5 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,22 @@ external! {
fn array_size(array: ObjectRef) -> i64;
}

impl<T: IsObjectRef> IsObjectRef for Array<T> {
impl<T: IsObjectRef + 'static> IsObjectRef for Array<T> {
type Object = Object;
fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.object.as_ptr()
}

fn into_ptr(self) -> Option<ObjectPtr<Self::Object>> {
self.object.into_ptr()
}

fn from_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
let object_ref = match object_ptr {
Some(o) => o.into(),
_ => panic!(),
};

Array {
object: object_ref,
_data: PhantomData,
Expand All @@ -67,7 +70,7 @@ impl<T: IsObjectRef> IsObjectRef for Array<T> {

impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
let iter = data.into_iter().map(T::into_arg_value).collect();
let iter = data.iter().map(T::into_arg_value).collect();

let func = Function::get("runtime.Array").expect(
"runtime.Array function is not registered, this is most likely a build or linking error",
Expand Down Expand Up @@ -151,9 +154,9 @@ impl<T: IsObjectRef> FromIterator<T> for Array<T> {
}
}

impl<'a, T: IsObjectRef> From<Array<T>> for ArgValue<'a> {
fn from(array: Array<T>) -> ArgValue<'a> {
array.object.into()
impl<'a, T: IsObjectRef> From<&'a Array<T>> for ArgValue<'a> {
fn from(array: &'a Array<T>) -> ArgValue<'a> {
(&array.object).into()
}
}

Expand Down
17 changes: 9 additions & 8 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use std::{

use crate::errors::Error;

pub use super::to_function::{ToFunction, Typed};
pub use super::to_function::{RawArgs, ToFunction, Typed};
use crate::object::AsArgValue;
pub use tvm_sys::{ffi, ArgValue, RetValue};

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -153,12 +154,12 @@ macro_rules! impl_to_fn {
where
Error: From<Err>,
Out: TryFrom<RetValue, Error = Err>,
$($t: Into<ArgValue<'static>>),*
$($t: for<'a> AsArgValue<'a>),*
{
fn from(func: Function) -> Self {
#[allow(non_snake_case)]
Box::new(move |$($t : $t),*| {
let args = vec![ $($t.into()),* ];
let args = vec![ $((&$t).as_arg_value()),* ];
Ok(func.invoke(args)?.try_into()?)
})
}
Expand Down Expand Up @@ -196,8 +197,8 @@ impl TryFrom<RetValue> for Function {
}
}

impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
impl<'a> From<&'a Function> for ArgValue<'a> {
fn from(func: &'a Function) -> ArgValue<'a> {
if func.handle().is_null() {
ArgValue::Null
} else {
Expand Down Expand Up @@ -291,12 +292,12 @@ where
}

pub fn register_untyped<S: Into<String>>(
f: fn(Vec<ArgValue<'static>>) -> Result<RetValue>,
f: for<'a> fn(Vec<ArgValue<'a>>) -> Result<RetValue>,
name: S,
override_: bool,
) -> Result<()> {
// TODO(@jroesch): can we unify all the code.
let func = f.to_function();
//TODO(@jroesch): can we unify the untpyed and typed registration functions.
let func = ToFunction::<RawArgs, RetValue>::to_function(f);
let name = name.into();
// Not sure about this code
let handle = func.handle();
Expand Down
7 changes: 4 additions & 3 deletions rust/tvm-rt/src/graph_rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ impl GraphRt {

let runtime_create_fn_ret = runtime_create_fn.invoke(vec![
graph.into(),
lib.into(),
(&lib).into(),
(&dev.device_type).into(),
// NOTE you must pass the device id in as i32 because that's what TVM expects
(dev.device_id as i32).into(),
]);

let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?;
Ok(Self {
module: graph_executor_module,
Expand All @@ -79,7 +80,7 @@ impl GraphRt {
pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> {
let ref set_input_fn = self.module.get_function("set_input", false)?;

set_input_fn.invoke(vec![name.into(), input.into()])?;
set_input_fn.invoke(vec![name.into(), (&input).into()])?;
Ok(())
}

Expand All @@ -101,7 +102,7 @@ impl GraphRt {
/// Extract the ith output from the graph executor and write the results into output.
pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> {
let get_output_fn = self.module.get_function("get_output", false)?;
get_output_fn.invoke(vec![i.into(), output.into()])?;
get_output_fn.invoke(vec![i.into(), (&output).into()])?;
Ok(())
}
}
21 changes: 11 additions & 10 deletions rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,17 @@ mod tests {
);
}

#[test]
fn bytearray() {
let w = vec![1u8, 2, 3, 4, 5];
let v = ByteArray::from(w.as_slice());
let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
assert_eq!(
tvm.data(),
w.iter().copied().collect::<Vec<u8>>().as_slice()
);
}
// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership.
// #[test]
// fn bytearray() {
// let w = vec![1u8, 2, 3, 4, 5];
// let v = ByteArray::from(w.as_slice());
// let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
// assert_eq!(
// tvm.data(),
// w.iter().copied().collect::<Vec<u8>>().as_slice()
// );
// }

#[test]
fn ty() {
Expand Down
16 changes: 8 additions & 8 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ external! {
fn map_items(map: ObjectRef) -> Array<ObjectRef>;
}

impl<K, V> FromIterator<(K, V)> for Map<K, V>
impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
fn from_iter<T: IntoIterator<Item = (&'a K, &'a V)>>(iter: T) -> Self {
let iter = iter.into_iter();
let (lower_bound, upper_bound) = iter.size_hint();
let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
for (k, v) in iter {
buffer.push(k.into());
buffer.push(v.into())
buffer.push(k.into_arg_value());
buffer.push(v.into_arg_value());
}
Self::from_data(buffer).expect("failed to convert from data")
}
Expand Down Expand Up @@ -202,13 +202,13 @@ where
}
}

impl<'a, K, V> From<Map<K, V>> for ArgValue<'a>
impl<'a, K, V> From<&'a Map<K, V>> for ArgValue<'a>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from(map: Map<K, V>) -> ArgValue<'a> {
map.object.into()
fn from(map: &'a Map<K, V>) -> ArgValue<'a> {
(&map.object).into()
}
}

Expand Down Expand Up @@ -268,7 +268,7 @@ mod test {
let mut std_map: HashMap<TString, TString> = HashMap::new();
std_map.insert("key1".into(), "value1".into());
std_map.insert("key2".into(), "value2".into());
let tvm_map = Map::from_iter(std_map.clone().into_iter());
let tvm_map = Map::from_iter(std_map.iter());
let back_map = tvm_map.into();
assert_eq!(std_map, back_map);
}
Expand Down
15 changes: 15 additions & 0 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ impl NDArrayContainer {
.cast::<NDArrayContainer>()
}
}

pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr<NDArrayContainer>) -> *mut NDArrayContainer
where
NDArrayContainer: 'a,
{
let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize;
unsafe {
object_ptr
.ptr
.as_ptr()
.cast::<u8>()
.offset(base_offset)
.cast::<NDArrayContainer>()
}
}
}

fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> {
Expand Down
19 changes: 16 additions & 3 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ mod object_ptr;

pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef};

pub trait AsArgValue<'a> {
fn as_arg_value(&'a self) -> ArgValue<'a>;
}

impl<'a, T: 'static> AsArgValue<'a> for T
where
&'a T: Into<ArgValue<'a>>,
{
fn as_arg_value(&'a self) -> ArgValue<'a> {
self.into()
}
}

// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we
// can't because of coherence rules. Instead, we generate them in the macro, and
// add what we can (including Into instead of From) as subtraits.
Expand All @@ -37,8 +50,8 @@ pub trait IsObjectRef:
Sized
+ Clone
+ Into<RetValue>
+ for<'a> AsArgValue<'a>
+ TryFrom<RetValue, Error = Error>
+ for<'a> Into<ArgValue<'a>>
+ for<'a> TryFrom<ArgValue<'a>, Error = Error>
+ std::fmt::Debug
{
Expand All @@ -51,8 +64,8 @@ pub trait IsObjectRef:
Self::from_ptr(None)
}

fn into_arg_value<'a>(self) -> ArgValue<'a> {
self.into()
fn into_arg_value<'a>(&'a self) -> ArgValue<'a> {
self.as_arg_value()
}

fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result<Self, Error> {
Expand Down
Loading

0 comments on commit a6aa3b8

Please sign in to comment.