Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Oct 23, 2020
1 parent 0ce0f32 commit 9608878
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 45 deletions.
8 changes: 6 additions & 2 deletions cmake/modules/LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
# specific language governing permissions and limitations
# under the License.

# LLVM rules
add_definitions(-DDMLC_USE_FOPEN64=0)
# Due to LLVM debug symbols you can sometimes face linking issues on
# certain compiler, platform combinations if you don't set NDEBUG.
#
# See https://github.com/imageworks/OpenShadingLanguage/issues/1069
# for more discussion.
add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1)

# Test if ${USE_LLVM} is not an explicit boolean false
# It may be a boolean or a string
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};

use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type};

struct ExternalItem {
attrs: Vec<Attribute>,
Expand Down
3 changes: 2 additions & 1 deletion rust/tvm-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ pub fn import_module(input: TokenStream) -> TokenStream {
import_module::macro_impl(input)
}

#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
#[proc_macro_error]
#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))]
pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
Expand Down
23 changes: 23 additions & 0 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
.map(attr_to_str)
.expect("Failed to get type_key");

let derive = get_attr(&derive_input, "no_derive").map(|_| false).unwrap_or(true);

let ref_id = get_attr(&derive_input, "ref_name")
.map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
.unwrap_or_else(|| {
Expand Down Expand Up @@ -185,5 +187,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {

expanded.extend(base_tokens);

if derive {
let derives = quote! {
impl std::hash::Hash for #ref_id {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.hash(state)
}
}

impl std::cmp::PartialEq for #ref_id {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

impl std::cmp::Eq for #ref_id {}
};


expanded.extend(derives);
}

TokenStream::from(expanded)
}
9 changes: 2 additions & 7 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,7 @@ external! {
#[name("ir.DebugPrint")]
pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
#[name("node.StructuralEqual")]
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef;
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool;
}

// external! {
// #[name("ir.TextPrinter")]
// fn as_text(object: ObjectRef) -> CString;
// }
16 changes: 16 additions & 0 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,22 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
}
}

impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write_i64(super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap())
}
}

impl<T: IsObject> PartialEq for ObjectPtr<T> {
fn eq(&self, other: &Self) -> bool {
let lhs = ObjectRef(Some(self.clone().upcast()));
let rhs = ObjectRef(Some(other.clone().upcast()));
super::structural_equal(lhs, rhs, false, false).unwrap()
}
}

impl<T: IsObject> Eq for ObjectPtr<T> {}

#[cfg(test)]
mod tests {
use super::{Object, ObjectPtr};
Expand Down
1 change: 1 addition & 0 deletions rust/tvm-rt/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use tvm_macros::Object;
#[derive(Object)]
#[ref_name = "String"]
#[type_key = "runtime.String"]
#[no_derive]
pub struct StringObj {
base: Object,
data: *const u8,
Expand Down
1 change: 0 additions & 1 deletion rust/tvm-rt/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
//! `RetValue` is the owned version of `TVMPODValue`.

use std::convert::TryFrom;
// use std::ffi::c_void;

use crate::{ArgValue, Module, RetValue};
use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
Expand Down
4 changes: 4 additions & 0 deletions rust/tvm-sys/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ impl DataType {
DataType::new(DL_FLOAT_CODE, bits, lanes)
}

pub const fn float32() -> DataType {
Self::float(32, 1)
}

pub const fn uint(bits: u8, lanes: u16) -> DataType {
DataType::new(DL_UINT_CODE, bits, lanes)
}
Expand Down
150 changes: 140 additions & 10 deletions rust/tvm/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

use std::io::Result as IOResult;
use std::iter::FromIterator;
use std::path::Path;

use crate::runtime::array::Array;
Expand All @@ -33,7 +34,6 @@ use super::{ty::GlobalTypeVar, relay};
use tvm_macros::Object;

// TODO(@jroesch): define type
type TypeDef = ObjectRef;
type TypeData = ObjectRef;

#[repr(C)]
Expand All @@ -52,9 +52,11 @@ external! {
fn parse_module(file_name: TVMString, source: TVMString) -> IRModule;
#[name("parser.ParseExpr")]
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
#[name("ir.IRModule")]
fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
// Module methods
#[name("ir.Module_Add")]
fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule;
#[name("ir.Module_AddDef")]
fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
#[name("ir.Module_GetGlobalVar")]
Expand All @@ -66,15 +68,15 @@ external! {
#[name("ir.Module_Lookup_str")]
fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
#[name("ir.Module_GetGlobalTypeVars")]
fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
fn module_get_global_type_vars(module: IRModule) -> Array<GlobalTypeVar>;
#[name("ir.Module_ContainGlobalVar")]
fn module_contains_global_var(name: TVMString) -> bool;
fn module_contains_global_var(module: IRModule, name: TVMString) -> bool;
#[name("ir.Module_ContainGlobalTypeVar")]
fn module_contains_global_type_var(name: TVMString) -> bool;
fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool;
#[name("ir.Module_LookupDef")]
fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData;
#[name("ir.Module_LookupDef_str")]
fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData;
#[name("ir.Module_LookupTag")]
fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
#[name("ir.Module_FromExpr")]
Expand All @@ -87,8 +89,12 @@ external! {

// Note: we don't expose update here as update is going to be removed.


impl IRModule {
pub fn new<F, T>(funcs: F, types: T) -> Result<IRModule>
where F: IntoIterator<Item=(GlobalVar, BaseFunc)>, T: IntoIterator<Item=(GlobalTypeVar, TypeData)> {
module_new(Map::from_iter(funcs), Map::from_iter(types))
}

pub fn parse<N, S>(file_name: N, source: S) -> IRModule
where
N: Into<TVMString>,
Expand All @@ -105,6 +111,13 @@ impl IRModule {
Ok(module)
}

pub fn add(
&mut self,
var: GlobalVar,
func: BaseFunc) -> Result<IRModule> {
module_add(self.clone(), var, func, true)
}

pub fn add_def(
&mut self,
type_name: GlobalTypeVar,
Expand Down Expand Up @@ -132,10 +145,127 @@ impl IRModule {
{
module_lookup_str(self.clone(), name.into())
}

pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
module_get_global_type_vars(self.clone())
}

pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
module_contains_global_var(self.clone(), name.into())
}

pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
module_contains_global_type_var(self.clone(), name.into())
}

pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
module_lookup_def(self.clone(), global)
}

pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
module_lookup_def_str(self.clone(), global)
}

pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
module_lookup_tag(self.clone(), tag)
}

pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
module_from_expr(expr, funcs, types)
}

pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
module_import(self.clone(), path.into())
}

pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
module_import_from_std(self.clone(), path.into())
}
}

#[cfg(test)]
mod tests {
// #[test]
// fn
use std::collections::HashMap;
use super::relay::*;
use super::*;
use super::super::span::Span;
use tvm_rt::IsObjectRef;

#[test]
fn test_module_add() -> anyhow::Result<()> {
let funcs = HashMap::<GlobalVar, BaseFunc>::new();
let types = HashMap::<GlobalTypeVar, TypeData>::new();
let mut module = IRModule::new(funcs, types)?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = Array::from_vec(vec![x.clone()])?;
let func = relay::Function::simple(params, x.upcast()).upcast();
let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?;
// let lfunc = module.lookup_str("foo")?;
// let lfunc = lfunc.downcast::<relay::Function>()?;
// assert_eq!(lfunc.params.len(), 1);
Ok(())
}

#[test]
fn test_module_add_def() {

}

#[test]
fn test_get_global_var() {

}

#[test]
fn test_get_global_vars() {

}

#[test]
fn test_lookup() {

}


// pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
// module_get_global_type_vars(self.clone())
// }

// pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
// module_contains_global_var(self.clone(), name.into())
// }

// pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
// module_contains_global_type_var(self.clone(), name.into())
// }

#[test]
fn test_lookup_def() {

}
// pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
// module_lookup_def(self.clone(), global)
// }

// pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
// module_lookup_def_str(self.clone(), global)
// }

// pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
// module_lookup_tag(self.clone(), tag)
// }

// pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
// module_from_expr(expr, funcs, types)
// }


// pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
// module_import(self.clone(), path.into())
// }


// pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
// module_import_from_std(self.clone(), path.into())
// }
}
Loading

0 comments on commit 9608878

Please sign in to comment.