diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fba35a9193f94..82689bda73612 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,8 @@ namespace tvm { +using tvm::runtime::String; + /*! * \brief Base type of all the expressions. * \sa Expr @@ -189,7 +192,7 @@ class GlobalVar; class GlobalVarNode : public RelayExprNode { public: /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; + String name_hint; void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index cc21450e25c1a..8f559ae24aacb 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,7 +38,7 @@ def asobject(self): def convert_to_object(value): - """Convert a python value to corresponding object type. + """Convert a Python value to corresponding object type. Parameters ---------- diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml index 784b35e2fdae4..7abc9ae64f7c6 100644 --- a/rust/macros/Cargo.toml +++ b/rust/macros/Cargo.toml @@ -32,5 +32,5 @@ proc-macro = true [dependencies] goblin = "0.0.24" proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +quote = "^1.0" +syn = { version = "1.0.17", features = ["full", "extra-traits"] } diff --git a/rust/macros/src/import_module.rs b/rust/macros/src/import_module.rs new file mode 100644 index 0000000000000..6b059ae363f82 --- /dev/null +++ b/rust/macros/src/import_module.rs @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use quote::quote; +use std::{fs::File, io::Read}; +use syn::parse::{Parse, ParseStream, Result}; +use syn::LitStr; + +use std::path::PathBuf; + +struct ImportModule { + importing_file: LitStr, +} + +impl Parse for ImportModule { + fn parse(input: ParseStream) -> Result { + let importing_file: LitStr = input.parse()?; + Ok(ImportModule { importing_file }) + } +} + +pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let import_module_args = syn::parse_macro_input!(input as ImportModule); + + let manifest = + std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); + + let mut path = PathBuf::new(); + path.push(manifest); + path = path.join(import_module_args.importing_file.value()); + + let mut fd = File::open(&path) + .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); + let mut buffer = Vec::new(); + fd.read_to_end(&mut buffer).unwrap(); + + let fn_names = match goblin::Object::parse(&buffer).unwrap() { + goblin::Object::Elf(elf) => elf + .syms + .iter() + .filter_map(|s| { + if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { + return None; + } + match elf.strtab.get(s.st_name) { + Some(Ok(name)) if name != "" => { + Some(syn::Ident::new(name, proc_macro2::Span::call_site())) + } + _ => None, + } + }) + .collect::>(), + goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { + obj.symbols() + .filter_map(|s| match s { + Ok((name, ref nlist)) + if nlist.is_global() + && nlist.n_sect != 0 + && !name.ends_with("tvm_module_ctx") => + { + Some(syn::Ident::new( + if name.starts_with('_') { + // Mach objects prepend a _ to globals. + &name[1..] + } else { + &name + }, + proc_macro2::Span::call_site(), + )) + } + _ => None, + }) + .collect::>() + } + _ => panic!("Unsupported object format."), + }; + + let extern_fns = quote! { + mod ext { + extern "C" { + #( + pub(super) fn #fn_names( + args: *const tvm_runtime::ffi::TVMValue, + type_codes: *const std::os::raw::c_int, + num_args: std::os::raw::c_int + ) -> std::os::raw::c_int; + )* + } + } + }; + + let fns = quote! { + use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError}; + #extern_fns + + #( + pub fn #fn_names(args: &[ArgValue]) -> Result { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = unsafe { + ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) + }; + if exit_code == 0 { + Ok(RetValue::default()) + } else { + Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) + } + } + )* + }; + + proc_macro::TokenStream::from(fns) +} diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs index 9f28c74febd62..e9ddc25ddf9c7 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/lib.rs @@ -17,121 +17,17 @@ * under the License. */ -extern crate proc_macro; - -use quote::quote; -use std::{fs::File, io::Read}; -use syn::parse::{Parse, ParseStream, Result}; -use syn::LitStr; - -use std::path::PathBuf; - -struct ImportModule { - importing_file: LitStr, -} - -impl Parse for ImportModule { - fn parse(input: ParseStream) -> Result { - let importing_file: LitStr = input.parse()?; - Ok(ImportModule { importing_file }) - } -} +use proc_macro::TokenStream; +mod import_module; +mod object; #[proc_macro] -pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let import_module_args = syn::parse_macro_input!(input as ImportModule); - - let manifest = - std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); - - let mut path = PathBuf::new(); - path.push(manifest); - path = path.join(import_module_args.importing_file.value()); - - let mut fd = File::open(&path) - .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); - let mut buffer = Vec::new(); - fd.read_to_end(&mut buffer).unwrap(); - - let fn_names = match goblin::Object::parse(&buffer).unwrap() { - goblin::Object::Elf(elf) => elf - .syms - .iter() - .filter_map(|s| { - if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { - return None; - } - match elf.strtab.get(s.st_name) { - Some(Ok(name)) if name != "" => { - Some(syn::Ident::new(name, proc_macro2::Span::call_site())) - } - _ => None, - } - }) - .collect::>(), - goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { - obj.symbols() - .filter_map(|s| match s { - Ok((name, ref nlist)) - if nlist.is_global() - && nlist.n_sect != 0 - && !name.ends_with("tvm_module_ctx") => - { - Some(syn::Ident::new( - if name.starts_with('_') { - // Mach objects prepend a _ to globals. - &name[1..] - } else { - &name - }, - proc_macro2::Span::call_site(), - )) - } - _ => None, - }) - .collect::>() - } - _ => panic!("Unsupported object format."), - }; - - let extern_fns = quote! { - mod ext { - extern "C" { - #( - pub(super) fn #fn_names( - args: *const tvm_runtime::ffi::TVMValue, - type_codes: *const std::os::raw::c_int, - num_args: std::os::raw::c_int - ) -> std::os::raw::c_int; - )* - } - } - }; - - let fns = quote! { - use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; - #extern_fns - - #( - pub fn #fn_names(args: &[TVMArgValue]) -> Result { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = unsafe { - ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) - }; - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) - } - } - )* - }; +pub fn import_module(input: TokenStream) -> TokenStream { + import_module::macro_impl(input) +} - proc_macro::TokenStream::from(fns) +#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +pub fn macro_impl(input: TokenStream) -> TokenStream { + // let input = proc_macro2::TokenStream::from(input); + TokenStream::from(object::macro_impl(input)) } diff --git a/rust/macros/src/object.rs b/rust/macros/src/object.rs new file mode 100644 index 0000000000000..96a86dd740749 --- /dev/null +++ b/rust/macros/src/object.rs @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::DeriveInput; +use syn::Ident; + +pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { + let derive_input = syn::parse_macro_input!(input as DeriveInput); + let payload_id = derive_input.ident; + + let mut type_key = None; + let mut ref_name = None; + let base = Some(Ident::new("base", Span::call_site())); + + for attr in derive_input.attrs { + if attr.path.is_ident("type_key") { + type_key = Some(attr.parse_meta().expect("foo")) + } + + if attr.path.is_ident("ref_name") { + ref_name = Some(attr.parse_meta().expect("foo")) + } + } + + let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key { + match name_value.lit { + syn::Lit::Str(type_key) => type_key, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name { + match name_value.lit { + syn::Lit::Str(ref_name) => ref_name, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_id = Ident::new(&ref_name.value(), Span::call_site()); + let base = base.expect("should be present"); + + let expanded = quote! { + unsafe impl tvm_rt::object::IsObject for #payload_id { + const TYPE_KEY: &'static str = #type_key; + + fn as_object<'s>(&'s self) -> &'s Object { + &self.#base.as_object() + } + } + + #[derive(Clone)] + pub struct #ref_id(Option>); + + impl tvm_rt::object::ToObjectRef for #ref_id { + fn to_object_ref(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } + } + + impl std::ops::Deref for #ref_id { + type Target = #payload_id; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().unwrap() + } + } + + impl std::convert::TryFrom for #ref_id { + type Error = ::anyhow::Error; + + fn try_from(ret_val: tvm_rt::RetValue) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let oref: ObjectRef = ret_val.try_into()?; + let ptr = oref.0.ok_or(anyhow::anyhow!("null ptr"))?; + let ptr = ptr.downcast::<#payload_id>()?; + Ok(#ref_id(Some(ptr))) + } + } + + impl<'a> From<#ref_id> for tvm_rt::ArgValue<'a> { + fn from(object_ref: #ref_id) -> tvm_rt::ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + tvm_rt::ArgValue:: + ObjectHandle(std::ptr::null::() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + impl<'a> From<&#ref_id> for tvm_rt::ArgValue<'a> { + fn from(object_ref: &#ref_id) -> tvm_rt::ArgValue<'a> { + let oref: #ref_id = object_ref.clone(); + tvm_rt::ArgValue::<'a>::from(oref) + } + } + + impl<'a> std::convert::TryFrom> for #ref_id { + type Error = anyhow::Error; + + fn try_from(arg_value: tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl<'a> std::convert::TryFrom<&tvm_rt::ArgValue<'a>> for #ref_id { + type Error = anyhow::Error; + + fn try_from(arg_value: &tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl From<#ref_id> for tvm_rt::RetValue { + fn from(object_ref: #ref_id) -> tvm_rt::RetValue { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + tvm_rt::RetValue::ObjectHandle(std::ptr::null::() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + }; + + TokenStream::from(expanded) +} + +// impl TryFrom for Var { +// type Error = anyhow::Error; + +// fn try_from(ret_val: RetValue) -> Result { +// let oref: ObjectRef = ret_val.try_into()?; +// let var_ptr = oref.0.ok_or(anyhow!("null ptr"))?; +// let var_ptr = var_ptr.downcast::()?; +// Ok(Var(Some(var_ptr))) +// } +// } diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore new file mode 100644 index 0000000000000..2430329c78b6a --- /dev/null +++ b/rust/tvm-rt/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml new file mode 100644 index 0000000000000..417f2567595c2 --- /dev/null +++ b/rust/tvm-rt/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-rt" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust bindings for the TVM runtime API." +repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +lazy_static = "1.1" +ndarray = "0.12" +num-traits = "0.2" +tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } +tvm-macros = { version = "0.1", path = "../macros" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md new file mode 100644 index 0000000000000..fff3b5673073f --- /dev/null +++ b/rust/tvm-rt/README.md @@ -0,0 +1,235 @@ + + + + + + + + + + + + + + + + + +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` + +## What Does This Crate Offer? + +Here is a major workflow + +1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) +2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example. + +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) +# compile the model +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) +# same the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) +``` + +Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demostrated in the following Rust snippet + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to ByteArray + let params: Vec = fs::read("deploy_param.params")?; + let barr = ByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec + let output = output.to_vec::()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installations + +Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. + +## Supported TVM Functionalities + +### Use TVM to Generate Shared Library + +One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. + +```python +import os +import tvm +from tvm import te +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.runtime.enabled("cuda"): + print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) + return + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the tests and examples custom `build.rs` for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use `register_global_func!` macro to convert and register a Rust +function of type `fn(&[ArgValue]) -> Result` to a global TVM **packed function** as follows + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[ArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = RetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); +} +``` diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs new file mode 100644 index 0000000000000..bceae5e87ba84 --- /dev/null +++ b/rust/tvm-rt/src/context.rs @@ -0,0 +1,76 @@ +use tvm_sys::ffi; +pub use tvm_sys::context::*; + +use std::os::raw::c_void; +use std::ptr; + +trait ContextExt { + /// Checks whether the context exists or not. + fn exist(&self) -> bool; + fn sync(&self) -> anyhow::Result<()>; + fn max_threads_per_block(&self) -> isize; + fn warp_size(&self) -> isize; + fn max_shared_memory_per_block(&self) -> isize; + fn compute_version(&self) -> isize; + fn device_name(&self) -> isize; + fn max_clock_rate(&self) -> isize; + fn multi_processor_count(&self) -> isize; + fn max_thread_dimensions(&self) -> isize; +} + +macro_rules! impl_device_attrs { + ($(($attr_name:ident, $attr_kind:expr));+) => { + $( + fn $attr_name(&self) -> isize { + get_device_attr(self.device_type.0 as i32, self.device_id as i32, 0) + .expect("should not fail") as isize + } + + )+ + }; +} + +external_func! { + fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32 as "runtime.GetDeviceAttr"; +} + + +impl ContextExt for Context { + fn exist(&self) -> bool { + let exists = get_device_attr(self.device_type.0 as i32, self.device_id as i32, 0) + .expect("should not fail"); + + exists != 0 + } + + /// Synchronize the context stream. + fn sync(&self) -> anyhow::Result<()> { + check_call!(ffi::TVMSynchronize( + self.device_type.0 as i32, + self.device_id as i32, + ptr::null_mut() as *mut c_void + )); + Ok(()) + } + + impl_device_attrs!((max_threads_per_block, 1); + (warp_size, 2); + (max_shared_memory_per_block, 3); + (compute_version, 4); + (device_name, 5); + (max_clock_rate, 6); + (multi_processor_count, 7); + (max_thread_dimensions, 8)); +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sync() { + let ctx = Context::cpu(0); + assert!(ctx.sync().is_ok()) + } +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs new file mode 100644 index 0000000000000..77dbba7475275 --- /dev/null +++ b/rust/tvm-rt/src/errors.rs @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use thiserror::Error; + +#[derive(Debug, Error)] +#[error("Cannot convert from an empty array.")] +pub struct EmptyArrayError; + +#[derive(Debug, Error)] +#[error("Handle `{name}` is null.")] +pub struct NullHandleError { + pub name: String, +} + +#[derive(Debug, Error)] +#[error("Function was not set in `function::Builder`")] +pub struct FunctionNotFoundError; + +#[derive(Debug, Error)] +#[error("Expected type `{expected}` but found `{actual}`")] +pub struct TypeMismatchError { + pub expected: String, + pub actual: String, +} + +#[derive(Debug, Error)] +#[error("Missing NDArray shape.")] +pub struct MissingShapeError; diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs new file mode 100644 index 0000000000000..739c7a09d0956 --- /dev/null +++ b/rust/tvm-rt/src/function.rs @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::{ + collections::BTreeMap, + ffi::{CStr, CString}, + mem::{self, MaybeUninit}, + os::raw::{c_char, c_int}, + ptr, slice, str, + sync::Mutex, +}; + +use anyhow::{Result}; +use lazy_static::lazy_static; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use super::to_function::{ToFunction, Typed}; +use super::to_boxed_fn::ToBoxedFn; + +lazy_static! { + static ref GLOBAL_FUNCTIONS: Mutex>> = { + let mut out_size = 0 as c_int; + let mut names_ptr = ptr::null_mut() as *mut *const c_char; + check_call!(ffi::TVMFuncListGlobalNames( + &mut out_size as *mut _, + &mut names_ptr as *mut _, + )); + let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) }; + + let names_list: Vec = + names_list + .iter() + .map(|&p| unsafe { CStr::from_ptr(p).to_str().unwrap().into() }) + .collect(); + + // println!("{:?}", &names_list); + + let names_list = names_list + .into_iter() + .map(|p| (p, None)) + .collect(); + + Mutex::new(names_list) + }; +} + +/// Wrapper around TVM function handle which includes `is_global` +/// indicating whether the function is global or not, and `is_cloned` showing +/// not to drop a cloned function from Rust side. +/// The value of these fields can be accessed through their respective methods. +#[derive(Debug, Hash)] +pub struct Function { + pub(crate) handle: ffi::TVMFunctionHandle, + // whether the registered function is global or not. + is_global: bool, + // whether the function has been cloned from frontend or not. + is_cloned: bool, +} + +unsafe impl Send for Function {} +unsafe impl Sync for Function {} + +impl Function { + pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { + Function { + handle, + is_global: false, + is_cloned: false, + } + } + + /// For a given function, it returns a function by name. + pub fn get>(name: S) -> Option<&'static Function> { + let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); + globals.get_mut(name.as_ref()).and_then(|maybe_func| { + if maybe_func.is_none() { + let name = CString::new(name.as_ref()).unwrap(); + let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMFuncGetGlobal( + name.as_ptr() as *const c_char, + &mut handle as *mut _ + )); + maybe_func.replace(Function { + handle, + is_global: true, + is_cloned: false, + }); + } + + unsafe { + mem::transmute::, Option<&'static Function>>(maybe_func.as_ref()) + } + }) + } + + /// Returns the underlying TVM function handle. + pub fn handle(&self) -> ffi::TVMFunctionHandle { + self.handle + } + + /// Returns `true` if the underlying TVM function is global and `false` otherwise. + pub fn is_global(&self) -> bool { + self.is_global + } + + /// Returns `true` if the underlying TVM function has been cloned + /// from the frontend and `false` otherwise. + pub fn is_cloned(&self) -> bool { + self.is_cloned + } + + /// Calls the function that created from `Builder`. + pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { + let num_args = arg_buf.len(); + let (mut values, mut type_codes): (Vec, Vec) = + arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); + + let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() }; + let mut ret_type_code = 0i32; + check_call!(ffi::TVMFuncCall( + self.handle, + values.as_mut_ptr(), + type_codes.as_mut_ptr() as *mut i32, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); + + Ok(RetValue::from_tvm_value(ret_val, ret_type_code as u32)) + } + + pub fn to_boxed_fn(&'static self) -> Box where F: ToBoxedFn { + F::to_boxed_fn(self) + } +} + +impl Clone for Function { + fn clone(&self) -> Function { + Self { + handle: self.handle, + is_global: self.is_global, + is_cloned: true, + } + } +} + +impl Drop for Function { + fn drop(&mut self) { + if !self.is_global && !self.is_cloned { + check_call!(ffi::TVMFuncFree(self.handle)); + } + } +} + +/// Registers a Rust function with signature +/// `fn(&[ArgValue]) -> Result` +/// as a **global TVM packed function** from frontend to TVM backend. +/// +/// Use [`register_global_func`] if overriding an existing global TVM function +/// is not required. +/// +/// ## Example +/// +/// ``` +/// # use tvm_rt::{ArgValue, function, RetValue}; +/// # use tvm_rt::function::Builder; +/// # use anyhow::Error; +/// use std::convert::TryInto; +/// +/// fn sum(args: &[ArgValue]) -> Result { +/// let mut ret = 0i64; +/// for arg in args.iter() { +/// let arg: i64 = arg.try_into()?; +/// ret += arg; +/// } +/// let ret_val = RetValue::from(ret); +/// Ok(ret_val) +/// } +/// +/// function::register(sum, "mysum".to_owned()).unwrap(); +/// let mut registered = Builder::default(); +/// registered.get_function("mysum"); +/// assert!(registered.func.is_some()); +/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register>(f: F, name: S) -> Result<()> +where + F: ToFunction, + F: Typed, +{ + register_override(f, name, false) +} + +/// Registers a Rust function with signature +/// `fn(&[ArgValue]) -> Result` +/// as a **global TVM packed function** from frontend to TVM backend. +/// +/// Use [`register_global_func`] if overriding an existing global TVM function +/// is not required. +/// +/// ## Example +/// +/// ``` +/// # use tvm_rt::{ArgValue, function, RetValue}; +/// # use tvm_rt::function::Builder; +/// # use anyhow::Error; +/// use std::convert::TryInto; +/// +/// fn sum(args: &[ArgValue]) -> Result { +/// let mut ret = 0i64; +/// for arg in args.iter() { +/// let arg: i64 = arg.try_into()?; +/// ret += arg; +/// } +/// let ret_val = RetValue::from(ret); +/// Ok(ret_val) +/// } +/// +/// function::register_override(sum, "mysum".to_owned(), false).unwrap(); +/// let mut registered = Builder::default(); +/// registered.get_function("mysum"); +/// assert!(registered.func.is_some()); +/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> +where + F: ToFunction, + F: Typed, +{ + let func = f.to_function(); + let name = name.into(); + let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); + // Not sure about this code + let handle = func.handle(); + globals.insert(name.clone(), Some(func)); + let name= CString::new(name)?; + check_call!(ffi::TVMFuncRegisterGlobal( + name.into_raw(), + handle, + override_ as c_int + )); + + Ok(()) +} + +#[macro_export] +macro_rules! external_func { + (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { + ::paste::item! { + #[allow(non_upper_case_globals)] + static []: ::once_cell::sync::Lazy<&'static $crate::Function> = + ::once_cell::sync::Lazy::new(|| { + $crate::Function::get($ext_name) + .expect(concat!("unable to load external function", stringify!($ext_name), "from TVM registry.")) + }); + } + + pub fn $name($($arg : $ty),*) -> Result<$ret_type, anyhow::Error> { + let func_ref: &$crate::Function = ::paste::expr! { &*[] }; + let func_ref: Box anyhow::Result<$ret_type>> = func_ref.to_boxed_fn(); + let res: $ret_type = func_ref($($arg),*)?; + Ok(res) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::{Function}; + + static CANARY: &str = "runtime.ModuleLoadFromFile"; + + // #[test] + // fn list_global_func() { + // assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY)); + // } + + #[test] + fn get_fn() { + assert!(Function::get(CANARY).is_some()); + assert!(Function::get("does not exists!").is_none()); + } + + #[test] + fn register_and_call_closure0() { + use crate::function; + + fn constfn() -> i64 { + return 10; + } + + function::register_override(constfn, "constfn".to_owned(), true).unwrap(); + let func = Function::get("constfn").unwrap(); + let func = func.to_boxed_fn:: Result>(); + let ret = func().unwrap(); + assert_eq!(ret, 10); + } + + // #[test] + // fn register_and_call_closure1() { + // use crate::function::{self}; + + // fn ident(x: i64) -> i64 { + // return x; + // } + + // function::register_override(ident, "ident".to_owned(), false).unwrap(); + // let func = Function::get("ident").unwrap(); + // let func = func.to_boxed_fn:: Result>(); + // assert_eq!(func(60).unwrap(), 60); + // } +} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs new file mode 100644 index 0000000000000..e9ae02f339307 --- /dev/null +++ b/rust/tvm-rt/src/lib.rs @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime frontend. +//! +//! One particular use case is that given optimized deep learning model artifacts, +//! (compiled with TVM) which include a shared library +//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them +//! in Rust idomatically to create a TVM Graph Runtime and +//! run the model for some inputs and get the +//! desired predictions *all in Rust*. +//! +//! Checkout the `examples` repository for more details. + +extern crate ndarray as rust_ndarray; + +pub use crate as tvm_rt; + +pub mod object; +pub mod string; + +pub use object::*; +pub use string::*; + +use std::{ + ffi::{CStr, CString}, + str, +}; + +use anyhow::Error; + +pub use crate::{ + context::{Context, TVMDeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, +}; + +pub use function::{ArgValue, RetValue}; +pub use tvm_sys::byte_array::ByteArray; +pub use tvm_sys::datatype::DataType; + +use tvm_sys::ffi; + +// Macro to check the return call to TVM runtime shared library. +#[macro_export] +macro_rules! check_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + panic!("{}", $crate::get_last_error()); + } + }}; +} + +/// Gets the last error message. +pub fn get_last_error() -> &'static str { + unsafe { + match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { + Ok(s) => s, + Err(_) => "Invalid UTF-8 message", + } + } +} + +pub(crate) fn set_last_error(err: &Error) { + let c_string = CString::new(err.to_string()).unwrap(); + unsafe { + ffi::TVMAPISetLastError(c_string.as_ptr()); + } +} + +#[macro_use] +pub mod function; +pub mod context; +pub mod errors; +pub mod module; +pub mod ndarray; +pub mod to_function; +pub mod to_boxed_fn; +pub mod value; + +/// Outputs the current TVM version. +pub fn version() -> &'static str { + match str::from_utf8(ffi::TVM_VERSION) { + Ok(s) => s, + Err(_) => "Invalid UTF-8 string", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn print_version() { + println!("TVM version: {}", version()); + } + + #[test] + fn set_error() { + let err = errors::EmptyArrayError; + set_last_error(&err.into()); + assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string()); + } +} diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs new file mode 100644 index 0000000000000..f9b49d9a58e15 --- /dev/null +++ b/rust/tvm-rt/src/module.rs @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Provides the [`Module`] type and methods for working with runtime TVM modules. + +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, + path::Path, + ptr, +}; + +use anyhow::{anyhow, ensure, Error}; +use tvm_sys::ffi; + +use crate::{errors, function::Function}; + +const ENTRY_FUNC: &str = "__tvm_main__"; + +/// Wrapper around TVM module handle which contains an entry function. +/// The entry function can be applied to an imported module through [`entry_func`]. +/// +/// [`entry_func`]:struct.Module.html#method.entry_func +#[derive(Debug, Clone)] +pub struct Module { + pub(crate) handle: ffi::TVMModuleHandle, + entry_func: Option, +} + + +external_func! { + fn runtime_enabled(target: CString) -> i32 as "runtime.RuntimeEnabled"; +} + +external_func! { + fn load_from_file(file_name: CString, format: CString) -> Module as "runtime.ModuleLoadFromFile"; +} + + +impl Module { + pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { + Self { + handle, + entry_func: None, + } + } + + pub fn entry(&mut self) -> Option<&Function> { + if self.entry_func.is_none() { + self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); + } + self.entry_func.as_ref() + } + + /// Gets a function by name from a registered module. + pub fn get_function(&self, name: &str, query_import: bool) -> Result { + let name = CString::new(name)?; + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMModGetFunction( + self.handle, + name.as_ptr() as *const c_char, + query_import as c_int, + &mut fhandle as *mut _ + )); + ensure!( + !fhandle.is_null(), + errors::NullHandleError { + name: name.into_string()?.to_string() + } + ); + Ok(Function::new(fhandle)) + } + + /// Imports a dependent module such as `.ptx` for gpu. + pub fn import_module(&self, dependent_module: Module) { + check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) + } + + /// Loads a module shared library from path. + pub fn load>(path: &P) -> Result { + let ext = CString::new( + path.as_ref() + .extension() + .unwrap_or_else(|| std::ffi::OsStr::new("")) + .to_str() + .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?, + )?; + let cpath = CString::new( + path.as_ref() + .to_str() + .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?, + )?; + let module = load_from_file(cpath, ext)?; + Ok(module) + } + + /// Checks if a target device is enabled for a module. + pub fn enabled(&self, target: &str) -> bool { + let target = CString::new(target).unwrap(); + let enabled = runtime_enabled(target).unwrap(); + enabled != 0 + } + + /// Returns the underlying module handle. + pub fn handle(&self) -> ffi::TVMModuleHandle { + self.handle + } +} + +impl Drop for Module { + fn drop(&mut self) { + check_call!(ffi::TVMModFree(self.handle)); + } +} diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs new file mode 100644 index 0000000000000..4653117b88fbd --- /dev/null +++ b/rust/tvm-rt/src/ndarray.rs @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements the [`NDArray`] type for working with *TVM tensors* or +//! coverting from a Rust's ndarray to TVM `NDArray`. +//! +//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! # use tvm_rt::{NDArray, Context, DataType}; +//! # use ndarray::{Array, ArrayD}; +//! # use std::str::FromStr; +//! use std::convert::TryFrom; +//! +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +use crate::errors; +use anyhow::{bail, ensure, Result}; +use num_traits::Num; +use rust_ndarray::{Array, ArrayD}; +use std::convert::TryInto; +use std::ffi::c_void; +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, ByteArray, Context, DataType}; + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +/// +/// Wrapper around TVM array handle. +#[derive(Debug)] +pub enum NDArray { + Borrowed { handle: ffi::TVMArrayHandle }, + Owned { handle: *mut c_void }, +} + +impl NDArray { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { + NDArray::Borrowed { handle } + } + + pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { + NDArray::Owned { handle } + } + + pub fn as_dltensor(&self) -> &DLTensor { + unsafe { + match self { + NDArray::Borrowed { ref handle } => std::mem::transmute(*handle), + NDArray::Owned { ref handle } => std::mem::transmute(*handle), + } + } + } + + pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { + unsafe { + match self { + NDArray::Borrowed { ref handle } => std::mem::transmute(*handle), + NDArray::Owned { ref handle } => std::mem::transmute(*handle), + } + } + } + + pub fn is_view(&self) -> bool { + if let &NDArray::Borrowed { .. } = self { + true + } else { + false + } + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = self.as_dltensor(); + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option { + self.shape().map(|v| v.iter().product()) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> Context { + self.as_dltensor().ctx.into() + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> DataType { + self.as_dltensor().dtype.into() + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + self.as_dltensor() + .ndim + .try_into() + .expect("number of dimensions must always be positive") + } + + /// Returns the strides of the underlying NDArray. + pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let sz = self.ndim() * mem::size_of::(); + let strides_ptr = self.as_dltensor().strides as *const usize; + let slc = slice::from_raw_parts(strides_ptr, sz); + Some(slc) + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result { + Ok(match self.strides() { + None => true, + Some(strides) => { + // errors::MissingShapeError in case shape is not determined + self.shape() + .ok_or(errors::MissingShapeError)? + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + self.as_dltensor().byte_offset as isize + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let mut shape = [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.to_vec::().unwrap(), data); + /// ``` + pub fn to_vec(&self) -> Result> { + ensure!(self.shape().is_some(), errors::EmptyArrayError); + let earr = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + Context::cpu(0), + self.dtype(), + ); + let target = self.copy_to_ndarray(earr)?; + let arr = target.as_dltensor(); + let sz = self.size().ok_or(errors::MissingShapeError)?; + let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) + } + + /// Converts the NDArray to [`ByteArray`]. + pub fn to_bytearray(&self) -> Result { + let v = self.to_vec::()?; + Ok(ByteArray::from(v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. + pub fn copy_from_buffer(&mut self, data: &mut [T]) { + check_call!(ffi::TVMArrayCopyFromBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + data.len() * mem::size_of::() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + if self.dtype() != target.dtype() { + bail!( + "{}", + errors::TypeMismatchError { + expected: self.dtype().to_string(), + actual: target.dtype().to_string(), + } + ); + } + check_call!(ffi::TVMArrayCopyFromTo( + self.as_raw_dltensor(), + target.as_raw_dltensor(), + ptr::null_mut() as ffi::TVMStreamHandle + )); + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &Context) -> Result { + let tmp = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + *target, + self.dtype(), + ); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. + pub fn from_rust_ndarray( + rnd: &ArrayD, + ctx: Context, + dtype: DataType, + ) -> Result { + let shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. + pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type.0 as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::Borrowed { handle: handle } + } +} + +macro_rules! impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = anyhow::Error; + fn try_from(nd: &NDArray) -> Result> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = anyhow::Error; + fn try_from(nd: &mut NDArray) -> Result> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) + } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if let &mut NDArray::Owned { .. } = self { + check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); + } + } +} + +mod sealed { + /// Private trait to prevent other traits from being implemeneted in downstream crates. + pub trait Sealed {} +} + +/// A trait for the supported 32-bits numerical types in frontend. +pub trait Num32: Num + sealed::Sealed { + const BITS: u8 = 32; +} + +macro_rules! impl_num32 { + ($($type:ty),+) => { + $( + impl sealed::Sealed for $type {} + impl Num32 for $type {} + )+ + }; +} + +impl_num32!(i32, u32, f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let shape = &mut [1, 2, 3]; + let ctx = Context::cpu(0); + let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!( + ndarray.size().unwrap(), + shape.to_vec().into_iter().product() + ); + assert_eq!(ndarray.ndim(), 3); + assert!(ndarray.strides().is_none()); + assert_eq!(ndarray.byte_offset(), 0); + } + + #[test] + fn copy() { + let shape = &mut [4]; + let mut data = vec![1i32, 2, 3, 4]; + let ctx = Context::cpu(0); + let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert!(ndarray.to_vec::().is_ok()); + ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.to_vec::().unwrap(), data); + assert_eq!(ndarray.ndim(), 1); + assert!(ndarray.is_contiguous().is_ok()); + assert_eq!(ndarray.byte_offset(), 0); + let shape = vec![4]; + let e = NDArray::empty( + &shape, + Context::cpu(0), + DataType::from_str("int32").unwrap(), + ); + let nd = ndarray.copy_to_ndarray(e); + assert!(nd.is_ok()); + assert_eq!(nd.unwrap().to_vec::().unwrap(), data); + } + + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = Context::cpu(0); + let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } + + #[test] + fn rust_ndarray() { + let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + .unwrap() + .into_dyn(); + let nd = + NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + .unwrap(); + assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + assert!(rnd.all_close(&a, 1e-8f32)); + } +} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs new file mode 100644 index 0000000000000..8d8efdf9d2a9d --- /dev/null +++ b/rust/tvm-rt/src/object/mod.rs @@ -0,0 +1,99 @@ +use std::convert::TryFrom; +use std::convert::TryInto; +use std::ffi::CString; +use tvm_sys::{ArgValue, RetValue}; +use crate::external_func; + +mod object_ptr; + +pub use object_ptr::{IsObject, Object, ObjectPtr}; + +#[derive(Clone)] +pub struct ObjectRef(pub Option>); + +impl ObjectRef { + pub fn null() -> ObjectRef { + ObjectRef(None) + } +} + +pub trait ToObjectRef { + fn to_object_ref(&self) -> ObjectRef; +} + +impl ToObjectRef for ObjectRef { + fn to_object_ref(&self) -> ObjectRef { + self.clone() + } +} + +// impl ToObjectRef for &T { +// fn to_object_ref(&self) -> ObjectRef { +// (*self).to_object_ref() +// } +// } + +impl TryFrom for ObjectRef { + type Error = anyhow::Error; + + fn try_from(ret_val: RetValue) -> Result { + let optr = ret_val.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl From for RetValue { + fn from(object_ref: ObjectRef) -> RetValue { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => RetValue::ObjectHandle(std::ptr::null::() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> std::convert::TryFrom> for ObjectRef { + type Error = anyhow::Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result { + let optr = arg_value.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { + type Error = anyhow::Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result { + // TODO(@jroesch): remove the clone + let value: ArgValue<'a> = arg_value.clone(); + ObjectRef::try_from(value) + } +} + +impl<'a> From for ArgValue<'a> { + fn from(object_ref: ObjectRef) -> ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => ArgValue::ObjectHandle(std::ptr::null::() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> From<&ObjectRef> for ArgValue<'a> { + fn from(object_ref: &ObjectRef) -> ArgValue<'a> { + let oref: ObjectRef = object_ref.clone(); + ArgValue::<'a>::from(oref) + } +} + +external_func! { + fn debug_print(object: ObjectRef) -> CString as "ir.DebugPrinter"; +} + +external_func! { + fn as_text(object: ObjectRef) -> CString as "ir.TextPrinter"; +} diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs new file mode 100644 index 0000000000000..c716c05183221 --- /dev/null +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -0,0 +1,283 @@ +use anyhow::Context; +use std::convert::TryFrom; +use std::ffi::CString; +use std::ptr::NonNull; +use tvm_sys::ffi::{self, /* TVMObjectFree, */ TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::{ArgValue, RetValue}; + +type Deleter = unsafe extern "C" fn(object: *mut T) -> (); + +#[derive(Debug)] +#[repr(C)] +pub struct Object { + pub type_index: u32, + pub ref_count: i32, + pub fdeleter: Deleter, +} + +unsafe extern "C" fn delete(object: *mut Object) { + let typed_object: *mut T = std::mem::transmute(object); + T::typed_delete(typed_object); +} + +fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { + let mut is_derived = 0; + crate::check_call!(ffi::TVMObjectDerivedFrom( + child_type_index, + parent_type_index, + &mut is_derived + )); + if is_derived == 0 { + false + } else { + true + } +} + +impl Object { + fn new(type_index: u32, deleter: Deleter) -> Object { + Object { + type_index, + // Note: do not touch this field directly again, this is + // a critical section, we write a 1 to the atomic which will now + // be managed by the C++ atomics. + // In the future we should probably use C-atomcis. + ref_count: 1, + fdeleter: deleter, + } + } + + fn get_type_index() -> u32 { + let type_key = T::TYPE_KEY; + let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { + return 0; + } else { + let mut index = 0; + unsafe { + let index_ptr = std::mem::transmute(&mut index); + if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 { + panic!(crate::get_last_error()) + } + } + return index; + } + } + + pub fn base_object() -> Object { + let index = Object::get_type_index::(); + Object::new(index, delete::) + } +} + +pub unsafe trait IsObject { + const TYPE_KEY: &'static str; + + fn as_object<'s>(&'s self) -> &'s Object; + + unsafe extern "C" fn typed_delete(_object: *mut Self) { + // let object = Box::from_raw(object); + // drop(object) + } +} + +unsafe impl IsObject for Object { + const TYPE_KEY: &'static str = "Object"; + + fn as_object<'s>(&'s self) -> &'s Object { + self + } +} + +#[repr(C)] +pub struct ObjectPtr { + pub ptr: NonNull, +} + +impl ObjectPtr { + fn from_raw(object_ptr: *mut Object) -> Option> { + println!("{:?}", object_ptr); + let non_null = NonNull::new(object_ptr); + non_null.map(|ptr| ObjectPtr { ptr }) + } +} + +impl Clone for ObjectPtr { + fn clone(&self) -> Self { + unsafe { + let raw_ptr = std::mem::transmute(self.ptr); + assert_eq!(TVMObjectRetain(raw_ptr), 0); + ObjectPtr { ptr: self.ptr } + } + } +} + +// impl Drop for ObjectPtr { +// fn drop(&mut self) { +// unsafe { +// let raw_ptr = std::mem::transmute(self.ptr); +// assert_eq!(TVMObjectFree(raw_ptr), 0) +// } +// } +// } + +impl ObjectPtr { + pub fn new(object: T) -> ObjectPtr { + let object_ptr = Box::new(object); + let ptr = NonNull::from(Box::leak(object_ptr)); + ObjectPtr { ptr } + } + + pub fn count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. + self.as_object().ref_count + } + + fn as_object<'s>(&'s self) -> &'s Object { + unsafe { self.ptr.as_ref().as_object() } + } + + pub fn upcast(&self) -> ObjectPtr { + ObjectPtr { + ptr: self.ptr.cast(), + } + } + + pub fn downcast(&self) -> anyhow::Result> { + let child_index = Object::get_type_index::(); + let object_index = self.as_object().type_index; + + let is_derived = if child_index == object_index { + true + } else { + // TODO(@jroesch): write tests + derived_from(object_index, child_index) + }; + + if is_derived { + Ok(ObjectPtr { + ptr: self.ptr.cast(), + }) + } else { + Err(anyhow::anyhow!("failed to downcast to object subtype")) + } + } +} + +impl std::ops::Deref for ObjectPtr { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { self.ptr.as_ref() } + } +} + +impl<'a, T: IsObject> From> for RetValue { + fn from(object_ptr: ObjectPtr) -> RetValue { + let raw_object_ptr = object_ptr.ptr.as_ptr(); + // Should be able to hide this unsafety in raw bindings. + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + RetValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom for ObjectPtr { + type Error = anyhow::Error; + + fn try_from(ret_value: RetValue) -> Result, Self::Error> { + match ret_value { + RetValue::ObjectHandle(handle) => { + let handle: *mut Object = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + optr.downcast() + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +impl<'a, T: IsObject> From> for ArgValue<'a> { + fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { + let raw_object_ptr = object_ptr.ptr.as_ptr(); + // Should be able to hide this unsafety in raw bindings. + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + ArgValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom> for ObjectPtr { + type Error = anyhow::Error; + fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + optr.downcast() + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr { + type Error = anyhow::Error; + fn try_from(arg_value: &ArgValue<'a>) -> Result, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + optr.downcast() + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +#[cfg(test)] +mod tests { + use super::{Object, ObjectPtr}; + use anyhow::{ensure, Result}; + use std::convert::TryInto; + use tvm_sys::{ArgValue, RetValue}; + + #[test] + fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) + } + + #[test] + fn roundtrip_retvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::()); + let ret_value: RetValue = ptr.clone().into(); + let ptr2: ObjectPtr = ret_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } + + #[test] + fn roundtrip_argvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::()); + let arg_value: ArgValue = ptr.clone().into(); + let ptr2: ObjectPtr = arg_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs new file mode 100644 index 0000000000000..ac806252bf58c --- /dev/null +++ b/rust/tvm-rt/src/string.rs @@ -0,0 +1,72 @@ +use std::ffi::{CString, NulError}; +use std::os::raw::c_char; + +use super::{Object, ObjectPtr, ObjectRef}; +use crate as tvm_rt; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "String"] +#[type_key = "runtime.String"] +pub struct StringObj { + base: Object, + data: *const c_char, + size: u64, +} + +impl String { + pub fn new(string: std::string::String) -> Result { + let cstring = CString::new(string)?; + + // The string is being corrupted. + // why is this wrong + let length = cstring.as_bytes().len(); + + let string_obj = StringObj { + base: Object::base_object::(), + data: cstring.into_raw(), + size: length as u64, + }; + + let object_ptr = ObjectPtr::new(string_obj); + Ok(String(Some(object_ptr))) + } + + pub fn to_cstring(&self) -> Result { + use std::slice; + let ptr = self.0.as_ref().unwrap().data; + let size = self.0.as_ref().unwrap().size; + unsafe { + let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize); + CString::new(slice) + } + } + + pub fn to_string(&self) -> anyhow::Result { + let string = self.to_cstring()?.into_string()?; + Ok(string) + } +} + +#[cfg(test)] +mod tests { + use super::String; + use crate::object::debug_print; + use crate::ToObjectRef; + use anyhow::{ensure, Result}; + + #[test] + fn test_string_debug() -> Result<()> { + let s = String::new("foo".to_string()).unwrap(); + let object_ref = s.to_object_ref(); + println!("about to call"); + let string = debug_print(object_ref)?; + println!("after call"); + ensure!( + string.into_string().expect("is cstring").contains("foo"), + "string content is invalid" + ); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs new file mode 100644 index 0000000000000..7a560b6bb757c --- /dev/null +++ b/rust/tvm-rt/src/to_boxed_fn.rs @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use anyhow::Result; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::{Module}; + +use super::function::Function; + +pub trait ToBoxedFn { + fn to_boxed_fn(func: &'static Function) -> Box; +} + +use std::convert::{TryInto, TryFrom}; + +impl ToBoxedFn for dyn Fn() -> Result + where E: std::error::Error + Send + Sync + 'static, + O: TryFrom, { + fn to_boxed_fn(func: &'static Function) -> Box { + Box::new(move || { + let mut builder = Builder::default(); + builder.func = Some(func); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A) -> Result + where E: std::error::Error + Send + Sync + 'static, + A: Into>, + O: TryFrom, { + fn to_boxed_fn(func: &'static Function) -> Box { + Box::new(move |a: A| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B) -> Result + where E: std::error::Error + Send + Sync + 'static, + A: Into>, + B: Into>, + O: TryFrom, { + fn to_boxed_fn(func: &'static Function) -> Box { + Box::new(move |a: A, b: B| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + builder.arg(b.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B, C) -> Result + where E: std::error::Error + Send + Sync + 'static, + A: Into>, + B: Into>, + C: Into>, + O: TryFrom, { + fn to_boxed_fn(func: &'static Function) -> Box { + Box::new(move |a: A, b: B, c: C| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B, C, D) -> Result + where E: std::error::Error + Send + Sync + 'static, + A: Into>, + B: Into>, + C: Into>, + D: Into>, + O: TryFrom, { + fn to_boxed_fn(func: &'static Function) -> Box { + Box::new(move |a: A, b: B, c: C, d: D| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + builder.arg(d.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +/// Function builder in order to create and call functions. +/// +/// *Note:* Currently TVM functions accept *at most* one return value. +#[derive(Default)] +pub struct Builder<'a, 'm> { + pub func: Option<&'m Function>, + pub arg_buf: Vec>, + pub ret_buf: Option, +} + +impl<'a, 'm> Builder<'a, 'm> { + pub fn new( + func: Option<&'m Function>, + arg_buf: Vec>, + ret_buf: Option, + ) -> Self { + Self { + func, + arg_buf, + ret_buf, + } + } + + pub fn get_function(&mut self, name: &'m str) -> &mut Self { + self.func = Function::get(name); + self + } + + /// Pushes a [`ArgValue`] into the function argument buffer. + pub fn arg(&mut self, arg: T) -> &mut Self + where + ArgValue<'a>: From, + { + self.arg_buf.push(arg.into()); + self + } + + /// Pushes multiple [`ArgValue`]s into the function argument buffer. + pub fn args(&mut self, args: I) -> &mut Self + where + I: IntoIterator, + ArgValue<'a>: From, + { + args.into_iter().for_each(|arg| { + self.arg(arg); + }); + self + } + + /// Sets an output for a function that requirs a mutable output to be provided. + /// See the `basics` in tests for an example. + pub fn set_output(&mut self, ret: T) -> &mut Self + where + RetValue: From, + { + self.ret_buf = Some(ret.into()); + self + } + + pub fn invoke(self) -> Result { + self.func.unwrap().invoke(self.arg_buf) + } + +} + +/// Converts a [`Function`] to builder. Currently, this is the best way to work with +/// TVM functions. +impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> { + fn from(func: &'m Function) -> Self { + Builder::new(Some(func), Vec::new(), None) + } +} + +/// Converts a mutable reference of a [`Module`] to [`Builder`]. +impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> { + fn from(module: &'m mut Module) -> Self { + Builder::new(module.entry(), Vec::new(), None) + } +} +#[cfg(test)] +mod tests { + use anyhow::Result; + use crate::function::{self, Function}; + + #[test] + fn to_boxed_fn0() { + fn boxed0() -> i64 { + return 10; + } + + function::register_override(boxed0, "boxed0".to_owned(), true).unwrap(); + let func = Function::get("boxed0").unwrap(); + let typed_func: Box Result> = func.to_boxed_fn(); + assert_eq!(typed_func().unwrap(), 10); + } +} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs new file mode 100644 index 0000000000000..6954650ff59f6 --- /dev/null +++ b/rust/tvm-rt/src/to_function.rs @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::{ + mem::MaybeUninit, + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use anyhow::Result; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use super::Function; +use std::convert::{TryFrom, TryInto}; + +/// A trait representing whether the function arguments +/// and return type can be assigned to a TVM packed function. +/// +/// By splitting the conversion to function into two traits +/// we are able to improve error reporting, by splitting the +/// conversion of inputs and outputs to this trait. +/// +/// And the implementation of it to `ToFunction`. +pub trait Typed { + fn args(i: &[ArgValue<'static>]) -> anyhow::Result; + fn ret(o: O) -> RetValue; +} + +impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result> for F +where + F: Fn(&'a [ArgValue]) -> anyhow::Result, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<&'a [ArgValue<'static>]> { + // this is BAD but just hacking for time being + Ok(unsafe { std::mem::transmute(args) }) + } + + fn ret(ret_value: anyhow::Result) -> RetValue { + ret_value.unwrap() + } +} + +impl> Typed<(), O> for F +where + F: Fn() -> O, +{ + fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<()> { + debug_assert!(_args.len() == 0); + Ok(()) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E: Into> Typed<(A,), O> for F +where + F: Fn(A) -> O, + E: std::error::Error + Send + Sync + 'static, + A: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + Ok((a,)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E: Into> Typed<(A, B), O> for F +where + F: Fn(A, B) -> O, + E: std::error::Error + Send + Sync + 'static, + A: TryFrom, Error = E>, + B: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + Ok((a, b)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E: Into> Typed<(A, B, C), O> for F +where + F: Fn(A, B, C) -> O, + E: std::error::Error + Send + Sync + 'static, + A: TryFrom, Error = E>, + B: TryFrom, Error = E>, + C: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B, C)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + let c: C = args[2].clone().try_into()?; + Ok((a, b, c)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +pub trait ToFunction: Sized { + type Handle; + + fn into_raw(self) -> *mut Self::Handle; + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> anyhow::Result + where + Self: Typed; + + fn drop(handle: *mut Self::Handle); + + fn to_function(self) -> Function + where + Self: Typed, + { + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + let resource_handle = self.into_raw(); + check_call!(ffi::TVMFuncCreateFromCFunc( + Some(Self::tvm_callback), + resource_handle as *mut _, + Some(Self::tvm_finalizer), + &mut fhandle as *mut _ + )); + println!("fnhandle: {:?}", fhandle); + Function::new(fhandle) + } + + /// The callback function which is wrapped converted by TVM + /// into a packed function stored in fhandle. + unsafe extern "C" fn tvm_callback( + args: *mut ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: ffi::TVMRetValueHandle, + fhandle: *mut c_void, + ) -> c_int + where + Self: Typed, + { + // turning off the incorrect linter complaints + #![allow(unused_assignments, unused_unsafe)] + println!("here"); + let len = num_args as usize; + let args_list = slice::from_raw_parts_mut(args, len); + let type_codes_list = slice::from_raw_parts_mut(type_codes, len); + let mut local_args: Vec = Vec::new(); + let mut value = MaybeUninit::uninit().assume_init(); + let mut tcode = MaybeUninit::uninit().assume_init(); + let rust_fn = fhandle as *mut Self::Handle; + for i in 0..len { + value = args_list[i]; + println!("{:?}", value.v_handle); + tcode = type_codes_list[i]; + if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int + { + check_call!(ffi::TVMCbArgToReturn( + &mut value as *mut _, + &mut tcode as *mut _ + )); + println!("{:?}", value.v_handle); + } + let arg_value = ArgValue::from_tvm_value(value, tcode as u32); + println!("{:?}", arg_value); + local_args.push(arg_value); + } + println!("before call"); + let rv = match Self::call(rust_fn, local_args.as_slice()) { + Ok(v) => v, + Err(msg) => { + crate::set_last_error(&msg); + return -1; + } + }; + println!("after call"); + + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); + let mut ret_type_code = ret_tcode as c_int; + check_call!(ffi::TVMCFuncSetReturn( + ret, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + 1 as c_int + )); + 0 + } + + /// The finalizer which is invoked when the packed function's + /// reference count is zero. + unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { + let handle = std::mem::transmute(fhandle); + Self::drop(handle) + } +} + +// /// A wrapper that is used to work around inference issues for bare functions. +// /// +// /// Used to implement `register_untyped`. +// pub(self) struct RawFunction { +// fn_ptr: for<'a> fn (&'a [ArgValue<'static>]) -> Result +// } + +// impl RawFunction { +// fn new(fn_ptr: for<'a> fn (&'a [ArgValue<'static>]) -> Result) -> RawFunction { +// RawFunction { fn_ptr: fn_ptr } +// } +// } + +// impl Typed<&[ArgValue<'static>], ()> for RawFunction { +// fn args(i: &[ArgValue<'static>]) -> anyhow::Result<&[ArgValue<'static>]> { +// Ok(i) +// } + +// fn ret(o: O) -> RetValue; +// } + +// impl ToFunction<(), ()> for RawFunction +// { +// type Handle = fn(&[ArgValue<'static>]) -> Result; + +// fn into_raw(self) -> *mut Self::Handle { +// self.fn_ptr as *mut Self::Handle +// } + +// fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result { +// let handle: Self::Handle = unsafe { std::mem::transmute(handle) }; +// let r = handle(args); +// println!("afters"); +// r +// } + +// // Function's don't need de-allocation because the pointers are into the code section of memory. +// fn drop(_: *mut Self::Handle) {} +// } + +impl ToFunction<(), O> for F +where + F: Fn() -> O + 'static, +{ + type Handle = Box O + 'static>; + + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(Box::new(self)); + Box::into_raw(ptr) + } + + fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result + where + F: Typed<(), O>, + { + // Ideally we shouldn't need to clone, probably doesn't really matter. + let out = unsafe { (*handle)() }; + Ok(F::ret(out)) + } + + fn drop(_: *mut Self::Handle) {} +} + +macro_rules! to_function_instance { + ($(($param:ident,$index:tt),)+) => { + impl ToFunction<($($param,)+), O> for + F where F: Fn($($param,)+) -> O + 'static { + type Handle = Box O + 'static>; + + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(Box::new(self)); + Box::into_raw(ptr) + } + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result where F: Typed<($($param,)+), O> { + // Ideally we shouldn't need to clone, probably doesn't really matter. + let args = F::args(args)?; + let out = unsafe { + (*handle)($(args.$index),+) + }; + Ok(F::ret(out)) + } + + fn drop(_: *mut Self::Handle) {} + } + } +} + +to_function_instance!((A, 0),); +to_function_instance!((A, 0), (B, 1),); +to_function_instance!((A, 0), (B, 1), (C, 2),); +to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),); + +#[cfg(test)] +mod tests { + // use super::RawFunction; + use super::{Function, ToFunction, Typed}; + + fn zero() -> i32 { + 10 + } + + fn helper(f: F) -> Function + where + F: ToFunction, + F: Typed, + { + f.to_function() + } + + // fn func_args(args: &[ArgValue<'static>]) -> anyhow::Result { + // Ok(10.into()) + // } + + // #[test] + // fn test_fn_ptr() { + // let raw_fn = RawFunction::new(func_args); + // raw_fn.to_function(); + // } + + #[test] + fn test_to_function0() { + helper(zero); + } + + fn one_arg(i: i32) -> i32 { + i + } + + #[test] + fn test_to_function1() { + helper(one_arg); + } + + fn two_arg(i: i32, j: i32) -> i32 { + i + j + } + + #[test] + fn test_to_function2() { + helper(two_arg); + } +} diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs new file mode 100644 index 0000000000000..a9355e0a8d541 --- /dev/null +++ b/rust/tvm-rt/src/value.rs @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements [`ArgValue`] and [`RetValue`] types +//! and their conversions needed for the types used in frontend crate. +//! `RetValue` is the owned version of `TVMPODValue`. + +use std::convert::TryFrom; +// use std::ffi::c_void; + +use crate::{ArgValue, Function, Module, NDArray, RetValue}; +use tvm_sys::{ + errors::ValueDowncastError, + ffi::{TVMFunctionHandle, TVMModuleHandle}, + try_downcast, +}; + +macro_rules! impl_handle_val { + ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { + impl<'a> From<&'a $type> for ArgValue<'a> { + fn from(arg: &'a $type) -> Self { + ArgValue::$variant(arg.handle() as $inner_type) + } + } + + impl<'a> From<&'a mut $type> for ArgValue<'a> { + fn from(arg: &'a mut $type) -> Self { + ArgValue::$variant(arg.handle() as $inner_type) + } + } + + impl<'a> TryFrom> for $type { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(val) }) + } + } + + impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(*val) }) + } + } + + impl From<$type> for RetValue { + fn from(val: $type) -> RetValue { + RetValue::$variant(val.handle() as $inner_type) + } + } + + impl TryFrom for $type { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |RetValue::$variant(val)| { $ctor(val) }) + } + } + }; +} + +impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new); +impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); + +impl<'a> From<&'a NDArray> for ArgValue<'a> { + fn from(arg: &'a NDArray) -> Self { + match arg { + &NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), + &NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), + } + } +} + +impl<'a> From<&'a mut NDArray> for ArgValue<'a> { + fn from(arg: &'a mut NDArray) -> Self { + match arg { + &mut NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), + &mut NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), + } + } +} + +impl<'a> TryFrom> for NDArray { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> NDArray, + |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, + |ArgValue::ArrayHandle(val)| { NDArray::new(val) }) + } +} + +impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> NDArray, + |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) }, + |ArgValue::ArrayHandle(val)| { NDArray::new(*val) }) + } +} + +impl From for RetValue { + fn from(val: NDArray) -> RetValue { + match val { + NDArray::Owned { handle } => RetValue::NDArrayHandle(handle), + _ => panic!("NYI"), + } + } +} + +impl TryFrom for NDArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> NDArray, + |RetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, + |RetValue::ArrayHandle(val)| { NDArray::new(val) }) + } +} + +#[cfg(test)] +mod tests { + use std::{convert::TryInto, str::FromStr}; + + use crate::{ByteArray, Context, DataType}; + + use super::*; + + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = ByteArray::from(w.as_slice()); + let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + assert_eq!( + tvm.data(), + w.iter().copied().collect::>().as_slice() + ); + } + + #[test] + fn ty() { + let t = DataType::from_str("int32").unwrap(); + let tvm: DataType = RetValue::from(t).try_into().unwrap(); + assert_eq!(tvm, t); + } + + #[test] + fn ctx() { + let c = Context::from_str("gpu").unwrap(); + let tvm: Context = RetValue::from(c).try_into().unwrap(); + assert_eq!(tvm, c); + } +} diff --git a/rust/tvm-rt/tests/test_ir.rs b/rust/tvm-rt/tests/test_ir.rs new file mode 100644 index 0000000000000..7d9e4758b856c --- /dev/null +++ b/rust/tvm-rt/tests/test_ir.rs @@ -0,0 +1,36 @@ +// use std::convert::TryInto; +// use std::str::FromStr; +// use tvm_rt::string::String as TString; +// use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef}; +// use tvm::{call_packed, DLDataType, Function}; +// use tvm_sys::RetValue; + +// #[test] +// fn test_new_object() -> anyhow::Result<()> { +// let object = Object::base_object::(); +// let ptr = ObjectPtr::new(object); +// assert_eq!(ptr.count(), 1); +// Ok(()) +// } + +// #[test] +// fn test_new_string() -> anyhow::Result<()> { +// let string = TString::new("hello world!".to_string())?; +// Ok(()) +// } + +// #[test] +// fn test_obj_build() -> anyhow::Result<()> { +// let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found."); + +// let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert."); + +// let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337) +// .expect("foo") +// .try_into() +// .unwrap(); + +// debug_print(&ret_val); + +// Ok(()) +// } diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7272213ad406f..b3223889cc72d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -162,7 +162,7 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_GLOBAL("ir.GlobalVar") -.set_body_typed([](std::string name){ +.set_body_typed([](String name){ return GlobalVar(name); }); @@ -214,4 +214,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } p->stream << '}'; }); + +TVM_REGISTER_GLOBAL("ir.DebugPrinter") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef ref = args[0]; + std::stringstream ss; + ss << ref; + *ret = ss.str(); +}); + } // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bda997a59d4d7..fc9546a14c9af 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -193,8 +193,7 @@ class RelayTextPrinter : case kTypeData: return Doc::Text("TypeData"); default: - LOG(ERROR) << "Unknown Kind"; - throw; + CHECK(false) << "Unknown Kind"; } } /*! @@ -479,7 +478,8 @@ class RelayTextPrinter : } Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc::Text('@' + op->name_hint); + std::string name_hint = op->name_hint; + return Doc::Text('@' + name_hint); } Doc VisitExpr_(const OpNode* op) final { @@ -939,4 +939,13 @@ TVM_REGISTER_GLOBAL("ir.PrettyPrint") TVM_REGISTER_GLOBAL("ir.AsText") .set_body_typed(AsText); + +TVM_REGISTER_GLOBAL("ir.TextPrinter") +.set_body_typed([](ObjectRef node) { + std::cout << "The program: " << node << std::endl; + auto text = AsText(node, false, nullptr); + std::cout << "The text " << text; + return text; +}); + } // namespace tvm diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index e6c83928b098a..65ee57f6a3f38 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -164,7 +164,7 @@ Function ToCPS(const Function& f, // only look unfold non-external calls. BaseFunc base_func = m->Lookup(gv); if (auto* n = base_func.as()) { - auto cps_gv = GlobalVar(gv->name_hint + "_cps"); + auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps"); cm->insert({gv, cps_gv}); m->Add(cps_gv, ToCPS(GetRef(n), m, cm)); } else { diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 03012006bd797..5496159c211c5 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -244,12 +244,26 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { API_END(); } +int TVMObjectRetain(TVMObjectHandle obj) { + API_BEGIN(); + tvm::runtime::ObjectInternal::ObjectRetain(obj); + API_END(); +} + int TVMObjectFree(TVMObjectHandle obj) { API_BEGIN(); tvm::runtime::ObjectInternal::ObjectFree(obj); API_END(); } + +int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { + API_BEGIN(); + *is_derived = tvm::runtime::TypeContext::Global()-> + DerivedFrom(child_type_index, parent_type_index); + API_END(); +} + int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index( diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 79551309d67c0..ab48802e774c2 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -37,6 +37,15 @@ namespace runtime { */ class ObjectInternal { public: + /*! + * \brief Retain an object handle. + */ + static void ObjectRetain(TVMObjectHandle obj) { + if (obj != nullptr) { + static_cast(obj)->IncRef(); + } + } + /*! * \brief Free an object handle. */