Skip to content

Commit

Permalink
All tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Nov 5, 2020
1 parent fdaa4af commit 726a01b
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 28 deletions.
2 changes: 0 additions & 2 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ where
// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
#[name("node.MapSize")]
fn map_size(map: ObjectRef) -> i64;
#[name("node.MapGetItem")]
Expand Down
100 changes: 82 additions & 18 deletions rust/tvm/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,20 @@ impl IRModule {
module_lookup_tag(self.clone(), tag)
}

pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
module_from_expr(expr, funcs, types)
pub fn from_expr<E>(expr: E) -> Result<IRModule>
where E: IsObjectRef, E::Object: AsRef<<relay::Expr as IsObjectRef>::Object> {
Self::from_expr_with_items(expr, HashMap::new(), HashMap::new())
}

pub fn from_expr_with_items<E, F, T>(expr: E, funcs: F, types: T) -> Result<IRModule>
where F: IntoIterator<Item=(GlobalVar, BaseFunc)>,
T: IntoIterator<Item=(GlobalTypeVar, TypeData)>,
E: IsObjectRef,
E::Object: AsRef<<relay::Expr as IsObjectRef>::Object> {
module_from_expr(expr.upcast(), Map::from_iter(funcs), Map::from_iter(types))
}


pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
module_import(self.clone(), path.into())
}
Expand All @@ -212,6 +222,33 @@ mod tests {
use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind};
use crate::ir::span::Span;

fn add_dummy_functions(names: Vec<&str>) -> Result<IRModule> {
let mut module = IRModule::empty()?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = vec![x.clone()];
let func = relay::Function::simple(params, x);

for name in names {
let gv = GlobalVar::new(name.into(), Span::null());
module = module.add(gv, func.clone())?;
}

Ok(module)
}

fn add_dummy_types(names: Vec<&str>) -> Result<IRModule> {
let mut module = IRModule::empty()?;

for name in names {
let name: String = name.into();
let name = GlobalTypeVar::new(name, TypeKind::Type, Span::null());
let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null());
module.add_def(name, type_data, true)?;
}

Ok(module)
}

#[test]
fn test_module_add() -> anyhow::Result<()> {
let mut module = IRModule::empty()?;
Expand All @@ -229,7 +266,7 @@ mod tests {
fn test_module_add_def() -> Result<()> {
let mut module = IRModule::empty()?;
let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null());
let type_data = TypeData::new(name.clone(), vec![], vec![]);
let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null());
module.add_def(name.clone(), type_data, true)?;
let by_gtv = module.lookup_def(name)?;
let by_gv = module.lookup_def_str("my_type")?;
Expand All @@ -251,39 +288,66 @@ mod tests {

#[test]
fn test_get_global_vars() -> Result<()> {
let mut module = IRModule::empty()?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = vec![x.clone()];
let func = relay::Function::simple(params, x);
let gv_foo = GlobalVar::new("foo".into(), Span::null());
let module = module.add(gv_foo.clone(), func)?;
let gv = module.get_global_var("foo")?;
assert_eq!(gv_foo, gv);
let names = vec!["foo", "bar", "baz"];
let module = add_dummy_functions(names.clone())?;
let gvars: Vec<String> =
module.get_global_vars()?.into_iter().map(|gv| {
gv.name_hint.as_str().unwrap().to_string()
}).collect();

for name in names {
assert!(gvars.contains(&name.to_string()));
}

Ok(())
}

#[test]
fn test_get_global_type_vars() {
fn test_get_global_type_vars() -> Result<()> {
let names = vec!["foo", "bar", "baz"];
let module = add_dummy_types(names.clone())?;
let gvars: Vec<String> =
module.get_global_type_vars()?.into_iter().map(|gv| {
gv.name_hint.as_str().unwrap().to_string()
}).collect();

for name in names {
assert!(gvars.contains(&name.to_string()));
}

Ok(())
}

#[test]
fn test_contains_global_var() {
fn test_contains_global_var() -> Result<()> {
let module = add_dummy_functions(vec!["foo"])?;
assert!(module.contains_global_var("foo")?);
Ok(())
}

#[test]
fn test_contains_global_type_var() {
fn test_contains_global_type_var() -> Result<()> {
let module = add_dummy_types(vec!["foo"])?;
assert!(module.contains_global_type_var("foo")?);
Ok(())
}

// TODO(@jroesch): not really sure about this API at all.
// pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
// module_lookup_tag(self.clone(), tag)
// }

// TODO(@jroesch): do we need to test this?
// pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
// module_from_expr(expr, funcs, types)
// }
#[test]
fn test_from_expr() -> Result<()> {
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = vec![x.clone()];
let func = relay::Function::simple(params, x);
let module = IRModule::from_expr(func.clone())?;
let main_fn = module.lookup_str("main")?;
let main_fn = main_fn.downcast::<relay::Function>()?;
assert_eq!(main_fn, func);
Ok(())
}

#[test]
fn test_import() -> Result<()> {
Expand Down
14 changes: 7 additions & 7 deletions rust/tvm/src/ir/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tvm_rt::{array::Array, DataType};
use crate::ir::span::Span;
use crate::ir::relay::Constructor;
use crate::ir::PrimExpr;
use crate::runtime::{IsObject, Object, ObjectPtr};
use crate::runtime::{IsObject, Object, ObjectPtr, string::String as TString};

#[repr(C)]
#[derive(Object, Debug)]
Expand Down Expand Up @@ -110,7 +110,7 @@ pub enum TypeKind {
#[type_key = "TypeVar"]
pub struct TypeVarNode {
pub base: TypeNode,
pub name_hint: String,
pub name_hint: TString,
pub kind: TypeKind,
}

Expand All @@ -121,13 +121,13 @@ pub struct TypeVarNode {
#[type_key = "GlobalTypeVar"]
pub struct GlobalTypeVarNode {
pub base: TypeNode,
pub name_hint: String,
pub name_hint: TString,
pub kind: TypeKind,
}

impl GlobalTypeVar {
pub fn new<S>(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar
where S: Into<String> {
where S: Into<TString> {
let node = GlobalTypeVarNode {
base: TypeNode::base::<GlobalTypeVarNode>(span),
name_hint: name_hint.into(),
Expand Down Expand Up @@ -266,7 +266,7 @@ pub struct TypeDataNode {
/// We adopt nominal typing for ADT definitions;
/// that is, differently-named ADT definitions with same constructors
/// have different types.
pub base: Object,
pub base: TypeNode,
pub type_name: GlobalTypeVar,
/// The type variables (to allow for polymorphism).
pub type_vars: Array<TypeVar>,
Expand All @@ -275,13 +275,13 @@ pub struct TypeDataNode {
}

impl TypeData {
pub fn new<TypeVars, Ctors>(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors) -> TypeData
pub fn new<TypeVars, Ctors>(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors, span: Span) -> TypeData
where TypeVars: IntoIterator<Item=TypeVar>,
Ctors: IntoIterator<Item=Constructor>,
{
use std::iter::FromIterator;
let type_data = TypeDataNode {
base: Object::base::<TypeDataNode>(),
base: TypeNode::base::<TypeDataNode>(span,),
type_name,
type_vars: Array::from_iter(type_vars),
constructors: Array::from_iter(constructors),
Expand Down
8 changes: 7 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,9 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret)
*ret = mod;
});

TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_typed([](IRModule module, GlobalTypeVar var, TypeData type_def, bool update) {
module->AddTypeDef(var, type_def, update);
});

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
Expand All @@ -439,6 +441,10 @@ TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
.set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);

TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalTypeVar")
.set_body_method<IRModule>(&IRModuleNode::ContainGlobalTypeVar);


TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
.set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);

Expand Down

0 comments on commit 726a01b

Please sign in to comment.