From c33457c005b08f9cd96ed8dfadfd637803c25936 Mon Sep 17 00:00:00 2001 From: Barak Ugav Date: Sat, 3 Aug 2024 15:38:43 +0300 Subject: [PATCH] Prepare for release `0.1.0-rc.1` (#11) * Fill more metadata in Cargo.toml * Remove 'extension-' prefix from features names * Implement common traints for structs and enums * Remove re-use at lib level * Crate documentation * Change version to 0.1.0-rc.1 * Avoid linking to static libs when building docs * Enable all features when built by docs.rs --- .gitignore | 1 + Cargo.toml | 32 +++- README.md | 14 +- examples/hello_world_add/Cargo.toml | 2 +- examples/hello_world_add/src/main.rs | 17 +- examples/hello_world_add_module/Cargo.toml | 5 +- examples/hello_world_add_module/src/main.rs | 10 +- executorch-sys/Cargo.toml | 33 +++- executorch-sys/LICENSE | 201 ++++++++++++++++++++ executorch-sys/README.md | 76 ++++++++ executorch-sys/build.rs | 27 +-- executorch-sys/src/lib.rs | 79 ++++++++ src/data_loader.rs | 20 +- src/error.rs | 49 ++++- src/evalue.rs | 37 +++- src/lib.rs | 102 +++++++--- src/memory.rs | 6 + src/module.rs | 19 +- src/platform.rs | 12 ++ src/program.rs | 50 ++++- src/tensor.rs | 93 +++++++++ src/util.rs | 39 ++++ 22 files changed, 845 insertions(+), 79 deletions(-) create mode 100644 executorch-sys/LICENSE create mode 100644 executorch-sys/README.md create mode 100644 src/platform.rs diff --git a/.gitignore b/.gitignore index 46b3c51..8d7b4b6 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ target/ *.pdb .vscode/ +.venv/ diff --git a/Cargo.toml b/Cargo.toml index 6b709f4..a7c8faa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,14 +7,38 @@ members = [ [package] name = "executorch" -version = "0.1.0" +version = "0.1.0-rc.1" authors = ["Barak Ugav "] edition = "2021" +description = "Rust bindings for ExecuTorch - On-device AI across mobile, embedded and edge for PyTorch" +readme = "README.md" repository = "https://github.com/barakugav/executorch-rs" license = "Apache-2.0" +keywords = [ + "executorch", + 
"pytorch", + "ai", + "ml", + "machine-learning", + "mobile", + "embedded", + "edge-device", + "bindings", +] +categories = [ + "algorithms", + "mathematics", + "embedded", + "no-std", + "no-std::no-alloc", +] +include = ["Cargo.toml", "src/", "README.md", "LICENSE"] + +[package.metadata.docs.rs] +features = ["data-loader", "module"] [dependencies] -executorch-sys = { path = "executorch-sys" } +executorch-sys = { path = "executorch-sys", version = "0.1.0-rc.1" } ndarray = "0.15.6" log = "0.4.22" @@ -24,5 +48,5 @@ cc = "1.1.6" envsubst = "0.2.1" [features] -extension-data-loader = ["executorch-sys/extension-data-loader"] -extension-module = ["executorch-sys/extension-module"] +data-loader = ["executorch-sys/data-loader"] +module = ["executorch-sys/module"] diff --git a/README.md b/README.md index 958ff88..dbc28c6 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,9 @@ with open("model.pte", "wb") as file: ``` Execute the model in Rust: ```rust -use executorch::{EValue, Module, Tag, Tensor, TensorImpl}; +use executorch::evalue::{EValue, Tag}; +use executorch::module::Module; +use executorch::tensor::{Tensor, TensorImpl}; use ndarray::array; let mut module = Module::new("model.pte", None); @@ -45,10 +47,10 @@ let outputs = module.forward(&[input_evalue1, input_evalue2]).unwrap(); assert_eq!(outputs.len(), 1); let output = outputs.into_iter().next().unwrap(); assert_eq!(output.tag(), Some(Tag::Tensor)); -let output = output.as_tensor().as_array::(); +let output = output.as_tensor(); println!("Output tensor computed: {:?}", output); -assert_eq!(output, array![2.0].into_dyn()); +assert_eq!(array![2.0_f32], output.as_array()); ``` See `example/hello_world_add` and `example/hello_world_add_module` for the complete examples. @@ -121,6 +123,6 @@ println!("cargo::rustc-link-search={}/kernels/portable/", libs_dir); Note that the `portable_ops_lib` is linked with `+whole-archive` to ensure that all symbols are included in the binary. 
## Cargo Features -- `default`: disables all features. -- `extension-data-loader`: include the `FileDataLoader` strut. The `libextension_data_loader.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON`. -- `extension-module`: include the `Module` strut. The `libextension_module_static.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_MODULE=ON`. +By default all features are disabled. +- `data-loader`: include the `FileDataLoader` struct. The `libextension_data_loader.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON`. +- `module`: include the `Module` struct. The `libextension_module_static.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_MODULE=ON`. diff --git a/examples/hello_world_add/Cargo.toml b/examples/hello_world_add/Cargo.toml index 370c288..75e5923 100644 --- a/examples/hello_world_add/Cargo.toml +++ b/examples/hello_world_add/Cargo.toml @@ -5,7 +5,7 @@ version = "0.0.0" edition = "2021" [dependencies] -executorch = { path = "../../", features = ["extension-data-loader"] } +executorch = { path = "../../", features = ["data-loader"] } log = "0.4.22" env_logger = "0.11.3" ndarray = "0.15.6" diff --git a/examples/hello_world_add/src/main.rs b/examples/hello_world_add/src/main.rs index ca4df79..2302390 100644 --- a/examples/hello_world_add/src/main.rs +++ b/examples/hello_world_add/src/main.rs @@ -1,11 +1,12 @@ #![deny(warnings)] use executorch::data_loader::FileDataLoader; +use executorch::evalue::{EValue, Tag}; +use executorch::memory::{HierarchicalAllocator, MallocMemoryAllocator, MemoryManager}; +use executorch::program::{Program, ProgramVerification}; +use executorch::tensor::{Tensor, TensorImpl}; use executorch::util::Span; -use executorch::{ - EValue, HierarchicalAllocator, MallocMemoryAllocator, MemoryManager, Program, - ProgramVerification, Tag, Tensor, TensorImpl, -}; 
+ use ndarray::array; use std::vec; @@ -14,7 +15,7 @@ fn main() { .filter_level(log::LevelFilter::Debug) .init(); - executorch::pal_init(); + executorch::platform::pal_init(); let mut file_data_loader = FileDataLoader::new("model.pte", None).unwrap(); @@ -57,10 +58,10 @@ fn main() { method_exe.set_input(&input_evalue2, 1).unwrap(); let outputs = method_exe.execute().unwrap(); - let output = outputs.get_output(0); + let output = &outputs[0]; assert_eq!(output.tag(), Some(Tag::Tensor)); - let output = output.as_tensor().as_array_dyn::(); + let output = output.as_tensor(); println!("Output tensor computed: {:?}", output); - assert_eq!(output, array![2.0].into_dyn()); + assert_eq!(array![2.0_f32], output.as_array()); } diff --git a/examples/hello_world_add_module/Cargo.toml b/examples/hello_world_add_module/Cargo.toml index f66bae6..b599b3d 100644 --- a/examples/hello_world_add_module/Cargo.toml +++ b/examples/hello_world_add_module/Cargo.toml @@ -5,10 +5,7 @@ version = "0.0.0" edition = "2021" [dependencies] -executorch = { path = "../../", features = [ - "extension-data-loader", - "extension-module", -] } +executorch = { path = "../../", features = ["data-loader", "module"] } log = "0.4.22" env_logger = "0.11.3" ndarray = "0.15.6" diff --git a/examples/hello_world_add_module/src/main.rs b/examples/hello_world_add_module/src/main.rs index 730a949..30d6d03 100644 --- a/examples/hello_world_add_module/src/main.rs +++ b/examples/hello_world_add_module/src/main.rs @@ -1,6 +1,8 @@ #![deny(warnings)] -use executorch::{EValue, Module, Tag, Tensor, TensorImpl}; +use executorch::evalue::{EValue, Tag}; +use executorch::module::Module; +use executorch::tensor::{Tensor, TensorImpl}; use ndarray::array; fn main() { @@ -8,7 +10,7 @@ fn main() { .filter_level(log::LevelFilter::Debug) .init(); - executorch::pal_init(); + executorch::platform::pal_init(); let mut module = Module::new("model.pte", None); @@ -24,8 +26,8 @@ fn main() { assert_eq!(outputs.len(), 1); let output = 
outputs.into_iter().next().unwrap(); assert_eq!(output.tag(), Some(Tag::Tensor)); - let output = output.as_tensor().as_array_dyn::(); + let output = output.as_tensor(); println!("Output tensor computed: {:?}", output); - assert_eq!(output, array![2.0].into_dyn()); + assert_eq!(array![2.0_f32], output.as_array()); } diff --git a/executorch-sys/Cargo.toml b/executorch-sys/Cargo.toml index 02cb772..d6af534 100644 --- a/executorch-sys/Cargo.toml +++ b/executorch-sys/Cargo.toml @@ -1,11 +1,38 @@ [package] name = "executorch-sys" -version = "0.1.0" +version = "0.1.0-rc.1" +authors = ["Barak Ugav "] edition = "2021" +description = "Unsafe Rust bindings for ExecuTorch - On-device AI across mobile, embedded and edge for PyTorch" +readme = "README.md" +repository = "https://github.com/barakugav/executorch-rs" +license = "Apache-2.0" +keywords = [ + "executorch", + "pytorch", + "ai", + "ml", + "machine-learning", + "mobile", + "embedded", + "edge-device", + "bindings", +] +categories = [ + "algorithms", + "mathematics", + "embedded", + "no-std", + "no-std::no-alloc", +] +links = "executorch" + +[package.metadata.docs.rs] +features = ["data-loader", "module"] [features] -extension-data-loader = [] -extension-module = [] +data-loader = [] +module = [] [dependencies] diff --git a/executorch-sys/LICENSE b/executorch-sys/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/executorch-sys/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/executorch-sys/README.md b/executorch-sys/README.md new file mode 100644 index 0000000..9773243 --- /dev/null +++ b/executorch-sys/README.md @@ -0,0 +1,76 @@ +# executorch-sys + +For a general description of the project, see the the `executorch` crate. + +## Build +To build the library, you need to build the C++ library first. +The C++ library allow for great flexibility with many flags, customizing which modules, kernels, and extensions are built. +Multiple static libraries are built, and the Rust library links to them. 
+In the following example we build the C++ library with the necessary flags to run example `hello_world_add`: +```bash +# Clone the C++ library +cd ${TEMP_DIR} +git clone --depth 1 --branch v0.2.1 https://github.com/pytorch/executorch.git +cd executorch +git submodule sync --recursive +git submodule update --init --recursive + +# Install requirements +./install_requirements.sh + +# Build C++ library +mkdir cmake-out && cd cmake-out +cmake \ + -DDEXECUTORCH_SELECT_OPS_LIST=aten::add.out \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=OFF \ + -DBUILD_EXECUTORCH_PORTABLE_OPS=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + .. +make -j + +# Static libraries are in cmake-out/ +# core: +# cmake-out/libexecutorch.a +# cmake-out/libexecutorch_no_prim_ops.a +# kernels implementations: +# cmake-out/kernels/portable/libportable_ops_lib.a +# cmake-out/kernels/portable/libportable_kernels.a +# extension data loader, enabled with EXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON: +# cmake-out/extension/data_loader/libextension_data_loader.a +# extension module, enabled with EXECUTORCH_BUILD_EXTENSION_MODULE=ON: +# cmake-out/extension/module/libextension_module_static.a + +# Run example +# We set EXECUTORCH_RS_EXECUTORCH_LIB_DIR to the path of the C++ build output +cd ${EXECUTORCH_RS_DIR}/examples/hello_world_add +python export_model.py +EXECUTORCH_RS_EXECUTORCH_LIB_DIR=${TEMP_DIR}/executorch/cmake-out cargo run +``` + +The `executorch` crate will always look for the following static libraries: +- `libexecutorch.a` +- `libexecutorch_no_prim_ops.a` + +Additional libs are required if feature flags are enabled (see next section): +- `libextension_data_loader.a` +- `libextension_module_static.a` + +The static libraries of the kernels implementations are required only if your model uses them, and they should be **linked 
manually** by the binary that uses the `executorch` crate. +For example, the `hello_world_add` example uses a model with a single addition operation, so it compile the C++ library with `DEXECUTORCH_SELECT_OPS_LIST=aten::add.out` and contain the following lines in its `build.rs`: +```rust +println!("cargo::rustc-link-lib=static=portable_kernels"); +println!("cargo::rustc-link-lib=static:+whole-archive=portable_ops_lib"); + +let libs_dir = std::env::var("EXECUTORCH_RS_EXECUTORCH_LIB_DIR").unwrap(); +println!("cargo::rustc-link-search={}/kernels/portable/", libs_dir); +``` +Note that the `portable_ops_lib` is linked with `+whole-archive` to ensure that all symbols are included in the binary. + +## Cargo Features +By default all features are disabled. +- `data-loader`: include the `FileDataLoader` struct. The `libextension_data_loader.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON`. +- `module`: include the `Module` struct. The `libextension_module_static.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_MODULE=ON`. 
diff --git a/executorch-sys/build.rs b/executorch-sys/build.rs index 4c3718f..47b238e 100644 --- a/executorch-sys/build.rs +++ b/executorch-sys/build.rs @@ -104,18 +104,18 @@ fn generate_bindings(executorch_headers: &Path) { .allowlist_item("torch::executor::MemoryManager") .allowlist_item("torch::executor::MethodMeta") .allowlist_item("torch::executor::util::MallocMemoryAllocator") - // extension-data-loader + // feature data-loader .allowlist_item("torch::executor::util::FileDataLoader") .allowlist_item("torch::executor::util::MmapDataLoader") .allowlist_item("torch::executor::util::BufferDataLoader") - // extension-module + // feature module .allowlist_item("torch::executor::Module") .blocklist_item("std::.*") .blocklist_item("torch::executor::Method_StepState") .blocklist_item("torch::executor::Method_InitializationState") .blocklist_item("torch::executor::Program_kMinHeadBytes") .blocklist_item("torch::executor::EventTracerEntry") - // extension-module + // feature module .blocklist_item("torch::executor::Module_MethodHolder") .blocklist_item("torch::executor::Module_load_method") .blocklist_item("torch::executor::Module_is_method_loaded") @@ -135,11 +135,11 @@ fn generate_bindings(executorch_headers: &Path) { .opaque_type("torch::executor::MemoryAllocator") .opaque_type("torch::executor::HierarchicalAllocator") .opaque_type("torch::executor::TensorInfo") - // extension-data-loader + // feature data-loader .opaque_type("torch::executor::util::FileDataLoader") .opaque_type("torch::executor::util::MmapDataLoader") .opaque_type("torch::executor::util::BufferDataLoader") - // extension-module + // feature module .opaque_type("torch::executor::Module") .rustified_enum("torch::executor::Error") .rustified_enum("torch::executor::ScalarType") @@ -147,9 +147,9 @@ fn generate_bindings(executorch_headers: &Path) { .rustified_enum("torch::executor::Program_Verification") .rustified_enum("torch::executor::Program_HeaderStatus") 
.rustified_enum("torch::executor::TensorShapeDynamism") - // extension-data-loader + // feature data-loader .rustified_enum("torch::executor::util::MmapDataLoader_MlockConfig") - // extension-module + // feature module .rustified_enum("torch::executor::Module_MlockConfig") .no_copy(".*") // TODO: specific some exact types, regex act weird .manually_drop_union(".*") @@ -164,6 +164,11 @@ fn generate_bindings(executorch_headers: &Path) { } fn link_executorch() { + if std::env::var("DOCS_RS").is_ok() { + // Skip linking to the static library when building documentation + return; + } + let libs_dir = std::env::var("EXECUTORCH_RS_EXECUTORCH_LIB_DIR") .expect("EXECUTORCH_RS_EXECUTORCH_LIB_DIR is not set, can't locate executorch static libs"); let libs_dir = envsubst::substitute( @@ -184,7 +189,7 @@ fn link_executorch() { println!("cargo::rustc-link-lib=static=executorch"); println!("cargo::rustc-link-lib=static=executorch_no_prim_ops"); - if cfg!(feature = "extension-data-loader") { + if cfg!(feature = "data-loader") { println!( "cargo::rustc-link-search={}/extension/data_loader/", libs_dir @@ -192,7 +197,7 @@ fn link_executorch() { println!("cargo::rustc-link-lib=static=extension_data_loader"); } - if cfg!(feature = "extension-module") { + if cfg!(feature = "module") { println!("cargo::rustc-link-search={}/extension/module/", libs_dir); // TODO: extension_module or extension_module_static ? 
println!("cargo::rustc-link-lib=static=extension_module_static"); @@ -207,10 +212,10 @@ fn cpp_ext_dir() -> PathBuf { fn cpp_defines() -> Vec<&'static str> { let mut defines = vec![]; - if cfg!(feature = "extension-data-loader") { + if cfg!(feature = "data-loader") { defines.push("EXECUTORCH_RS_EXTENSION_DATA_LOADER"); } - if cfg!(feature = "extension-module") { + if cfg!(feature = "module") { defines.push("EXECUTORCH_RS_EXTENSION_MODULE"); } defines diff --git a/executorch-sys/src/lib.rs b/executorch-sys/src/lib.rs index e2d6019..974ed43 100644 --- a/executorch-sys/src/lib.rs +++ b/executorch-sys/src/lib.rs @@ -1,3 +1,82 @@ +//! Unsafe bindings for ExecuTorch - On-device AI across mobile, embedded and edge for PyTorch. +//! +//! Provides a low level Rust bindings for the ExecuTorch library. +//! For the common use case, it is recommended to use the high-level API provided by the `executorch` crate, where +//! a more detailed documentation can be found. +//! +//! +//! To build the library, you need to build the C++ library first. +//! The C++ library allow for great flexibility with many flags, customizing which modules, kernels, and extensions are built. +//! Multiple static libraries are built, and the Rust library links to them. +//! In the following example we build the C++ library with the necessary flags to run example `hello_world_add`: +//! ```bash +//! # Clone the C++ library +//! cd ${TEMP_DIR} +//! git clone --depth 1 --branch v0.2.1 https://github.com/pytorch/executorch.git +//! cd executorch +//! git submodule sync --recursive +//! git submodule update --init --recursive +//! +//! # Install requirements +//! ./install_requirements.sh +//! +//! # Build C++ library +//! mkdir cmake-out && cd cmake-out +//! cmake \ +//! -DDEXECUTORCH_SELECT_OPS_LIST=aten::add.out \ +//! -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ +//! -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=OFF \ +//! -DBUILD_EXECUTORCH_PORTABLE_OPS=ON \ +//! 
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ +//! -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ +//! -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \ +//! -DEXECUTORCH_ENABLE_LOGGING=ON \ +//! .. +//! make -j +//! +//! # Static libraries are in cmake-out/ +//! # core: +//! # cmake-out/libexecutorch.a +//! # cmake-out/libexecutorch_no_prim_ops.a +//! # kernels implementations: +//! # cmake-out/kernels/portable/libportable_ops_lib.a +//! # cmake-out/kernels/portable/libportable_kernels.a +//! # extension data loader, enabled with EXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON: +//! # cmake-out/extension/data_loader/libextension_data_loader.a +//! # extension module, enabled with EXECUTORCH_BUILD_EXTENSION_MODULE=ON: +//! # cmake-out/extension/module/libextension_module_static.a +//! +//! # Run example +//! # We set EXECUTORCH_RS_EXECUTORCH_LIB_DIR to the path of the C++ build output +//! cd ${EXECUTORCH_RS_DIR}/examples/hello_world_add +//! python export_model.py +//! EXECUTORCH_RS_EXECUTORCH_LIB_DIR=${TEMP_DIR}/executorch/cmake-out cargo run +//! ``` +//! +//! The `executorch` crate will always look for the following static libraries: +//! - `libexecutorch.a` +//! - `libexecutorch_no_prim_ops.a` +//! +//! Additional libs are required if feature flags are enabled (see next section): +//! - `libextension_data_loader.a` +//! - `libextension_module_static.a` +//! +//! The static libraries of the kernels implementations are required only if your model uses them, and they should be **linked manually** by the binary that uses the `executorch` crate. +//! For example, the `hello_world_add` example uses a model with a single addition operation, so it compile the C++ library with `DEXECUTORCH_SELECT_OPS_LIST=aten::add.out` and contain the following lines in its `build.rs`: +//! ```rust +//! println!("cargo::rustc-link-lib=static=portable_kernels"); +//! println!("cargo::rustc-link-lib=static:+whole-archive=portable_ops_lib"); +//! +//! 
let libs_dir = std::env::var("EXECUTORCH_RS_EXECUTORCH_LIB_DIR").unwrap(); +//! println!("cargo::rustc-link-search={}/kernels/portable/", libs_dir); +//! ``` +//! Note that the `portable_ops_lib` is linked with `+whole-archive` to ensure that all symbols are included in the binary. +//! +//! ## Cargo Features +//! By default all features are disabled. +//! - `data-loader`: include the `FileDataLoader` struct. The `libextension_data_loader.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON`. +//! - `module`: include the `Module` struct. The `libextension_module_static.a` static library is required, compile C++ `executorch` with `EXECUTORCH_BUILD_EXTENSION_MODULE=ON`. + mod c_link { #![allow(dead_code)] #![allow(unused_imports)] diff --git a/src/data_loader.rs b/src/data_loader.rs index ad83899..362b7fe 100644 --- a/src/data_loader.rs +++ b/src/data_loader.rs @@ -1,22 +1,31 @@ +//! Data loaders for loading execution plans (models) from a data source. +//! +//! Data loaders are used to load execution plans from a data source, such as a file or a buffer. +//! To include the data loader functionality, enable the `data-loader` feature. + use std::cell::UnsafeCell; use crate::et_c; /// Loads from a data source. +/// +/// This struct is like a base class for data loaders. All other data loaders implement `AsRef` and other +/// structs, such as `Program`, take a reference to `DataLoader` instead of the concrete data loader type. 
pub struct DataLoader(pub(crate) UnsafeCell<et_c::DataLoader>);
-#[cfg(feature = "extension-data-loader")]
+#[cfg(feature = "data-loader")]
 pub use file_data_loader::{BufferDataLoader, FileDataLoader, MlockConfig, MmapDataLoader};
-#[cfg(feature = "extension-data-loader")]
+#[cfg(feature = "data-loader")]
 mod file_data_loader {
     use std::cell::UnsafeCell;
     use std::ffi::CString;
     use std::marker::PhantomData;
     use std::path::Path;
 
+    use crate::error::Result;
     use crate::util::IntoRust;
-    use crate::{et_c, et_rs_c, Result};
+    use crate::{et_c, et_rs_c};
 
     use super::DataLoader;
@@ -138,5 +147,10 @@ mod file_data_loader {
         }
     }
 
+    /// Describes how and whether to lock loaded pages with `mlock()`.
+    ///
+    /// Using `mlock()` typically loads all of the pages immediately, and will
+    /// typically ensure that they are not swapped out. The actual behavior
+    /// will depend on the host system.
     pub type MlockConfig = et_c::util::MmapDataLoader_MlockConfig;
 }
diff --git a/src/error.rs b/src/error.rs
index d80759e..c4c50c1 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,9 +1,11 @@
+//! Error types used in the `executorch` crate.
+
 use std::mem::ManuallyDrop;
 
 use crate::{et_c, et_rs_c, util::IntoRust};
 
 /// ExecuTorch Error type.
-#[derive(Debug)]
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 #[repr(u8)]
 pub enum Error {
     /* System errors */
DelegateInvalidHandle = et_c::Error::DelegateInvalidHandle as u8, } +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let msg = match self { + Error::Internal => "An internal error occurred", + Error::InvalidState => "Executor is in an invalid state for a target", + Error::EndOfMethod => "No more steps of execution to run", + Error::NotSupported => "Operation is not supported in the current context", + Error::NotImplemented => "Operation is not yet implemented", + Error::InvalidArgument => "User provided an invalid argument", + Error::InvalidType => "Object is an invalid type for the operation", + Error::OperatorMissing => "Operator(s) missing in the operator registry", + Error::NotFound => "Requested resource could not be found", + Error::MemoryAllocationFailed => "Could not allocate the requested memory", + Error::AccessFailed => "Could not access a resource", + Error::InvalidProgram => "Error caused by the contents of a program", + Error::DelegateInvalidCompatibility => { + "Backend receives an incompatible delegate version" + } + Error::DelegateMemoryAllocationFailed => "Backend fails to allocate memory", + Error::DelegateInvalidHandle => "The handle is invalid", + }; + write!(f, "{}", msg) + } +} +impl std::error::Error for Error {} impl IntoRust for et_c::Error { type RsType = Result<()>; @@ -73,7 +100,8 @@ impl IntoRust for et_c::Error { } } -pub type Result = std::result::Result; +pub(crate) type Result = std::result::Result; + impl IntoRust for et_c::Result { type RsType = Result; fn rs(self) -> Self::RsType { @@ -122,3 +150,20 @@ impl IntoRust for et_rs_c::Result_MethodMeta { } } } + +#[cfg(test)] +mod tests { + use super::Error; + + #[test] + fn test_error_send() { + fn assert_send() {} + assert_send::(); + } + + #[test] + fn test_error_sync() { + fn assert_sync() {} + assert_sync::(); + } +} diff --git a/src/evalue.rs b/src/evalue.rs index a89bf64..1b79b52 100644 --- a/src/evalue.rs +++ 
b/src/evalue.rs @@ -1,12 +1,19 @@ +//! Module for `EValue` and related types. +//! +//! `EValue` is a type-erased value that can hold different types like scalars, lists or tensors. It is used to pass +//! arguments to and return values from the runtime. + +use std::fmt::Debug; use std::marker::PhantomData; use std::mem::ManuallyDrop; +use crate::error::{Error, Result}; use crate::util::{ArrayRef, IntoRust}; -use crate::{et_c, et_rs_c, tensor::Tensor, Error, Result}; +use crate::{et_c, et_rs_c, tensor::Tensor}; /// A tag indicating the type of the value stored in an `EValue`. #[repr(u8)] -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub enum Tag { /// Tag for value `Tensor`. Tensor = et_c::Tag::Tensor as u8, @@ -442,6 +449,26 @@ impl<'a> TryFrom<&EValue<'a>> for &'a [Tensor<'a>] { } } } +impl Debug for EValue<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut st = f.debug_struct("EValue"); + st.field("tag", &self.tag()); + match self.tag() { + Some(Tag::Int) => st.field("value", &self.as_i64()), + Some(Tag::Double) => st.field("value", &self.as_f64()), + Some(Tag::Bool) => st.field("value", &self.as_bool()), + Some(Tag::Tensor) => st.field("value", &self.as_tensor()), + Some(Tag::String) => st.field("value", &self.as_chars()), + Some(Tag::ListInt) => st.field("value", &self.as_i64_arr()), + Some(Tag::ListDouble) => st.field("value", &self.as_f64_arr()), + Some(Tag::ListBool) => st.field("value", &self.as_bool_arr()), + Some(Tag::ListTensor) => st.field("value", &self.as_tensor_arr()), + Some(Tag::ListOptionalTensor) => st.field("value", &"Unsupported type"), + None => st.field("value", &"None"), + }; + st.finish() + } +} /// Helper class used to correlate EValues in the executor table, with the /// unwrapped list of the proper type. 
Because values in the runtime's values
@@ -551,9 +578,15 @@ impl<'a, T: BoxedEvalue> BoxedEvalueList<'a, T> {
             unsafe { ArrayRef::from_inner(&unwrapped_list).as_slice() }
         }
     }
 }
+impl Debug for BoxedEvalueList<'_, T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.get().fmt(f)
+    }
+}
 
 /// A trait for types that can be used within a `BoxedEvalueList`.
 pub trait BoxedEvalue {
+    /// The `Tag` variant corresponding to the boxed type.
     const TAG: Tag;
     private_decl! {}
 }
diff --git a/src/lib.rs b/src/lib.rs
index 1a15abf..67b5b1d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,35 +1,85 @@
 #![deny(warnings)]
+#![deny(missing_docs)]
+
+//! Bindings for ExecuTorch - On-device AI across mobile, embedded and edge for PyTorch.
+//!
+//! Provides a high-level Rust API for executing PyTorch models on mobile, embedded and edge devices using the
+//! [ExecuTorch library](https://pytorch.org/executorch-overview), specifically the [C++ API](https://github.com/pytorch/executorch).
+//! PyTorch models are created and exported in Python, and then loaded and executed on-device using the
+//! ExecuTorch library.
+//!
+//! The following example creates a simple model in Python, exports it, and then executes it in Rust:
+//!
+//! Create a model in `Python` and export it:
+//! ```python
+//! import torch
+//! from executorch.exir import to_edge
+//! from torch.export import export
+//!
+//! class Add(torch.nn.Module):
+//! def __init__(self):
+//! super(Add, self).__init__()
+//!
+//! def forward(self, x: torch.Tensor, y: torch.Tensor):
+//! return x + y
+//!
+//!
+//! aten_dialect = export(Add(), (torch.ones(1), torch.ones(1)))
+//! edge_program = to_edge(aten_dialect)
+//! executorch_program = edge_program.to_executorch()
+//! with open("model.pte", "wb") as file:
+//! file.write(executorch_program.buffer)
+//! ```
+//!
+//! Execute the model in Rust:
+//! ```rust
+//! use executorch::evalue::{EValue, Tag};
+//! use executorch::module::Module;
+//! 
use executorch::tensor::{Tensor, TensorImpl};
+//! use ndarray::array;
+//!
+//! let mut module = Module::new("model.pte", None);
+//!
+//! let data1 = array![1.0_f32];
+//! let input_tensor1 = TensorImpl::from_array(data1.view());
+//! let input_evalue1 = EValue::from_tensor(Tensor::new(input_tensor1.as_ref()));
+//!
+//! let data2 = array![1.0_f32];
+//! let input_tensor2 = TensorImpl::from_array(data2.view());
+//! let input_evalue2 = EValue::from_tensor(Tensor::new(input_tensor2.as_ref()));
+//!
+//! let outputs = module.forward(&[input_evalue1, input_evalue2]).unwrap();
+//! assert_eq!(outputs.len(), 1);
+//! let output = outputs.into_iter().next().unwrap();
+//! assert_eq!(output.tag(), Some(Tag::Tensor));
+//! let output = output.as_tensor();
+//!
+//! println!("Output tensor computed: {:?}", output);
+//! assert_eq!(array![2.0_f32], output.as_array());
+//! ```
+//!
+//! The library has a few features that can be enabled or disabled, by default all are disabled:
+//! - `module`: Enable the [`module`] API, a high-level API for loading and executing PyTorch models. It is an alternative
+//! to the lower-level `Program` API, which is more suitable for embedded systems.
+//! - `data-loader`: Enable the [`data_loader`] module for loading data.
+//!
+//! The C++ API is still in Alpha, and this Rust lib will continue to change with it. Currently the supported
+//! executorch version is `0.2.1`.
+//!
+//! To use the library you must compile the C++ executorch library yourself, as there are many configurations that
+//! determine which modules, backends, and operations are supported. See the `executorch-sys` crate for more info. 
use executorch_sys::executorch_rs as et_rs_c; use executorch_sys::torch::executor as et_c; #[macro_use] mod private; - -mod error; -pub use error::{Error, Result}; - pub mod data_loader; - -mod memory; -pub use memory::{HierarchicalAllocator, MallocMemoryAllocator, MemoryManager}; - -mod program; -pub use program::{Method, MethodMeta, Program, ProgramVerification}; - -#[cfg(feature = "extension-module")] -mod module; -#[cfg(feature = "extension-module")] -pub use module::Module; - -mod evalue; -pub use evalue::{EValue, Tag}; - -mod tensor; -pub use tensor::{Tensor, TensorImpl, TensorInfo, TensorMut}; - +pub mod error; +pub mod evalue; +pub mod memory; +#[cfg(feature = "module")] +pub mod module; +pub mod platform; +pub mod program; +pub mod tensor; pub mod util; - -pub fn pal_init() { - unsafe { executorch_sys::et_pal_init() }; -} diff --git a/src/memory.rs b/src/memory.rs index 4f1cbd7..16c0c24 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,3 +1,9 @@ +//! Memory management classes. +//! +//! The ExecuTorch library allow the user to control memory allocation using the structs in the module. +//! This enable using the library in embedded systems where dynamic memory allocation is not allowed, or when allocation +//! is a performance bottleneck. + use std::cell::UnsafeCell; use std::marker::PhantomData; use std::ptr; diff --git a/src/module.rs b/src/module.rs index 3fe7a98..7680969 100644 --- a/src/module.rs +++ b/src/module.rs @@ -1,12 +1,26 @@ +//! A higher-level for simple execution of programs. +//! +//! This module provides a higher-level interface for loading programs and executing methods within them. +//! Compared to the lower-level [`program`](crate::program) interface, the `module` interface is more user-friendly, +//! uses the default memory allocator, and provides automatic memory management. +//! +//! This module is enabled by the `module` feature. +//! +//! See the `hello_world_add_module` example for how to load and execute a module. 
+ use std::collections::HashSet; use std::path::Path; use std::ptr; +use crate::error::Result; +use crate::evalue::EValue; +use crate::program::{MethodMeta, ProgramVerification}; use crate::util::{self, ArrayRef, IntoRust}; -use crate::{et_c, et_rs_c, EValue, Result}; -use crate::{MethodMeta, ProgramVerification}; +use crate::{et_c, et_rs_c}; /// A facade class for loading programs and executing methods within them. +/// +/// See the `hello_world_add_module` example for how to load and execute a module. pub struct Module(et_c::Module); impl Module { /// Constructs an instance by loading a program from a file with specified @@ -180,4 +194,5 @@ impl Drop for Module { } } +/// Enum to define memory locking behavior. pub type MlockConfig = et_c::Module_MlockConfig; diff --git a/src/platform.rs b/src/platform.rs new file mode 100644 index 0000000..b918a1a --- /dev/null +++ b/src/platform.rs @@ -0,0 +1,12 @@ +//! Platform abstraction layer to allow individual platform libraries to override +//! symbols in ExecuTorch. +//! +//! PAL functions are defined as C functions so a platform library implementer can use C in lieu of C++. + +/// Initialize the platform abstraction layer. +/// +/// This function should be called before any other function provided by the PAL +/// to initialize any global state. Typically overridden by PAL implementer. +pub fn pal_init() { + unsafe { executorch_sys::et_pal_init() }; +} diff --git a/src/program.rs b/src/program.rs index bf63644..5ad46a8 100644 --- a/src/program.rs +++ b/src/program.rs @@ -1,13 +1,27 @@ +//! Lower-level API for loading and executing ExecuTorch programs. +//! +//! This module is the lowest level API for the ExecuTorch library. It provides the ability to load and execute +//! programs, while controlling memory allocation and execution. +//! +//! See the `hello_world_add` example for how to load and execute a program. 
+
 use std::ffi::{CStr, CString};
 use std::marker::PhantomData;
+use std::ops::Index;
 use std::ptr;
 
 use crate::data_loader::DataLoader;
+use crate::error::Result;
 use crate::evalue::EValue;
+use crate::evalue::Tag;
+use crate::memory::MemoryManager;
+use crate::tensor::TensorInfo;
 use crate::util::IntoRust;
-use crate::{et_c, et_rs_c, MemoryManager, Result, Tag, TensorInfo};
+use crate::{et_c, et_rs_c};
 
 /// A deserialized ExecuTorch program binary.
+///
+/// See the `hello_world_add` example for how to load and execute a program.
 pub struct Program<'a>(et_c::Program, PhantomData<&'a ()>);
 impl<'a> Program<'a> {
     /// Loads a Program from the provided loader. The Program will hold a pointer
@@ -108,7 +122,9 @@ impl Drop for Program<'_> {
     }
 }
 
+/// Types of validation that the Program can do before parsing the data.
 pub type ProgramVerification = et_c::Program_Verification;
+/// Describes the presence of an ExecuTorch program header.
 pub type HeaderStatus = et_c::Program_HeaderStatus;
 
 /// Describes a a method in an ExecuTorch program.
@@ -218,12 +234,16 @@ impl<'a> MethodMeta<'a> {
     }
 }
 
+/// An executable method of an ExecuTorch program. Maps to a python method like
+/// `forward()` on the original `nn.Module`.
 pub struct Method<'a>(et_c::Method, PhantomData<&'a ()>);
 impl<'a> Method<'a> {
+    /// Starts the execution of the method.
     pub fn start_execution(&mut self) -> Execution<'_> {
         Execution::new(&mut self.0)
     }
 
+    /// Returns the number of inputs the Method expects.
     pub fn inputs_size(&self) -> usize {
         unsafe { self.0.inputs_size() }
     }
@@ -234,6 +254,7 @@ impl Drop for Method<'_> {
     }
 }
 
+/// A method execution builder used to set inputs and execute the method.
 pub struct Execution<'a> {
     method: &'a mut et_c::Method,
     set_inputs: u64,
@@ -250,12 +271,23 @@ impl<'a> Execution<'a> {
         }
     }
 
+    /// Sets the internal input value to be equivalent to the provided value.
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - The evalue to copy into the method input. 
If the evalue is a tensor, the data is copied in most
+    /// cases, so the tensor passed in here does not always need to outlive this call. But there is a case where the
+    /// Method will keep a pointer to the tensor's data. Based on the memory plan of the method, the inputs may not
+    /// have buffer space pre-allocated for them. In this case the executor will alias the memory of the tensors
+    /// provided as inputs here rather than deepcopy the input into the memory planned arena.
+    /// * `input_idx` - Zero-based index of the input to set. Must be less than the value returned by inputs_size().
     pub fn set_input<'b: 'a>(&mut self, input: &'b EValue, input_idx: usize) -> Result<()> {
         unsafe { self.method.set_input(&input.0, input_idx) }.rs()?;
         self.set_inputs |= 1 << input_idx;
         Ok(())
     }
 
+    /// Execute the method.
     pub fn execute(self) -> Result<Outputs<'a>> {
         assert_eq!(
             self.set_inputs,
@@ -266,6 +298,10 @@ impl<'a> Execution<'a> {
         Ok(Outputs::new(self.method))
     }
 }
+
+/// The outputs of a method execution.
+///
+/// Access the outputs of a method execution by indexing into the Outputs object.
 pub struct Outputs<'a> {
     method: &'a mut et_c::Method,
 }
@@ -274,8 +310,16 @@ impl<'a> Outputs<'a> {
         Self { method }
     }
 
-    pub fn get_output(&self, output_idx: usize) -> &'a EValue<'a> {
-        let val = unsafe { &*self.method.get_output(output_idx) };
+    /// Returns the number of outputs the Method returns.
+    pub fn len(&self) -> usize {
+        unsafe { self.method.outputs_size() }
+    }
+}
+impl<'a> Index<usize> for Outputs<'a> {
+    type Output = EValue<'a>;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        let val = unsafe { &*self.method.get_output(index) };
         // SAFETY: et_c::EValue as EValue has the same memory layout
         unsafe { std::mem::transmute::<&et_c::EValue, &EValue<'a>>(val) }
     }
diff --git a/src/tensor.rs b/src/tensor.rs
index fe7d632..0fba740 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1,4 +1,7 @@
+//! Tensor struct is a type-erased input or output tensor to an executorch program.
+ use std::any::TypeId; +use std::fmt::Debug; use std::marker::PhantomData; use ndarray::{ArrayBase, ArrayView, ArrayViewD, ArrayViewMut, Dimension, IxDyn, ShapeBuilder}; @@ -20,28 +23,51 @@ pub type StridesType = executorch_sys::exec_aten::StridesType; #[repr(u8)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ScalarType { + /// 8-bit unsigned integer, `u8` Byte = et_c::ScalarType::Byte as u8, + /// 8-bit signed, integer, `i8` Char = et_c::ScalarType::Char as u8, + /// 16-bit signed integer, `i16` Short = et_c::ScalarType::Short as u8, + /// 32-bit signed integer, `i32` Int = et_c::ScalarType::Int as u8, + /// 64-bit signed integer, `i64` Long = et_c::ScalarType::Long as u8, + /// **\[Unsupported\]** 16-bit floating point Half = et_c::ScalarType::Half as u8, + /// 32-bit floating point, `f32` Float = et_c::ScalarType::Float as u8, + /// 64-bit floating point, `f64` Double = et_c::ScalarType::Double as u8, + /// **\[Unsupported\]** 16-bit complex floating point ComplexHalf = et_c::ScalarType::ComplexHalf as u8, + /// **\[Unsupported\]** 32-bit complex floating point ComplexFloat = et_c::ScalarType::ComplexFloat as u8, + /// **\[Unsupported\]** 64-bit complex floating point ComplexDouble = et_c::ScalarType::ComplexDouble as u8, + /// Boolean, `bool` Bool = et_c::ScalarType::Bool as u8, + /// **\[Unsupported\]** 8-bit quantized integer QInt8 = et_c::ScalarType::QInt8 as u8, + /// **\[Unsupported\]** 8-bit quantized unsigned integer QUInt8 = et_c::ScalarType::QUInt8 as u8, + /// **\[Unsupported\]** 32-bit quantized integer QInt32 = et_c::ScalarType::QInt32 as u8, + /// **\[Unsupported\]** 16-bit floating point using the bfloat16 format BFloat16 = et_c::ScalarType::BFloat16 as u8, + /// **\[Unsupported\]** QUInt4x2 = et_c::ScalarType::QUInt4x2 as u8, + /// **\[Unsupported\]** QUInt2x4 = et_c::ScalarType::QUInt2x4 as u8, + /// **\[Unsupported\]** Bits1x8 = et_c::ScalarType::Bits1x8 as u8, + /// **\[Unsupported\]** Bits2x4 = et_c::ScalarType::Bits2x4 as u8, + /// 
**\[Unsupported\]** Bits4x2 = et_c::ScalarType::Bits4x2 as u8, + /// **\[Unsupported\]** Bits8 = et_c::ScalarType::Bits8 as u8, + /// **\[Unsupported\]** Bits16 = et_c::ScalarType::Bits16 as u8, } impl ScalarType { @@ -106,6 +132,7 @@ impl ScalarType { /// A trait for types that can be used as scalar types in Tensors. pub trait Scalar { + /// The `ScalarType` enum variant of the implementing type. const TYPE: ScalarType; private_decl! {} } @@ -483,6 +510,61 @@ impl<'a, D: Data> TensorImplBase<'a, D> { } } +impl Debug for TensorBase<'_, D> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut st = f.debug_struct("Tensor"); + st.field("scalar_type", &self.scalar_type()); + + fn add_data_field( + this: &TensorBase<'_, D>, + st: &mut std::fmt::DebugStruct, + ) { + match this.dim() { + 0 => st.field("data", &this.as_array::()), + 1 => st.field("data", &this.as_array::()), + 2 => st.field("data", &this.as_array::()), + 3 => st.field("data", &this.as_array::()), + 4 => st.field("data", &this.as_array::()), + 5 => st.field("data", &this.as_array::()), + 6 => st.field("data", &this.as_array::()), + _ => st.field("data", &this.as_array_dyn::()), + }; + } + fn add_data_field_unsupported(st: &mut std::fmt::DebugStruct) { + st.field("data", &"unsupported"); + } + match self.scalar_type() { + Some(ScalarType::Byte) => add_data_field::<_, u8>(self, &mut st), + Some(ScalarType::Char) => add_data_field::<_, i8>(self, &mut st), + Some(ScalarType::Short) => add_data_field::<_, i16>(self, &mut st), + Some(ScalarType::Int) => add_data_field::<_, i32>(self, &mut st), + Some(ScalarType::Long) => add_data_field::<_, i64>(self, &mut st), + Some(ScalarType::Half) => add_data_field_unsupported(&mut st), + Some(ScalarType::Float) => add_data_field::<_, f32>(self, &mut st), + Some(ScalarType::Double) => add_data_field::<_, f64>(self, &mut st), + Some(ScalarType::ComplexHalf) => add_data_field_unsupported(&mut st), + Some(ScalarType::ComplexFloat) => 
add_data_field_unsupported(&mut st),
+            Some(ScalarType::ComplexDouble) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::Bool) => add_data_field::<_, bool>(self, &mut st),
+            Some(ScalarType::QInt8) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::QUInt8) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::QInt32) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::BFloat16) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::QUInt4x2) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::QUInt2x4) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::Bits1x8) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::Bits2x4) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::Bits4x2) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::Bits8) => add_data_field_unsupported(&mut st),
+            Some(ScalarType::Bits16) => add_data_field_unsupported(&mut st),
+            None => {
+                st.field("data", &"None");
+            }
+        };
+        st.finish()
+    }
+}
+
 /// An immutable tensor implementation that does not own the underlying data.
 pub type TensorImpl<'a> = TensorImplBase<'a, View>;
 impl<'a> TensorImpl<'a> {
@@ -701,6 +783,17 @@ impl<'a> TensorInfo<'a> {
         unsafe { et_c::TensorInfo_nbytes(&self.0) }
     }
 }
+impl Debug for TensorInfo<'_> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("TensorInfo")
+            .field("sizes", &self.sizes())
+            .field("dim_order", &self.dim_order())
+            .field("scalar_type", &self.scalar_type())
+            .field("nbytes", &self.nbytes())
+            .finish()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use ndarray::{arr1, arr2, Array1, Array2, Array3, Ix3};
diff --git a/src/util.rs b/src/util.rs
index c47919c..65564af 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,3 +1,11 @@
+//! Utility functions and types.
+//!
+//! Most of the structs in this module may seem redundant in Rust, but they are wrappers around C++ types
+//! that are used in the C++ API. 
Some structs and functions accept these types as arguments, so they are +//! necessary to interact with the C++ API. + +use std::fmt::Debug; +use std::hash::Hash; use std::marker::PhantomData; use std::mem::ManuallyDrop; @@ -57,6 +65,11 @@ impl<'a, T> ArrayRef<'a, T> { unsafe { std::slice::from_raw_parts(self.0.Data, self.0.Length) } } } +impl Debug for ArrayRef<'_, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_slice().fmt(f) + } +} /// Represent a reference to an array (0 or more elements /// consecutively in memory), i.e. a start pointer and a length. It allows @@ -99,6 +112,11 @@ impl<'a, T> Span<'a, T> { unsafe { std::slice::from_raw_parts_mut(self.0.data_, self.0.length_) } } } +impl Debug for Span<'_, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_slice().fmt(f) + } +} /// Leaner optional class, subset of c10, std, and boost optional APIs. pub struct Optional(et_c::optional); @@ -152,6 +170,27 @@ impl From> for Optional { Optional::new(opt) } } +impl Clone for Optional { + fn clone(&self) -> Self { + Self::new(self.as_ref().cloned()) + } +} +impl PartialEq for Optional { + fn eq(&self, other: &Self) -> bool { + self.as_ref() == other.as_ref() + } +} +impl Eq for Optional {} +impl Hash for Optional { + fn hash(&self, state: &mut H) { + self.as_ref().hash(state) + } +} +impl Debug for Optional { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_ref().fmt(f) + } +} #[allow(dead_code)] pub(crate) fn str2chars(s: &str) -> Result<&[std::os::raw::c_char], &'static str> {