Skip to content

Commit

Permalink
Add tvm-rt
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 9, 2020
1 parent aded92d commit 3953fa2
Show file tree
Hide file tree
Showing 27 changed files with 3,054 additions and 123 deletions.
5 changes: 4 additions & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/runtime/container.h>
#include <tvm/ir/span.h>
#include <tvm/ir/type.h>
#include <string>
Expand All @@ -36,6 +37,8 @@

namespace tvm {

using tvm::runtime::String;

/*!
* \brief Base type of all the expressions.
* \sa Expr
Expand Down Expand Up @@ -189,7 +192,7 @@ class GlobalVar;
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
String name_hint;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def asobject(self):


def convert_to_object(value):
"""Convert a python value to corresponding object type.
"""Convert a Python value to corresponding object type.
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions rust/macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ proc-macro = true
[dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
quote = "1.0"
syn = "1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
133 changes: 133 additions & 0 deletions rust/macros/src/import_module.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use quote::quote;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::LitStr;

use std::path::PathBuf;

struct ImportModule {
importing_file: LitStr,
}

impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
Ok(ImportModule { importing_file })
}
}

pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);

let manifest =
std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");

let mut path = PathBuf::new();
path.push(manifest);
path = path.join(import_module_args.importing_file.value());

let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();

let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};

let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};

let fns = quote! {
use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError};
#extern_fns

#(
pub fn #fn_names(args: &[ArgValue]) -> Result<RetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(RetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};

proc_macro::TokenStream::from(fns)
}
124 changes: 10 additions & 114 deletions rust/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,121 +17,17 @@
* under the License.
*/

extern crate proc_macro;

use quote::quote;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::LitStr;

use std::path::PathBuf;

struct ImportModule {
importing_file: LitStr,
}

impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
Ok(ImportModule { importing_file })
}
}
use proc_macro::TokenStream;
mod import_module;
mod object;

#[proc_macro]
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);

let manifest =
std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");

let mut path = PathBuf::new();
path.push(manifest);
path = path.join(import_module_args.importing_file.value());

let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();

let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};

let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};

let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns

#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
pub fn import_module(input: TokenStream) -> TokenStream {
import_module::macro_impl(input)
}

proc_macro::TokenStream::from(fns)
#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
}
Loading

0 comments on commit 3953fa2

Please sign in to comment.