Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUST][FRONTEND] Fix resnet example #3000

Merged
merged 1 commit into from
Apr 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion rust/common/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ TVMPODValue! {
Bytes(val) => {
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes)
}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) }
}
}

Expand Down Expand Up @@ -260,12 +260,24 @@ impl<'a> From<&'a str> for TVMArgValue<'a> {
}
}

impl<'a> From<String> for TVMArgValue<'a> {
fn from(s: String) -> Self {
Self::String(CString::new(s).unwrap())
}
}

impl<'a> From<&'a CStr> for TVMArgValue<'a> {
fn from(s: &'a CStr) -> Self {
Self::Str(s)
}
}

impl<'a> From<&'a TVMByteArray> for TVMArgValue<'a> {
fn from(s: &'a TVMByteArray) -> Self {
Self::Bytes(s)
}
}

impl<'a> TryFrom<TVMArgValue<'a>> for &'a str {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
Expand Down
52 changes: 46 additions & 6 deletions rust/common/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

use std::str::FromStr;
use std::{os::raw::c_char, str::FromStr};

use failure::Error;

Expand Down Expand Up @@ -157,17 +157,57 @@ impl_tvm_context!(
DLDeviceType_kDLExtDev: [ext_dev]
);

/// A struct holding TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello";
/// let barr = TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ```
impl TVMByteArray {
nhynes marked this conversation as resolved.
Show resolved Hide resolved
/// Gets the underlying byte-array
pub fn data(&self) -> &'static [u8] {
unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) }
}

/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.size
}

/// Converts the underlying byte-array to `Vec<u8>`
pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec()
}
}

impl<'a> From<&'a [u8]> for TVMByteArray {
fn from(bytes: &[u8]) -> Self {
Self {
data: bytes.as_ptr() as *const i8,
size: bytes.len(),
// Needs AsRef for Vec
impl<T: AsRef<[u8]>> From<T> for TVMByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
let v = b"hello";
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
}
}
17 changes: 15 additions & 2 deletions rust/frontend/examples/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,25 @@ This end-to-end example shows how to:
* build `Resnet 18` with `tvm` and `nnvm` from Python
* use the provided Rust frontend API to test for an input image

To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
To run the example with pretrained resnet weights, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).

* **Build the example**: `cargo build`
* **Build the example**: `cargo build

To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with
`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details.

* **Run the example**: `cargo run`

Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with

```
let output = Command::new("python")
.arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
.arg(&format!("--pretrained"))
.output()
.expect("Failed to execute command");
```

Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`!
15 changes: 11 additions & 4 deletions rust/frontend/examples/resnet/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@
* under the License.
*/

use std::process::Command;
use std::{path::Path, process::Command};

fn main() {
let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
let output = Command::new("python3")
.arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
.output()
.expect("Failed to execute command");
assert!(
std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(),
"Could not prepare demo: {}",
String::from_utf8(output.stderr).unwrap().trim()
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
println!(
"cargo:rustc-link-search=native={}",
Expand Down
62 changes: 37 additions & 25 deletions rust/frontend/examples/resnet/src/build_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,18 @@

import numpy as np

import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download

import tvm
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc
import nnvm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description='Resnet build example')
aa = parser.add_argument
aa('--build-dir', type=str, required=True, help='directory to put the build artifacts')
aa('--pretrained', action='store_true', help='use a pretrained resnet')
aa('--batch-size', type=int, default=1, help='input image batch size')
aa('--opt-level', type=int, default=3,
help='level of optimization. 0 is unoptimized and 3 is the highest level')
Expand All @@ -45,7 +44,7 @@
aa('--image-name', type=str, default='cat.png', help='name of input image to download')
args = parser.parse_args()

target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
build_dir = args.build_dir
batch_size = args.batch_size
opt_level = args.opt_level
target = tvm.target.create(args.target)
Expand All @@ -57,30 +56,42 @@ def build(target_dir):
deploy_lib = osp.join(target_dir, 'deploy_lib.o')
if osp.exists(deploy_lib):
return
# download the pretrained resnet18 trained on imagenet1k dataset for
# image classification task
block = get_model('resnet18_v1', pretrained=True)

sym, params = nnvm.frontend.from_mxnet(block)
# add the softmax layer for prediction
net = nnvm.sym.softmax(sym)
if args.pretrained:
# needs mxnet installed
from mxnet.gluon.model_zoo.vision import get_model

# if `--pretrained` is enabled, it downloads a pretrained
# resnet18 trained on imagenet1k dataset for image classification task
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, {"data": data_shape})
# we want a probability so add a softmax operator
net = relay.Function(net.params, relay.nn.softmax(net.body),
None, net.type_params, net.attrs)
else:
# use random weights from relay.testing
net, params = relay.testing.resnet.get_workload(
num_layers=18, batch_size=batch_size, image_shape=image_shape)

# compile the model
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build_module.build(net, target, params=params)

# save the model artifacts
lib.save(deploy_lib)
cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
[osp.join(target_dir, "deploy_lib.o")])

with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
fo.write(graph.json())
fo.write(graph)

with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
fo.write(relay.save_param_dict(params))

def download_img_labels():
""" Download an image and imagenet1k class labels for test"""
from mxnet.gluon.utils import download

img_name = 'cat.png'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
Expand All @@ -97,11 +108,11 @@ def download_img_labels():
w = csv.writer(fout)
w.writerows(synset.items())

def test_build(target_dir):
def test_build(build_dir):
""" Sanity check with random input"""
graph = open(osp.join(target_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read())
graph = open(osp.join(build_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(build_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
ctx = tvm.cpu()
module = graph_runtime.create(graph, lib, ctx)
Expand All @@ -112,10 +123,11 @@ def test_build(target_dir):

if __name__ == '__main__':
logger.info("building the model")
build(target_dir)
build(build_dir)
logger.info("build was successful")
logger.info("test the build artifacts")
test_build(target_dir)
test_build(build_dir)
logger.info("test was successful")
download_img_labels()
logger.info("image and synset downloads are successful")
if args.pretrained:
download_img_labels()
logger.info("image and synset downloads are successful")
5 changes: 2 additions & 3 deletions rust/frontend/examples/resnet/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ fn main() {
let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
graph,
&lib,
&ctx.device_type,
&ctx.device_id
Expand All @@ -107,8 +107,7 @@ fn main() {
.get_function("set_input", false)
.unwrap();

let data_str = "data".to_string();
call_packed!(set_input_fn, &data_str, &input).unwrap();
call_packed!(set_input_fn, "data".to_string(), &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
Expand Down
92 changes: 0 additions & 92 deletions rust/frontend/src/bytearray.rs

This file was deleted.

Loading