diff --git a/examples/arm/image_classification_example/README.md b/examples/arm/image_classification_example/README.md
new file mode 100644
index 00000000000..25faa6eb65a
--- /dev/null
+++ b/examples/arm/image_classification_example/README.md
@@ -0,0 +1,17 @@
+# Image Classification Example Application
+
+This end-to-end example shows how to use the Arm backend in ExecuTorch across both the ahead-of-time (AoT) and runtime flows. It does so by providing:
+
+- Scripts to fine-tune a DeiT-Tiny model on the Oxford-IIIT Pet dataset, quantize it, and export an Ethos-U–ready ExecuTorch program.
+- A simple bare-metal image-classification app for Corstone-320 (Ethos-U85-256) that embeds the exported program and a sample image.
+- Instructions for running the app on the Corstone-320 Fixed Virtual Platform (FVP).
+
+## Layout
+
+The example is divided into two sections:
+
+- `model_export/README.md` — Covers fine-tuning a model for a new use case, quantization to INT8, lowering to Ethos-U via ExecuTorch, and `.pte` generation.
+- `runtime/README.md` — Covers building the bare-metal app, generating headers from the `.pte` and image, and running on the FVP.
+
+In addition, this example uses `../executor_runner/` for various utilities (linker scripts, memory allocators, and the PTE-to-header converter).
diff --git a/examples/arm/image_classification_example/model_export/README.md b/examples/arm/image_classification_example/model_export/README.md
new file mode 100644
index 00000000000..5ab18b6060f
--- /dev/null
+++ b/examples/arm/image_classification_example/model_export/README.md
@@ -0,0 +1,77 @@
+# DeiT Fine-Tuning & Export
+
+This example provides two scripts:
+
+- `train_deit.py` — Fine-tunes the DeiT-Tiny model, originally trained on ImageNet-1k, on the Oxford-IIIT Pet dataset so that the network classifies cat and dog breeds. This demonstrates how to prepare a model for a new use case before lowering it via ExecuTorch.
+- `export_deit.py` — Loads the trained checkpoint from `train_deit.py`, applies post-training quantization (PT2E), evaluates the quantized model, and exports an Ethos-U–ready ExecuTorch program.
+
+Both scripts use the Oxford-IIIT Pet dataset because it is relatively small, which allows this example to run on a high-end laptop or desktop.
+
+See the sections below for requirements and exact commands.
+
+## Requirements
+
+- Python 3.10+ with `executorch` and the dependencies in `requirements-examples.txt`.
+- Internet access to download the pretrained weights and the Oxford-IIIT Pet dataset.
+
+## Fine-tuning DeiT-Tiny
+
+The `train_deit.py` script can be run as follows:
+
+```bash
+python examples/arm/image_classification_example/model_export/train_deit.py \
+    --output-dir ./deit-tiny-oxford-pet \
+    --num-epochs 3
+```
+
+The script splits off part of the training set for validation, fine-tunes the model, reports test-set accuracy, and by default saves the model to `deit-tiny-oxford-pet/final_model`.
+Running this script achieves a test-set accuracy of 86.10% in FP32.
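+
+As an optional sanity check (not part of the example scripts), the saved checkpoint can be reloaded and used to classify a single test image. The snippet below is a minimal sketch that assumes the default output directory shown above:
+
+```python
+import torch
+from datasets import load_dataset
+from transformers import AutoImageProcessor
+from transformers.models.vit.modeling_vit import ViTForImageClassification
+
+model = ViTForImageClassification.from_pretrained(
+    "./deit-tiny-oxford-pet/final_model"
+).eval()
+preprocessor = AutoImageProcessor.from_pretrained(
+    "facebook/deit-tiny-patch16-224", use_fast=True
+)
+
+# Classify one image from the Oxford-IIIT Pet test split.
+example = load_dataset("timm/oxford-iiit-pet")["test"][0]
+inputs = preprocessor(example["image"].convert("RGB"), return_tensors="pt")
+with torch.no_grad():
+    predicted_id = model(**inputs).logits.argmax(dim=-1).item()
+print(model.config.id2label[predicted_id], "| ground-truth label id:", example["label"])
+```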
+
+## Export and quantize
+
+The `export_deit.py` script can be run as follows:
+
+```bash
+python examples/arm/image_classification_example/model_export/export_deit.py \
+    --model-path ./deit-tiny-oxford-pet/final_model \
+    --output-path ./deit_quantized_exported.pte \
+    --num-calibration-samples 300 \
+    --num-test-samples 100
+```
+
+During export, the script:
+
+- Exports the FP32 model using `torch.export.export()`.
+- Applies symmetric INT8 quantization to each operator using the PT2E flow.
+- Targets `Ethos-U85-256` with the `Shared_Sram` memory mode and lowers the network to Ethos-U.
+- Writes the ExecuTorch program (`.pte`) to the requested path.
+
+Running this script on the checkpoint produced by `train_deit.py` gives a test-set accuracy of 85.00% for the quantized model, measured on 100 samples.
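+
+Condensed from `export_deit.py`, the quantize-and-lower flow looks roughly as follows. Treat it as a sketch rather than a replacement for the script: the checkpoint path is assumed to be the default from `train_deit.py`, and a random tensor stands in for the preprocessed calibration images the script actually uses:
+
+```python
+import torch
+from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
+from executorch.backends.arm.quantizer import (
+    EthosUQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+from executorch.extension.export_util.utils import save_pte_program
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from transformers.models.vit.modeling_vit import ViTForImageClassification
+
+model = ViTForImageClassification.from_pretrained(
+    "./deit-tiny-oxford-pet/final_model"
+).eval()
+example_input = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
+
+compile_spec = EthosUCompileSpec(target="ethos-u85-256", memory_mode="Shared_Sram")
+quantizer = EthosUQuantizer(compile_spec)
+quantizer.set_global(get_symmetric_quantization_config())
+
+# 1. Export the FP32 model and insert observers (PT2E prepare).
+prepared = prepare_pt2e(torch.export.export(model, (example_input,)).module(), quantizer)
+# 2. Calibrate; the real script feeds ~300 preprocessed dataset images here.
+prepared(example_input)
+quantized = torch.export.export(convert_pt2e(prepared), (example_input,))
+# 3. Lower the Ethos-U partition and serialize the ExecuTorch program.
+edge = to_edge_transform_and_lower(
+    quantized,
+    partitioner=[EthosUPartitioner(compile_spec)],
+    compile_config=EdgeCompileConfig(_check_ir_validity=False),
+)
+executorch_program = edge.to_executorch(
+    config=ExecutorchBackendConfig(extract_delegate_segments=False)
+)
+save_pte_program(executorch_program, "deit_quantized_exported.pte")
+```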
+
+### Interpreting Vela Output
+
+After the model has been compiled for Ethos-U, the Vela compiler will output a network summary. You will see output similar to:
+
+```
+Network summary for out
+Accelerator configuration               Ethos_U85_256
+System configuration            Ethos_U85_SYS_DRAM_Mid
+Memory mode                               Shared_Sram
+Accelerator clock                                1000 MHz
+Design peak SRAM bandwidth                      29.80 GB/s
+Design peak DRAM bandwidth                      11.18 GB/s
+
+Total SRAM used                               1291.80 KiB
+Total DRAM used                               5289.91 KiB
+
+CPU operators =                                     0 (0.0%)
+NPU operators =                                   898 (100.0%)
+
+... (Truncated)
+```
+
+Some of this information is key to understanding the example application, which will run this model on device:
+
+- The `Accelerator configuration` is `Ethos_U85_256`, so it will only work on an Ethos-U85 system. The FVP for this is Corstone-320.
+- The `Memory mode` is `Shared_Sram`, so the tensor arena is allocated in SRAM while the model data is read from flash and DRAM.
+- The `Total SRAM used` is `1291.80 KiB`, so at least this much memory will need to be allocated for the tensor arena.
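+
+If you want to target a different Ethos-U configuration, the compile spec in `export_deit.py` is the only place that needs to change; its `target` and `memory_mode` values are what appear as `Accelerator configuration` and `Memory mode` in the Vela summary. The snippet below is an assumption based on the other Arm backend examples (this example itself only exercises Ethos-U85-256 on Corstone-320), so check the Vela output and pick the matching FVP if you try it:
+
+```python
+from executorch.backends.arm.ethosu import EthosUCompileSpec
+
+# Hypothetical alternative target (e.g. Ethos-U55, as found on Corstone-300);
+# not exercised by this example.
+compile_spec = EthosUCompileSpec(
+    target="ethos-u55-128",      # -> "Accelerator configuration"
+    memory_mode="Shared_Sram",   # -> "Memory mode"
+)
+```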
diff --git a/examples/arm/image_classification_example/model_export/export_deit.py b/examples/arm/image_classification_example/model_export/export_deit.py
new file mode 100644
index 00000000000..9f2bcdd54f1
--- /dev/null
+++ b/examples/arm/image_classification_example/model_export/export_deit.py
@@ -0,0 +1,168 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+import torch
+import tqdm
+from datasets import DatasetDict, load_dataset
+
+from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
+from executorch.backends.arm.quantizer import (
+    EthosUQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+from executorch.extension.export_util.utils import save_pte_program
+
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from transformers import AutoImageProcessor
+from transformers.models.vit.modeling_vit import ViTForImageClassification
+
+
+def make_transform(preprocessor):
+    def transform(batch):
+        img = [item.convert("RGB") for item in batch["image"]]
+        inputs = preprocessor(img, return_tensors="pt")
+
+        return {
+            "pixel_values": inputs["pixel_values"].unsqueeze(0),
+            "labels": batch["label"],
+        }
+
+    return transform
+
+
+def quantize_model(model, quantizer, calibration_data):
+    example_input = calibration_data[0]["pixel_values"]
+
+    exported_model = torch.export.export(
+        model,
+        (example_input,),
+    ).module()
+
+    prepared_model = prepare_pt2e(exported_model, quantizer)
+
+    print("\nCalibrating the model...")
+    for example in tqdm.tqdm(calibration_data):
+        prepared_model(example["pixel_values"])
+
+    pt2e_deit = convert_pt2e(prepared_model)
+
+    return torch.export.export(
+        pt2e_deit,
+        (example_input,),
+    )
+
+
+def measure_accuracy(quantized_model, test_set):
+    examples = 0
+    correct = 0
+
+    print("\nMeasuring accuracy on the test set...")
+    tbar = tqdm.tqdm(test_set)
+    for example in tbar:
+        img = example["pixel_values"]
+        output = quantized_model(img)
+        output = output.logits.argmax(dim=-1).item()
+
+        if output == example["labels"]:
+            correct += 1
+        examples += 1
+        accuracy = correct / examples
+
+        tbar.set_description(f"Accuracy: {accuracy:.4f}")
+
+    print(f"Top-1 accuracy on {examples} test samples: {accuracy:.4f}")
+    return accuracy
+
+
+if __name__ == "__main__":
+    argparser = argparse.ArgumentParser(
+        description="Quantize and export the fine-tuned DeiT model"
+    )
+    argparser.add_argument(
+        "--model-path",
+        type=str,
+        default="./deit-tiny-oxford-pet/final_model",
+        help="Path to the fine-tuned DeiT model.",
+    )
+    argparser.add_argument(
+        "--output-path",
+        type=str,
+        default="./deit_quantized_exported.pte",
+        help="Path to save the exported quantized model.",
+    )
+    argparser.add_argument(
+        "--num-calibration-samples",
+        type=int,
+        default=300,
+        help="Number of samples to use for calibration.",
+    )
+    argparser.add_argument(
+        "--num-test-samples",
+        type=int,
+        default=100,
+        help="Number of samples to use for testing accuracy.",
+    )
+    args = argparser.parse_args()
+
+    deit = ViTForImageClassification.from_pretrained(
+        args.model_path,
+        num_labels=37,
+        ignore_mismatched_sizes=True,
+    ).eval()
+    image_preprocessor = AutoImageProcessor.from_pretrained(
+        "facebook/deit-tiny-patch16-224", use_fast=True
+    )
+
+    compile_spec = EthosUCompileSpec(
+        target="ethos-u85-256",
+        memory_mode="Shared_Sram",
+    )
+
+    quantizer = EthosUQuantizer(compile_spec)
+    operator_config = get_symmetric_quantization_config()
+    quantizer.set_global(operator_config)
+
+    ds = load_dataset("timm/oxford-iiit-pet")
+
+    split = ds["train"].train_test_split(test_size=0.1, seed=42)
+    dataset = DatasetDict(
+        {
+            "train": split["train"],
+            "validation": split["test"],
+            "test": ds["test"],
+        }
+    )
+    dataset = dataset.with_transform(make_transform(image_preprocessor))
+
+    with torch.no_grad():
+        quantized_deit = quantize_model(
+            deit,
+            quantizer,
+            dataset["train"].take(args.num_calibration_samples),
+        )
+        measure_accuracy(
+            quantized_deit.module(), dataset["test"].take(args.num_test_samples)
+        )
+
+    partition = EthosUPartitioner(compile_spec)
+    edge_program = to_edge_transform_and_lower(
+        programs=quantized_deit,
+        partitioner=[partition],
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+        ),
+    )
+    executorch_program = edge_program.to_executorch(
+        config=ExecutorchBackendConfig(extract_delegate_segments=False)
+    )
+
+    save_pte_program(executorch_program, args.output_path)
+    print(f"\nExported model saved to {args.output_path}")
diff --git a/examples/arm/image_classification_example/model_export/train_deit.py b/examples/arm/image_classification_example/model_export/train_deit.py
new file mode 100644
index 00000000000..d4cf7d8274d
--- /dev/null
+++ b/examples/arm/image_classification_example/model_export/train_deit.py
@@ -0,0 +1,130 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+from pathlib import Path
+
+import numpy as np
+import torch
+from datasets import DatasetDict, load_dataset
+from evaluate import load as load_metric
+from transformers import AutoImageProcessor, set_seed, Trainer, TrainingArguments
+
+from transformers.models.vit.modeling_vit import ViTForImageClassification
+
+
+def make_transform(preprocessor):
+    def transform(batch):
+        img = [item.convert("RGB") for item in batch["image"]]
+        inputs = preprocessor(img, return_tensors="pt")
+
+        return {
+            "pixel_values": inputs["pixel_values"],
+            "labels": batch["label"],
+        }
+
+    return transform
+
+
+def make_compute_metrics(accuracy_metric):
+    def compute_metrics(eval_pred):
+        logits, labels = eval_pred
+        preds = np.argmax(logits, axis=-1)
+        return accuracy_metric.compute(predictions=preds, references=labels)
+
+    return compute_metrics
+
+
+def collate_fn(examples):
+    pixel_values = torch.stack([example["pixel_values"] for example in examples])
+    labels = torch.tensor([example["labels"] for example in examples])
+    return {"pixel_values": pixel_values, "labels": labels}
+
+
+if __name__ == "__main__":
+    # Set the seed for reproducibility
+    set_seed(42)
+
+    argparser = argparse.ArgumentParser(
+        description="Fine-tune DeiT model on Oxford-IIIT Pet dataset"
+    )
+    argparser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./deit-tiny-oxford-pet",
+        help="Directory to save the trained model",
+    )
+    argparser.add_argument(
+        "--num-epochs", type=int, default=3, help="Number of training epochs"
+    )
+    args = argparser.parse_args()
+    ds = load_dataset("timm/oxford-iiit-pet")
+
+    # Create the mappings between labels and IDs
+    labels = ds["train"].features["label"].names
+    ids2label = dict(enumerate(labels))
+    label2ids = {l: i for i, l in enumerate(labels)}
+
+    deit = ViTForImageClassification.from_pretrained(
+        "facebook/deit-tiny-patch16-224",
+        num_labels=37,
+        ignore_mismatched_sizes=True,
+        id2label=ids2label,
+        label2id=label2ids,
+    )
+    image_preprocessor = AutoImageProcessor.from_pretrained(
+        "facebook/deit-tiny-patch16-224", use_fast=True
+    )
+
+    # Create a validation set by splitting the training set into two parts
+    split = ds["train"].train_test_split(test_size=0.1, seed=42)
+    dataset = DatasetDict(
+        {
+            "train": split["train"],
+            "validation": split["test"],
+            "test": ds["test"],
+        }
+    )
+    dataset = dataset.with_transform(make_transform(image_preprocessor))
+
+    training_args = TrainingArguments(
+        output_dir=args.output_dir,
+        per_device_train_batch_size=16,
+        eval_strategy="steps",
+        num_train_epochs=args.num_epochs,
+        fp16=False,
+        save_steps=100,
+        eval_steps=100,
+        logging_steps=10,
+        learning_rate=2e-4,
+        save_total_limit=2,
+        remove_unused_columns=False,
+        push_to_hub=False,
+        load_best_model_at_end=True,
+        report_to="none",
+        use_mps_device=torch.backends.mps.is_available(),
+    )
+
+    accuracy_metric = load_metric("accuracy")
+    trainer = Trainer(
+        model=deit,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["validation"],
+        data_collator=collate_fn,
+        compute_metrics=make_compute_metrics(accuracy_metric),
+    )
+
+    print("\nStarting to fine-tune DeiT-Tiny on the Oxford-IIIT Pet dataset...")
+    trainer.train()
+
+    print("\nEvaluating the model on the test set...")
+    result = trainer.evaluate(dataset["test"])
+    print(f"Test set accuracy: {result['eval_accuracy']:.4f}")
+
+    final_model_path = Path(args.output_dir) / "final_model"
+    trainer.save_model(str(final_model_path))
+    print(f"\nTrained model saved to {final_model_path}")