Skip to content

Commit

Permalink
Add tvm-rt crate
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 28, 2020
1 parent a072da0 commit 9fc0d4a
Show file tree
Hide file tree
Showing 34 changed files with 3,098 additions and 136 deletions.
2 changes: 2 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

namespace tvm {

using tvm::runtime::String;

/*!
* \brief Base type of all the expressions.
* \sa Expr
Expand Down
22 changes: 12 additions & 10 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,15 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);

/*!
* \brief Increase the reference count of an object.
*
* \param obj The object handle.
* \note Internally we increase the reference counter of the object.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectRetain(TVMObjectHandle obj);

/*!
* \brief Free the object.
*
Expand All @@ -514,16 +523,6 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param nbytes The number of bytes in memory.
* \param alignment The alignment of the memory.
* \param type_hint The type of elements. Only needed by certain backends such
* as nbytes & alignment are sufficient for most backends.
* \param out_data The allocated device pointer.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint, void** out_data);

Expand Down Expand Up @@ -554,6 +553,9 @@ TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void*
TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream);

TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived);


#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
Expand Down
17 changes: 13 additions & 4 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,20 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array container of ObjectRef in DSL graph.
* Array implements copy-on-write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
* \brief Array, container representing a contigious sequence of ObjectRefs.
*
* operator[] only provide const access, use Set to mutate the content.
* Array implements in-place copy-on-write semantics.
*
* As in typical copy-on-write, a method which would typically mutate the array
* instead opaquely copies the underlying container, and then acts on its copy.
*
* If the array has reference count equal to one, we directly update the
* container in place without copying. This is optimization is sound because
* when the reference count is equal to one this reference is guranteed to be
* the sole pointer to the container.
*
*
* operator[] only provides const access, use Set to mutate the content.
* \tparam T The content ObjectRef type.
*/
template <typename T,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def asobject(self):


def convert_to_object(value):
"""Convert a python value to corresponding object type.
"""Convert a Python value to corresponding object type.
Parameters
----------
Expand Down
3 changes: 2 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ members = [
"frontend/tests/basics",
"frontend/tests/callback",
"frontend/examples/resnet",
"tvm-sys"
"tvm-sys",
"tvm-rt"
]
4 changes: 2 additions & 2 deletions rust/macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ proc-macro = true
[dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
quote = "1.0"
syn = "1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
133 changes: 133 additions & 0 deletions rust/macros/src/import_module.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use quote::quote;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::LitStr;

use std::path::PathBuf;

struct ImportModule {
importing_file: LitStr,
}

impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
Ok(ImportModule { importing_file })
}
}

pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);

let manifest =
std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");

let mut path = PathBuf::new();
path.push(manifest);
path = path.join(import_module_args.importing_file.value());

let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();

let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};

let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};

let fns = quote! {
use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError};
#extern_fns

#(
pub fn #fn_names(args: &[ArgValue]) -> Result<RetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(RetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};

proc_macro::TokenStream::from(fns)
}
124 changes: 10 additions & 114 deletions rust/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,121 +17,17 @@
* under the License.
*/

extern crate proc_macro;

use quote::quote;
use std::{fs::File, io::Read};
use syn::parse::{Parse, ParseStream, Result};
use syn::LitStr;

use std::path::PathBuf;

struct ImportModule {
importing_file: LitStr,
}

impl Parse for ImportModule {
fn parse(input: ParseStream) -> Result<Self> {
let importing_file: LitStr = input.parse()?;
Ok(ImportModule { importing_file })
}
}
use proc_macro::TokenStream;
mod import_module;
mod object;

#[proc_macro]
pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let import_module_args = syn::parse_macro_input!(input as ImportModule);

let manifest =
std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");

let mut path = PathBuf::new();
path.push(manifest);
path = path.join(import_module_args.importing_file.value());

let mut fd = File::open(&path)
.unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
let mut buffer = Vec::new();
fd.read_to_end(&mut buffer).unwrap();

let fn_names = match goblin::Object::parse(&buffer).unwrap() {
goblin::Object::Elf(elf) => elf
.syms
.iter()
.filter_map(|s| {
if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
return None;
}
match elf.strtab.get(s.st_name) {
Some(Ok(name)) if name != "" => {
Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
}
_ => None,
}
})
.collect::<Vec<_>>(),
goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
obj.symbols()
.filter_map(|s| match s {
Ok((name, ref nlist))
if nlist.is_global()
&& nlist.n_sect != 0
&& !name.ends_with("tvm_module_ctx") =>
{
Some(syn::Ident::new(
if name.starts_with('_') {
// Mach objects prepend a _ to globals.
&name[1..]
} else {
&name
},
proc_macro2::Span::call_site(),
))
}
_ => None,
})
.collect::<Vec<_>>()
}
_ => panic!("Unsupported object format."),
};

let extern_fns = quote! {
mod ext {
extern "C" {
#(
pub(super) fn #fn_names(
args: *const tvm_runtime::ffi::TVMValue,
type_codes: *const std::os::raw::c_int,
num_args: std::os::raw::c_int
) -> std::os::raw::c_int;
)*
}
}
};

let fns = quote! {
use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
#extern_fns

#(
pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
let exit_code = unsafe {
ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
};
if exit_code == 0 {
Ok(TVMRetValue::default())
} else {
Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
}
}
)*
};
pub fn import_module(input: TokenStream) -> TokenStream {
import_module::macro_impl(input)
}

proc_macro::TokenStream::from(fns)
#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
}
Loading

0 comments on commit 9fc0d4a

Please sign in to comment.