diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8467f6a92ea8..f08f86169bfb 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -22,6 +22,7 @@ members = [ "runtime", "runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_dso", + "runtime/tests/test_wasm32", "runtime/tests/test_nn", "frontend", "frontend/tests/basics", diff --git a/rust/common/build.rs b/rust/common/build.rs index b3ae7b6d1837..07326f41f801 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -51,6 +51,7 @@ fn main() { .layout_tests(false) .derive_partialeq(true) .derive_eq(true) + .derive_default(true) .generate() .expect("unable to generate bindings") .write_to_file(PathBuf::from("src/c_runtime_api.rs")) diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs index d0a66a62b8bf..a8f4f989c146 100644 --- a/rust/common/src/array.rs +++ b/rust/common/src/array.rs @@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray { shape: arr.shape().as_ptr() as *const i64 as *mut i64, strides: arr.strides().as_ptr() as *const isize as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index 2ae64e7a32b3..33b2993bf3da 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -31,8 +31,13 @@ pub mod ffi { include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - pub type BackendPackedCFunc = - extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + pub type BackendPackedCFunc = extern "C" fn( + args: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + out_ret_value: *mut TVMValue, + out_ret_tcode: *mut u32, + ) -> c_int; } pub mod array; diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index 2b6c7c217e28..c38b3ff8e527 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -297,6 +297,7 @@ impl<'a> Tensor<'a> { self.strides.as_ref().unwrap().as_ptr() } as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 518bf724f319..71541ba27826 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -382,7 +382,18 @@ named! { // Converts a bytes to String. named! { name, - map_res!(length_data!(le_u64), |b: &[u8]| String::from_utf8(b.to_vec())) + do_parse!( + len_l: le_u32 >> + len_h: le_u32 >> + data: take!(len_l) >> + ( + if len_h == 0 { + String::from_utf8(data.to_vec()).unwrap() + } else { + panic!("Too long string") + } + ) + ) } // Parses a TVMContext diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs index 856dd78193bc..cb4d7776dd0b 100644 --- a/rust/runtime/src/module/mod.rs +++ b/rust/runtime/src/module/mod.rs @@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< (val, code as i32) }) .unzip(); - let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); + let ret: TVMRetValue = TVMRetValue::default(); + let (mut ret_val, mut ret_type_code) = ret.to_tvm_value(); + let exit_code = func( + values.as_ptr(), + type_codes.as_ptr(), + values.len() as i32, + &mut ret_val, + &mut ret_type_code, + ); if exit_code == 0 { - Ok(TVMRetValue::default()) + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code)) } else { Err(tvm_common::errors::FuncCallError::get_with_context( func_name.clone(), diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index f473bbf3990a..b8be01270ae7 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -18,7 +18,6 @@ */ use std::{ - env, os::raw::{c_int, c_void}, sync::{ atomic::{AtomicUsize, Ordering}, @@ -27,6 +26,9 @@ use std::{ thread::{self, JoinHandle}, }; +#[cfg(not(target_arch = "wasm32"))] +use std::env; + use crossbeam::channel::{bounded, Receiver, Sender}; use tvm_common::ffi::TVMParallelGroupEnv; @@ -147,7 +149,10 @@ impl ThreadPool { fn run_worker(queue: Receiver) { loop { - let task = queue.recv().expect("should recv"); + let task = match queue.recv() { + Ok(v) => v, + Err(_) => break, + }; let result = task.run(); if result == ::min_value() { break; diff --git a/rust/runtime/tests/test_wasm32/.cargo/config b/rust/runtime/tests/test_wasm32/.cargo/config new file mode 100644 index 000000000000..6b77899cb333 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/.cargo/config @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasi" diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml new file mode 100644 index 000000000000..1d3373a9e60f --- /dev/null +++ b/rust/runtime/tests/test_wasm32/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-wasm32" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs new file mode 100644 index 000000000000..8b72be290267 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/build.rs @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::PathBuf, process::Command}; + +fn main() { + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).unwrap(); + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest_wasm32.a"); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + obj_file.exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8"); + let output = Command::new(ar) + .arg("rcs") + .arg(&lib_file) + .arg(&obj_file) + .output() + .expect("Failed to execute command"); + assert!( + lib_file.exists(), + "Could not create archive: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-lib=static=test_wasm32"); + println!("cargo:rustc-link-search=native={}", out_dir.display()); +} diff --git a/rust/runtime/tests/test_wasm32/src/build_test_lib.py b/rust/runtime/tests/test_wasm32/src/build_test_lib.py new file mode 100755 index 000000000000..6016c60c4ea3 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/src/build_test_lib.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm +from tvm import te + +def main(): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.te.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o')) + +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_wasm32/src/main.rs b/rust/runtime/tests/test_wasm32/src/main.rs new file mode 100644 index 000000000000..a46cfa979bec --- /dev/null +++ b/rust/runtime/tests/test_wasm32/src/main.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern "C" { + static __tvm_module_ctx: i32; +} + +#[no_mangle] +unsafe fn __get_tvm_module_ctx() -> i32 { + // Refer a symbol in the libtest_wasm32.a to make sure that the link of the + // library is not optimized out. + __tvm_module_ctx +} + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +fn main() { + // try static + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 0ec0a2e86450..04d6c94e10e8 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -103,7 +103,8 @@ "KEYS", "DISCLAIMER", "Jenkinsfile", - # sgx config + # cargo config + "rust/runtime/tests/test_wasm32/.cargo/config", "apps/sgx/.cargo/config", # html for demo purposes "tests/webgl/test_static_webgl_library.html", diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index fae07d34e992..5529632e50ea 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -54,6 +54,12 @@ cd tests/test_tvm_dso cargo run cd - +# # run wasm32 test +# cd tests/test_wasm32 +# cargo build +# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm +# cd - + # run nn graph test cd tests/test_nn cargo run