Skip to content

Commit

Permalink
Improve Rust bindings: Map, Array, String, various IR nodes (#6339)
Browse files Browse the repository at this point in the history
* Fix datatype

* Add initialize macro

* Add some TIR nodes

* Better downcasting

* Improve Array and add Map

* Convert to new string API

* Clean up some warnings

* Add ConstIntBound type

* Run cargo fmt

* Remove debug prints

* Add some more ops

* Fix some string code

Co-authored-by: Jared Roesch <jroesch@octoml.ai>
  • Loading branch information
mwillsey and jroesch authored Aug 28, 2020
1 parent 4c9a391 commit c899b3c
Show file tree
Hide file tree
Showing 15 changed files with 644 additions and 55 deletions.
6 changes: 3 additions & 3 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
type Object = #payload_id;

fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
fn as_object_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
self.0.as_ref()
}

fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
fn from_object_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>>) -> Self {
#ref_id(object_ptr)
}
}
Expand All @@ -104,7 +104,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {

fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
use std::convert::TryInto;
let oref: ObjectRef = ret_val.try_into()?;
let oref: #tvm_rt_crate::ObjectRef = ret_val.try_into()?;
let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
let ptr = ptr.downcast::<#payload_id>()?;
Ok(#ref_id(Some(ptr)))
Expand Down
45 changes: 44 additions & 1 deletion rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
use crate::{
external,
function::{Function, Result},
RetValue,
ArgValue, RetValue,
};

#[repr(C)]
Expand All @@ -40,6 +40,8 @@ pub struct Array<T: IsObjectRef> {
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
#[name("node.ArraySize")]
fn array_size(array: ObjectRef) -> i64;
}

impl<T: IsObjectRef> Array<T> {
Expand Down Expand Up @@ -76,4 +78,45 @@ impl<T: IsObjectRef> Array<T> {
let oref: ObjectRef = array_get_item(self.object.clone(), index)?;
oref.downcast()
}

pub fn len(&self) -> i64 {
array_size(self.object.clone()).expect("size should never fail")
}
}

impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
fn from(array: Array<T>) -> ArgValue<'static> {
array.object.into()
}
}

impl<T: IsObjectRef> From<Array<T>> for RetValue {
fn from(array: Array<T>) -> RetValue {
array.object.into()
}
}

impl<'a, T: IsObjectRef> TryFrom<ArgValue<'a>> for Array<T> {
type Error = Error;

fn try_from(array: ArgValue<'a>) -> Result<Array<T>> {
let object_ref: ObjectRef = array.try_into()?;
// TODO: type check
Ok(Array {
object: object_ref,
_data: PhantomData,
})
}
}

impl<'a, T: IsObjectRef> TryFrom<RetValue> for Array<T> {
type Error = Error;

fn try_from(array: RetValue) -> Result<Array<T>> {
let object_ref = array.try_into()?;
Ok(Array {
object: object_ref,
_data: PhantomData,
})
}
}
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl Function {
match rv {
RetValue::ObjectHandle(object) => {
let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap();
println!("after wrapped call: {}", optr.count());
// println!("after wrapped call: {}", optr.count());
crate::object::ObjectPtr::leak(optr);
}
_ => {}
Expand Down
1 change: 1 addition & 0 deletions rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub mod array;
pub mod context;
pub mod errors;
pub mod function;
pub mod map;
pub mod module;
pub mod ndarray;
mod to_function;
Expand Down
264 changes: 264 additions & 0 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::iter::FromIterator;
use std::marker::PhantomData;

use crate::object::debug_print;

use crate::array::Array;
use crate::errors::Error;
use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
use crate::ArgValue;
use crate::{
external,
function::{Function, Result},
RetValue,
};

#[repr(C)]
#[derive(Clone)]
pub struct Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
object: ObjectRef,
_data: PhantomData<(K, V)>,
}

// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
#[name("node.MapSize")]
fn map_size(map: ObjectRef) -> i64;
#[name("node.MapGetItem")]
fn map_get_item(map_object: ObjectRef, key: ObjectRef) -> ObjectRef;
#[name("node.MapCount")]
fn map_count(map: ObjectRef, key: ObjectRef) -> ObjectRef;
#[name("node.MapItems")]
fn map_items(map: ObjectRef) -> Array<ObjectRef>;
}

impl<K, V> FromIterator<(K, V)> for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
let iter = iter.into_iter();
let (lower_bound, upper_bound) = iter.size_hint();
let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
for (k, v) in iter {
buffer.push(k.to_object_ref().into());
buffer.push(v.to_object_ref().into())
}
Self::from_data(buffer).expect("failed to convert from data")
}
}

impl<K, V> Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
pub fn from_data(data: Vec<ArgValue>) -> Result<Map<K, V>> {
let func = Function::get("node.Map").expect(
"node.Map function is not registered, this is most likely a build or linking error",
);

let map_data: ObjectPtr<Object> = func.invoke(data)?.try_into()?;

debug_assert!(
map_data.count() >= 1,
"map_data count is {}",
map_data.count()
);

Ok(Map {
object: ObjectRef(Some(map_data)),
_data: PhantomData,
})
}

pub fn get(&self, key: &K) -> Result<V>
where
V: TryFrom<RetValue, Error = Error>,
{
let oref: ObjectRef = map_get_item(self.object.clone(), key.to_object_ref())?;
oref.downcast()
}
}

pub struct IntoIter<K, V> {
// NB: due to FFI this isn't as lazy as one might like
key_and_values: Array<ObjectRef>,
next_key: i64,
_data: PhantomData<(K, V)>,
}

impl<K, V> Iterator for IntoIter<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
type Item = (K, V);

#[inline]
fn next(&mut self) -> Option<(K, V)> {
if self.next_key < self.key_and_values.len() {
let key = self
.key_and_values
.get(self.next_key as isize)
.expect("this should always succeed");
let value = self
.key_and_values
.get((self.next_key as isize) + 1)
.expect("this should always succeed");
self.next_key += 2;
Some((key.downcast().unwrap(), value.downcast().unwrap()))
} else {
None
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
((self.key_and_values.len() / 2) as usize, None)
}
}

impl<K, V> IntoIterator for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
type Item = (K, V);
type IntoIter = IntoIter<K, V>;

fn into_iter(self) -> IntoIter<K, V> {
let items = map_items(self.object).expect("unable to get map items");
IntoIter {
key_and_values: items,
next_key: 0,
_data: PhantomData,
}
}
}

use std::fmt;

impl<K, V> fmt::Debug for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let ctr = debug_print(self.object.clone()).unwrap();
fmt.write_fmt(format_args!("{:?}", ctr))
}
}

impl<K, V, S> From<Map<K, V>> for HashMap<K, V, S>
where
K: Eq + std::hash::Hash,
K: IsObjectRef,
V: IsObjectRef,
S: std::hash::BuildHasher + std::default::Default,
{
fn from(map: Map<K, V>) -> HashMap<K, V, S> {
HashMap::from_iter(map.into_iter())
}
}

impl<'a, K, V> From<Map<K, V>> for ArgValue<'a>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from(map: Map<K, V>) -> ArgValue<'a> {
map.object.into()
}
}

impl<K, V> From<Map<K, V>> for RetValue
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from(map: Map<K, V>) -> RetValue {
map.object.into()
}
}

impl<'a, K, V> TryFrom<ArgValue<'a>> for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
type Error = Error;

fn try_from(array: ArgValue<'a>) -> Result<Map<K, V>> {
let object_ref: ObjectRef = array.try_into()?;
// TODO: type check
Ok(Map {
object: object_ref,
_data: PhantomData,
})
}
}

impl<K, V> TryFrom<RetValue> for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
type Error = Error;

fn try_from(array: RetValue) -> Result<Map<K, V>> {
let object_ref = array.try_into()?;
// TODO: type check
Ok(Map {
object: object_ref,
_data: PhantomData,
})
}
}

#[cfg(test)]
mod test {
use std::collections::HashMap;

use super::*;
use crate::string::String as TString;

#[test]
fn test_from_into_hash_map() {
let mut std_map: HashMap<TString, TString> = HashMap::new();
std_map.insert("key1".into(), "value1".into());
std_map.insert("key2".into(), "value2".into());
let tvm_map = Map::from_iter(std_map.clone().into_iter());
let back_map = tvm_map.into();
assert_eq!(std_map, back_map);
}
}
4 changes: 4 additions & 0 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ impl<'a> From<ObjectRef> for ArgValue<'a> {
external! {
#[name("ir.DebugPrint")]
fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
#[name("node.StructuralEqual")]
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef;
}

// external! {
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
RetValue::ObjectHandle(handle) => {
let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
println!("back to type {}", optr.count());
// println!("back to type {}", optr.count());
optr.downcast()
}
_ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")),
Expand All @@ -315,7 +315,7 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
ArgValue::ObjectHandle(handle) => {
let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
println!("count: {}", optr.count());
// println!("count: {}", optr.count());
optr.downcast()
}
_ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
Expand Down
Loading

0 comments on commit c899b3c

Please sign in to comment.