Commit: Merge remote-tracking branch 'upstream/main' into main

shingjan committed Jul 29, 2021
2 parents 4438dcb + bef7bf9 commit 2bc9da5
Showing 106 changed files with 1,945 additions and 1,163 deletions.
17 changes: 8 additions & 9 deletions apps/wasm-standalone/README.md
@@ -116,16 +116,10 @@ This project should be considered **experimental** at the very early stage, all

- Build DL library in the WebAssembly format.

- Download model
- Compile the model

```
cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx
```

- Compile

```
LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx
cd wasm-graph/tools && LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3
```

### Build wasm-graph package
@@ -170,9 +164,14 @@ $ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/syn
$ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv
original image dimensions: (256, 256)
resized image dimensions: (224, 224)
input image belongs to the class `tabby, tabby cat`
input image belongs to the class `tiger cat`
```

Note: this example also works without WASI support. To run in pure wasm32, modify `wasm-graph/.cargo/config` to change
the target to `wasm32-unknown-unknown` (a sketch of this change follows below) and uncomment the raw wasm engine in
`wasm-runtime/src/graph.rs`. SIMD may not be supported without WASI, so you may also need to delete ` -mattr=+simd128`
in the [build script](apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py).
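
As a rough sketch, the `wasm-graph/.cargo/config` edit mentioned in the note might look like the following; the file's actual contents are not shown in this diff, so everything beyond the `target` value is an assumption:

```
# Hypothetical sketch of wasm-graph/.cargo/config; only the `target` value
# is implied by the note above, the surrounding layout is illustrative.
[build]
# Default target with WASI support:
# target = "wasm32-wasi"
# Pure wasm32, without WASI support:
target = "wasm32-unknown-unknown"
```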

## Future Work

### More networks support
18 changes: 10 additions & 8 deletions apps/wasm-standalone/wasm-graph/src/lib.rs
@@ -48,6 +48,7 @@ lazy_static! {
"/lib/graph.json"
)))
.unwrap();

let params_bytes =
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params"));
let params = tvm_graph_rt::load_param_dict(params_bytes)
@@ -57,6 +58,7 @@ lazy_static! {
.collect::<HashMap<String, TVMTensor<'static>>>();

let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();

exec.load_params(params);

Mutex::new(exec)
@@ -68,14 +70,14 @@ pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 {
let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) };
let input: TVMTensor = in_tensor.as_dltensor().into();

GRAPH_EXECUTOR.lock().unwrap().set_input("data", input);
GRAPH_EXECUTOR.lock().unwrap().run();
let output = GRAPH_EXECUTOR
.lock()
.unwrap()
.get_output(0)
.unwrap()
.as_dltensor(false);
// since this executor is not multi-threaded, we can acquire lock once
let mut executor = GRAPH_EXECUTOR.lock().unwrap();

executor.set_input("data", input);

executor.run();

let output = executor.get_output(0).unwrap().as_dltensor(false);

let out_tensor: Tensor = output.into();
let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) };
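
The rewritten `run` above acquires the executor lock once and holds the guard across `set_input`, `run`, and `get_output`, instead of paying a lock/unlock round-trip per call as the old chain of `GRAPH_EXECUTOR.lock().unwrap()` statements did. A toy sketch of the difference with a plain `Mutex` (not the real `GraphExecutor`):

```
use std::sync::Mutex;

fn main() {
    let executor = Mutex::new(vec![0i32; 3]);

    // Before: every statement re-acquires the mutex through a temporary
    // guard that is dropped at the end of the statement.
    executor.lock().unwrap()[0] = 1;
    executor.lock().unwrap()[1] = 2;

    // After: lock once and hold the guard across all accesses.
    let mut exec = executor.lock().unwrap();
    exec[2] = 3;
    assert_eq!(*exec, vec![1, 2, 3]);
}
```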
2 changes: 1 addition & 1 deletion apps/wasm-standalone/wasm-graph/src/types.rs
@@ -24,7 +24,7 @@ use std::{
};
pub use tvm_sys::ffi::DLTensor;
use tvm_sys::ffi::{
DLDevice, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU,
DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDevice, DLDeviceType_kDLCPU,
};

#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
19 changes: 13 additions & 6 deletions apps/wasm-standalone/wasm-graph/src/utils.rs
@@ -24,13 +24,20 @@ use std::ptr;
pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor {
let in_addr = in_addr as *mut u8;

let mut data_vec = Vec::new();
for i in 0..in_size {
data_vec.push(ptr::read(in_addr.offset(i as isize)));
}
let input: Tensor = serde_json::from_slice(&data_vec).unwrap();
println!("DEBUG: in_addr {:?}, in_size {:?}", in_addr, in_size);

let data_vec = unsafe { std::slice::from_raw_parts(in_addr, in_size) };

input
let input = serde_json::from_slice(&data_vec);
match input {
Ok(result) => {
println!("DEBUG: SER SUCCEED!!! and Ok");
result
}
Err(e) => {
panic!("DEBUG: SER SUCCEED!!! but Err, {:?}", &e);
}
}
}

pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize {
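
The new `load_input` also replaces the old byte-by-byte `ptr::read` loop with a single `std::slice::from_raw_parts` view over the input buffer. A self-contained sketch of that pattern, with a toy vector standing in for the wasm memory:

```
fn main() {
    let backing = vec![1u8, 2, 3, 4];
    let (addr, len) = (backing.as_ptr(), backing.len());

    // SAFETY: `addr` points to `len` initialized bytes owned by `backing`,
    // which outlives the borrowed slice.
    let view: &[u8] = unsafe { std::slice::from_raw_parts(addr, len) };

    // One borrow instead of `len` separate ptr::read copies.
    assert_eq!(view, &[1u8, 2, 3, 4][..]);
}
```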
67 changes: 49 additions & 18 deletions apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
File mode changed: 100644 → 100755
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.

"""Builds a simple graph for testing."""
"""Builds a simple resnet50 graph for testing."""
import argparse
import os
import subprocess
@@ -25,47 +25,78 @@
import onnx
import tvm
from tvm import relay, runtime
from tvm.contrib.download import download_testdata
from tvm.contrib import graph_executor

from PIL import Image
import numpy as np
import tvm.relay as relay

def _get_mod_and_params(model_file):
onnx_model = onnx.load(model_file)
shape_dict = {}
for input in onnx_model.graph.input:
shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim]
# This example uses resnet50-v2-7 model
model_url = "".join(
[
"https://github.com/onnx/models/raw/",
"master/vision/classification/resnet/model/",
"resnet50-v2-7.onnx",
]
)

return relay.frontend.from_onnx(onnx_model, shape_dict)


def build_graph_lib(model_file, opt_level):
def build_graph_lib(opt_level):
"""Compiles the pre-trained model with TVM"""
out_dir = os.path.join(sys.path[0], "../lib")
if not os.path.exists(out_dir):
os.makedirs(out_dir)

# Compile the relay mod
mod, params = _get_mod_and_params(model_file)
# Follow the tutorial to download and compile the model
model_path = download_testdata(model_url, "resnet50-v2-7.onnx", module="onnx")
onnx_model = onnx.load(model_path)

img_url = "https://s3.amazonaws.com/model-server/inputs/kitten.jpg"
img_path = download_testdata(img_url, "imagenet_cat.png", module="data")

# Resize it to 224x224
resized_image = Image.open(img_path).resize((224, 224))
img_data = np.asarray(resized_image).astype("float32")

# Our input image is in HWC layout while ONNX expects CHW input, so convert the array
img_data = np.transpose(img_data, (2, 0, 1))

# Normalize according to the ImageNet input specification
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev

# Add the batch dimension, as we are expecting 4-dimensional input: NCHW.
img_data = np.expand_dims(norm_img_data, axis=0)

input_name = "data"
shape_dict = {input_name: img_data.shape}

mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
target = "llvm -mtriple=wasm32-unknown-unknown -mattr=+simd128 --system-lib"

with tvm.transform.PassContext(opt_level=opt_level):
graph_json, lib, params = relay.build(mod, target=target, params=params)
factory = relay.build(mod, target=target, params=params)

# Save the model artifacts to obj_file
obj_file = os.path.join(out_dir, "graph.o")
lib.save(obj_file)
factory.get_lib().save(obj_file)

# Run llvm-ar to archive obj_file into lib_file
lib_file = os.path.join(out_dir, "libgraph_wasm32.a")
cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), "rcs", lib_file, obj_file]
subprocess.run(cmds)

# Save the json and params
with open(os.path.join(out_dir, "graph.json"), "w") as f_graph:
f_graph.write(graph_json)

f_graph.write(factory.get_graph_json())
with open(os.path.join(out_dir, "graph.params"), "wb") as f_params:
f_params.write(runtime.save_param_dict(params))
f_params.write(runtime.save_param_dict(factory.get_params()))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ONNX model build example")
parser.add_argument("model_file", type=str, help="the path of onnx model file")
parser.add_argument(
"-O",
"--opt-level",
@@ -75,4 +106,4 @@ def build_graph_lib(model_file, opt_level):
)
args = parser.parse_args()

build_graph_lib(args.model_file, args.opt_level)
build_graph_lib(args.opt_level)
4 changes: 2 additions & 2 deletions apps/wasm-standalone/wasm-runtime/Cargo.toml
@@ -26,8 +26,8 @@ license = "Apache-2.0"
keywords = ["wasm", "machine learning", "wasmtime"]

[dependencies]
wasmtime = "0.16.0"
wasmtime-wasi = "0.16.0"
wasmtime = "0.28.0"
wasmtime-wasi = "0.28.0"
anyhow = "1.0.31"
serde = "1.0.53"
serde_json = "1.0.53"
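
A note on the dependency bump above: wasmtime 0.16 → 0.28 is a breaking upgrade. In 0.28 the `Store` owns the host state (here the `WasiCtx`) and nearly every API call threads `&mut Store` explicitly, which is what drives the churn in `graph.rs` below. A minimal sketch of the 0.28-style setup, mirroring the calls this diff uses (the `module.wasm` path is a placeholder, not part of the commit):

```
use anyhow::Result;
use wasmtime::{Config, Engine, Linker, Module, Store};
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};

fn main() -> Result<()> {
    // 0.28: Engine::new returns a Result, and SIMD is opt-in via Config.
    let engine = Engine::new(Config::new().wasm_simd(true))?;

    // The Linker is built from the Engine (0.16 tied it to a Store).
    let mut linker: Linker<WasiCtx> = Linker::new(&engine);
    wasmtime_wasi::add_to_linker(&mut linker, |ctx| ctx)?;

    // The Store now owns the host state (the WasiCtx).
    let wasi = WasiCtxBuilder::new().inherit_stdio().inherit_args()?.build();
    let mut store = Store::new(&engine, wasi);

    // "module.wasm" is a placeholder path, not part of this commit.
    let module = Module::from_file(&engine, "module.wasm")?;
    let _instance = linker.instantiate(&mut store, &module)?;
    Ok(())
}
```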
79 changes: 52 additions & 27 deletions apps/wasm-standalone/wasm-runtime/src/graph.rs
@@ -19,14 +19,17 @@

use anyhow::Result;
use wasmtime::*;
use wasmtime_wasi::{Wasi, WasiCtx};
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder};

use super::Tensor;

pub struct GraphExecutor {
pub(crate) wasm_addr: i32,
pub(crate) input_size: i32,
pub(crate) output_size: i32,
pub(crate) store: Option<Store<WasiCtx>>,
// Non-WASI version:
// pub(crate) store: Option<Store<()>>,
pub(crate) instance: Option<Instance>,
}

@@ -37,25 +40,44 @@ impl GraphExecutor {
wasm_addr: 0,
input_size: 0,
output_size: 0,
store: None,
instance: None,
}
}

pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> {
let engine = Engine::new(Config::new().wasm_simd(true));
let store = Store::new(&engine);
// It seems WASI in this example is not necessary

// Non-WASI version: works without SIMD
// let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
// let mut store = Store::new(&engine, ());
// let module = Module::from_file(store.engine(), &wasm_graph_file)?;

// let instance = Instance::new(&mut store, &module, &[])?;

// self.instance = Some(instance);
// self.store = Some(store);

// Ok(())

// WASI version:
let engine = Engine::new(Config::new().wasm_simd(true)).unwrap();
// First set up our linker which is going to be linking modules together. We
// want our linker to have wasi available, so we set that up here as well.
let mut linker = Linker::new(&store);
let mut linker = Linker::new(&engine);
wasmtime_wasi::add_to_linker(&mut linker, |s| s)?;
// Create an instance of `Wasi` which contains a `WasiCtx`. Note that
// `WasiCtx` provides a number of ways to configure what the target program
// will have access to.
let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?);
wasi.add_to_linker(&mut linker)?;
let wasi = WasiCtxBuilder::new()
.inherit_stdio()
.inherit_args()?
.build();
let mut store = Store::new(&engine, wasi);

let module = Module::from_file(&store, &wasm_graph_file)?;
self.instance = Some(linker.instantiate(&module)?);
let module = Module::from_file(&engine, &wasm_graph_file)?;
self.instance = Some(linker.instantiate(&mut store, &module)?);
self.store = Some(store);

Ok(())
}
@@ -65,26 +87,24 @@ impl GraphExecutor {
.instance
.as_ref()
.unwrap()
.get_memory("memory")
.get_memory(self.store.as_mut().unwrap(), "memory")
.ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;

// Specify the wasm address to access the wasm memory.
let wasm_addr = memory.data_size();
let wasm_addr = memory.data_size(self.store.as_mut().unwrap());

// Serialize the data into a JSON string.
let in_data = serde_json::to_vec(&input_data)?;
let in_size = in_data.len();

// Grow the memory by enough pages for in_size to avoid out-of-bounds access.
memory.grow((in_size >> 16) as u32 + 1)?;
memory.grow(self.store.as_mut().unwrap(), (in_size >> 16) as u32 + 1)?;

// Insert the input data into wasm memory.
for i in 0..in_size {
unsafe {
memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap();
}
}
memory.write(self.store.as_mut().unwrap(), wasm_addr, &in_data)?;

self.wasm_addr = wasm_addr as i32;
self.input_size = in_size as i32;

Ok(())
}

@@ -94,11 +114,12 @@
.instance
.as_ref()
.unwrap()
.get_func("run")
.ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?
.get2::<i32, i32, i32>()?;
.get_func(self.store.as_mut().unwrap(), "run")
.ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?;

let out_size = run(self.wasm_addr, self.input_size)?;
let params = [Val::I32(self.wasm_addr), Val::I32(self.input_size)];
let out_size = run.call(self.store.as_mut().unwrap(), &params[..])?;
let out_size = (*out_size)[0].unwrap_i32();
if out_size == 0 {
panic!("graph run failed!");
}
@@ -107,18 +128,22 @@
Ok(())
}

pub fn get_output(&self) -> Result<Tensor> {
pub fn get_output(&mut self) -> Result<Tensor> {
let memory = self
.instance
.as_ref()
.unwrap()
.get_memory("memory")
.get_memory(self.store.as_mut().unwrap(), "memory")
.ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;

let out_data = unsafe {
&memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize]
};
let out_vec: Tensor = serde_json::from_slice(out_data).unwrap();
let mut out_data = vec![0 as u8; self.output_size as _];
memory.read(
self.store.as_mut().unwrap(),
self.wasm_addr as _,
&mut out_data,
)?;

let out_vec: Tensor = serde_json::from_slice(&out_data).unwrap();
Ok(out_vec)
}
}
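
Another 0.28 change visible in this file: the typed `get2::<i32, i32, i32>()` accessor is gone, so exports are now invoked through the untyped `Func::call` with `Val` parameters, and the result is unwrapped by hand. A self-contained toy sketch of that calling convention (the inline WAT module is invented here for illustration):

```
use anyhow::Result;
use wasmtime::{Engine, Instance, Module, Store, Val};

fn main() -> Result<()> {
    let engine = Engine::default();
    // A trivial (i32, i32) -> i32 export standing in for the real `run`.
    let wat = r#"(module
        (func (export "run") (param i32 i32) (result i32)
            local.get 0
            local.get 1
            i32.add))"#;
    let module = Module::new(&engine, wat)?;
    let mut store = Store::new(&engine, ());
    let instance = Instance::new(&mut store, &module, &[])?;

    let run = instance
        .get_func(&mut store, "run")
        .expect("failed to find `run` function export");
    let params = [Val::I32(2), Val::I32(3)];
    let results = run.call(&mut store, &params)?; // Box<[Val]>
    assert_eq!(results[0].unwrap_i32(), 5);
    Ok(())
}
```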
3 changes: 3 additions & 0 deletions docker/Dockerfile.ci_arm
@@ -23,6 +23,9 @@ FROM ubuntu:18.04
RUN apt-get update --fix-missing
RUN apt-get install -y ca-certificates gnupg2

# Globally disable pip cache
RUN pip config set global.cache-dir false

COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh
