diff --git a/examples/arm/image_classification_example/runtime/CMakeLists.txt b/examples/arm/image_classification_example/runtime/CMakeLists.txt
new file mode 100644
index 00000000000..8ec999a8f4b
--- /dev/null
+++ b/examples/arm/image_classification_example/runtime/CMakeLists.txt
@@ -0,0 +1,147 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.20)
+
+project(image_classification_minimal_application)
+
+# Example ExecuTorch demo for bare metal Cortex-M based systems
+set(ET_DIR_PATH
+    "${CMAKE_CURRENT_SOURCE_DIR}/../../../.."
+    CACHE PATH "Path to ExecuTorch dir"
+)
+
+set(ET_BUILD_DIR_PATH
+    "${ET_DIR_PATH}/cmake-out-arm"
+    CACHE PATH "Path to ExecuTorch build/install dir"
+)
+set(ET_INCLUDE_PATH
+    "${ET_DIR_PATH}/.."
+    CACHE PATH "Path to ExecuTorch headers"
+)
+
+set(ET_PTE_FILE_PATH
+    ""
+    CACHE PATH "Path to ExecuTorch model pte"
+)
+
+set(IMAGE_PATH
+    ""
+    CACHE
+      PATH
+      "Path to an RGB image to use for the application (e.g. a jpg image of a cat or a dog)"
+)
+
+set(ETHOS_SDK_PATH
+    "${ET_DIR_PATH}/examples/arm/ethos-u-scratch/ethos-u"
+    CACHE PATH "Path to Ethos-U bare metal driver/env"
+)
+
+set(PYTHON_EXECUTABLE
+    "python"
+    CACHE PATH "Define to override the python executable used"
+)
+
+if(NOT EXISTS "${IMAGE_PATH}")
+  message(
+    FATAL_ERROR
+      "Image not provided. Please provide the path to an image for the image classification application and retry."
+  )
+endif()
+if(NOT EXISTS "${ET_PTE_FILE_PATH}")
+  message(
+    FATAL_ERROR
+      "pte file not provided. Please provide a pte file for the application and retry."
+  )
+endif()
+if(NOT EXISTS "${ETHOS_SDK_PATH}")
+  message(
+    FATAL_ERROR
+      "The ${ETHOS_SDK_PATH} directory does not exist. Please run the examples/arm/setup.sh script and retry."
+  )
+endif()
+
+find_package(
+  executorch REQUIRED HINTS "${ET_BUILD_DIR_PATH}/lib/cmake/ExecuTorch"
+)
+
+# The core_platform project defines the corstone-320 target and contains the
+# start-up code for the Cortex-M
+add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target)
+
+add_executable(img_class_example main.cpp)
+
+target_sources(
+  img_class_example
+  PRIVATE main.cpp ../executor_runner/arm_memory_allocator.cpp
+          ../executor_runner/arm_perf_monitor.cpp
+)
+
+target_link_libraries(
+  img_class_example
+  PUBLIC executorch
+         ethosu_target_init
+         extension_runner_util
+         quantized_ops_lib
+         portable_kernels
+         cortex_m_kernels
+         cortex_m_ops_lib
+)
+
+# The EthosUBackend must be linked as a whole archive so that its backend
+# registration code is not discarded by the linker
+target_link_libraries(
+  img_class_example PUBLIC "-Wl,--whole-archive" executorch_delegate_ethos_u
+                           "-Wl,--no-whole-archive"
+)
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+  set(LINK_FILE_EXT ld)
+  set(LINK_FILE_OPTION "-T")
+  set(COMPILER_PREPROCESSOR_OPTIONS -E -x c -P)
+endif()
+
+set(LINK_FILE_OUT_BASE "platform_linker_script")
+set(LINK_FILE_IN "${CMAKE_SOURCE_DIR}/../../executor_runner/Corstone-320.ld")
+set(LINK_FILE_OUT
+    ${CMAKE_CURRENT_BINARY_DIR}/${LINK_FILE_OUT_BASE}.${LINK_FILE_EXT}
+)
+# The ETHOSU_ARENA symbol is defined in the Corstone-320 linker script and
+# controls the placement of the intermediate tensors in the application memory
+# map. Because the compile spec in the AoT flow generates the pte for the
+# Shared_Sram memory mode, the application sets ETHOSU_ARENA to 0 so that the
+# intermediate tensors are placed in SRAM. If you generate a pte for a
+# different memory mode, you need to change the placement in the linker
+# script. Read
+# https://docs.pytorch.org/executorch/stable/backends-arm-ethos-u.html#ethos-u-memory-modes
+# for more information.
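+#
+# For illustration, the arena placement is assumed to follow a C-preprocessor
+# pattern along these lines inside Corstone-320.ld (the execute_process() call
+# below resolves it into the final linker script; the exact section names in
+# the real script may differ):
+#
+#   #if (ETHOSU_ARENA == 0)
+#     /* place .bss.tensor_arena in the SRAM region */
+#   #else
+#     /* place .bss.tensor_arena in the DDR region */
+#   #endif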
+set(ETHOSU_ARENA "0")
+# Generate the linker script - we have a few if/else statements in
+# Corstone-320.ld/Corstone-300.ld that are preprocessed into a final linker
+# script.
+execute_process(
+  COMMAND
+    ${CMAKE_C_COMPILER} ${COMPILER_PREPROCESSOR_OPTIONS}
+    -DETHOSU_ARENA=${ETHOSU_ARENA} -o ${LINK_FILE_OUT} ${LINK_FILE_IN}
+)
+target_link_options(
+  img_class_example PRIVATE ${LINK_FILE_OPTION} ${LINK_FILE_OUT}
+)
+
+# Run the pte_to_header.py script to convert the .pte file to an array in a .h
+# file
+execute_process(
+  COMMAND
+    ${PYTHON_EXECUTABLE}
+    ${CMAKE_SOURCE_DIR}/../../executor_runner/pte_to_header.py --pte
+    ${ET_PTE_FILE_PATH} --outdir ${CMAKE_CURRENT_BINARY_DIR}
+)
+
+# Convert an RGB image to an array in a .h file
+execute_process(
+  COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/rgb_to_array.py --image
+          ${IMAGE_PATH} --output ${CMAKE_CURRENT_BINARY_DIR}/image.h
+)
+
+target_include_directories(
+  img_class_example
+  PRIVATE ${ET_INCLUDE_PATH} ${ET_DIR_PATH}/runtime/core/portable_type/c10
+          ${CMAKE_SOURCE_DIR}/../../executor_runner ${CMAKE_CURRENT_BINARY_DIR}
+)
diff --git a/examples/arm/image_classification_example/runtime/README.md b/examples/arm/image_classification_example/runtime/README.md
new file mode 100644
index 00000000000..7e78d98bd41
--- /dev/null
+++ b/examples/arm/image_classification_example/runtime/README.md
@@ -0,0 +1,24 @@
+1. Make sure you have set up the Ethos-U ExecuTorch dependencies by running the examples/arm/setup.sh script. See the [readme](../../README.md) for setup instructions.
+
+2. Build ExecuTorch from the `examples/arm` folder, cross-compiled for a Cortex-M device.
+```
+$ cmake --preset arm-baremetal \
+    -DCMAKE_BUILD_TYPE=Release \
+    -B../../cmake-out-arm ../..
+$ cmake --build ../../cmake-out-arm --target install -j$(nproc)
+```
+
+3. Set up the build system. You need to provide the path to the DeiT-Tiny pte generated in the
+`examples/arm/image_classification_example/export` folder. You also need to provide an image of a cat or a dog; you can
+download one from the [HuggingFace Oxford-IIIT Pet dataset](https://huggingface.co/datasets/timm/oxford-iiit-pet).
+```
+$ cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake -DET_PTE_FILE_PATH=<path-to-pte> -DIMAGE_PATH=<path-to-image> -Bsimple_app_deit_tiny image_classification_example/runtime
+```
+
+4. Compile the application.
+```
+$ cmake --build simple_app_deit_tiny -j$(nproc) -- img_class_example
+```
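+
+Optionally, you can inspect the static memory footprint of the resulting ELF
+before deploying it, using the standard `size` tool from the same GNU
+toolchain (a quick sanity check; the exact numbers depend on your model):
+```
+$ arm-none-eabi-size simple_app_deit_tiny/img_class_example
+```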
+
+5. Deploy the application on the Corstone-320 Fixed Virtual Platform (FVP). Assuming you have the Corstone-320 FVP installed on your PATH, run the following command.
+```
+$ FVP_Corstone_SSE-320 -C mps4_board.subsystem.ethosu.num_macs=256 -C mps4_board.visualisation.disable-visualisation=1 -C vis_hdlcd.disable_visualisation=1 -C mps4_board.telnetterminal0.start_telnet=0 -C mps4_board.uart0.out_file='-' -C mps4_board.uart0.shutdown_on_eot=1 -a simple_app_deit_tiny/img_class_example -C mps4_board.subsystem.ethosu.extra_args="--fast"
+```
\ No newline at end of file
diff --git a/examples/arm/image_classification_example/runtime/main.cpp b/examples/arm/image_classification_example/runtime/main.cpp
new file mode 100644
index 00000000000..d5719cd416c
--- /dev/null
+++ b/examples/arm/image_classification_example/runtime/main.cpp
@@ -0,0 +1,243 @@
+/*
+ * Copyright 2025 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cinttypes>
+#include <memory>
+#include <set>
+#include <utility>
+#include <vector>
+
+#include <executorch/extension/data_loader/buffer_data_loader.h>
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/hierarchical_allocator.h>
+#include <executorch/runtime/core/memory_allocator.h>
+#include <executorch/runtime/core/span.h>
+#include <executorch/runtime/executor/method.h>
+#include <executorch/runtime/executor/method_meta.h>
+#include <executorch/runtime/executor/program.h>
+#include <executorch/runtime/platform/log.h>
+#include <executorch/runtime/platform/runtime.h>
+
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::extension::BufferDataLoader;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::HierarchicalAllocator;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::MemoryManager;
+using executorch::runtime::Method;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Program;
+using executorch::runtime::Result;
+using executorch::runtime::Span;
+
+#include "arm_memory_allocator.h"
+
+#include "image.h"
+#include "model_pte.h"
+
+// This application is for a model fine-tuned on the Oxford-IIIT Pet dataset
+// (https://huggingface.co/datasets/timm/oxford-iiit-pet/blob/main/README.md).
+// The dataset contains the following 37 pet breed labels (cats and dogs).
+constexpr const char* labels[] = {
+    "abyssinian",
+    "american_bulldog",
+    "american_pit_bull_terrier",
+    "basset_hound",
+    "beagle",
+    "bengal",
+    "birman",
+    "bombay",
+    "boxer",
+    "british_shorthair",
+    "chihuahua",
+    "egyptian_mau",
+    "english_cocker_spaniel",
+    "english_setter",
+    "german_shorthaired",
+    "great_pyrenees",
+    "havanese",
+    "japanese_chin",
+    "keeshond",
+    "leonberger",
+    "maine_coon",
+    "miniature_pinscher",
+    "newfoundland",
+    "persian",
+    "pomeranian",
+    "pug",
+    "ragdoll",
+    "russian_blue",
+    "saint_bernard",
+    "samoyed",
+    "scottish_terrier",
+    "shiba_inu",
+    "siamese",
+    "sphynx",
+    "staffordshire_bull_terrier",
+    "wheaten_terrier",
+    "yorkshire_terrier"};
+
+const size_t method_allocation_pool_size = 1 * 1024 * 1024;
+unsigned char __attribute__((
+    section("input_data_sec"),
+    aligned(16))) method_allocation_pool[method_allocation_pool_size];
+
+/*
+The to_edge_transform_and_lower step reports
+  Total SRAM used 1291.80 KiB
+therefore we allocate a 1.3 MiB temporary allocation pool to store the peak
+intermediate tensor usage during inference.
+*/
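+/*
+As a quick sanity check of that pool size (plain arithmetic, not tool output):
+  peak usage reported: 1291.80 KiB * 1024  ~= 1322803 bytes
+  pool reserved below: 1.3 * 1024 * 1024   ~= 1363148 bytes
+which leaves roughly 39 KiB of headroom in the arena.
+*/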
+*/ +const size_t temp_allocation_pool_size = 1.3 * 1024 * 1024; +unsigned char __attribute__(( + section(".bss.tensor_arena"), + aligned(16))) temp_allocation_pool[temp_allocation_pool_size]; + +int main() { + executorch::runtime::runtime_init(); + ET_LOG(Info, "Runtime initialized"); + BufferDataLoader loader(model_pte, sizeof(model_pte)); + ET_LOG(Info, "Size of the model = %d", sizeof(model_pte)); + Result program = Program::load(&loader); + ET_CHECK_MSG(program.ok(), "Program::load failed: 0x%x", program.error()); + + const auto method_name_result = program->get_method_name(0); + ET_CHECK_MSG(method_name_result.ok(), "Program has no methods"); + const char* method_name = *method_name_result; + ET_LOG(Info, "Running method %s", method_name); + + Result method_meta_result = program->method_meta(method_name); + ET_CHECK_MSG( + method_meta_result.ok(), + "method_meta lookup failed: 0x%x", + method_meta_result.error()); + + ArmMemoryAllocator method_allocator( + method_allocation_pool_size, method_allocation_pool); + ArmMemoryAllocator temp_allocator( + temp_allocation_pool_size, temp_allocation_pool); + + std::vector planned_buffers; // Owns the memory + std::vector> planned_spans; // Passed to the allocator + size_t num_memory_planned_buffers = + method_meta_result->num_memory_planned_buffers(); + ET_LOG(Info, "num_memory_planned_buffers = %zu", num_memory_planned_buffers); + for (size_t id = 0; id < num_memory_planned_buffers; ++id) { + size_t buffer_size = + method_meta_result->memory_planned_buffer_size(id).get(); + ET_LOG(Info, "Planned memory buffer_size %zu %zu bytes", id, buffer_size); + + uint8_t* buffer = reinterpret_cast( + method_allocator.allocate(buffer_size, 16UL)); + + ET_CHECK_MSG( + buffer != nullptr, + "Could not allocate memory for memory planned buffer size %zu", + buffer_size); + planned_buffers.push_back(buffer); + planned_spans.push_back({planned_buffers.back(), buffer_size}); + } + HierarchicalAllocator planned_memory( + {planned_spans.data(), planned_spans.size()}); + + MemoryManager memory_manager( + &method_allocator, &planned_memory, &temp_allocator); + + Result method = program->load_method(method_name, &memory_manager); + + size_t num_inputs = method->inputs_size(); + ET_LOG(Info, "Number of input tensors = %zu", num_inputs); + ET_CHECK_MSG( + num_inputs == 1, + "DEiT-Tiny has a single input tensor, but the provided model has more input tensors"); + + EValue* input_evalues = method_allocator.allocateList(num_inputs); + Error err = method->get_inputs(input_evalues, num_inputs); + ET_CHECK_MSG(err == Error::Ok, "Get inputs failed"); + Tensor& tensor = + input_evalues[0].toTensor(); // DEiT-Tiny has a single input tensor. + size_t expected_elems = tensor.numel(); + + size_t image_elements = sizeof(image_data) / + sizeof(image_data[0]); // number of elements of the array in image.h + ET_CHECK_MSG( + expected_elems == image_elements, + "Input tensor expects %zu elements, but image_data has %zu elements", + expected_elems, + image_elements); + + switch (tensor.scalar_type()) { + case ScalarType::Float: { + float* dst = tensor.mutable_data_ptr(); + for (size_t j = 0; j < tensor.numel(); ++j) { + dst[j] = image_data[j]; + } + break; + } + default: + ET_CHECK_MSG( + false, + "Input tensor datatype is not float. 
+  switch (tensor.scalar_type()) {
+    case ScalarType::Float: {
+      float* dst = tensor.mutable_data_ptr<float>();
+      for (size_t j = 0; j < expected_elems; ++j) {
+        dst[j] = image_data[j];
+      }
+      break;
+    }
+    default:
+      ET_CHECK_MSG(
+          false,
+          "Input tensor datatype is not float, but the image data to populate "
+          "the input tensor with is float");
+      break;
+  }
+  Error status_inference = method->execute(); // run inference
+  ET_CHECK_MSG(
+      status_inference == Error::Ok,
+      "Inference failed 0x%" PRIx32,
+      static_cast<uint32_t>(status_inference));
+
+  size_t num_outputs = method->outputs_size();
+  std::vector<EValue> outputs(num_outputs);
+  Error status_outputs = method->get_outputs(outputs.data(), outputs.size());
+  ET_CHECK_MSG(
+      status_outputs == Error::Ok,
+      "get_outputs failed 0x%" PRIx32,
+      static_cast<uint32_t>(status_outputs));
+
+  std::set<std::pair<float, size_t>> set_confidence_idx;
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    if (!outputs[i].isTensor())
+      continue;
+    Tensor out = outputs[i].toTensor();
+    switch (out.scalar_type()) {
+      case ScalarType::Float: {
+        // When we generate the pte file in the AoT flow, we use the float32
+        // datatype as input to the model (with Q/DQ nodes around every
+        // operator). Therefore, we only handle the float32 case in the
+        // application logic.
+        const float* data = out.const_data_ptr<float>();
+        for (size_t j = 0; j < static_cast<size_t>(out.numel()); ++j)
+          set_confidence_idx.insert({data[j], j});
+        break;
+      }
+      default:
+        ET_LOG(
+            Info,
+            "Output tensor has unsupported dtype %d",
+            static_cast<int>(out.scalar_type()));
+        break;
+    }
+  }
+  size_t printed = 0;
+  size_t topK = 5;
+  size_t num_labels = sizeof(labels) / sizeof(labels[0]);
+  ET_LOG(
+      Info,
+      "Top %zu classes in descending order (highest score first)",
+      topK);
+  for (auto it = set_confidence_idx.rbegin();
+       it != set_confidence_idx.rend() && printed < topK;
+       ++it, ++printed) {
+    size_t class_id = it->second;
+    const char* class_name =
+        (class_id < num_labels) ? labels[class_id] : "unknown";
+    ET_LOG(
+        Info,
+        "Class %zu (%s) with score %f",
+        class_id,
+        class_name,
+        it->first);
+  }
+  return 0;
+}
diff --git a/examples/arm/image_classification_example/runtime/rgb_to_array.py b/examples/arm/image_classification_example/runtime/rgb_to_array.py
new file mode 100644
index 00000000000..63b0d73526c
--- /dev/null
+++ b/examples/arm/image_classification_example/runtime/rgb_to_array.py
@@ -0,0 +1,78 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+from PIL import Image
+
+
+def resize_and_crop_center(img: Image.Image, target_size):
+    """Resize while keeping the aspect ratio, then center-crop to (width, height)."""
+    target_w, target_h = target_size
+    # Compute scaling factor to preserve aspect ratio
+    scale_ratio = max(target_h, target_w) / min(img.height, img.width)
+    resized_w = int(img.width * scale_ratio)
+    resized_h = int(img.height * scale_ratio)
+    # Resize the image
+    img = img.resize((resized_w, resized_h), resample=Image.Resampling.BILINEAR)
+    # Calculate crop box (center crop)
+    left = (resized_w - target_w) / 2
+    top = (resized_h - target_h) / 2
+    right = (resized_w + target_w) / 2
+    bottom = (resized_h + target_h) / 2
+    # Keep the center of the image; crop away everything outside the box
+    return img.crop((left, top, right, bottom))
+
+
+def convert_image_to_c_array(
+    image_path, output_path, image_size=(224, 224), array_name="image_data"
+):
+    # Load and preprocess image
+    img = Image.open(image_path).convert("RGB")
+    img = resize_and_crop_center(img, image_size)
+    # PIL images convert to NumPy arrays in channels-last (HWC) layout, but the
+    # model expects channels-first (CHW), so transpose.
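+    # For example, a 224x224 RGB image arrives here with shape (224, 224, 3)
+    # and becomes (3, 224, 224) after the transpose below.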
+    img_channels_first = np.transpose(np.asarray(img), (2, 0, 1))
+    data = img_channels_first.astype(np.float32) / 255.0
+    data = data.flatten()
+    # Format as a C array
+    array_lines = []
+    for i in range(0, len(data), 20):  # 20 values per line
+        line = ", ".join(f"{val:.6f}" for val in data[i : i + 20])
+        array_lines.append("    " + line + ",")
+    c_array_str = f"""#include <stddef.h>
+const float {array_name}[{len(data)}] = {{
+{os.linesep.join(array_lines)}
+}};
+"""
+    # Write to the output file
+    with open(output_path, "w") as f:
+        f.write(c_array_str)
+    print(f"Converted '{image_path}' → '{output_path}' ({len(data)} float values)")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--image", required=True, help="Path to an RGB image (e.g. a jpg file)"
+    )
+    parser.add_argument(
+        "--output", required=True, help="Output path for the generated C array"
+    )
+    parser.add_argument(
+        "--resolution",
+        required=False,
+        type=int,
+        nargs=2,
+        default=(224, 224),
+        help="Resolution (width height) of the output image",
+    )
+    args = parser.parse_args()
+    image_path = args.image
+    output_path = args.output
+    image_size = args.resolution
+    convert_image_to_c_array(image_path, output_path, image_size)
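+
+# For reference, a typical invocation of this script looks like the following
+# (file names are illustrative):
+#
+#   python rgb_to_array.py --image dog.jpg --output image.h --resolution 224 224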