From 6aae48b3785aeba565138f4bb63884ae9930b2aa Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 4 Mar 2021 10:43:14 -0800 Subject: [PATCH] Fixes for using Python APIs from Rust. (#7085) * Rewrite the Rust Module API and change some imports causing crashes. This commit also updates the docs to remove outdated information. * Renable Python test and remove warnings * Python test still flaky * Fix broken module test * Fix broken test * Reset test file --- python/tvm/relay/analysis/analysis.py | 6 +- .../tvm/relay/analysis/annotated_regions.py | 2 +- python/tvm/relay/analysis/call_graph.py | 4 +- .../relay/backend/graph_runtime_factory.py | 2 +- python/tvm/relay/build_module.py | 16 +- python/tvm/relay/frontend/__init__.py | 3 - python/tvm/relay/frontend/tensorflow.py | 7 +- python/tvm/topi/cuda/__init__.py | 2 - rust/tvm-rt/README.md | 4 +- rust/tvm-rt/src/lib.rs | 28 ++- rust/tvm-rt/src/map.rs | 12 + rust/tvm-rt/src/module.rs | 58 ++--- rust/tvm-rt/src/object/object_ptr.rs | 13 +- rust/tvm-rt/src/to_function.rs | 1 + rust/tvm-rt/src/value.rs | 106 -------- rust/tvm/Cargo.toml | 3 +- rust/tvm/README.md | 233 ++---------------- rust/tvm/src/compiler/graph_rt.rs | 124 ++++++++++ rust/tvm/src/compiler/mod.rs | 20 ++ rust/tvm/src/ir/expr.rs | 12 +- rust/tvm/src/ir/function.rs | 10 +- rust/tvm/src/ir/module.rs | 4 +- rust/tvm/src/ir/relay/mod.rs | 85 +++---- rust/tvm/src/ir/tir.rs | 7 +- rust/tvm/src/ir/ty.rs | 19 +- rust/tvm/src/lib.rs | 4 +- rust/tvm/src/python.rs | 24 +- rust/tvm/src/runtime/graph_rt.rs | 12 +- rust/tvm/tests/basics/src/main.rs | 5 +- rust/tvm/tests/basics/src/tvm_add.py | 1 - src/runtime/module.cc | 2 +- 31 files changed, 392 insertions(+), 437 deletions(-) delete mode 100644 rust/tvm-rt/src/value.rs create mode 100644 rust/tvm/src/compiler/graph_rt.rs create mode 100644 rust/tvm/src/compiler/mod.rs diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 7e49461dff52..48e9ce0643a9 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -20,9 +20,9 @@ This file contains the set of passes for Relay, which exposes an interface for configuring the passes and scripting them in Python. """ -from tvm.ir import IRModule -from tvm.relay import transform, build_module -from tvm.runtime.ndarray import cpu +from ...ir import IRModule +from ...relay import transform, build_module +from ...runtime.ndarray import cpu from . import _ffi_api from .feature import Feature diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py index 437b97b0fa16..a18ccb97836b 100644 --- a/python/tvm/relay/analysis/annotated_regions.py +++ b/python/tvm/relay/analysis/annotated_regions.py @@ -17,7 +17,7 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Regions used in Relay.""" -from tvm.runtime import Object +from ...runtime import Object from . import _ffi_api diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py index 966659aac494..fd9704d0af1f 100644 --- a/python/tvm/relay/analysis/call_graph.py +++ b/python/tvm/relay/analysis/call_graph.py @@ -17,8 +17,8 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Call graph used in Relay.""" -from tvm.ir import IRModule -from tvm.runtime import Object +from ...ir import IRModule +from ...runtime import Object from ..expr import GlobalVar from . 
import _ffi_api diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 4c6ac47b71b4..3427a62cd491 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -21,7 +21,7 @@ from tvm.runtime import ndarray -class GraphRuntimeFactoryModule(object): +class GraphRuntimeFactoryModule: """Graph runtime factory module. This is a module of graph runtime factory diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index f05e105ed2a2..79eb7e4f19ff 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -25,7 +25,7 @@ from tvm.ir.transform import PassContext from tvm.tir import expr as tvm_expr -from .. import nd as _nd, autotvm +from .. import nd as _nd, autotvm, register_func from ..target import Target from ..contrib import graph_runtime as _graph_rt from . import _build_module @@ -194,6 +194,20 @@ def get_params(self): return ret +@register_func("tvm.relay.module_export_library") +def _module_export(module, file_name): # fcompile, addons, kwargs? + return module.export_library(file_name) + + +@register_func("tvm.relay.build") +def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): + """A wrapper around build which discards the Python GraphFactoryRuntime. + This wrapper is suitable to be used from other programming languages as + the runtime::Module can be freely passed between language boundaries. + """ + return build(mod, target, target_host, params, mod_name).module + + def build(mod, target=None, target_host=None, params=None, mod_name="default"): # fmt: off # pylint: disable=line-too-long diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 7e16499ccc44..aa8ac4fc7434 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -20,9 +20,6 @@ Contains the model importers currently defined for Relay. """ - -from __future__ import absolute_import - from .mxnet import from_mxnet from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 20eb95ba7c00..b75331a4f9a2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1051,10 +1051,11 @@ def _impl(inputs, attr, params, mod): def _sparse_tensor_dense_matmul(): - # Sparse utility from scipy - from scipy.sparse import csr_matrix - def _impl(inputs, attr, params, mod): + # Loading this by default causes TVM to not be loadable from other languages. 
+ # Sparse utility from scipy + from scipy.sparse import csr_matrix + assert len(inputs) == 4, "There should be 4 input tensors" indices_tensor = _infer_value(inputs[0], params, mod).asnumpy() diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 52e64804d692..c2f55668d2e2 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -17,8 +17,6 @@ # pylint: disable=redefined-builtin, wildcard-import """CUDA specific declaration and schedules.""" -from __future__ import absolute_import as _abs - from .conv1d import * from .conv1d_transpose_ncw import * from .conv2d import * diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md index a99eeaa578dd..58b1f8a30a39 100644 --- a/rust/tvm-rt/README.md +++ b/rust/tvm-rt/README.md @@ -17,8 +17,8 @@ # TVM Runtime Support -This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime. -Currently this is tested on `1.42.0` and above. +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime, +see [here](https://github.com/apache/tvm/blob/main/rust/tvm/README.md) for more details. ## What Does This Crate Offer? diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 4b163eff9c8f..5f9ab1617378 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -99,7 +99,6 @@ pub mod map; pub mod module; pub mod ndarray; mod to_function; -pub mod value; /// Outputs the current TVM version. pub fn version() -> &'static str { @@ -112,6 +111,8 @@ pub fn version() -> &'static str { #[cfg(test)] mod tests { use super::*; + use crate::{ByteArray, Context, DataType}; + use std::{convert::TryInto, str::FromStr}; #[test] fn print_version() { @@ -127,4 +128,29 @@ mod tests { errors::NDArrayError::EmptyArray.to_string() ); } + + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = ByteArray::from(w.as_slice()); + let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + assert_eq!( + tvm.data(), + w.iter().copied().collect::>().as_slice() + ); + } + + #[test] + fn ty() { + let t = DataType::from_str("int32").unwrap(); + let tvm: DataType = RetValue::from(t).try_into().unwrap(); + assert_eq!(tvm, t); + } + + #[test] + fn ctx() { + let c = Context::from_str("gpu").unwrap(); + let tvm: Context = RetValue::from(c).try_into().unwrap(); + assert_eq!(tvm, c); + } } diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 5ea48893d86b..d6dfaf3641b8 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -107,6 +107,18 @@ where let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?; oref.downcast() } + + pub fn empty() -> Self { + Self::from_iter(vec![].into_iter()) + } + + //(@jroesch): I don't think this is a correct implementation. + pub fn null() -> Self { + Map { + object: ObjectRef::null(), + _data: PhantomData, + } + } } pub struct IntoIter { diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index c0822a5045e6..6109819939af 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -26,21 +26,24 @@ use std::{ ptr, }; +use crate::object::Object; +use tvm_macros::Object; use tvm_sys::ffi; use crate::errors::Error; +use crate::String as TString; use crate::{errors, function::Function}; -const ENTRY_FUNC: &str = "__tvm_main__"; - /// Wrapper around TVM module handle which contains an entry function. /// The entry function can be applied to an imported module through [`entry_func`]. 
/// /// [`entry_func`]:struct.Module.html#method.entry_func -#[derive(Debug, Clone)] -pub struct Module { - pub(crate) handle: ffi::TVMModuleHandle, - entry_func: Option, +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "Module"] +#[type_key = "runtime.Module"] +pub struct ModuleNode { + base: Object, } crate::external! { @@ -49,21 +52,18 @@ crate::external! { #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; + + #[name("runtime.ModuleSaveToFile")] + fn save_to_file(module: Module, name: TString, fmt: TString); + + // TODO(@jroesch): we need to refactor this + #[name("tvm.relay.module_export_library")] + fn export_library(module: Module, file_name: TString); } impl Module { - pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { - Self { - handle, - entry_func: None, - } - } - - pub fn entry(&mut self) -> Option { - if self.entry_func.is_none() { - self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); - } - self.entry_func.clone() + pub fn default_fn(&mut self) -> Result { + self.get_function("default", true) } /// Gets a function by name from a registered module. @@ -72,7 +72,7 @@ impl Module { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; check_call!(ffi::TVMModGetFunction( - self.handle, + self.handle(), name.as_ptr() as *const c_char, query_import as c_int, &mut fhandle as *mut _ @@ -87,7 +87,7 @@ impl Module { /// Imports a dependent module such as `.ptx` for gpu. pub fn import_module(&self, dependent_module: Module) { - check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) + check_call!(ffi::TVMModImport(self.handle(), dependent_module.handle())) } /// Loads a module shared library from path. @@ -110,6 +110,14 @@ impl Module { Ok(module) } + pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> { + save_to_file(self.clone(), name.into(), fmt.into()) + } + + pub fn export_library(&self, name: String) -> Result<(), Error> { + export_library(self.clone(), name.into()) + } + /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); @@ -118,13 +126,7 @@ impl Module { } /// Returns the underlying module handle. 
- pub fn handle(&self) -> ffi::TVMModuleHandle { - self.handle - } -} - -impl Drop for Module { - fn drop(&mut self) { - check_call!(ffi::TVMModFree(self.handle)); + pub unsafe fn handle(&self) -> ffi::TVMModuleHandle { + self.0.clone().unwrap().into_raw() as *mut _ } } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 8df6041956b8..264d5febd103 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -267,6 +267,10 @@ impl ObjectPtr { Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } } + + pub unsafe fn into_raw(self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -300,7 +304,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { use crate::ndarray::NDArrayContainer; match ret_value { - RetValue::ObjectHandle(handle) => { + RetValue::ObjectHandle(handle) | RetValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); optr.downcast() @@ -329,6 +333,11 @@ impl<'a, T: IsObject> From> for ArgValue<'a> { assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } + "runtime.Module" => { + let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + assert!(!raw_ptr.is_null()); + ArgValue::ModuleHandle(raw_ptr) + } _ => { let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); @@ -346,7 +355,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { use crate::ndarray::NDArrayContainer; match arg_value { - ArgValue::ObjectHandle(handle) => { + ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); optr.downcast() diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index affd81b0e7ed..c5ede7d224ce 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -255,6 +255,7 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); +impl_typed_and_to_function!(6; A, B, C, D, E, G); #[cfg(test)] mod tests { diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs deleted file mode 100644 index b8cd190176c4..000000000000 --- a/rust/tvm-rt/src/value.rs +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module implements [`ArgValue`] and [`RetValue`] types -//! and their conversions needed for the types used in frontend crate. -//! `RetValue` is the owned version of `TVMPODValue`. 
- -use std::convert::TryFrom; - -use crate::{ArgValue, Module, RetValue}; -use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; - -macro_rules! impl_handle_val { - ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { - impl<'a> From<&'a $type> for ArgValue<'a> { - fn from(arg: &'a $type) -> Self { - ArgValue::$variant(arg.handle() as $inner_type) - } - } - - impl<'a> From<&'a mut $type> for ArgValue<'a> { - fn from(arg: &'a mut $type) -> Self { - ArgValue::$variant(arg.handle() as $inner_type) - } - } - - impl<'a> TryFrom> for $type { - type Error = ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(val) }) - } - } - - impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { - type Error = ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(*val) }) - } - } - - impl From<$type> for RetValue { - fn from(val: $type) -> RetValue { - RetValue::$variant(val.handle() as $inner_type) - } - } - - impl TryFrom for $type { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result<$type, Self::Error> { - try_downcast!(val -> $type, |RetValue::$variant(val)| { $ctor(val) }) - } - } - }; -} - -impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); - -#[cfg(test)] -mod tests { - use std::{convert::TryInto, str::FromStr}; - - use crate::{ByteArray, Context, DataType}; - - use super::*; - - #[test] - fn bytearray() { - let w = vec![1u8, 2, 3, 4, 5]; - let v = ByteArray::from(w.as_slice()); - let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - assert_eq!( - tvm.data(), - w.iter().copied().collect::>().as_slice() - ); - } - - #[test] - fn ty() { - let t = DataType::from_str("int32").unwrap(); - let tvm: DataType = RetValue::from(t).try_into().unwrap(); - assert_eq!(tvm, t); - } - - #[test] - fn ctx() { - let c = Context::from_str("gpu").unwrap(); - let tvm: Context = RetValue::from(c).try_into().unwrap(); - assert_eq!(tvm, c); - } -} diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 29d2003b5089..9438f340f78f 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -50,9 +50,10 @@ tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" -pyo3 = { version = "0.11.1", optional = true } +pyo3 = { version = "^0.13", optional = true } codespan-reporting = "0.9.5" structopt = { version = "0.3" } +tracing = "^0.1" [[bin]] name = "tyck" diff --git a/rust/tvm/README.md b/rust/tvm/README.md index 26f9f1fbedfd..75fabe7d9a1b 100644 --- a/rust/tvm/README.md +++ b/rust/tvm/README.md @@ -15,221 +15,40 @@ -# TVM Runtime Frontend Support +# TVM -This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm). +The code works on **Stable Rust** and is tested against `rustc 1.47`. -## What Does This Crate Offer? - -Here is a major workflow - -1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/) -2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. -3. 
Deploy your models using **Rust** :heart: - -### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k - -Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example. - -Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM - -```python -block = get_model('resnet18_v1', pretrained=True) - -sym, params = relay.frontend.from_mxnet(block, shape_dict) -# compile the model -with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build( - net, target, params=params) -# same the model artifacts -lib.save(os.path.join(target_dir, "deploy_lib.o")) -cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), - [os.path.join(target_dir, "deploy_lib.o")]) - -with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: - fo.write(graph.json()) -with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(params)) -``` +You can find the API Documentation [here](https://tvm.apache.org/docs/api/rust/tvm/index.html). -Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image - -![cat](https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true) +## What Does This Crate Offer? -as demostrated in the following Rust snippet +The goal of this crate is to provide bindings to both the TVM compiler and runtime +APIs. First train your **Deep Learning** model using any major framework such as +[PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/). +Then use **TVM** to build and deploy optimized model artifacts on a supported devices such as CPU, GPU, OpenCL and specialized accelerators. -```rust - let graph = fs::read_to_string("deploy_graph.json")?; - // load the built module - let lib = Module::load(&Path::new("deploy_lib.so"))?; - // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); - let runtime_create_fn_ret = call_packed!( - runtime_create_fn, - &graph, - &lib, - &ctx.device_type, - &ctx.device_id - )?; - // get graph runtime module - let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; - // get the registered `load_params` from runtime module - let ref load_param_fn = graph_runtime_module - .get_function("load_params", false) - .unwrap(); - // parse parameters and convert to TVMByteArray - let params: Vec = fs::read("deploy_param.params")?; - let barr = TVMByteArray::from(¶ms); - // load the parameters - call_packed!(load_param_fn, &barr)?; - // get the set_input function - let ref set_input_fn = graph_runtime_module - .get_function("set_input", false) - .unwrap(); +The Rust bindings are composed of a few crates: +- The [tvm](https://tvm.apache.org/docs/api/rust/tvm/index.html) crate which exposes Rust bindings to + both the compiler and runtime. +- The [tvm_macros](https://tvm.apache.org/docs/api/rust/tvm/index.html) crate which provides macros + which generate unsafe boilerplate for TVM's data structures. +- The [tvm_rt](https://tvm.apache.org/docs/api/rust/tvm_rt/index.html) crate which exposes Rust + bindings to the TVM runtime APIs. +- The [tvm_sys] crate which provides raw bindings and linkage to the TVM C++ library. +- The [tvm_graph_rt] crate which implements a version of the TVM graph runtime in Rust vs. C++. 
- call_packed!(set_input_fn, "data", &input)?; - // get `run` function from runtime module - let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); - // execute the run function. Note that it has no argument - call_packed!(run_fn,)?; - // prepare to get the output - let output_shape = &mut [1, 1000]; - let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); - // get the `get_output` function from runtime module - let ref get_output_fn = graph_runtime_module - .get_function("get_output", false) - .unwrap(); - // execute the get output function - call_packed!(get_output_fn, &0, &output)?; - // flatten the output as Vec - let output = output.to_vec::()?; -``` +These crates have been recently refactored and reflect a much different philosophy than +previous bindings, as well as much increased support for more of the TVM API including +exposing all of the compiler internals. -and the model correctly predicts the input image as **tiger cat**. +These are still very much in development and should not be considered stable, but contributions +and usage is welcome and encouraged. If you want to discuss design issues check our Discourse +[forum](https://discuss.tvm.ai) and for bug reports check our GitHub [repository](https://github.com/apache/tvm). -## Installations +## Install -Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. +Please follow the TVM [install](https://tvm.apache.org/docs/install/index.html) instructions, `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. *Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. - -## Supported TVM Functionalities - -### Use TVM to Generate Shared Library - -One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. - -```python -import os -import tvm -from tvm import te -from tvm.contrib import cc - -def test_add(target_dir): - if not tvm.runtime.enabled("cuda"): - print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) - return - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - bx, tx = s[C].split(C.op.axis[0], factor=64) - s[C].bind(bx, tvm.thread_axis("blockIdx.x")) - s[C].bind(tx, tvm.thread_axis("threadIdx.x")) - fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") - - fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) - fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) - cc.create_shared(os.path.join(target_dir, "add_gpu.so"), - [os.path.join(target_dir, "add_gpu.o")]) - - -if __name__ == "__main__": - import sys - if len(sys.argv) != 2: - sys.exit(-1) - test_add(sys.argv[1]) -``` - -### Run the Generated Shared Library - -The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. 
- -```rust -extern crate tvm_frontend as tvm; - -use tvm::*; - -fn main() { - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; - let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); - arr.copy_from_buffer(data.as_mut_slice()); - let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); - let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); - let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); - assert!(fadd.enabled("gpu")); - fadd.import_module(fadd_dep); - fadd.entry(); - function::Builder::from(&mut fadd) - .arg(&arr) - .arg(&arr) - .set_output(&mut ret)? - .invoke() - .unwrap(); - - assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); -} -``` - -**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by -`cargo:rustc-link-search=native=add_gpu`. - -See the tests and examples custom `build.rs` for more details. - -### Convert and Register a Rust Function as a TVM Packed Function - -One can use `register_global_func!` macro to convert and register a Rust -function of type `fn(&[TVMArgValue]) -> Result` to a global TVM **packed function** as follows - -```rust -#[macro_use] -extern crate tvm_frontend as tvm; -use std::convert::TryInto; -use tvm::*; - -fn main() { - register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { - let mut ret = 0f32; - let shape = &mut [2]; - for arg in args.iter() { - let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); - let arg: NDArray = arg.try_into()?; - let arr = arg.copy_to_ndarray(e).unwrap(); - let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); - ret += rnd.scalar_sum(); - } - let ret_val = TVMRetValue::from(&ret); - Ok(ret_val) - } - } - - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; - let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); - arr.copy_from_buffer(data.as_mut_slice()); - let mut registered = function::Builder::default(); - let ret: f64 = registered - .get_function("sum", true) - .arg(&arr) - .arg(&arr) - .invoke() - .unwrap() - .try_into() - .unwrap(); - - assert_eq!(ret, 14f64); -} -``` diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs new file mode 100644 index 000000000000..6b5873398cab --- /dev/null +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::convert::TryInto; +use std::io::Read; +use std::path::Path; + +use once_cell::sync::Lazy; +use thiserror::Error; + +use crate::ir::IRModule; +use crate::python; +use crate::runtime::{map::Map, Function, Module as RtModule, NDArray, String}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("{0}")] + IO(#[from] std::io::Error), + #[error("{0}")] + TVM(#[from] crate::errors::Error), +} + +static TVM_BUILD: Lazy = Lazy::new(|| { + python::import("tvm").unwrap(); + python::import("tvm.relay").unwrap(); + Function::get("tvm.relay.build").unwrap() +}); + +fn _compile_module( + module: IRModule, + target: String, + target_host: String, + params: Map, + module_name: String, +) -> Result { + // The RAW API is Fn(IRModule, String, String, Map, String); + let module = TVM_BUILD.invoke(vec![ + module.into(), + target.into(), + target_host.into(), + params.into(), + module_name.into(), + ])?; + let module: RtModule = module.try_into().unwrap(); + Ok(module) +} + +#[derive(Debug)] +pub struct CompilerConfig { + target: Option, + target_host: Option, + params: Map, + module_name: Option, +} + +impl Default for CompilerConfig { + fn default() -> Self { + CompilerConfig { + target: None, + target_host: None, + params: Map::empty(), + module_name: None, + } + } +} + +/// Compile a module from a configuration and IRModule. +/// +/// # Arguments +/// +/// * `config` - The configuration for the compiler. +/// * `module` - The IRModule to compile. +pub fn compile_module(config: CompilerConfig, module: IRModule) -> Result { + let target = config.target.unwrap_or("llvm".into()); + _compile_module( + module, + target, + "llvm".into(), + Map::::empty(), + "default".into(), + ) +} + +/// Compile an IRModule on disk and output a runtime module to disk. +/// +/// # Arguments +/// * `config` - The configuration for the compiler. +/// * `ir_mod_path` - The path the serialized IRModule. +// +/// * `output_rt_mod_path` - The path to the output runtime module. +pub fn compile_from_disk( + config: CompilerConfig, + ir_mod_path: P1, + output_rt_mod_path: P2, +) -> Result<(), Error> +where + P1: AsRef, + P2: AsRef, +{ + let mut input_file = std::fs::File::open(ir_mod_path.as_ref())?; + let mut input_module_text = std::string::String::new(); + input_file.read_to_string(&mut input_module_text)?; + let input_module = IRModule::parse("name", input_module_text)?; + let rt_module = compile_module(config, input_module)?; + let output_path_str = output_rt_mod_path.as_ref().display().to_string(); + rt_module.export_library(output_path_str)?; + Ok(()) +} diff --git a/rust/tvm/src/compiler/mod.rs b/rust/tvm/src/compiler/mod.rs new file mode 100644 index 000000000000..ed8b47edbad4 --- /dev/null +++ b/rust/tvm/src/compiler/mod.rs @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +pub mod graph_rt; diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 653169def3a4..03d8a4920718 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -32,12 +32,14 @@ use super::span::Span; #[type_key = "Expr"] pub struct BaseExprNode { pub base: Object, + pub span: Span, } impl BaseExprNode { - pub fn base() -> BaseExprNode { + pub fn base(span: Span) -> BaseExprNode { BaseExprNode { base: Object::base::(), + span, } } } @@ -52,9 +54,9 @@ pub struct PrimExprNode { } impl PrimExprNode { - pub fn base(datatype: DataType) -> PrimExprNode { + pub fn base(datatype: DataType, span: Span) -> PrimExprNode { PrimExprNode { - base: BaseExprNode::base::(), + base: BaseExprNode::base::(span), datatype, } } @@ -70,9 +72,9 @@ pub struct GlobalVarNode { } impl GlobalVar { - pub fn new(name_hint: String, _span: Span) -> GlobalVar { + pub fn new(name_hint: String, span: Span) -> GlobalVar { let node = GlobalVarNode { - base: relay::ExprNode::base::(), + base: relay::ExprNode::base::(span), name_hint: name_hint.into(), }; GlobalVar(Some(ObjectPtr::new(node))) diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs index 14c00ea02bf6..43aca869f385 100644 --- a/rust/tvm/src/ir/function.rs +++ b/rust/tvm/src/ir/function.rs @@ -17,12 +17,12 @@ * under the License. */ -use crate::ir::relay::ExprNode; -use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; - use tvm_macros::Object; -// Define Calling Convention. +use super::span::Span; + +use crate::ir::relay::ExprNode; +use crate::runtime::{IsObject, IsObjectRef, ObjectRef}; // TODO(@jroesch): define DictAttrs pub type DictAttrs = ObjectRef; @@ -39,7 +39,7 @@ pub struct BaseFuncNode { impl BaseFuncNode { pub fn base() -> BaseFuncNode { BaseFuncNode { - base: ExprNode::base::(), + base: ExprNode::base::(Span::null()), attrs: ::null(), } } diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index a09f70dc25b9..513a906f6db4 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -279,8 +279,8 @@ mod tests { let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null()); let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null()); module.add_def(name.clone(), type_data, true)?; - let by_gtv = module.lookup_def(name)?; - let by_gv = module.lookup_def_str("my_type")?; + let _by_gtv = module.lookup_def(name)?; + let _by_gv = module.lookup_def_str("my_type")?; Ok(()) } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 9d2983237acb..f43967f28d60 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -23,7 +23,7 @@ use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::span::Span; -use super::ty::{Type, TypeNode}; +use super::ty::Type; use tvm_macros::Object; use tvm_rt::NDArray; @@ -39,19 +39,14 @@ pub mod attrs; #[type_key = "RelayExpr"] pub struct ExprNode { pub base: BaseExprNode, - pub span: ObjectRef, pub checked_type: Type, } impl ExprNode { - pub fn base() -> ExprNode { + pub fn base(span: Span) -> ExprNode { ExprNode { - base: BaseExprNode::base::(), - span: ObjectRef::null(), - checked_type: Type::from(TypeNode { - base: Object::base::(), - span: Span::null(), - }), + base: BaseExprNode::base::(span.clone()), + checked_type: Type::null(), } } } @@ -85,9 +80,9 @@ pub struct ConstantNode { } impl Constant { - pub fn new(data: NDArray, _span: 
ObjectRef) -> Constant { + pub fn new(data: NDArray, span: Span) -> Constant { let node = ConstantNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), data: data, }; Constant(Some(ObjectPtr::new(node))) @@ -104,9 +99,9 @@ pub struct TupleNode { } impl Tuple { - pub fn new(fields: Array, _span: ObjectRef) -> Tuple { + pub fn new(fields: Array, span: Span) -> Tuple { let node = TupleNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), fields, }; Tuple(Some(ObjectPtr::new(node))) @@ -124,9 +119,9 @@ pub struct VarNode { } impl Var { - pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var { + pub fn new(name_hint: String, type_annotation: Type, span: Span) -> Var { let node = VarNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), vid: Id::new(name_hint.into()), type_annotation: type_annotation, }; @@ -165,10 +160,10 @@ impl Call { args: Array, attrs: Attrs, type_args: Array, - _span: ObjectRef, + span: Span, ) -> Call { let node = CallNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), op: op, args: args, attrs: attrs, @@ -190,9 +185,9 @@ pub struct LetNode { } impl Let { - pub fn new(var: Var, value: Expr, body: Expr, _span: ObjectRef) -> Let { + pub fn new(var: Var, value: Expr, body: Expr, span: Span) -> Let { let node = LetNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), var, value, body, @@ -213,9 +208,9 @@ pub struct IfNode { } impl If { - pub fn new(cond: Expr, true_branch: Expr, false_branch: Expr, _span: ObjectRef) -> If { + pub fn new(cond: Expr, true_branch: Expr, false_branch: Expr, span: Span) -> If { let node = IfNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), cond, true_branch, false_branch, @@ -235,9 +230,9 @@ pub struct TupleGetItemNode { } impl TupleGetItem { - pub fn new(tuple: Expr, index: i32, _span: ObjectRef) -> TupleGetItem { + pub fn new(tuple: Expr, index: i32, span: Span) -> TupleGetItem { let node = TupleGetItemNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), tuple, index, }; @@ -255,9 +250,9 @@ pub struct RefCreateNode { } impl RefCreate { - pub fn new(value: Expr, _span: ObjectRef) -> RefCreate { + pub fn new(value: Expr, span: Span) -> RefCreate { let node = RefCreateNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), value, }; RefCreate(Some(ObjectPtr::new(node))) @@ -274,9 +269,9 @@ pub struct RefReadNode { } impl RefRead { - pub fn new(ref_value: Expr, _span: ObjectRef) -> RefRead { + pub fn new(ref_value: Expr, span: Span) -> RefRead { let node = RefReadNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), ref_value, }; RefRead(Some(ObjectPtr::new(node))) @@ -294,9 +289,9 @@ pub struct RefWriteNode { } impl RefWrite { - pub fn new(ref_value: Expr, value: Expr, _span: ObjectRef) -> RefWrite { + pub fn new(ref_value: Expr, value: Expr, span: Span) -> RefWrite { let node = RefWriteNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), ref_value, value, }; @@ -316,9 +311,9 @@ pub struct ConstructorNode { } impl Constructor { - pub fn new(name_hint: String, inputs: Array, tag: i32, _span: ObjectRef) -> Constructor { + pub fn new(name_hint: String, inputs: Array, tag: i32, span: Span) -> Constructor { let node = ConstructorNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), name_hint, inputs, tag, @@ -335,14 +330,14 @@ impl Constructor { #[type_key = "relay.Pattern"] pub struct PatternNode { pub base: Object, - pub span: ObjectRef, + pub span: Span, } impl PatternNode { - pub fn 
base() -> PatternNode { + pub fn base(span: Span) -> PatternNode { PatternNode { base: Object::base::(), - span: ObjectRef::null(), + span: span, } } } @@ -356,9 +351,9 @@ pub struct PatternWildcardNode { } impl PatternWildcard { - pub fn new(_span: ObjectRef) -> PatternWildcard { + pub fn new(span: Span) -> PatternWildcard { let node = PatternWildcardNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), }; PatternWildcard(Some(ObjectPtr::new(node))) } @@ -374,9 +369,9 @@ pub struct PatternVarNode { } impl PatternVar { - pub fn new(var: Var, _span: ObjectRef) -> PatternVar { + pub fn new(var: Var, span: Span) -> PatternVar { let node = PatternVarNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), var: var, }; PatternVar(Some(ObjectPtr::new(node))) @@ -397,10 +392,10 @@ impl PatternConstructor { pub fn new( constructor: Constructor, patterns: Array, - _span: ObjectRef, + span: Span, ) -> PatternConstructor { let node = PatternConstructorNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), constructor, patterns, }; @@ -418,9 +413,9 @@ pub struct PatternTupleNode { } impl PatternTuple { - pub fn new(patterns: Array, _span: ObjectRef) -> PatternTuple { + pub fn new(patterns: Array, span: Span) -> PatternTuple { let node = PatternTupleNode { - base: PatternNode::base::(), + base: PatternNode::base::(span), patterns, }; PatternTuple(Some(ObjectPtr::new(node))) @@ -438,7 +433,7 @@ pub struct ClauseNode { } impl Clause { - pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { + pub fn new(lhs: Pattern, rhs: Expr, _span: Span) -> Clause { let node = ClauseNode { base: Object::base::(), lhs, @@ -460,9 +455,9 @@ pub struct MatchNode { } impl Match { - pub fn new(data: Expr, clauses: Array, complete: bool, _span: ObjectRef) -> Match { + pub fn new(data: Expr, clauses: Array, complete: bool, span: Span) -> Match { let node = MatchNode { - base: ExprNode::base::(), + base: ExprNode::base::(span), data, clauses, complete, diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index ccbe30c95820..dcbec520d3b6 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -18,7 +18,9 @@ */ use super::{PrimExpr, PrimExprNode}; -use crate::runtime::String as TVMString; + +use crate::ir::span::Span; +use crate::runtime::{IsObjectRef, String as TVMString}; use crate::DataType; use tvm_macros::Object; @@ -36,7 +38,7 @@ macro_rules! 
define_node { impl $name { pub fn new(datatype: DataType, $($id : $t,)*) -> $name { - let base = PrimExprNode::base::<$node>(datatype); + let base = PrimExprNode::base::<$node>(datatype, Span::null()); let node = $node { base, $($id),* }; node.into() } @@ -56,7 +58,6 @@ impl From for IntImm { impl From for PrimExpr { fn from(i: i32) -> PrimExpr { - use crate::runtime::IsObjectRef; IntImm::from(i).upcast() } } diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index f7c52b51f332..83fdbfeb66aa 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -23,7 +23,7 @@ use tvm_rt::{array::Array, DataType}; use crate::ir::relay::Constructor; use crate::ir::span::Span; use crate::ir::PrimExpr; -use crate::runtime::{string::String as TString, IsObject, Object, ObjectPtr}; +use crate::runtime::{string::String as TString, IsObject, IsObjectRef, Object, ObjectPtr}; #[repr(C)] #[derive(Object, Debug)] @@ -147,8 +147,17 @@ pub struct TupleTypeNode { } impl TupleType { + // todo add coercion + pub fn new(fields: Vec, span: Span) -> Self { + let node = TupleTypeNode { + base: TypeNode::base::(span), + fields: Array::from_vec(fields).unwrap(), + }; + ObjectPtr::new(node).into() + } + pub fn empty() -> TupleType { - todo!() + TupleType::new(vec![], Span::null()) } } @@ -236,7 +245,13 @@ impl TensorType { }; ObjectPtr::new(node).into() } + + pub fn static_sh(shape: Vec, dtype: DataType, span: Span) -> TensorType { + let sh = Array::from_vec(shape.into_iter().map(Into::into).collect()).unwrap(); + Self::new(sh, dtype, span) + } } + // TODO(@jroesch): implement these in future. // // using TypeCall = tvm::TypeCall; diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index e86420eb70c9..caae07775d21 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -39,7 +39,9 @@ pub use tvm_rt::errors; pub use tvm_rt::function; pub use tvm_rt::module; pub use tvm_rt::ndarray; -pub use tvm_rt::value; + +#[cfg(feature = "python")] +pub mod compiler; pub mod ir; #[cfg(feature = "python")] pub mod python; diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs index 89558af733b3..c224fb4db372 100644 --- a/rust/tvm/src/python.rs +++ b/rust/tvm/src/python.rs @@ -29,6 +29,8 @@ use pyo3::prelude::*; pub fn load() -> Result { let gil = Python::acquire_gil(); let py = gil.python(); + // let main_mod = initialize(); + //let main_mod = main_mod.as_ref(py); load_python_tvm_(py).map_err(|e| { // We can't display Python exceptions via std::fmt::Display, // so print the error here manually. 
@@ -36,25 +38,33 @@ pub fn load() -> Result { }) } -// const TVMC_CODE: &'static str = include_str!("tvmc.py"); +pub fn import(mod_to_import: &str) -> PyResult<()> { + let gil = Python::acquire_gil(); + let py = gil.python(); + import_python(py, mod_to_import)?; + Ok(()) +} + +fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> { + let imported_mod = py.import(to_import)?; + Ok(imported_mod) +} fn load_python_tvm_(py: Python) -> PyResult { - let sys = py.import("tvm")?; - let version: String = sys.get("__version__")?.extract()?; - // py.run(TVMC_CODE, None, None)?; + let imported_mod = import_python(py, "tvm")?; + let version: String = imported_mod.get("__version__")?.extract()?; Ok(version) } #[cfg(test)] mod tests { - use super::load_python_tvm_; + use super::*; use anyhow::Result; - use pyo3::prelude::*; #[ignore] #[test] fn test_run() -> Result<()> { - load_python_tvm_(Python::acquire_gil().python()).unwrap(); + load().unwrap(); Ok(()) } } diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm/src/runtime/graph_rt.rs index 8b26ebb4ca22..fcc41aca560f 100644 --- a/rust/tvm/src/runtime/graph_rt.rs +++ b/rust/tvm/src/runtime/graph_rt.rs @@ -34,13 +34,23 @@ pub struct GraphRt { } impl GraphRt { + /// Create a graph runtime directly from a runtime module. + pub fn from_module(module: Module, ctx: Context) -> Result { + let default: Box Result> = + module.get_function("default", false)?.into(); + + Ok(Self { + module: default(ctx)?, + }) + } + /// Create a graph runtime from the deprecated graph, lib, ctx triple. pub fn create_from_parts(graph: &str, lib: Module, ctx: Context) -> Result { let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - (&lib).into(), + lib.into(), (&ctx.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (ctx.device_id as i32).into(), diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index e4249a491746..450ab48dc1b2 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -30,6 +30,7 @@ fn main() { } else { (Context::gpu(0), "gpu") }; + let dtype = DataType::from_str("float32").unwrap(); let mut arr = NDArray::empty(shape, ctx, dtype); arr.copy_from_buffer(data.as_mut_slice()); @@ -38,11 +39,13 @@ fn main() { if !fadd.enabled(ctx_name) { return; } + if cfg!(feature = "gpu") { fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); } - fadd.entry() + // todo(@jroesch): fix the entry_name + fadd.get_function("__tvm_main__", false) .expect("module must have entry point") .invoke(vec![(&arr).into(), (&arr).into(), (&ret).into()]) .unwrap(); diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py index b9672fbf4aaf..3c1fc64d3e36 100755 --- a/rust/tvm/tests/basics/src/tvm_add.py +++ b/rust/tvm/tests/basics/src/tvm_add.py @@ -37,7 +37,6 @@ def main(target, out_dir): s[C].bind(tx, te.thread_axis("threadIdx.x")) fadd = tvm.build(s, [A, B, C], target, target_host="llvm", name="myadd") - fadd.save(osp.join(out_dir, "test_add.o")) if target == "cuda": fadd.imported_modules[0].save(osp.join(out_dir, "test_add.ptx")) diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 4cec5e3643c1..d84a8215421f 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -178,7 +178,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { 
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") - .set_body_typed([](Module mod, std::string name, std::string fmt) { + .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) { mod->SaveToFile(name, fmt); });