diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 6797f16c3829..b2ce50d91f58 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -37,6 +37,8 @@ namespace tvm { +using tvm::runtime::String; + /*! * \brief Base type of all the expressions. * \sa Expr diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index bf24f992bda4..213c7059a5f9 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -514,6 +514,15 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Increase the reference count of an object. + * + * \param obj The object handle. + * \note Internally we increase the reference counter of the object. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectRetain(TVMObjectHandle obj); + /*! * \brief Free the object. * @@ -564,6 +573,16 @@ TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream); +/*! + * \brief Check that an object is derived from another. + * \param child_type_index The type index of the derived type. + * \param parent_type_index The type index of the parent type. + * \param is_derived A boolean representing whether this predicate holds. + * \return 0 when success, -1 when failure happens. + */ +TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, + int* is_derived); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 6bc6fbf5b026..2b3eb9264a48 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -511,11 +511,20 @@ class ArrayNode : public Object, public InplaceArrayBase { }; /*! - * \brief Array container of ObjectRef in DSL graph. - * Array implements copy-on-write semantics, which means array is mutable - * but copy will happen when array is referenced in more than two places. + * \brief Array, container representing a contigious sequence of ObjectRefs. * - * operator[] only provide const access, use Set to mutate the content. + * Array implements in-place copy-on-write semantics. + * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. * \tparam T The content ObjectRef type. */ template , + ret_type: ReturnType, +} + +impl Parse for External { + fn parse(input: ParseStream) -> Result { + let method: TraitItemMethod = input.parse()?; + assert_eq!(method.attrs.len(), 1); + let sig = method.sig; + let tvm_name = method.attrs[0].parse_meta()?; + let tvm_name = match tvm_name { + Meta::List(meta_list) => { + let name = meta_list.path.get_ident().expect("name"); + assert_eq!(name.to_string(), "name".to_string()); + match meta_list.nested.first() { + Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(), + _ => panic!(), + } + } + _ => panic!(), + }; + assert_eq!(method.default, None); + assert!(method.semi_token != None); + let ident = sig.ident; + let generics = sig.generics; + let inputs = sig.inputs.iter().map(|param| param.clone()).collect(); + let ret_type = sig.output; + + Ok(External { + tvm_name, + ident, + generics, + inputs, + ret_type, + }) + } +} + +struct ExternalInput { + externs: Vec, +} + +impl Parse for ExternalInput { + fn parse(input: ParseStream) -> Result { + let mut externs: Vec = Vec::new(); + + loop { + if input.is_empty() { + break; + } + externs.push(input.parse()?); + } + + Ok(ExternalInput { externs }) + } +} + +pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ext_input = syn::parse_macro_input!(input as ExternalInput); + + let tvm_rt_crate = crate::util::get_tvm_rt_crate(); + + let err_type = quote! { #tvm_rt_crate::Error }; + + let mut items = Vec::new(); + + for external in &ext_input.externs { + let name = &external.ident; + let global_name = format!("global_{}", external.ident); + let global_name = Ident::new(&global_name, Span::call_site()); + let ext_name = &external.tvm_name; + + let ty_params: Vec = external + .generics + .params + .iter() + .map(|ty_param| match ty_param { + syn::GenericParam::Type(param) => param.clone(), + _ => panic!(), + }) + .collect(); + + let args = &external.inputs; + + let (args, tys): (Vec, Vec) = args + .iter() + .map(|arg| match arg { + FnArg::Typed(pat_type) => match &*pat_type.pat { + Pat::Ident(pat_ident) => { + let ident: Ident = pat_ident.ident.clone(); + let ty: Type = *pat_type.ty.clone(); + (ident, ty) + } + _ => panic!(), + }, + _ => panic!(), + }) + .unzip(); + + let ret_type = match &external.ret_type { + ReturnType::Type(_, rtype) => *rtype.clone(), + _ => panic!(), + }; + + let global = quote! { + #[allow(non_upper_case_globals)] + static #global_name: ::once_cell::sync::Lazy<#tvm_rt_crate::Function> = + ::once_cell::sync::Lazy::new(|| { + #tvm_rt_crate::Function::get(#ext_name) + .expect(concat!("unable to load external function", stringify!(#ext_name), "from TVM registry.")) + }); + }; + + items.push(global); + + let wrapper = quote! { + pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> { + let func_ref: #tvm_rt_crate::Function = #global_name.clone(); + let func_ref: Box Result<#ret_type, #err_type>> = func_ref.to_boxed_fn(); + let res: #ret_type = func_ref(#(#args),*)?; + Ok(res) + } + }; + + items.push(wrapper); + } + + proc_macro::TokenStream::from(quote! { + #(#items + )* + }) +} diff --git a/rust/tvm-macros/src/import_module.rs b/rust/tvm-macros/src/import_module.rs new file mode 100644 index 000000000000..6b059ae363f8 --- /dev/null +++ b/rust/tvm-macros/src/import_module.rs @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use quote::quote; +use std::{fs::File, io::Read}; +use syn::parse::{Parse, ParseStream, Result}; +use syn::LitStr; + +use std::path::PathBuf; + +struct ImportModule { + importing_file: LitStr, +} + +impl Parse for ImportModule { + fn parse(input: ParseStream) -> Result { + let importing_file: LitStr = input.parse()?; + Ok(ImportModule { importing_file }) + } +} + +pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let import_module_args = syn::parse_macro_input!(input as ImportModule); + + let manifest = + std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); + + let mut path = PathBuf::new(); + path.push(manifest); + path = path.join(import_module_args.importing_file.value()); + + let mut fd = File::open(&path) + .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); + let mut buffer = Vec::new(); + fd.read_to_end(&mut buffer).unwrap(); + + let fn_names = match goblin::Object::parse(&buffer).unwrap() { + goblin::Object::Elf(elf) => elf + .syms + .iter() + .filter_map(|s| { + if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { + return None; + } + match elf.strtab.get(s.st_name) { + Some(Ok(name)) if name != "" => { + Some(syn::Ident::new(name, proc_macro2::Span::call_site())) + } + _ => None, + } + }) + .collect::>(), + goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { + obj.symbols() + .filter_map(|s| match s { + Ok((name, ref nlist)) + if nlist.is_global() + && nlist.n_sect != 0 + && !name.ends_with("tvm_module_ctx") => + { + Some(syn::Ident::new( + if name.starts_with('_') { + // Mach objects prepend a _ to globals. + &name[1..] + } else { + &name + }, + proc_macro2::Span::call_site(), + )) + } + _ => None, + }) + .collect::>() + } + _ => panic!("Unsupported object format."), + }; + + let extern_fns = quote! { + mod ext { + extern "C" { + #( + pub(super) fn #fn_names( + args: *const tvm_runtime::ffi::TVMValue, + type_codes: *const std::os::raw::c_int, + num_args: std::os::raw::c_int + ) -> std::os::raw::c_int; + )* + } + } + }; + + let fns = quote! { + use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError}; + #extern_fns + + #( + pub fn #fn_names(args: &[ArgValue]) -> Result { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = unsafe { + ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) + }; + if exit_code == 0 { + Ok(RetValue::default()) + } else { + Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) + } + } + )* + }; + + proc_macro::TokenStream::from(fns) +} diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs new file mode 100644 index 000000000000..603e1ceaafcc --- /dev/null +++ b/rust/tvm-macros/src/lib.rs @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro::TokenStream; + +mod external; +mod import_module; +mod object; +mod util; + +#[proc_macro] +pub fn import_module(input: TokenStream) -> TokenStream { + import_module::macro_impl(input) +} + +#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +pub fn macro_impl(input: TokenStream) -> TokenStream { + // let input = proc_macro2::TokenStream::from(input); + TokenStream::from(object::macro_impl(input)) +} + +#[proc_macro] +pub fn external(input: TokenStream) -> TokenStream { + external::macro_impl(input) +} diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs new file mode 100644 index 000000000000..bee22c367189 --- /dev/null +++ b/rust/tvm-macros/src/object.rs @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::DeriveInput; +use syn::Ident; + +use crate::util::get_tvm_rt_crate; + +pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { + let tvm_rt_crate = get_tvm_rt_crate(); + let derive_input = syn::parse_macro_input!(input as DeriveInput); + let payload_id = derive_input.ident; + + let mut type_key = None; + let mut ref_name = None; + let base = Some(Ident::new("base", Span::call_site())); + + for attr in derive_input.attrs { + if attr.path.is_ident("type_key") { + type_key = Some(attr.parse_meta().expect("foo")) + } + + if attr.path.is_ident("ref_name") { + ref_name = Some(attr.parse_meta().expect("foo")) + } + } + + let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key { + match name_value.lit { + syn::Lit::Str(type_key) => type_key, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name { + match name_value.lit { + syn::Lit::Str(ref_name) => ref_name, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_id = Ident::new(&ref_name.value(), Span::call_site()); + let base = base.expect("should be present"); + + let expanded = quote! { + unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { + const TYPE_KEY: &'static str = #type_key; + + fn as_object<'s>(&'s self) -> &'s Object { + &self.#base.as_object() + } + } + + #[derive(Clone)] + pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); + + impl #tvm_rt_crate::object::ToObjectRef for #ref_id { + fn to_object_ref(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } + } + + impl std::ops::Deref for #ref_id { + type Target = #payload_id; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().unwrap() + } + } + + impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { + type Error = #tvm_rt_crate::Error; + + fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let oref: ObjectRef = ret_val.try_into()?; + let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?; + let ptr = ptr.downcast::<#payload_id>()?; + Ok(#ref_id(Some(ptr))) + } + } + + impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + #tvm_rt_crate::ArgValue:: + ObjectHandle(std::ptr::null::() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { + let oref: #ref_id = object_ref.clone(); + #tvm_rt_crate::ArgValue::<'a>::from(oref) + } + } + + impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { + type Error = #tvm_rt_crate::Error; + + fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id { + type Error = #tvm_rt_crate::Error; + + fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl From<#ref_id> for #tvm_rt_crate::RetValue { + fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + #tvm_rt_crate::RetValue::ObjectHandle(std::ptr::null::() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + }; + + TokenStream::from(expanded) +} diff --git a/rust/tvm-macros/src/util.rs b/rust/tvm-macros/src/util.rs new file mode 100644 index 000000000000..1e720f04dfef --- /dev/null +++ b/rust/tvm-macros/src/util.rs @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro2::TokenStream; +use quote::quote; +use std::env; + +pub fn get_tvm_rt_crate() -> TokenStream { + if env::var("CARGO_PKG_NAME").unwrap() == "tvm-rt" { + quote!(crate) + } else { + quote!(tvm_rt) + } +} diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore new file mode 100644 index 000000000000..2430329c78b6 --- /dev/null +++ b/rust/tvm-rt/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml new file mode 100644 index 000000000000..465ae583ab6c --- /dev/null +++ b/rust/tvm-rt/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-rt" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust bindings for the TVM runtime API." +repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +ndarray = "0.12" +num-traits = "0.2" +tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } +tvm-macros = { version = "0.1", path = "../tvm-macros" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[dev-dependencies] +anyhow = "^1.0" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md new file mode 100644 index 000000000000..7c87939db301 --- /dev/null +++ b/rust/tvm-rt/README.md @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + +# TVM Runtime Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime. +Currently this is tested on `1.42.0` and above. + +## What Does This Crate Offer? + +TVM is an end-to-end deep learning compiler which takes high level machine learning +models or tensor computations and lowers them into executable code for a variety +of heterogenous devices (e.g., CPU, GPU). + +This crate provides access to the APIs for manipulating runtime data structures, +as well as TVM's cross-language Object system which functions similarly to systems +such as COM, enabling cross-language interoperability. + +## Installations + +Please follow TVM [installation](https://tvm.apache.org/docs/install/index.html) instructions, +`export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +### Example of registering a cross-language closure. + +One can use `register!` macro to expose a Rust closure with arguments which implement `TryFrom` +and return types which implement `Into`. Once registered with TVM these functions can be +accessed via Python or C++, or any other language which implements the TVM packed function convention +see `docs.tvm.ai` for more information. + +```rust +use tvm_rt::{ArgValue, RetValue}; +use tvm_rt::function::{Function, Result, register}; + +fn sum(x: i64, y: i64, z: i64) -> i64 { + x + y + z +} + +fn main() { + register(sum, "mysum".to_owned()).unwrap(); + let func = Function::get("mysum").unwrap(); + let boxed_fn = func.to_boxed_fn:: Result>(); + let ret = boxed_fn(10, 20, 30).unwrap(); + assert_eq!(ret, 60); +} +``` diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs new file mode 100644 index 000000000000..b0fea33c6c61 --- /dev/null +++ b/rust/tvm-rt/src/context.rs @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::os::raw::c_void; +use std::ptr; + +use crate::errors::Error; + +use tvm_sys::ffi; + +pub use tvm_sys::context::*; + +trait ContextExt { + /// Checks whether the context exists or not. + fn exist(&self) -> bool; + fn sync(&self) -> Result<(), Error>; + fn max_threads_per_block(&self) -> isize; + fn warp_size(&self) -> isize; + fn max_shared_memory_per_block(&self) -> isize; + fn compute_version(&self) -> isize; + fn device_name(&self) -> isize; + fn max_clock_rate(&self) -> isize; + fn multi_processor_count(&self) -> isize; + fn max_thread_dimensions(&self) -> isize; +} + +macro_rules! impl_device_attrs { + ($(($attr_name:ident, $attr_kind:expr));+) => { + $( + fn $attr_name(&self) -> isize { + get_device_attr(self.device_type as i32, self.device_id as i32, 0) + .expect("should not fail") as isize + } + + )+ + }; +} + +crate::external! { + #[name("runtime.GetDeviceAttr")] + fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32; +} + +impl ContextExt for Context { + fn exist(&self) -> bool { + let exists = get_device_attr(self.device_type as i32, self.device_id as i32, 0) + .expect("should not fail"); + + exists != 0 + } + + /// Synchronize the context stream. + fn sync(&self) -> Result<(), Error> { + check_call!(ffi::TVMSynchronize( + self.device_type as i32, + self.device_id as i32, + ptr::null_mut() as *mut c_void + )); + Ok(()) + } + + impl_device_attrs!((max_threads_per_block, 1); + (warp_size, 2); + (max_shared_memory_per_block, 3); + (compute_version, 4); + (device_name, 5); + (max_clock_rate, 6); + (multi_processor_count, 7); + (max_thread_dimensions, 8)); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sync() { + let ctx = Context::cpu(0); + assert!(ctx.sync().is_ok()) + } +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs new file mode 100644 index 000000000000..0b45ebf445bf --- /dev/null +++ b/rust/tvm-rt/src/errors.rs @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::DataType; +use thiserror::Error; + +#[derive(Debug, Error)] +#[error("Function was not set in `function::Builder`")] +pub struct FunctionNotFoundError; + +#[derive(Debug, Error)] +#[error("Expected type `{expected}` but found `{actual}`")] +pub struct TypeMismatchError { + pub expected: String, + pub actual: String, +} + +#[derive(Debug, Error)] +pub enum NDArrayError { + #[error("Missing NDArray shape.")] + MissingShape, + #[error("Cannot convert from an empty array.")] + EmptyArray, + #[error("Invalid datatype when attempting to convert ndarray.")] + InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError), + #[error("a shape error occurred in the Rust ndarray library")] + ShapeError(#[from] ndarray::ShapeError), + #[error("Expected type `{expected}` but found `{actual}`")] + DataTypeMismatch { + expected: DataType, + actual: DataType, + }, +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("{0}")] + Downcast(#[from] tvm_sys::errors::ValueDowncastError), + #[error("raw pointer passed across boundary was null")] + Null, + #[error("failed to load module due to invalid path {0}")] + ModuleLoadPath(String), + #[error("failed to convert String into CString due to embedded nul character")] + ToCString(#[from] std::ffi::NulError), + #[error("failed to convert CString into String")] + FromCString(#[from] std::ffi::IntoStringError), + #[error("Handle `{0}` is null.")] + NullHandle(String), + #[error("{0}")] + NDArray(#[from] NDArrayError), + #[error("{0}")] + CallFailed(String), +} + +impl Error { + pub fn downcast(actual_type: String, expected_type: &'static str) -> Error { + Self::Downcast(tvm_sys::errors::ValueDowncastError { + actual_type, + expected_type, + }) + } +} diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs new file mode 100644 index 000000000000..cb8777a6227b --- /dev/null +++ b/rust/tvm-rt/src/function.rs @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::convert::TryFrom; +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, + ptr, str, +}; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::errors::Error; + +use super::to_boxed_fn::ToBoxedFn; +use super::to_function::{ToFunction, Typed}; + +pub type Result = std::result::Result; + +/// Wrapper around TVM function handle which includes `is_global` +/// indicating whether the function is global or not, and `is_cloned` showing +/// not to drop a cloned function from Rust side. +/// The value of these fields can be accessed through their respective methods. +#[derive(Debug, Hash)] +pub struct Function { + pub(crate) handle: ffi::TVMFunctionHandle, + // whether the registered function is global or not. + is_global: bool, + from_rust: bool, +} + +unsafe impl Send for Function {} +unsafe impl Sync for Function {} + +impl Function { + pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { + Function { + handle, + is_global: false, + from_rust: false, + } + } + + /// For a given function, it returns a function by name. + pub fn get>(name: S) -> Option { + let name = CString::new(name.as_ref()).unwrap(); + let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; + + check_call!(ffi::TVMFuncGetGlobal( + name.as_ptr() as *const c_char, + &mut handle as *mut _ + )); + + if handle.is_null() { + None + } else { + Some(Function { + handle, + is_global: true, + from_rust: false, + }) + } + } + + pub fn get_boxed>(name: S) -> Option> + where + F: ToBoxedFn, + { + Self::get(name).map(|f| f.to_boxed_fn::()) + } + + /// Returns the underlying TVM function handle. + pub fn handle(&self) -> ffi::TVMFunctionHandle { + self.handle + } + + /// Returns `true` if the underlying TVM function is global and `false` otherwise. + pub fn is_global(&self) -> bool { + self.is_global + } + + /// Calls the function that created from `Builder`. + pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { + let num_args = arg_buf.len(); + let (mut values, mut type_codes): (Vec, Vec) = + arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); + let mut ret_val = ffi::TVMValue { v_int64: 0 }; + let mut ret_type_code = 0i32; + + check_call!(ffi::TVMFuncCall( + self.handle, + values.as_mut_ptr() as *mut ffi::TVMValue, + type_codes.as_mut_ptr() as *mut c_int, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); + + Ok(RetValue::from_tvm_value(ret_val, ret_type_code as u32)) + } + + pub fn to_boxed_fn(self) -> Box + where + F: ToBoxedFn, + { + F::to_boxed_fn(self) + } +} + +impl Clone for Function { + fn clone(&self) -> Function { + Self { + handle: self.handle, + is_global: self.is_global, + from_rust: true, + } + } +} + +// impl Drop for Function { +// fn drop(&mut self) { +// if !self.is_global && !self.is_cloned { +// check_call!(ffi::TVMFuncFree(self.handle)); +// } +// } +// } + +impl From for RetValue { + fn from(func: Function) -> RetValue { + RetValue::FuncHandle(func.handle) + } +} + +impl TryFrom for Function { + type Error = Error; + + fn try_from(ret_value: RetValue) -> Result { + match ret_value { + RetValue::FuncHandle(handle) => Ok(Function::new(handle)), + _ => Err(Error::downcast( + format!("{:?}", ret_value), + "FunctionHandle", + )), + } + } +} + +impl<'a> From for ArgValue<'a> { + fn from(func: Function) -> ArgValue<'a> { + ArgValue::FuncHandle(func.handle) + } +} + +impl<'a> TryFrom> for Function { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result { + match arg_value { + ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), + _ => Err(Error::downcast( + format!("{:?}", arg_value), + "FunctionHandle", + )), + } + } +} + +impl<'a> TryFrom<&ArgValue<'a>> for Function { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result { + match arg_value { + ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), + _ => Err(Error::downcast( + format!("{:?}", arg_value), + "FunctionHandle", + )), + } + } +} + +/// Registers a Rust function with an arbitrary type signature in +/// the TVM registry. +/// +/// +/// A function is convertible if and only if its arguments and return types are convertible +/// to and from TVM values respectively. +/// +/// Use [`register_override`] if control of overriding existing global TVM function +/// is required, this function will panic if a function is already registered. +/// +/// ## Example +/// +/// ``` +/// # use tvm_rt::{ArgValue, RetValue}; +/// # use tvm_rt::function::{Function, Result, register}; +/// +/// fn sum(x: i64, y: i64, z: i64) -> i64 { +/// x + y + z +/// } +/// +/// register(sum, "mysum".to_owned()).unwrap(); +/// let func = Function::get("mysum").unwrap(); +/// let boxed_fn = func.to_boxed_fn:: Result>(); +/// let ret = boxed_fn(10, 20, 30).unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register>(f: F, name: S) -> Result<()> +where + F: ToFunction, + F: Typed, +{ + register_override(f, name, false) +} + +/// Register a function with explicit control over whether to override an existing registration or not. +/// +/// See `register` for more details on how to use the registration API. +pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> +where + F: ToFunction, + F: Typed, +{ + let func = f.to_function(); + let name = name.into(); + // Not sure about this code + let handle = func.handle(); + let name = CString::new(name)?; + check_call!(ffi::TVMFuncRegisterGlobal( + name.into_raw(), + handle, + override_ as c_int + )); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::Function; + + static CANARY: &str = "runtime.ModuleLoadFromFile"; + + #[test] + fn get_fn() { + assert!(Function::get(CANARY).is_some()); + assert!(Function::get("does not exists!").is_none()); + } + + #[test] + fn register_and_call_closure0() { + use crate::function; + use function::Result; + + fn constfn() -> i64 { + return 10; + } + + function::register_override(constfn, "constfn".to_owned(), true).unwrap(); + + let func = Function::get_boxed:: Result, _>("constfn").unwrap(); + let ret = func().unwrap(); + assert_eq!(ret, 10); + } + + #[test] + fn register_and_call_closure1() { + use crate::function::{self}; + + fn ident(x: i64) -> i64 { + return x; + } + + function::register_override(ident, "ident".to_owned(), true).unwrap(); + let func = Function::get_boxed:: Result, _>("ident").unwrap(); + assert_eq!(func(60).unwrap(), 60); + } +} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs new file mode 100644 index 000000000000..10f8317bf7bd --- /dev/null +++ b/rust/tvm-rt/src/lib.rs @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime. +//! +//! The TVM runtime API contains the data structures used by higher-level TVM executors. +//! Specifically it exposes the basic types such as NDArray, as well as the more general object system. +//! The TVM object system enables cross-language interoperability including that of closures for all +//! supported languages including C++, and Python. + +pub mod object; +pub mod string; + +pub use object::*; +pub use string::*; + +use std::{ + ffi::{CStr, CString}, + str, +}; + +pub use crate::{ + context::{Context, DeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, +}; + +pub use function::{ArgValue, RetValue}; +pub use tvm_sys::byte_array::ByteArray; +pub use tvm_sys::datatype::DataType; +use tvm_sys::ffi; + +pub use tvm_macros::external; + +// Macro to check the return call to TVM runtime shared library. + +#[macro_export] +macro_rules! tvm_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + Err($crate::get_last_error().into()) + } else { + Ok(()) + } + }}; +} + +#[macro_export] +macro_rules! check_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + panic!("{}", $crate::get_last_error()); + } + }}; +} + +/// Gets the last error message. +pub fn get_last_error() -> &'static str { + unsafe { + match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { + Ok(s) => s, + Err(_) => "Invalid UTF-8 message", + } + } +} + +pub(crate) fn set_last_error(err: &E) { + let c_string = CString::new(err.to_string()).unwrap(); + unsafe { + ffi::TVMAPISetLastError(c_string.as_ptr()); + } +} + +#[macro_use] +pub mod function; +pub mod context; +pub mod errors; +pub mod module; +pub mod ndarray; +pub mod to_boxed_fn; +mod to_function; +pub mod value; + +/// Outputs the current TVM version. +pub fn version() -> &'static str { + match str::from_utf8(ffi::TVM_VERSION) { + Ok(s) => s, + Err(_) => "Invalid UTF-8 string", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn print_version() { + println!("TVM version: {}", version()); + } + + #[test] + fn set_error() { + let err = errors::NDArrayError::EmptyArray; + set_last_error(&err); + assert_eq!( + get_last_error().trim(), + errors::NDArrayError::EmptyArray.to_string() + ); + } +} diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs new file mode 100644 index 000000000000..b540c1ba9981 --- /dev/null +++ b/rust/tvm-rt/src/module.rs @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Provides the [`Module`] type and methods for working with runtime TVM modules. + +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, + path::Path, + ptr, +}; + +use tvm_sys::ffi; + +use crate::errors::Error; +use crate::{errors, function::Function}; + +const ENTRY_FUNC: &str = "__tvm_main__"; + +/// Wrapper around TVM module handle which contains an entry function. +/// The entry function can be applied to an imported module through [`entry_func`]. +/// +/// [`entry_func`]:struct.Module.html#method.entry_func +#[derive(Debug, Clone)] +pub struct Module { + pub(crate) handle: ffi::TVMModuleHandle, + entry_func: Option, +} + +crate::external! { + #[name("runtime.RuntimeEnabled")] + fn runtime_enabled(target: CString) -> i32; + + #[name("runtime.ModuleLoadFromFile")] + fn load_from_file(file_name: CString, format: CString) -> Module; +} + +impl Module { + pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { + Self { + handle, + entry_func: None, + } + } + + pub fn entry(&mut self) -> Option { + if self.entry_func.is_none() { + self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); + } + self.entry_func.clone() + } + + /// Gets a function by name from a registered module. + pub fn get_function(&self, name: &str, query_import: bool) -> Result { + let name = CString::new(name)?; + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMModGetFunction( + self.handle, + name.as_ptr() as *const c_char, + query_import as c_int, + &mut fhandle as *mut _ + )); + + if !fhandle.is_null() { + return Err(errors::Error::NullHandle(name.into_string()?.to_string())); + } + + Ok(Function::new(fhandle)) + } + + /// Imports a dependent module such as `.ptx` for gpu. + pub fn import_module(&self, dependent_module: Module) { + check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) + } + + /// Loads a module shared library from path. + pub fn load>(path: &P) -> Result { + let ext = CString::new( + path.as_ref() + .extension() + .unwrap_or_else(|| std::ffi::OsStr::new("")) + .to_str() + .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, + )?; + + let cpath = CString::new( + path.as_ref() + .to_str() + .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, + )?; + + let module = load_from_file(cpath, ext)?; + Ok(module) + } + + /// Checks if a target device is enabled for a module. + pub fn enabled(&self, target: &str) -> bool { + let target = CString::new(target).unwrap(); + let enabled = runtime_enabled(target).unwrap(); + enabled != 0 + } + + /// Returns the underlying module handle. + pub fn handle(&self) -> ffi::TVMModuleHandle { + self.handle + } +} + +impl Drop for Module { + fn drop(&mut self) { + check_call!(ffi::TVMModFree(self.handle)); + } +} diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs new file mode 100644 index 000000000000..b7ae4622849d --- /dev/null +++ b/rust/tvm-rt/src/ndarray.rs @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements the [`NDArray`] type for working with *TVM tensors* or +//! coverting from a Rust's ndarray to TVM `NDArray`. +//! +//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! # use tvm_rt::{NDArray, Context, DataType}; +//! # use ndarray::{Array, ArrayD}; +//! # use std::str::FromStr; +//! use std::convert::TryFrom; +//! +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::convert::TryInto; +use std::ffi::c_void; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +use crate::errors::NDArrayError; + +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, ByteArray, Context, DataType}; + +use ndarray::{Array, ArrayD}; +use num_traits::Num; + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +/// +/// Wrapper around TVM array handle. +#[derive(Debug)] +pub enum NDArray { + Borrowed { handle: ffi::TVMArrayHandle }, + Owned { handle: *mut c_void }, +} + +impl NDArray { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { + NDArray::Borrowed { handle } + } + + pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { + NDArray::Owned { handle } + } + + pub fn as_dltensor(&self) -> &DLTensor { + let ptr: *mut DLTensor = match self { + NDArray::Borrowed { ref handle } => *handle, + NDArray::Owned { ref handle } => *handle as *mut DLTensor, + }; + + unsafe { std::mem::transmute(ptr) } + } + + pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { + match self { + NDArray::Borrowed { handle } => *handle, + NDArray::Owned { handle } => *handle as *mut DLTensor, + } + } + + pub fn is_view(&self) -> bool { + if let &NDArray::Borrowed { .. } = self { + true + } else { + false + } + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = self.as_dltensor(); + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option { + self.shape().map(|v| v.iter().product()) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> Context { + self.as_dltensor().ctx.into() + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> DataType { + self.as_dltensor().dtype.into() + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + self.as_dltensor() + .ndim + .try_into() + .expect("number of dimensions must always be positive") + } + + /// Returns the strides of the underlying NDArray. + pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let sz = self.ndim() * mem::size_of::(); + let strides_ptr = self.as_dltensor().strides as *const usize; + let slc = slice::from_raw_parts(strides_ptr, sz); + Some(slc) + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result { + Ok(match self.strides() { + None => true, + Some(strides) => { + // NDArrayError::MissingShape in case shape is not determined + self.shape() + .ok_or(NDArrayError::MissingShape)? + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + self.as_dltensor().byte_offset as isize + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let mut shape = [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.to_vec::().unwrap(), data); + /// ``` + pub fn to_vec(&self) -> Result, NDArrayError> { + if !self.shape().is_some() { + return Err(NDArrayError::EmptyArray); + } + let earr = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + Context::cpu(0), + self.dtype(), + ); + let target = self.copy_to_ndarray(earr)?; + let arr = target.as_dltensor(); + let sz = self.size().ok_or(NDArrayError::MissingShape)?; + let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) + } + + /// Converts the NDArray to [`ByteArray`]. + pub fn to_bytearray(&self) -> Result { + let v = self.to_vec::()?; + Ok(ByteArray::from(v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. + pub fn copy_from_buffer(&mut self, data: &mut [T]) { + check_call!(ffi::TVMArrayCopyFromBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + data.len() * mem::size_of::() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + if self.dtype() != target.dtype() { + return Err(NDArrayError::DataTypeMismatch { + expected: self.dtype(), + actual: target.dtype(), + }); + } + + check_call!(ffi::TVMArrayCopyFromTo( + self.as_raw_dltensor(), + target.as_raw_dltensor(), + ptr::null_mut() as ffi::TVMStreamHandle + )); + + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &Context) -> Result { + let tmp = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + *target, + self.dtype(), + ); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. + pub fn from_rust_ndarray( + rnd: &ArrayD, + ctx: Context, + dtype: DataType, + ) -> Result { + let shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. + pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + let dtype: tvm_sys::ffi::DLDataType = dtype.into(); + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::Borrowed { handle: handle } + } +} + +macro_rules! impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &NDArray) -> Result, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + } + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &mut NDArray) -> Result, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + }; + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if let &mut NDArray::Owned { .. } = self { + check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); + } + } +} + +mod sealed { + /// Private trait to prevent other traits from being implemeneted in downstream crates. + pub trait Sealed {} +} + +/// A trait for the supported 32-bits numerical types in frontend. +pub trait Num32: Num + sealed::Sealed { + const BITS: u8 = 32; +} + +macro_rules! impl_num32 { + ($($type:ty),+) => { + $( + impl sealed::Sealed for $type {} + impl Num32 for $type {} + )+ + }; +} + +impl_num32!(i32, u32, f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let shape = &mut [1, 2, 3]; + let ctx = Context::cpu(0); + let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!( + ndarray.size().unwrap(), + shape.to_vec().into_iter().product() + ); + assert_eq!(ndarray.ndim(), 3); + assert!(ndarray.strides().is_none()); + assert_eq!(ndarray.byte_offset(), 0); + } + + #[test] + fn copy() { + let shape = &mut [4]; + let mut data = vec![1i32, 2, 3, 4]; + let ctx = Context::cpu(0); + let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert!(ndarray.to_vec::().is_ok()); + ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.to_vec::().unwrap(), data); + assert_eq!(ndarray.ndim(), 1); + assert!(ndarray.is_contiguous().is_ok()); + assert_eq!(ndarray.byte_offset(), 0); + let shape = vec![4]; + let e = NDArray::empty( + &shape, + Context::cpu(0), + DataType::from_str("int32").unwrap(), + ); + let nd = ndarray.copy_to_ndarray(e); + assert!(nd.is_ok()); + assert_eq!(nd.unwrap().to_vec::().unwrap(), data); + } + + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = Context::cpu(0); + let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } + + #[test] + fn rust_ndarray() { + let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + .unwrap() + .into_dyn(); + let nd = + NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + .unwrap(); + assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + assert!(rnd.all_close(&a, 1e-8f32)); + } +} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs new file mode 100644 index 000000000000..c49f84e2d916 --- /dev/null +++ b/rust/tvm-rt/src/object/mod.rs @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::convert::TryFrom; +use std::convert::TryInto; +use std::ffi::CString; + +use crate::errors::Error; +use crate::external; + +use tvm_sys::{ArgValue, RetValue}; + +mod object_ptr; + +pub use object_ptr::{IsObject, Object, ObjectPtr}; + +#[derive(Clone)] +pub struct ObjectRef(pub Option>); + +impl ObjectRef { + pub fn null() -> ObjectRef { + ObjectRef(None) + } +} + +pub trait ToObjectRef { + fn to_object_ref(&self) -> ObjectRef; +} + +impl ToObjectRef for ObjectRef { + fn to_object_ref(&self) -> ObjectRef { + self.clone() + } +} + +impl TryFrom for ObjectRef { + type Error = Error; + + fn try_from(ret_val: RetValue) -> Result { + let optr = ret_val.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl From for RetValue { + fn from(object_ref: ObjectRef) -> RetValue { + use std::ffi::c_void; + let object_ptr = object_ref.0; + match object_ptr { + None => RetValue::ObjectHandle(std::ptr::null::() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> std::convert::TryFrom> for ObjectRef { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result { + let optr = arg_value.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result { + // TODO(@jroesch): remove the clone + let value: ArgValue<'a> = arg_value.clone(); + ObjectRef::try_from(value) + } +} + +impl<'a> From for ArgValue<'a> { + fn from(object_ref: ObjectRef) -> ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => ArgValue::ObjectHandle(std::ptr::null::() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> From<&ObjectRef> for ArgValue<'a> { + fn from(object_ref: &ObjectRef) -> ArgValue<'a> { + let oref: ObjectRef = object_ref.clone(); + ArgValue::<'a>::from(oref) + } +} + +external! { + #[name("ir.DebugPrint")] + fn debug_print(object: ObjectRef) -> CString; +} + +// external! { +// #[name("ir.TextPrinter")] +// fn as_text(object: ObjectRef) -> CString; +// } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs new file mode 100644 index 000000000000..40e218454f6a --- /dev/null +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::convert::TryFrom; +use std::ffi::CString; +use std::ptr::NonNull; +use std::sync::atomic::AtomicI32; + +use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::{ArgValue, RetValue}; + +use crate::errors::Error; + +type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); + +#[derive(Debug)] +#[repr(C)] +pub struct Object { + pub type_index: u32, + // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. + // NB: in general we should not touch this in Rust. + pub(self) ref_count: AtomicI32, + pub fdeleter: Deleter, +} + +unsafe extern "C" fn delete(object: *mut Object) { + let typed_object: *mut T = std::mem::transmute(object); + T::typed_delete(typed_object); +} + +fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { + let mut is_derived = 0; + crate::check_call!(ffi::TVMObjectDerivedFrom( + child_type_index, + parent_type_index, + &mut is_derived + )); + + if is_derived == 0 { + false + } else { + true + } +} + +impl Object { + fn new(type_index: u32, deleter: Deleter) -> Object { + Object { + type_index, + // Note: do not touch this field directly again, this is + // a critical section, we write a 1 to the atomic which will now + // be managed by the C++ atomics. + // In the future we should probably use C-atomcis. + ref_count: AtomicI32::new(0), + fdeleter: deleter, + } + } + + fn get_type_index() -> u32 { + let type_key = T::TYPE_KEY; + let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { + return 0; + } else { + let mut index = 0; + unsafe { + let index_ptr = std::mem::transmute(&mut index); + if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 { + panic!(crate::get_last_error()) + } + } + return index; + } + } + + pub fn base_object() -> Object { + let index = Object::get_type_index::(); + Object::new(index, delete::) + } + + pub(self) fn inc_ref(&self) { + unsafe { + let raw_ptr = std::mem::transmute(self); + assert_eq!(TVMObjectRetain(raw_ptr), 0); + } + } + + pub(self) fn dec_ref(&self) { + unsafe { + let raw_ptr = std::mem::transmute(self); + assert_eq!(TVMObjectFree(raw_ptr), 0); + } + } +} + +pub unsafe trait IsObject { + const TYPE_KEY: &'static str; + + fn as_object<'s>(&'s self) -> &'s Object; + + unsafe extern "C" fn typed_delete(object: *mut Self) { + let object = Box::from_raw(object); + drop(object) + } +} + +unsafe impl IsObject for Object { + const TYPE_KEY: &'static str = "Object"; + + fn as_object<'s>(&'s self) -> &'s Object { + self + } +} + +#[repr(C)] +pub struct ObjectPtr { + pub ptr: NonNull, +} + +fn inc_ref(ptr: NonNull) { + unsafe { ptr.as_ref().as_object().inc_ref() } +} + +fn dec_ref(ptr: NonNull) { + unsafe { ptr.as_ref().as_object().dec_ref() } +} + +impl ObjectPtr { + fn from_raw(object_ptr: *mut Object) -> Option> { + let non_null = NonNull::new(object_ptr); + non_null.map(|ptr| ObjectPtr { ptr }) + } +} + +impl Clone for ObjectPtr { + fn clone(&self) -> Self { + inc_ref(self.ptr); + ObjectPtr { ptr: self.ptr } + } +} + +impl Drop for ObjectPtr { + fn drop(&mut self) { + dec_ref(self.ptr); + } +} + +impl ObjectPtr { + pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut T + where + T: 'a, + { + unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr() } + } + + pub fn new(object: T) -> ObjectPtr { + let object_ptr = Box::new(object); + let object_ptr = Box::leak(object_ptr); + let ptr = NonNull::from(object_ptr); + inc_ref(ptr); + ObjectPtr { ptr } + } + + pub fn count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. + self.as_object() + .ref_count + .load(std::sync::atomic::Ordering::SeqCst) + } + + fn as_object<'s>(&'s self) -> &'s Object { + unsafe { self.ptr.as_ref().as_object() } + } + + pub fn upcast(&self) -> ObjectPtr { + ObjectPtr { + ptr: self.ptr.cast(), + } + } + + pub fn downcast(&self) -> Result, Error> { + let child_index = Object::get_type_index::(); + let object_index = self.as_object().type_index; + + let is_derived = if child_index == object_index { + true + } else { + // TODO(@jroesch): write tests + derived_from(object_index, child_index) + }; + + if is_derived { + Ok(ObjectPtr { + ptr: self.ptr.cast(), + }) + } else { + Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) + } + } +} + +impl std::ops::Deref for ObjectPtr { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { self.ptr.as_ref() } + } +} + +impl<'a, T: IsObject> From> for RetValue { + fn from(object_ptr: ObjectPtr) -> RetValue { + let raw_object_ptr = ObjectPtr::leak(object_ptr); + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + RetValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom for ObjectPtr { + type Error = Error; + + fn try_from(ret_value: RetValue) -> Result, Self::Error> { + match ret_value { + RetValue::ObjectHandle(handle) => { + let handle: *mut Object = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.downcast() + } + _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), + } + } +} + +impl<'a, T: IsObject> From> for ArgValue<'a> { + fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { + let raw_object_ptr = ObjectPtr::leak(object_ptr); + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + ArgValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom> for ObjectPtr { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.downcast() + } + _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), + } + } +} + +impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.downcast() + } + _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), + } + } +} + +#[cfg(test)] +mod tests { + use super::{Object, ObjectPtr}; + use anyhow::{ensure, Result}; + use std::convert::TryInto; + use tvm_sys::{ArgValue, RetValue}; + + #[test] + fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) + } + + #[test] + fn roundtrip_retvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::()); + let ret_value: RetValue = ptr.clone().into(); + let ptr2: ObjectPtr = ret_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } + + #[test] + fn roundtrip_argvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::()); + let arg_value: ArgValue = ptr.clone().into(); + let ptr2: ObjectPtr = arg_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } + + fn test_fn(o: ObjectPtr) -> ObjectPtr { + assert_eq!(o.count(), 2); + return o; + } + + #[test] + fn test_ref_count_boundary() { + use super::*; + use crate::function::{register, Function, Result}; + let ptr = ObjectPtr::new(Object::base_object::()); + let stay = ptr.clone(); + assert_eq!(ptr.count(), 2); + register(test_fn, "my_func").unwrap(); + let func = Function::get("my_func").unwrap(); + let func = func.to_boxed_fn::) -> Result>>(); + func(ptr).unwrap(); + assert_eq!(stay.count(), 1); + } +} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs new file mode 100644 index 000000000000..26758b1170e7 --- /dev/null +++ b/rust/tvm-rt/src/string.rs @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::ffi::{CString, NulError}; +use std::os::raw::c_char; + +use super::errors::Error; +use super::{Object, ObjectPtr, ObjectRef}; + +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "String"] +#[type_key = "runtime.String"] +pub struct StringObj { + base: Object, + data: *const c_char, + size: u64, +} + +impl String { + pub fn new(string: std::string::String) -> Result { + let cstring = CString::new(string)?; + + // The string is being corrupted. + // why is this wrong + let length = cstring.as_bytes().len(); + + let string_obj = StringObj { + base: Object::base_object::(), + data: cstring.into_raw(), + size: length as u64, + }; + + let object_ptr = ObjectPtr::new(string_obj); + Ok(String(Some(object_ptr))) + } + + pub fn to_cstring(&self) -> Result { + use std::slice; + let ptr = self.0.as_ref().unwrap().data; + let size = self.0.as_ref().unwrap().size; + unsafe { + let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize); + CString::new(slice) + } + } + + pub fn to_string(&self) -> Result { + let string = self.to_cstring()?.into_string()?; + Ok(string) + } +} + +// #[cfg(test)] +// mod tests { +// use super::String; +// use crate::object::debug_print; +// use crate::ToObjectRef; +// use anyhow::{ensure, Result}; + +// #[test] +// fn test_string_debug() -> Result<()> { +// let s = String::new("foo".to_string()).unwrap(); +// let object_ref = s.to_object_ref(); +// println!("about to call"); +// let string = debug_print(object_ref)?; +// println!("after call"); +// ensure!( +// string.into_string().expect("is cstring").contains("foo"), +// "string content is invalid" +// ); +// Ok(()) +// } +// } diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs new file mode 100644 index 000000000000..f0e5e80ff2ad --- /dev/null +++ b/rust/tvm-rt/src/to_boxed_fn.rs @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides a method for converting type erased TVM functions +//! into a boxed Rust closure. +//! +//! To call a registered function check the [`ToBoxedFn::to_boxed_fn`] method. +//! +//! See the tests and examples repository for more examples. + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::{errors, Module}; + +use super::function::{Function, Result}; + +pub trait ToBoxedFn { + fn to_boxed_fn(func: Function) -> Box; +} + +use std::convert::{TryFrom, TryInto}; + +impl ToBoxedFn for dyn Fn() -> Result +where + errors::Error: From, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move || { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A) -> Result +where + errors::Error: From, + A: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B) -> Result +where + errors::Error: From, + A: Into>, + B: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A, b: B| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + builder.arg(b.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B, C) -> Result +where + errors::Error: From, + A: Into>, + B: Into>, + C: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A, b: B, c: C| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B, C, D) -> Result +where + errors::Error: From, + A: Into>, + B: Into>, + C: Into>, + D: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A, b: B, c: C, d: D| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + builder.arg(d.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +/// Function builder in order to create and call functions. +/// +/// *Note:* Currently TVM functions accept *at most* one return value. +#[derive(Default)] +pub struct Builder<'a> { + pub func: Option, + pub arg_buf: Vec>, + pub ret_buf: Option, +} + +impl<'a, 'm> Builder<'a> { + pub fn new( + func: Option, + arg_buf: Vec>, + ret_buf: Option, + ) -> Self { + Self { + func, + arg_buf, + ret_buf, + } + } + + pub fn get_function(&mut self, name: &'m str) -> &mut Self { + self.func = Function::get(name); + self + } + + /// Pushes a [`ArgValue`] into the function argument buffer. + pub fn arg(&mut self, arg: T) -> &mut Self + where + ArgValue<'a>: From, + { + self.arg_buf.push(arg.into()); + self + } + + /// Pushes multiple [`ArgValue`]s into the function argument buffer. + pub fn args(&mut self, args: I) -> &mut Self + where + I: IntoIterator, + ArgValue<'a>: From, + { + args.into_iter().for_each(|arg| { + self.arg(arg); + }); + self + } + + /// Sets an output for a function that requires a mutable output to be provided. + /// See the `basics` in tests for an example. + pub fn set_output(&mut self, ret: T) -> &mut Self + where + RetValue: From, + { + self.ret_buf = Some(ret.into()); + self + } + + pub fn invoke(self) -> Result { + self.func.unwrap().invoke(self.arg_buf) + } +} + +/// Converts a [`Function`] to builder. Currently, this is the best way to work with +/// TVM functions. +impl<'a, 'm> From for Builder<'a> { + fn from(func: Function) -> Self { + Builder::new(Some(func), Vec::new(), None) + } +} + +/// Converts a mutable reference of a [`Module`] to [`Builder`]. +impl<'a, 'm> From<&'m mut Module> for Builder<'a> { + fn from(module: &'m mut Module) -> Self { + Builder::new(module.entry(), Vec::new(), None) + } +} +#[cfg(test)] +mod tests { + use crate::function::{self, Function, Result}; + + #[test] + fn to_boxed_fn0() { + fn boxed0() -> i64 { + return 10; + } + + function::register_override(boxed0, "boxed0".to_owned(), true).unwrap(); + let func = Function::get("boxed0").unwrap(); + let typed_func: Box Result> = func.to_boxed_fn(); + assert_eq!(typed_func().unwrap(), 10); + } +} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs new file mode 100644 index 000000000000..4814d098238a --- /dev/null +++ b/rust/tvm-rt/src/to_function.rs @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::convert::{TryFrom, TryInto}; +use std::{ + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use super::{function::Result, Function}; +use crate::errors::Error; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +/// A trait representing whether the function arguments +/// and return type can be assigned to a TVM packed function. +/// +/// By splitting the conversion to function into two traits +/// we are able to improve error reporting, by splitting the +/// conversion of inputs and outputs to this trait. +/// +/// And the implementation of it to `ToFunction`. +pub trait Typed { + fn args(i: &[ArgValue<'static>]) -> Result; + fn ret(o: O) -> RetValue; +} + +impl> Typed<(), O> for F +where + F: Fn() -> O, +{ + fn args(_args: &[ArgValue<'static>]) -> Result<()> { + debug_assert!(_args.len() == 0); + Ok(()) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E> Typed<(A,), O> for F +where + F: Fn(A) -> O, + Error: From, + A: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A,)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + Ok((a,)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E> Typed<(A, B), O> for F +where + F: Fn(A, B) -> O, + Error: From, + A: TryFrom, Error = E>, + B: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> { + debug_assert!(args.len() == 2); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + Ok((a, b)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E> Typed<(A, B, C), O> for F +where + F: Fn(A, B, C) -> O, + Error: From, + A: TryFrom, Error = E>, + B: TryFrom, Error = E>, + C: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> { + debug_assert!(args.len() == 3); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + let c: C = args[2].clone().try_into()?; + Ok((a, b, c)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +pub trait ToFunction: Sized { + type Handle; + + fn into_raw(self) -> *mut Self::Handle; + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result + where + Self: Typed; + + fn drop(handle: *mut Self::Handle); + + fn to_function(self) -> Function + where + Self: Typed, + { + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + let resource_handle = self.into_raw(); + + check_call!(ffi::TVMFuncCreateFromCFunc( + Some(Self::tvm_callback), + resource_handle as *mut _, + None, // Some(Self::tvm_finalizer), + &mut fhandle as *mut ffi::TVMFunctionHandle, + )); + + Function::new(fhandle) + } + + /// The callback function which is wrapped converted by TVM + /// into a packed function stored in fhandle. + unsafe extern "C" fn tvm_callback( + args: *mut ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: ffi::TVMRetValueHandle, + resource_handle: *mut c_void, + ) -> c_int + where + Self: Typed, + { + #![allow(unused_assignments, unused_unsafe)] + // turning off the incorrect linter complaints + let len = num_args as usize; + let args_list = slice::from_raw_parts_mut(args, len); + let type_codes_list = slice::from_raw_parts_mut(type_codes, len); + let mut local_args: Vec = Vec::new(); + let mut value = ffi::TVMValue { v_int64: 0 }; + let mut tcode = 0; + let resource_handle = resource_handle as *mut Self::Handle; + for i in 0..len { + value = args_list[i]; + tcode = type_codes_list[i]; + if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + { + check_call!(ffi::TVMCbArgToReturn( + &mut value as *mut _, + &mut tcode as *mut _ + )); + } + let arg_value = ArgValue::from_tvm_value(value, tcode as u32); + local_args.push(arg_value); + } + + let rv = match Self::call(resource_handle, local_args.as_slice()) { + Ok(v) => v, + Err(msg) => { + crate::set_last_error(&msg); + return -1; + } + }; + + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); + let mut ret_type_code = ret_tcode as c_int; + + check_call!(ffi::TVMCFuncSetReturn( + ret, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + 1 as c_int + )); + 0 + } + + /// The finalizer which is invoked when the packed function's + /// reference count is zero. + unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { + let handle = std::mem::transmute(fhandle); + Self::drop(handle) + } +} + +impl ToFunction<(), O> for F +where + F: Fn() -> O + 'static, +{ + type Handle = Box O + 'static>; + + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(Box::new(self)); + Box::into_raw(ptr) + } + + fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result + where + F: Typed<(), O>, + { + // Ideally we shouldn't need to clone, probably doesn't really matter. + let out = unsafe { (*handle)() }; + Ok(F::ret(out)) + } + + fn drop(_: *mut Self::Handle) {} +} + +macro_rules! to_function_instance { + ($(($param:ident,$index:tt),)+) => { + impl ToFunction<($($param,)+), O> for + F where F: Fn($($param,)+) -> O + 'static { + type Handle = Box O + 'static>; + + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(Box::new(self)); + Box::into_raw(ptr) + } + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result where F: Typed<($($param,)+), O> { + // Ideally we shouldn't need to clone, probably doesn't really matter. + let args = F::args(args)?; + let out = unsafe { + (*handle)($(args.$index),+) + }; + Ok(F::ret(out)) + } + + fn drop(_: *mut Self::Handle) {} + } + } +} + +to_function_instance!((A, 0),); +to_function_instance!((A, 0), (B, 1),); +to_function_instance!((A, 0), (B, 1), (C, 2),); +to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),); + +#[cfg(test)] +mod tests { + use super::{Function, ToFunction, Typed}; + + fn zero() -> i32 { + 10 + } + + fn helper(f: F) -> Function + where + F: ToFunction, + F: Typed, + { + f.to_function() + } + + #[test] + fn test_to_function0() { + helper(zero); + } + + fn one_arg(i: i32) -> i32 { + i + } + + #[test] + fn test_to_function1() { + helper(one_arg); + } + + fn two_arg(i: i32, j: i32) -> i32 { + i + j + } + + #[test] + fn test_to_function2() { + helper(two_arg); + } +} diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs new file mode 100644 index 000000000000..1812c0cfbe45 --- /dev/null +++ b/rust/tvm-rt/src/value.rs @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements [`ArgValue`] and [`RetValue`] types +//! and their conversions needed for the types used in frontend crate. +//! `RetValue` is the owned version of `TVMPODValue`. + +use std::convert::TryFrom; +// use std::ffi::c_void; + +use crate::{ArgValue, Module, NDArray, RetValue}; +use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; + +macro_rules! impl_handle_val { + ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { + impl<'a> From<&'a $type> for ArgValue<'a> { + fn from(arg: &'a $type) -> Self { + ArgValue::$variant(arg.handle() as $inner_type) + } + } + + impl<'a> From<&'a mut $type> for ArgValue<'a> { + fn from(arg: &'a mut $type) -> Self { + ArgValue::$variant(arg.handle() as $inner_type) + } + } + + impl<'a> TryFrom> for $type { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(val) }) + } + } + + impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(*val) }) + } + } + + impl From<$type> for RetValue { + fn from(val: $type) -> RetValue { + RetValue::$variant(val.handle() as $inner_type) + } + } + + impl TryFrom for $type { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |RetValue::$variant(val)| { $ctor(val) }) + } + } + }; +} + +impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); + +impl<'a> From<&'a NDArray> for ArgValue<'a> { + fn from(arg: &'a NDArray) -> Self { + match arg { + &NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), + &NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), + } + } +} + +impl<'a> From<&'a mut NDArray> for ArgValue<'a> { + fn from(arg: &'a mut NDArray) -> Self { + match arg { + &mut NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), + &mut NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), + } + } +} + +impl<'a> TryFrom> for NDArray { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> NDArray, + |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, + |ArgValue::ArrayHandle(val)| { NDArray::new(val) }) + } +} + +impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> NDArray, + |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) }, + |ArgValue::ArrayHandle(val)| { NDArray::new(*val) }) + } +} + +impl From for RetValue { + fn from(val: NDArray) -> RetValue { + match val { + NDArray::Owned { handle } => RetValue::NDArrayHandle(handle), + _ => panic!("NYI"), + } + } +} + +impl TryFrom for NDArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> NDArray, + |RetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, + |RetValue::ArrayHandle(val)| { NDArray::new(val) }) + } +} + +#[cfg(test)] +mod tests { + use std::{convert::TryInto, str::FromStr}; + + use crate::{ByteArray, Context, DataType}; + + use super::*; + + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = ByteArray::from(w.as_slice()); + let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + assert_eq!( + tvm.data(), + w.iter().copied().collect::>().as_slice() + ); + } + + #[test] + fn ty() { + let t = DataType::from_str("int32").unwrap(); + let tvm: DataType = RetValue::from(t).try_into().unwrap(); + assert_eq!(tvm, t); + } + + #[test] + fn ctx() { + let c = Context::from_str("gpu").unwrap(); + let tvm: Context = RetValue::from(c).try_into().unwrap(); + assert_eq!(tvm, c); + } +} diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 40f28f45d76b..9bd95262820f 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -16,9 +16,12 @@ * specific language governing permissions and limitations * under the License. */ +use std::convert::TryFrom; use std::os::raw::c_char; +use crate::errors::ValueDowncastError; use crate::ffi::TVMByteArray; +use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. /// @@ -69,6 +72,39 @@ impl> From for ByteArray { } } +impl TryFrom> for ByteArray { + type Error = ValueDowncastError; + + fn try_from(val: ArgValue<'static>) -> Result { + match val { + ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), + _ => Err(ValueDowncastError { + expected_type: "ByteArray", + actual_type: format!("{:?}", val), + }), + } + } +} + +impl From for RetValue { + fn from(val: ByteArray) -> RetValue { + RetValue::Bytes(val.array) + } +} + +impl TryFrom for ByteArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + match val { + RetValue::Bytes(array) => Ok(ByteArray { array }), + _ => Err(ValueDowncastError { + expected_type: "ByteArray", + actual_type: format!("{:?}", val), + }), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs index 5dd414c17960..ccdee3f6f753 100644 --- a/rust/tvm-sys/src/datatype.rs +++ b/rust/tvm-sys/src/datatype.rs @@ -95,6 +95,16 @@ impl From for DataType { } } +impl From for DLDataType { + fn from(dtype: DataType) -> Self { + Self { + code: dtype.code, + bits: dtype.bits, + lanes: dtype.lanes, + } + } +} + #[derive(Debug, Error)] pub enum ParseDataTypeError { #[error("invalid number: {0}")] diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs index 8479ec62f19f..54fe261ec37e 100644 --- a/rust/tvm-sys/src/errors.rs +++ b/rust/tvm-sys/src/errors.rs @@ -39,7 +39,7 @@ impl FuncCallError { context, message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } .to_str() - .expect("double fault") + .expect("failed while attempting to retrieve the TVM error message") .to_owned(), } } diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs index dd28e3603f90..0f455e726d26 100644 --- a/rust/tvm-sys/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -34,8 +34,13 @@ pub mod ffi { include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - pub type BackendPackedCFunc = - extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + pub type BackendPackedCFunc = extern "C" fn( + args: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + out_ret_value: *mut TVMValue, + out_ret_tcode: *mut u32, + ) -> c_int; } pub mod array; diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 699b4dbf8271..97e285ccf0ac 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -185,4 +185,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } p->stream << '}'; }); + +TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { + std::stringstream ss; + ss << ref; + return ss.str(); +}); + } // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ad16f862ac2b..981d0c357e24 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -823,5 +823,12 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { return docs; } +TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { + std::cout << "The program: " << node << std::endl; + auto text = AsText(node, false, nullptr); + std::cout << "The text " << text; + return text; +}); + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 958590485d0f..6972d5a76b77 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -151,7 +151,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, // only look unfold non-external calls. BaseFunc base_func = m->Lookup(gv); if (auto* n = base_func.as()) { - auto cps_gv = GlobalVar(gv->name_hint + "_cps"); + auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps"); cm->insert({gv, cps_gv}); m->Add(cps_gv, ToCPS(GetRef(n), m, cm)); } else { diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 00be440441bd..dc5f1ceabbae 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -234,12 +234,25 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { API_END(); } +int TVMObjectRetain(TVMObjectHandle obj) { + API_BEGIN(); + tvm::runtime::ObjectInternal::ObjectRetain(obj); + API_END(); +} + int TVMObjectFree(TVMObjectHandle obj) { API_BEGIN(); tvm::runtime::ObjectInternal::ObjectFree(obj); API_END(); } +int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { + API_BEGIN(); + *is_derived = + tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index); + API_END(); +} + int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 35642fbb731b..f255b28ad04c 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -38,6 +38,15 @@ namespace runtime { */ class ObjectInternal { public: + /*! + * \brief Retain an object handle. + */ + static void ObjectRetain(TVMObjectHandle obj) { + if (obj != nullptr) { + static_cast(obj)->IncRef(); + } + } + /*! * \brief Free an object handle. */