Skip to content

Commit

Permalink
[Rust][IRModule] Flesh out IRModule methods (#6741)
Browse files Browse the repository at this point in the history
* WIP

* WIP

* WIP

* WIP

* Disable WASM and fix rebase

* Work on finishing tests

* Make entire object system printable

* Write some more tests for IRModule

* All tests pass

* Format

* Restore module.cc

* Bump syn
  • Loading branch information
jroesch authored Nov 5, 2020
1 parent 7291a92 commit a4bd5f8
Show file tree
Hide file tree
Showing 31 changed files with 599 additions and 193 deletions.
1 change: 0 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ members = [
"tvm-graph-rt",
"tvm-graph-rt/tests/test_tvm_basic",
"tvm-graph-rt/tests/test_tvm_dso",
"tvm-graph-rt/tests/test_wasm32",
"tvm-graph-rt/tests/test_nn",
"compiler-ext",
]
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ authors = ["TVM Contributors"]
edition = "2018"

[dependencies]
ndarray="0.12"
ndarray = "0.12"
tvm-graph-rt = { path = "../../" }

[build-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ proc-macro = true
goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] }
proc-macro-error = "^1.0"
51 changes: 42 additions & 9 deletions rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,35 @@
* under the License.
*/
use proc_macro2::Span;
use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};

use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
use syn::{
token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType,
Signature, Type, Visibility,
};

struct ExternalItem {
attrs: Vec<Attribute>,
visibility: Visibility,
sig: Signature,
}

impl Parse for ExternalItem {
fn parse(input: ParseStream) -> Result<Self> {
let item = ExternalItem {
attrs: input.call(Attribute::parse_outer)?,
visibility: input.parse()?,
sig: input.parse()?,
};
let _semi: Semi = input.parse()?;
Ok(item)
}
}

struct External {
visibility: Visibility,
tvm_name: String,
ident: Ident,
generics: Generics,
Expand All @@ -32,7 +55,8 @@ struct External {

impl Parse for External {
fn parse(input: ParseStream) -> Result<Self> {
let method: TraitItemMethod = input.parse()?;
let method: ExternalItem = input.parse()?;
let visibility = method.visibility;
assert_eq!(method.attrs.len(), 1);
let sig = method.sig;
let tvm_name = method.attrs[0].parse_meta()?;
Expand All @@ -47,8 +71,7 @@ impl Parse for External {
}
_ => panic!(),
};
assert_eq!(method.default, None);
assert!(method.semi_token != None);

let ident = sig.ident;
let generics = sig.generics;
let inputs = sig
Expand All @@ -60,6 +83,7 @@ impl Parse for External {
let ret_type = sig.output;

Ok(External {
visibility,
tvm_name,
ident,
generics,
Expand Down Expand Up @@ -98,6 +122,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut items = Vec::new();

for external in &ext_input.externs {
let visibility = &external.visibility;
let name = &external.ident;
let global_name = format!("global_{}", external.ident);
let global_name = Ident::new(&global_name, Span::call_site());
Expand All @@ -109,7 +134,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.iter()
.map(|ty_param| match ty_param {
syn::GenericParam::Type(param) => param.clone(),
_ => panic!(),
_ => abort! { ty_param,
"Only supports type parameters."
},
})
.collect();

Expand All @@ -124,15 +151,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ty: Type = *pat_type.ty.clone();
(ident, ty)
}
_ => panic!(),
_ => abort! { pat_type,
"Only supports type parameters."
},
},
pat => abort! {
pat, "invalid pattern type for function";

note = "{:?} is not allowed here", pat;
},
_ => panic!(),
})
.unzip();

let ret_type = match &external.ret_type {
ReturnType::Type(_, rtype) => *rtype.clone(),
_ => panic!(),
ReturnType::Default => syn::parse_str::<Type>("()").unwrap(),
};

let global = quote! {
Expand All @@ -147,7 +180,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);

let wrapper = quote! {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
#visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
let res: #ret_type = func_ref(#(#args),*)?;
Expand Down
5 changes: 4 additions & 1 deletion rust/tvm-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

use proc_macro::TokenStream;
use proc_macro_error::proc_macro_error;

mod external;
mod import_module;
Expand All @@ -29,12 +30,14 @@ pub fn import_module(input: TokenStream) -> TokenStream {
import_module::macro_impl(input)
}

#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
#[proc_macro_error]
#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))]
pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
}

#[proc_macro_error]
#[proc_macro]
pub fn external(input: TokenStream) -> TokenStream {
external::macro_impl(input)
Expand Down
32 changes: 31 additions & 1 deletion rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
.map(attr_to_str)
.expect("Failed to get type_key");

let derive = get_attr(&derive_input, "no_derive")
.map(|_| false)
.unwrap_or(true);

let ref_id = get_attr(&derive_input, "ref_name")
.map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
.unwrap_or_else(|| {
Expand Down Expand Up @@ -75,6 +79,12 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
_ => panic!("derive only works for structs"),
};

let ref_derives = if derive {
quote! { #[derive(Debug, Clone)]}
} else {
quote! { #[derive(Clone)] }
};

let mut expanded = quote! {
unsafe impl #tvm_rt_crate::object::IsObject for #payload_id {
const TYPE_KEY: &'static str = #type_key;
Expand All @@ -87,7 +97,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
}

#[derive(Clone)]
#ref_derives
pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);

impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
Expand Down Expand Up @@ -185,5 +195,25 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {

expanded.extend(base_tokens);

if derive {
let derives = quote! {
impl std::hash::Hash for #ref_id {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.hash(state)
}
}

impl std::cmp::PartialEq for #ref_id {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

impl std::cmp::Eq for #ref_id {}
};

expanded.extend(derives);
}

TokenStream::from(expanded)
}
15 changes: 14 additions & 1 deletion rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

use std::convert::{TryFrom, TryInto};
use std::iter::{IntoIterator, Iterator};
use std::iter::{FromIterator, IntoIterator, Iterator};
use std::marker::PhantomData;

use crate::errors::Error;
Expand Down Expand Up @@ -82,6 +82,13 @@ impl<T: IsObjectRef> Array<T> {
}
}

impl<T: IsObjectRef> std::fmt::Debug for Array<T> {
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
let as_vec: Vec<T> = self.clone().into_iter().collect();
write!(formatter, "{:?}", as_vec)
}
}

pub struct IntoIter<T: IsObjectRef> {
array: Array<T>,
pos: isize,
Expand Down Expand Up @@ -118,6 +125,12 @@ impl<T: IsObjectRef> IntoIterator for Array<T> {
}
}

impl<T: IsObjectRef> FromIterator<T> for Array<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
Array::from_vec(iter.into_iter().collect()).unwrap()
}
}

impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
fn from(array: Array<T>) -> ArgValue<'static> {
array.object.into()
Expand Down
2 changes: 0 additions & 2 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ where
// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
#[name("node.MapSize")]
fn map_size(map: ObjectRef) -> i64;
#[name("node.MapGetItem")]
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use crate::object::{Object, ObjectPtr};

/// See the [`module-level documentation`](../ndarray/index.html) for more details.
#[repr(C)]
#[derive(Object)]
#[derive(Object, Debug)]
#[ref_name = "NDArray"]
#[type_key = "runtime.NDArray"]
pub struct NDArrayContainer {
Expand Down
12 changes: 4 additions & 8 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub trait IsObjectRef:
+ TryFrom<RetValue, Error = Error>
+ for<'a> Into<ArgValue<'a>>
+ for<'a> TryFrom<ArgValue<'a>, Error = Error>
+ std::fmt::Debug
{
type Object: IsObject;
fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
Expand Down Expand Up @@ -88,14 +89,9 @@ pub trait IsObjectRef:

external! {
#[name("ir.DebugPrint")]
fn debug_print(object: ObjectRef) -> CString;
pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
#[name("node.StructuralEqual")]
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef;
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool;
}

// external! {
// #[name("ir.TextPrinter")]
// fn as_text(object: ObjectRef) -> CString;
// }
40 changes: 39 additions & 1 deletion rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use std::convert::TryFrom;
use std::ffi::CString;
use std::fmt;
use std::ptr::NonNull;
use std::sync::atomic::AtomicI32;

Expand Down Expand Up @@ -147,14 +148,26 @@ impl Object {
}
}

// impl fmt::Debug for Object {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// let index =
// format!("{} // key: {}", self.type_index, "the_key");

// f.debug_struct("Object")
// .field("type_index", &index)
// // TODO(@jroesch: do we expose other fields?)
// .finish()
// }
// }

/// An unsafe trait which should be implemented for an object
/// subtype.
///
/// The trait contains the type key needed to compute the type
/// index, a method for accessing the base object given the
/// subtype, and a typed delete method which is specialized
/// to the subtype.
pub unsafe trait IsObject: AsRef<Object> {
pub unsafe trait IsObject: AsRef<Object> + std::fmt::Debug {
const TYPE_KEY: &'static str;

unsafe extern "C" fn typed_delete(object: *mut Self) {
Expand Down Expand Up @@ -264,6 +277,13 @@ impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
}
}

impl<T: IsObject> fmt::Debug for ObjectPtr<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use std::ops::Deref;
write!(f, "{:?}", self.deref())
}
}

impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue {
fn from(object_ptr: ObjectPtr<T>) -> RetValue {
let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
Expand Down Expand Up @@ -342,6 +362,24 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
}
}

impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write_i64(
super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap(),
)
}
}

impl<T: IsObject> PartialEq for ObjectPtr<T> {
fn eq(&self, other: &Self) -> bool {
let lhs = ObjectRef(Some(self.clone().upcast()));
let rhs = ObjectRef(Some(other.clone().upcast()));
super::structural_equal(lhs, rhs, false, false).unwrap()
}
}

impl<T: IsObject> Eq for ObjectPtr<T> {}

#[cfg(test)]
mod tests {
use super::{Object, ObjectPtr};
Expand Down
3 changes: 2 additions & 1 deletion rust/tvm-rt/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ use super::Object;
use tvm_macros::Object;

#[repr(C)]
#[derive(Object)]
#[derive(Object, Debug)]
#[ref_name = "String"]
#[type_key = "runtime.String"]
#[no_derive]
pub struct StringObj {
base: Object,
data: *const u8,
Expand Down
1 change: 0 additions & 1 deletion rust/tvm-rt/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
//! `RetValue` is the owned version of `TVMPODValue`.

use std::convert::TryFrom;
// use std::ffi::c_void;

use crate::{ArgValue, Module, RetValue};
use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
Expand Down
Loading

0 comments on commit a4bd5f8

Please sign in to comment.