Skip to content

Commit

Permalink
tvm crate stage 3 of Rust refactor (#5769)
Browse files Browse the repository at this point in the history
* Adapt to new macro

* Add tvm crate

* Fix out of tree pass with new bindings

* Super slick API working

* Add examples

* Delay egg example and add ASF headers

* Move array.rs around

* Remove outdated tests will restore in CI PR

* Fix some memory issues

* Fix ref counting issue

* Formatting and cleanup

* Remove out-of-tree for now

* Remove out-of-tree
  • Loading branch information
jroesch authored Jun 18, 2020
1 parent 9ba98be commit d8c80c3
Show file tree
Hide file tree
Showing 25 changed files with 1,114 additions and 139 deletions.
4 changes: 3 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ members = [
"frontend/tests/callback",
"frontend/examples/resnet",
"tvm-sys",
"tvm-rt"
"tvm-macros",
"tvm-rt",
"tvm",
]
4 changes: 4 additions & 0 deletions rust/runtime/tests/test_wasm32/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ name = "test-wasm32"
version = "0.0.0"
license = "Apache-2.0"
authors = ["TVM Contributors"]
edition = "2018"

[dependencies]
ndarray="0.12"
tvm-runtime = { path = "../../" }

[build-dependencies]
anyhow = "^1.0"
14 changes: 10 additions & 4 deletions rust/runtime/tests/test_wasm32/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

use std::{path::PathBuf, process::Command};

fn main() {
use anyhow::{Context, Result};

fn main() -> Result<()> {
let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
out_dir.push("lib");

if !out_dir.is_dir() {
std::fs::create_dir(&out_dir).unwrap();
std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?;
}

let obj_file = out_dir.join("test.o");
Expand All @@ -36,7 +38,8 @@ fn main() {
))
.arg(&out_dir)
.output()
.expect("Failed to execute command");
.context("failed to execute Python script for generating TVM library")?;

assert!(
obj_file.exists(),
"Could not build tvm lib: {}",
Expand All @@ -49,12 +52,14 @@ fn main() {
);

let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");

let output = Command::new(ar)
.arg("rcs")
.arg(&lib_file)
.arg(&obj_file)
.output()
.expect("Failed to execute command");
.context("failed to run LLVM_AR command")?;

assert!(
lib_file.exists(),
"Could not create archive: {}",
Expand All @@ -68,4 +73,5 @@ fn main() {

println!("cargo:rustc-link-lib=static=test_wasm32");
println!("cargo:rustc-link-search=native={}", out_dir.display());
Ok(())
}
6 changes: 3 additions & 3 deletions rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

let tvm_rt_crate = crate::util::get_tvm_rt_crate();

let err_type = quote! { #tvm_rt_crate::Error };
let result_type = quote! { #tvm_rt_crate::function::Result };

let mut items = Vec::new();

Expand Down Expand Up @@ -142,9 +142,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);

let wrapper = quote! {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> Result<#ret_type, #err_type>> = func_ref.to_boxed_fn();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.to_boxed_fn();
let res: #ret_type = func_ref(#(#args),*)?;
Ok(res)
}
Expand Down
31 changes: 15 additions & 16 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::util::get_tvm_rt_crate;

pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
let tvm_rt_crate = get_tvm_rt_crate();
let result = quote! { #tvm_rt_crate::function::Result };
let error = quote! { #tvm_rt_crate::errors::Error };
let derive_input = syn::parse_macro_input!(input as DeriveInput);
let payload_id = derive_input.ident;

Expand Down Expand Up @@ -77,9 +79,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
#[derive(Clone)]
pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);

impl #tvm_rt_crate::object::ToObjectRef for #ref_id {
fn to_object_ref(&self) -> ObjectRef {
ObjectRef(self.0.as_ref().map(|o| o.upcast()))
impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
type Object = #payload_id;

fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.0.as_ref()
}

fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
#ref_id(object_ptr)
}
}

Expand All @@ -92,9 +100,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}

impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> {
fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
use std::convert::TryInto;
let oref: ObjectRef = ret_val.try_into()?;
let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
Expand Down Expand Up @@ -125,24 +133,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}

impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
}
}

impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #tvm_rt_crate::Error;

fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
}
}

impl From<#ref_id> for #tvm_rt_crate::RetValue {
fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue {
Expand Down
79 changes: 79 additions & 0 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::convert::{TryFrom, TryInto};
use std::marker::PhantomData;

use crate::errors::Error;
use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
use crate::{
external,
function::{Function, Result},
RetValue,
};

#[repr(C)]
#[derive(Clone)]
pub struct Array<T: IsObjectRef> {
object: ObjectRef,
_data: PhantomData<T>,
}

// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
}

impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
let iter = data
.iter()
.map(|element| element.to_object_ref().into())
.collect();

let func = Function::get("node.Array").expect(
"node.Array function is not registered, this is most likely a build or linking error",
);

// let array_data = func.invoke(iter)?;
// let array_data: ObjectRef = func.invoke(iter)?.try_into()?;
let array_data: ObjectPtr<Object> = func.invoke(iter)?.try_into()?;

debug_assert!(
array_data.count() >= 1,
"array reference count is {}",
array_data.count()
);

Ok(Array {
object: ObjectRef(Some(array_data)),
_data: PhantomData,
})
}

pub fn get(&self, index: isize) -> Result<T>
where
T: TryFrom<RetValue, Error = Error>,
{
let oref: ObjectRef = array_get_item(self.object.clone(), index)?;
oref.downcast()
}
}
2 changes: 2 additions & 0 deletions rust/tvm-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum Error {
NDArray(#[from] NDArrayError),
#[error("{0}")]
CallFailed(String),
#[error("this case will never occur")]
Infallible(#[from] std::convert::Infallible),
}

impl Error {
Expand Down
20 changes: 16 additions & 4 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ use std::{
ptr, str,
};

pub use tvm_sys::{ffi, ArgValue, RetValue};

use crate::errors::Error;

use super::to_boxed_fn::ToBoxedFn;
use super::to_function::{ToFunction, Typed};

pub use super::to_function::{ToFunction, Typed};
pub use tvm_sys::{ffi, ArgValue, RetValue};

pub type Result<T> = std::result::Result<T, Error>;

Expand Down Expand Up @@ -65,6 +65,14 @@ impl Function {
}
}

pub unsafe fn null() -> Self {
Function {
handle: std::ptr::null_mut(),
is_global: false,
from_rust: false,
}
}

/// For a given function, it returns a function by name.
pub fn get<S: AsRef<str>>(name: S) -> Option<Function> {
let name = CString::new(name.as_ref()).unwrap();
Expand Down Expand Up @@ -171,7 +179,11 @@ impl TryFrom<RetValue> for Function {

impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
ArgValue::FuncHandle(func.handle)
if func.handle.is_null() {
ArgValue::Null
} else {
ArgValue::FuncHandle(func.handle)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ pub(crate) fn set_last_error<E: std::error::Error>(err: &E) {
}
}

#[macro_use]
pub mod function;
pub mod array;
pub mod context;
pub mod errors;
pub mod function;
pub mod module;
pub mod ndarray;
pub mod to_boxed_fn;
Expand Down
22 changes: 11 additions & 11 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,17 +411,17 @@ mod tests {
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
}

#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
fn copy_wrong_dtype() {
let shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.];
let ctx = Context::cpu(0);
let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
nd_float.copy_from_buffer(&mut data);
let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap();
}
// #[test]
// #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
// fn copy_wrong_dtype() {
// let shape = vec![4];
// let mut data = vec![1f32, 2., 3., 4.];
// let ctx = Context::cpu(0);
// let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
// nd_float.copy_from_buffer(&mut data);
// let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
// nd_float.copy_to_ndarray(empty_int).unwrap();
// }

#[test]
fn rust_ndarray() {
Expand Down
Loading

0 comments on commit d8c80c3

Please sign in to comment.