Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tvm crate stage 3 of Rust refactor #5769

Merged
merged 13 commits into from
Jun 18, 2020
5 changes: 4 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,8 @@ members = [
"frontend/tests/callback",
"frontend/examples/resnet",
"tvm-sys",
"tvm-rt"
"tvm-macros",
"tvm-rt",
"tvm",
"out-of-tree"
]
33 changes: 33 additions & 0 deletions rust/out-of-tree/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
jroesch marked this conversation as resolved.
Show resolved Hide resolved
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[package]
name = "out-of-tree"
version = "0.1.0"
authors = ["Jared Roesch <jroesch@octoml.ai>"]
edition = "2018"


# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "my_pass"
crate-type = ["cdylib"]

[dependencies]
tvm = { version = "0.1", path = "../tvm" }
tvm-sys = { version = "0.1", path = "../tvm-sys" }
anyhow = "*"
44 changes: 44 additions & 0 deletions rust/out-of-tree/import_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.relay
from tvm.ir.transform import PassContext

x = tvm.relay.var("x", shape=(10,))
test_func = tvm.relay.Function([x], x)
test_mod = tvm.IRModule.from_expr(test_func)

pass_dylib = "/Users/jroesch/Git/tvm/rust/target/debug/libmy_pass.dylib"
jroesch marked this conversation as resolved.
Show resolved Hide resolved

def load_rust_extension(ext_dylib):
load_so = tvm.get_global_func("runtime.module.loadfile_so")
mod = load_so(ext_dylib)
mod.get_function("initialize")()


def load_pass(pass_name, dylib):
load_rust_extension(dylib)
return tvm.get_global_func(pass_name)

MyPass = load_pass("out_of_tree.Pass", pass_dylib)
ctx = PassContext()
import pdb; pdb.set_trace()
f = MyPass(test_func, test_mod, ctx)
mod = MyPass()(test_mod)

print(mod)
42 changes: 42 additions & 0 deletions rust/out-of-tree/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::ffi::c_void;
use std::os::raw::c_int;
use tvm::ir::relay::{self, Function};
use tvm::runtime::ObjectRef;
use tvm::transform::{function_pass, PassInfo, Pass, PassContext, IRModule};
use tvm::runtime::function::{register, Result};
use tvm::export_pass;

fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Function {
let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null());
relay::Function::new(
func.params.clone(),
var.to_expr(),
func.ret_type.clone(),
func.type_params.clone())
}

// fn the_pass() -> Result<Pass> {
// let pass_info = PassInfo::new(15, "RustPass".into(), vec![])?;
// function_pass(my_pass_fn, pass_info)
// }
jroesch marked this conversation as resolved.
Show resolved Hide resolved

export_pass!("out_of_tree.Pass", my_pass_fn);
1 change: 1 addition & 0 deletions rust/runtime/tests/test_wasm32/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ license = "Apache-2.0"
authors = ["TVM Contributors"]

[dependencies]
anyhow = "*"
jroesch marked this conversation as resolved.
Show resolved Hide resolved
ndarray="0.12"
tvm-runtime = { path = "../../" }
6 changes: 3 additions & 3 deletions rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

let tvm_rt_crate = crate::util::get_tvm_rt_crate();

let err_type = quote! { #tvm_rt_crate::Error };
let result_type = quote! { #tvm_rt_crate::function::Result };

let mut items = Vec::new();

Expand Down Expand Up @@ -142,9 +142,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);

let wrapper = quote! {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> Result<#ret_type, #err_type>> = func_ref.to_boxed_fn();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.to_boxed_fn();
let res: #ret_type = func_ref(#(#args),*)?;
Ok(res)
}
Expand Down
26 changes: 17 additions & 9 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::util::get_tvm_rt_crate;

pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
let tvm_rt_crate = get_tvm_rt_crate();
let result = quote! { #tvm_rt_crate::function::Result };
let error = quote! { #tvm_rt_crate::errors::Error };
let derive_input = syn::parse_macro_input!(input as DeriveInput);
let payload_id = derive_input.ident;

Expand Down Expand Up @@ -77,9 +79,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
#[derive(Clone)]
pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);

impl #tvm_rt_crate::object::ToObjectRef for #ref_id {
fn to_object_ref(&self) -> ObjectRef {
ObjectRef(self.0.as_ref().map(|o| o.upcast()))
impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
type Object = #payload_id;

fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.0.as_ref()
}

fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
#ref_id(object_ptr)
}
}

Expand All @@ -92,9 +100,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}

impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> {
fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
use std::convert::TryInto;
let oref: ObjectRef = ret_val.try_into()?;
let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
Expand Down Expand Up @@ -125,19 +133,19 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}

impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
}
}

impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
Expand Down
63 changes: 63 additions & 0 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::convert::{TryFrom, TryInto};
use std::marker::PhantomData;

use crate::object::{ObjectRef, IsObjectRef};
use crate::{external, RetValue, function::{Function, Result}};
use crate::errors::Error;

#[repr(C)]
#[derive(Clone)]
pub struct Array<T: IsObjectRef> {
object: ObjectRef,
_data: PhantomData<T>,
}

// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
}

impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
let iter = data.iter().map(|element| element.to_object_ref().into()).collect();

let func = Function::get("node.Array")
.expect("node.Array function is not registered, this is most likely a build or linking error");

let array_data = func.invoke(iter)?.try_into()?;

Ok(Array {
object: array_data,
_data: PhantomData,
})
}

pub fn get(&self, index: isize) -> Result<T>
where
T: TryFrom<RetValue, Error = Error>,
{
let oref: ObjectRef = array_get_item(self.object.clone(), index)?;
oref.downcast()
}
}
2 changes: 2 additions & 0 deletions rust/tvm-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum Error {
NDArray(#[from] NDArrayError),
#[error("{0}")]
CallFailed(String),
#[error("this case will never occur")]
Infallible(#[from] std::convert::Infallible),
}

impl Error {
Expand Down
5 changes: 3 additions & 2 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ use std::{
ptr, str,
};

pub use tvm_sys::{ffi, ArgValue, RetValue};

use crate::errors::Error;

use super::to_boxed_fn::ToBoxedFn;
use super::to_function::{ToFunction, Typed};

pub use tvm_sys::{ffi, ArgValue, RetValue};
pub use super::to_function::{ToFunction, Typed};

pub type Result<T> = std::result::Result<T, Error>;

Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ pub(crate) fn set_last_error<E: std::error::Error>(err: &E) {
}
}

#[macro_use]
pub mod function;
pub mod array;
pub mod context;
pub mod errors;
pub mod function;
pub mod module;
pub mod ndarray;
pub mod to_boxed_fn;
Expand Down
31 changes: 26 additions & 5 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,34 @@ impl ObjectRef {
}
}

pub trait ToObjectRef {
fn to_object_ref(&self) -> ObjectRef;
}
pub trait IsObjectRef: Sized {
type Object: IsObject;
fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self;

impl ToObjectRef for ObjectRef {
fn to_object_ref(&self) -> ObjectRef {
self.clone()
let object_ptr = self.as_object_ptr().cloned();
ObjectRef(object_ptr.map(|ptr| ptr.upcast()))
}

fn downcast<U: IsObjectRef>(&self) -> Result<U, Error> {
let ptr =
self.as_object_ptr()
.map(|ptr| ptr.downcast::<U::Object>());
let ptr = ptr.transpose()?;
Ok(U::from_object_ptr(ptr))
}
}

impl IsObjectRef for ObjectRef {
type Object = Object;

fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.0.as_ref()
}

fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
ObjectRef(object_ptr)
}
}

Expand Down
Loading