Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tvm crate stage 3 of Rust refactor #5769

Merged
merged 13 commits into from
Jun 18, 2020
4 changes: 3 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ members = [
"frontend/tests/callback",
"frontend/examples/resnet",
"tvm-sys",
"tvm-rt"
"tvm-macros",
"tvm-rt",
"tvm",
]
4 changes: 4 additions & 0 deletions rust/runtime/tests/test_wasm32/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ name = "test-wasm32"
version = "0.0.0"
license = "Apache-2.0"
authors = ["TVM Contributors"]
edition = "2018"

[dependencies]
ndarray="0.12"
tvm-runtime = { path = "../../" }

[build-dependencies]
anyhow = "^1.0"
14 changes: 10 additions & 4 deletions rust/runtime/tests/test_wasm32/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

use std::{path::PathBuf, process::Command};

fn main() {
use anyhow::{Context, Result};

fn main() -> Result<()> {
let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
out_dir.push("lib");

if !out_dir.is_dir() {
std::fs::create_dir(&out_dir).unwrap();
std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?;
}

let obj_file = out_dir.join("test.o");
Expand All @@ -36,7 +38,8 @@ fn main() {
))
.arg(&out_dir)
.output()
.expect("Failed to execute command");
.context("failed to execute Python script for generating TVM library")?;

assert!(
obj_file.exists(),
"Could not build tvm lib: {}",
Expand All @@ -49,12 +52,14 @@ fn main() {
);

let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");

let output = Command::new(ar)
.arg("rcs")
.arg(&lib_file)
.arg(&obj_file)
.output()
.expect("Failed to execute command");
.context("failed to run LLVM_AR command")?;

assert!(
lib_file.exists(),
"Could not create archive: {}",
Expand All @@ -68,4 +73,5 @@ fn main() {

println!("cargo:rustc-link-lib=static=test_wasm32");
println!("cargo:rustc-link-search=native={}", out_dir.display());
Ok(())
}
6 changes: 3 additions & 3 deletions rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

let tvm_rt_crate = crate::util::get_tvm_rt_crate();

let err_type = quote! { #tvm_rt_crate::Error };
let result_type = quote! { #tvm_rt_crate::function::Result };

let mut items = Vec::new();

Expand Down Expand Up @@ -142,9 +142,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);

let wrapper = quote! {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> Result<#ret_type, #err_type>> = func_ref.to_boxed_fn();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.to_boxed_fn();
let res: #ret_type = func_ref(#(#args),*)?;
Ok(res)
}
Expand Down
31 changes: 15 additions & 16 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use crate::util::get_tvm_rt_crate;

pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
let tvm_rt_crate = get_tvm_rt_crate();
let result = quote! { #tvm_rt_crate::function::Result };
let error = quote! { #tvm_rt_crate::errors::Error };
let derive_input = syn::parse_macro_input!(input as DeriveInput);
let payload_id = derive_input.ident;

Expand Down Expand Up @@ -77,9 +79,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
#[derive(Clone)]
pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);

impl #tvm_rt_crate::object::ToObjectRef for #ref_id {
fn to_object_ref(&self) -> ObjectRef {
ObjectRef(self.0.as_ref().map(|o| o.upcast()))
impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
type Object = #payload_id;

fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.0.as_ref()
}

fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
#ref_id(object_ptr)
}
}

Expand All @@ -92,9 +100,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}

impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> {
fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
use std::convert::TryInto;
let oref: ObjectRef = ret_val.try_into()?;
let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
Expand Down Expand Up @@ -125,24 +133,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}

impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #tvm_rt_crate::Error;
type Error = #error;

fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
}
}

impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #tvm_rt_crate::Error;

fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
}
}

impl From<#ref_id> for #tvm_rt_crate::RetValue {
fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue {
Expand Down
79 changes: 79 additions & 0 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::convert::{TryFrom, TryInto};
use std::marker::PhantomData;

use crate::errors::Error;
use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
use crate::{
external,
function::{Function, Result},
RetValue,
};

#[repr(C)]
#[derive(Clone)]
pub struct Array<T: IsObjectRef> {
object: ObjectRef,
_data: PhantomData<T>,
}

// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
}

impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
let iter = data
.iter()
.map(|element| element.to_object_ref().into())
.collect();

let func = Function::get("node.Array").expect(
"node.Array function is not registered, this is most likely a build or linking error",
);

// let array_data = func.invoke(iter)?;
// let array_data: ObjectRef = func.invoke(iter)?.try_into()?;
let array_data: ObjectPtr<Object> = func.invoke(iter)?.try_into()?;

debug_assert!(
array_data.count() >= 1,
"array reference count is {}",
array_data.count()
);

Ok(Array {
object: ObjectRef(Some(array_data)),
_data: PhantomData,
})
}

pub fn get(&self, index: isize) -> Result<T>
where
T: TryFrom<RetValue, Error = Error>,
{
let oref: ObjectRef = array_get_item(self.object.clone(), index)?;
oref.downcast()
}
}
2 changes: 2 additions & 0 deletions rust/tvm-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum Error {
NDArray(#[from] NDArrayError),
#[error("{0}")]
CallFailed(String),
#[error("this case will never occur")]
Infallible(#[from] std::convert::Infallible),
}

impl Error {
Expand Down
20 changes: 16 additions & 4 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ use std::{
ptr, str,
};

pub use tvm_sys::{ffi, ArgValue, RetValue};

use crate::errors::Error;

use super::to_boxed_fn::ToBoxedFn;
use super::to_function::{ToFunction, Typed};

pub use super::to_function::{ToFunction, Typed};
pub use tvm_sys::{ffi, ArgValue, RetValue};

pub type Result<T> = std::result::Result<T, Error>;

Expand Down Expand Up @@ -65,6 +65,14 @@ impl Function {
}
}

pub unsafe fn null() -> Self {
Function {
handle: std::ptr::null_mut(),
is_global: false,
from_rust: false,
}
}

/// For a given function, it returns a function by name.
pub fn get<S: AsRef<str>>(name: S) -> Option<Function> {
let name = CString::new(name.as_ref()).unwrap();
Expand Down Expand Up @@ -171,7 +179,11 @@ impl TryFrom<RetValue> for Function {

impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
ArgValue::FuncHandle(func.handle)
if func.handle.is_null() {
ArgValue::Null
} else {
ArgValue::FuncHandle(func.handle)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ pub(crate) fn set_last_error<E: std::error::Error>(err: &E) {
}
}

#[macro_use]
pub mod function;
pub mod array;
pub mod context;
pub mod errors;
pub mod function;
pub mod module;
pub mod ndarray;
pub mod to_boxed_fn;
Expand Down
22 changes: 11 additions & 11 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,17 +411,17 @@ mod tests {
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
}

#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
fn copy_wrong_dtype() {
let shape = vec![4];
let mut data = vec![1f32, 2., 3., 4.];
let ctx = Context::cpu(0);
let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
nd_float.copy_from_buffer(&mut data);
let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
nd_float.copy_to_ndarray(empty_int).unwrap();
}
// #[test]
// #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
// fn copy_wrong_dtype() {
// let shape = vec![4];
// let mut data = vec![1f32, 2., 3., 4.];
// let ctx = Context::cpu(0);
// let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
// nd_float.copy_from_buffer(&mut data);
// let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
// nd_float.copy_to_ndarray(empty_int).unwrap();
// }

#[test]
fn rust_ndarray() {
Expand Down
Loading