Skip to content

Commit

Permalink
Write some more tests for IRModule
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Nov 5, 2020
1 parent 1b72ca0 commit fdaa4af
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 58 deletions.
9 changes: 8 additions & 1 deletion rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

use std::convert::{TryFrom, TryInto};
use std::iter::{IntoIterator, Iterator};
use std::iter::{IntoIterator, Iterator, FromIterator};
use std::marker::PhantomData;

use crate::errors::Error;
Expand Down Expand Up @@ -125,6 +125,13 @@ impl<T: IsObjectRef> IntoIterator for Array<T> {
}
}

impl<T: IsObjectRef> FromIterator<T> for Array<T> {
fn from_iter<I: IntoIterator<Item=T>>(iter: I) -> Self {
Array::from_vec(iter.into_iter().collect()).unwrap()
}
}


impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
fn from(array: Array<T>) -> ArgValue<'static> {
array.object.into()
Expand Down
102 changes: 58 additions & 44 deletions rust/tvm/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ use crate::runtime::array::Array;
use crate::runtime::function::Result;
use crate::runtime::map::Map;
use crate::runtime::string::String as TVMString;
use crate::runtime::{external, Object, ObjectRef};
use crate::runtime::{external, Object, IsObjectRef};

use super::expr::GlobalVar;
use super::function::BaseFunc;
use super::function::{BaseFunc};
use super::source_map::SourceMap;
use super::{ty::GlobalTypeVar, relay};

// TODO(@jroesch): define type
type TypeData = ObjectRef;
use super::{ty::GlobalTypeVar, ty::TypeData, relay};

#[derive(Error, Debug)]
pub enum Error {
Expand Down Expand Up @@ -88,7 +85,7 @@ external! {
#[name("ir.Module_LookupDef")]
fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData;
#[name("ir.Module_LookupDef_str")]
fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData;
fn module_lookup_def_str(module: IRModule, global: TVMString) -> TypeData;
#[name("ir.Module_LookupTag")]
fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
#[name("ir.Module_FromExpr")]
Expand Down Expand Up @@ -131,11 +128,13 @@ impl IRModule {
Ok(module)
}

pub fn add(
pub fn add<F>(
&mut self,
var: GlobalVar,
func: BaseFunc) -> Result<IRModule> {
module_add(self.clone(), var, func, true)
func: F) -> Result<IRModule>
// todo(@jroesch): can we do better here? why doesn't BaseFunc::Object work?
where F: IsObjectRef, F::Object: AsRef<<BaseFunc as IsObjectRef>::Object> {
module_add(self.clone(), var, func.upcast(), true)
}

pub fn add_def(
Expand Down Expand Up @@ -183,8 +182,9 @@ impl IRModule {
module_lookup_def(self.clone(), global)
}

pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
module_lookup_def_str(self.clone(), global)
pub fn lookup_def_str<S>(&self, global: S) -> Result<TypeData>
where S: Into<TVMString> {
module_lookup_def_str(self.clone(), global.into())
}

pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
Expand All @@ -206,18 +206,18 @@ impl IRModule {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::relay::*;
use super::*;
use super::super::span::Span;
use tvm_rt::IsObjectRef;
use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind};
use crate::ir::span::Span;

#[test]
fn test_module_add() -> anyhow::Result<()> {
let mut module = IRModule::empty()?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = Array::from_vec(vec![x.clone()])?;
let func = relay::Function::simple(params, x.upcast()).upcast();
let params = vec![x.clone()];
let func = relay::Function::simple(params, x);
let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?;
let lfunc = module.lookup_str("foo")?;
let lfunc = lfunc.downcast::<relay::Function>()?;
Expand All @@ -226,38 +226,47 @@ mod tests {
}

#[test]
fn test_module_add_def() {
todo!("this is blocked on having ability to define ADTs")
fn test_module_add_def() -> Result<()> {
let mut module = IRModule::empty()?;
let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null());
let type_data = TypeData::new(name.clone(), vec![], vec![]);
module.add_def(name.clone(), type_data, true)?;
let by_gtv = module.lookup_def(name)?;
let by_gv = module.lookup_def_str("my_type")?;
Ok(())
}

#[test]
fn test_get_global_var() -> Result<()> {
let mut module = IRModule::empty()?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = vec![x.clone()];
let func = relay::Function::simple(params, x.upcast()).upcast();
let func = relay::Function::simple(params, x);
let gv_foo = GlobalVar::new("foo".into(), Span::null());
let module = module.add(gv_foo, func)?;
let gv = module.get_global_var("foo");
let module = module.add(gv_foo.clone(), func)?;
let gv = module.get_global_var("foo")?;
assert_eq!(gv_foo, gv);
Ok(())
}

#[test]
fn test_get_global_vars() {

fn test_get_global_vars() -> Result<()> {
let mut module = IRModule::empty()?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = vec![x.clone()];
let func = relay::Function::simple(params, x);
let gv_foo = GlobalVar::new("foo".into(), Span::null());
let module = module.add(gv_foo.clone(), func)?;
let gv = module.get_global_var("foo")?;
assert_eq!(gv_foo, gv);
Ok(())
}

#[test]
fn test_get_global_type_vars() {

}

#[test]
fn test_lookup() {

}

#[test]
fn test_contains_global_var() {
}
Expand All @@ -266,29 +275,34 @@ mod tests {
fn test_contains_global_type_var() {
}

#[test]
fn test_lookup_def() {

}

#[test]
fn lookup_def() {

}

// TODO(@jroesch): not really sure about this API at all.
// pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
// module_lookup_tag(self.clone(), tag)
// }

// TODO(@jroesch): do we need to test this?
// pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
// module_from_expr(expr, funcs, types)
// }

// pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
// module_import(self.clone(), path.into())
// }
#[test]
fn test_import() -> Result<()> {
let mut std_path: String = env!("CARGO_MANIFEST_DIR").into();
std_path += "/../../python/tvm/relay/std/prelude.rly";

// pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
// module_import_from_std(self.clone(), path.into())
// }
let mut mod1 = IRModule::empty()?;
mod1.import(std_path.clone())?;
mod1.lookup_str("map")?;

// TODO(@jroesch): this requires another patch of mine to enable.

// if cfg!(feature = "python") {
// crate::python::load().unwrap();
// let mut mod2 = IRModule::empty()?;
// mod2.import_from_std("prelude.rly")?;
// mod2.lookup_str("map")?;
// }

Ok(())
}
}
7 changes: 4 additions & 3 deletions rust/tvm/src/ir/relay/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,10 @@ impl Function {
Function(Some(ObjectPtr::new(node)))
}

pub fn simple(params: Vec<Var>, body: Expr) -> Function {
pub fn simple<E>(params: Vec<Var>, body: E) -> Function
where E: IsObjectRef, E::Object: AsRef<<Expr as IsObjectRef>::Object> {
let params = Array::from_vec(params).unwrap();
Self::new(params, body, Type::null(), Array::from_vec(vec![]).unwrap())
Self::new(params, body.upcast(), Type::null(), Array::from_vec(vec![]).unwrap())
}
}

Expand Down Expand Up @@ -547,7 +548,7 @@ def @main() -> float32 {
)
.unwrap();
let main = module
.lookup(module.get_global_var("main".to_string().into()).unwrap())
.lookup(module.get_global_var("main").unwrap())
.unwrap();
let func = main.downcast::<crate::ir::relay::Function>().unwrap();
let constant = func
Expand Down
47 changes: 37 additions & 10 deletions rust/tvm/src/ir/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* under the License.
*/

use super::span::Span;
use crate::runtime::{IsObject, Object, ObjectPtr};

use tvm_macros::Object;
use tvm_rt::{array::Array, DataType};

use super::PrimExpr;
use super::relay::Constructor;
use crate::ir::span::Span;
use crate::ir::relay::Constructor;
use crate::ir::PrimExpr;
use crate::runtime::{IsObject, Object, ObjectPtr};

#[repr(C)]
#[derive(Object, Debug)]
Expand Down Expand Up @@ -124,6 +125,18 @@ pub struct GlobalTypeVarNode {
pub kind: TypeKind,
}

impl GlobalTypeVar {
pub fn new<S>(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar
where S: Into<String> {
let node = GlobalTypeVarNode {
base: TypeNode::base::<GlobalTypeVarNode>(span),
name_hint: name_hint.into(),
kind: kind,
};
ObjectPtr::new(node).into()
}
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TupleType"]
Expand Down Expand Up @@ -249,16 +262,30 @@ The kind checker enforces this. */
#[ref_name = "TypeData"]
#[type_key = "relay.TypeData"]
pub struct TypeDataNode {
// /*!
// * \brief The header is simply the name of the ADT.
// * We adopt nominal typing for ADT definitions;
// * that is, differently-named ADT definitions with same constructors
// * have different types.
// */
/// The header is simply the name of the ADT.
/// We adopt nominal typing for ADT definitions;
/// that is, differently-named ADT definitions with same constructors
/// have different types.
pub base: Object,
pub type_name: GlobalTypeVar,
/// The type variables (to allow for polymorphism).
pub type_vars: Array<TypeVar>,
/// The constructors.
pub constructors: Array<Constructor>,
}

impl TypeData {
pub fn new<TypeVars, Ctors>(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors) -> TypeData
where TypeVars: IntoIterator<Item=TypeVar>,
Ctors: IntoIterator<Item=Constructor>,
{
use std::iter::FromIterator;
let type_data = TypeDataNode {
base: Object::base::<TypeDataNode>(),
type_name,
type_vars: Array::from_iter(type_vars),
constructors: Array::from_iter(constructors),
};
TypeData(Some(ObjectPtr::new(type_data)))
}
}

0 comments on commit fdaa4af

Please sign in to comment.