diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index 675b8ba5dc44..c75e9020cc93 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -155,7 +155,7 @@ TVMPODValue! { Bytes(val) => { (TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes) } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)} + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) } } } @@ -260,12 +260,24 @@ impl<'a> From<&'a str> for TVMArgValue<'a> { } } +impl<'a> From for TVMArgValue<'a> { + fn from(s: String) -> Self { + Self::String(CString::new(s).unwrap()) + } +} + impl<'a> From<&'a CStr> for TVMArgValue<'a> { fn from(s: &'a CStr) -> Self { Self::Str(s) } } +impl<'a> From<&'a TVMByteArray> for TVMArgValue<'a> { + fn from(s: &'a TVMByteArray) -> Self { + Self::Bytes(s) + } +} + impl<'a> TryFrom> for &'a str { type Error = ValueDowncastError; fn try_from(val: TVMArgValue<'a>) -> Result { diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index 94af95c62841..6d17db207865 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::str::FromStr; +use std::{os::raw::c_char, str::FromStr}; use failure::Error; @@ -157,17 +157,57 @@ impl_tvm_context!( DLDeviceType_kDLExtDev: [ext_dev] ); +/// A struct holding TVM byte-array. +/// +/// ## Example +/// +/// ``` +/// let v = b"hello"; +/// let barr = TVMByteArray::from(&v); +/// assert_eq!(barr.len(), v.len()); +/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); +/// ``` impl TVMByteArray { + /// Gets the underlying byte-array pub fn data(&self) -> &'static [u8] { unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) } } + + /// Gets the length of the underlying byte-array + pub fn len(&self) -> usize { + self.size + } + + /// Converts the underlying byte-array to `Vec` + pub fn to_vec(&self) -> Vec { + self.data().to_vec() + } } -impl<'a> From<&'a [u8]> for TVMByteArray { - fn from(bytes: &[u8]) -> Self { - Self { - data: bytes.as_ptr() as *const i8, - size: bytes.len(), +// Needs AsRef for Vec +impl> From for TVMByteArray { + fn from(arg: T) -> Self { + let arg = arg.as_ref(); + TVMByteArray { + data: arg.as_ptr() as *const c_char, + size: arg.len(), } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn convert() { + let v = vec![1u8, 2, 3]; + let barr = TVMByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); + let v = b"hello"; + let barr = TVMByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); + } +} diff --git a/rust/frontend/examples/resnet/README.md b/rust/frontend/examples/resnet/README.md index e84c099de411..3ce4a778e4bd 100644 --- a/rust/frontend/examples/resnet/README.md +++ b/rust/frontend/examples/resnet/README.md @@ -21,12 +21,25 @@ This end-to-end example shows how to: * build `Resnet 18` with `tvm` and `nnvm` from Python * use the provided Rust frontend API to test for an input image -To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +To run the example with pretrained resnet weights, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html). -* **Build the example**: `cargo build` +* **Build the example**: `cargo build To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with `println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. * **Run the example**: `cargo run` + +Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with + +``` +let output = Command::new("python") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .arg(&format!("--pretrained")) + .output() + .expect("Failed to execute command"); +``` + +Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`! diff --git a/rust/frontend/examples/resnet/build.rs b/rust/frontend/examples/resnet/build.rs index 037c3bbd97d2..b9a3c4ccdf12 100644 --- a/rust/frontend/examples/resnet/build.rs +++ b/rust/frontend/examples/resnet/build.rs @@ -17,16 +17,23 @@ * under the License. */ -use std::process::Command; +use std::{path::Path, process::Command}; fn main() { - let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + let output = Command::new("python3") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) .output() .expect("Failed to execute command"); assert!( - std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(), + Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), "Could not prepare demo: {}", - String::from_utf8(output.stderr).unwrap().trim() + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") ); println!( "cargo:rustc-link-search=native={}", diff --git a/rust/frontend/examples/resnet/src/build_resnet.py b/rust/frontend/examples/resnet/src/build_resnet.py index 5da1db63310e..2497a41c6ef7 100644 --- a/rust/frontend/examples/resnet/src/build_resnet.py +++ b/rust/frontend/examples/resnet/src/build_resnet.py @@ -24,19 +24,18 @@ import numpy as np -import mxnet as mx -from mxnet.gluon.model_zoo.vision import get_model -from mxnet.gluon.utils import download - import tvm +from tvm import relay +from tvm.relay import testing from tvm.contrib import graph_runtime, cc -import nnvm logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) parser = argparse.ArgumentParser(description='Resnet build example') aa = parser.add_argument +aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') +aa('--pretrained', action='store_true', help='use a pretrained resnet') aa('--batch-size', type=int, default=1, help='input image batch size') aa('--opt-level', type=int, default=3, help='level of optimization. 0 is unoptimized and 3 is the highest level') @@ -45,7 +44,7 @@ aa('--image-name', type=str, default='cat.png', help='name of input image to download') args = parser.parse_args() -target_dir = osp.dirname(osp.dirname(osp.realpath(__file__))) +build_dir = args.build_dir batch_size = args.batch_size opt_level = args.opt_level target = tvm.target.create(args.target) @@ -57,30 +56,42 @@ def build(target_dir): deploy_lib = osp.join(target_dir, 'deploy_lib.o') if osp.exists(deploy_lib): return - # download the pretrained resnet18 trained on imagenet1k dataset for - # image classification task - block = get_model('resnet18_v1', pretrained=True) - sym, params = nnvm.frontend.from_mxnet(block) - # add the softmax layer for prediction - net = nnvm.sym.softmax(sym) + if args.pretrained: + # needs mxnet installed + from mxnet.gluon.model_zoo.vision import get_model + + # if `--pretrained` is enabled, it downloads a pretrained + # resnet18 trained on imagenet1k dataset for image classification task + block = get_model('resnet18_v1', pretrained=True) + net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) + # we want a probability so add a softmax operator + net = relay.Function(net.params, relay.nn.softmax(net.body), + None, net.type_params, net.attrs) + else: + # use random weights from relay.testing + net, params = relay.testing.resnet.get_workload( + num_layers=18, batch_size=batch_size, image_shape=image_shape) + # compile the model - with nnvm.compiler.build_config(opt_level=opt_level): - graph, lib, params = nnvm.compiler.build( - net, target, shape={"data": data_shape}, params=params) + with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build_module.build(net, target, params=params) + # save the model artifacts lib.save(deploy_lib) cc.create_shared(osp.join(target_dir, "deploy_lib.so"), [osp.join(target_dir, "deploy_lib.o")]) with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: - fo.write(graph.json()) + fo.write(graph) with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: - fo.write(nnvm.compiler.save_param_dict(params)) + fo.write(relay.save_param_dict(params)) def download_img_labels(): """ Download an image and imagenet1k class labels for test""" + from mxnet.gluon.utils import download + img_name = 'cat.png' synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', '4d0b62f3d01426887599d4f7ede23ee5/raw/', @@ -97,11 +108,11 @@ def download_img_labels(): w = csv.writer(fout) w.writerows(synset.items()) -def test_build(target_dir): +def test_build(build_dir): """ Sanity check with random input""" - graph = open(osp.join(target_dir, "deploy_graph.json")).read() - lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so")) - params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read()) + graph = open(osp.join(build_dir, "deploy_graph.json")).read() + lib = tvm.module.load(osp.join(build_dir, "deploy_lib.so")) + params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) @@ -112,10 +123,11 @@ def test_build(target_dir): if __name__ == '__main__': logger.info("building the model") - build(target_dir) + build(build_dir) logger.info("build was successful") logger.info("test the build artifacts") - test_build(target_dir) + test_build(build_dir) logger.info("test was successful") - download_img_labels() - logger.info("image and synset downloads are successful") + if args.pretrained: + download_img_labels() + logger.info("image and synset downloads are successful") diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs index e50d92795883..cf24973ada5b 100644 --- a/rust/frontend/examples/resnet/src/main.rs +++ b/rust/frontend/examples/resnet/src/main.rs @@ -84,7 +84,7 @@ fn main() { let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); let runtime_create_fn_ret = call_packed!( runtime_create_fn, - &graph, + graph, &lib, &ctx.device_type, &ctx.device_id @@ -107,8 +107,7 @@ fn main() { .get_function("set_input", false) .unwrap(); - let data_str = "data".to_string(); - call_packed!(set_input_fn, &data_str, &input).unwrap(); + call_packed!(set_input_fn, "data".to_string(), &input).unwrap(); // get `run` function from runtime module let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); // execute the run function. Note that it has no argument diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs deleted file mode 100644 index a1d183d9f525..000000000000 --- a/rust/frontend/src/bytearray.rs +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! Provides [`TVMByteArray`] used for passing the model parameters -//! (stored as byte-array) to a runtime module. -//! -//! For more detail, please see the example `resnet` in `examples` repository. - -use std::os::raw::c_char; - -use tvm_common::ffi; - -/// A struct holding TVM byte-array. -/// -/// ## Example -/// -/// ``` -/// let v = b"hello".to_vec(); -/// let barr = TVMByteArray::from(&v); -/// assert_eq!(barr.len(), v.len()); -/// assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]); -/// ``` -#[derive(Debug, Clone)] -pub struct TVMByteArray { - pub(crate) inner: ffi::TVMByteArray, -} - -impl TVMByteArray { - pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray { - TVMByteArray { inner: barr } - } - - /// Gets the length of the underlying byte-array - pub fn len(&self) -> usize { - self.inner.size - } - - /// Gets the underlying byte-array as `Vec` - pub fn data(&self) -> Vec { - unsafe { - let sz = self.len(); - let mut ret_buf = Vec::with_capacity(sz); - ret_buf.set_len(sz); - self.inner.data.copy_to(ret_buf.as_mut_ptr(), sz); - ret_buf - } - } -} - -impl<'a, T: AsRef<[u8]>> From for TVMByteArray { - fn from(arg: T) -> Self { - let arg = arg.as_ref(); - let barr = ffi::TVMByteArray { - data: arg.as_ptr() as *const c_char, - size: arg.len(), - }; - TVMByteArray::new(barr) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn convert() { - let v = vec![1u8, 2, 3]; - let barr = TVMByteArray::from(&v); - assert_eq!(barr.len(), v.len()); - assert_eq!(barr.data(), vec![1i8, 2, 3]); - let v = b"hello".to_vec(); - let barr = TVMByteArray::from(&v); - assert_eq!(barr.len(), v.len()); - assert_eq!(barr.data(), vec![104i8, 101, 108, 108, 111]); - } -} diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index a5f0dd7b1019..d147871a3968 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -47,7 +47,7 @@ use failure::Error; use tvm_common::ffi; -use crate::function; +use crate::{function, TVMArgValue}; /// Device type can be from a supported device name. See the supported devices /// in [TVM](https://github.com/dmlc/tvm). @@ -60,7 +60,7 @@ use crate::function; ///``` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TVMDeviceType(pub usize); +pub struct TVMDeviceType(pub i64); impl Default for TVMDeviceType { /// default device is cpu. @@ -141,6 +141,12 @@ impl<'a> From<&'a str> for TVMDeviceType { } } +impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> { + fn from(dev: &TVMDeviceType) -> Self { + Self::Int(dev.0) + } +} + /// Represents the underlying device context. Default is cpu. /// /// ## Examples diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs index 6e4123cb6217..adb258dbd3d9 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/frontend/src/lib.rs @@ -30,7 +30,7 @@ //! //! Checkout the `examples` repository for more details. -#![feature(box_syntax)] +#![feature(box_syntax, type_alias_enum_variants)] #[macro_use] extern crate failure; @@ -48,7 +48,6 @@ use std::{ use failure::Error; pub use crate::{ - bytearray::TVMByteArray, context::{TVMContext, TVMDeviceType}, errors::*, function::Function, @@ -56,7 +55,7 @@ pub use crate::{ ndarray::NDArray, tvm_common::{ errors as common_errors, - ffi::{self, TVMType}, + ffi::{self, TVMByteArray, TVMType}, packed_func::{TVMArgValue, TVMRetValue}, }, }; @@ -89,7 +88,6 @@ pub(crate) fn set_last_error(err: &Error) { #[macro_use] pub mod function; -pub mod bytearray; pub mod context; pub mod errors; pub mod module; diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 46cafe3ccd79..1728fece5965 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -76,3 +76,7 @@ cargo run --bin float cargo run --bin array cargo run --bin string cd - + +cd examples/resnet +cargo build +cd -