17 changes: 17 additions & 0 deletions examples/arm/image_classification_example/README.md
# Image Classification Example Application

This end-to-end example shows how to use the Arm backend in ExecuTorch across both the ahead-of-time (AoT) and runtime flows. It does so by providing:

- Scripts to fine-tune a DeiT-Tiny model on the Oxford-IIIT Pet dataset, quantize it, and export an Ethos-U–ready ExecuTorch program.
- A simple bare-metal image-classification app for Corstone-320 (Ethos-U85-256) that embeds the exported program and a sample image.
- Running the app on the Corstone-320 Fixed Virtual Platform (FVP).

## Layout

The example is divided into two sections:

- `model_export/README.md` — Covers fine-tuning a model for a new use case, quantizing it to INT8, lowering it to Ethos-U via ExecuTorch, and generating the `.pte`.
- `runtime/README.md` — Covers building the bare-metal app, generating headers from the `.pte` and image, and running on the FVP.

In addition, this example uses `../executor_runner/` for various utilities (linker scripts, memory allocators, and the PTE-to-header converter).
77 changes: 77 additions & 0 deletions examples/arm/image_classification_example/model_export/README.md
# DeiT Fine-Tuning & Export

This example provides two scripts:

- `train_deit.py` — Fine-tunes the DeiT-Tiny model, originally trained on ImageNet-1k, on the Oxford-IIIT Pet dataset to repurpose the network to classify cat and dog breeds. This demonstrates how to prepare a model for a new use case before lowering it via ExecuTorch.
- `export_deit.py` — Loads the trained checkpoint from `train_deit.py`, applies post-training quantization (PT2E), evaluates the quantized model, and exports an Ethos-U–ready ExecuTorch program.

The Oxford-IIIT Pet dataset is used by both scripts as it's a relatively small dataset, allowing this example to be run on a high-end laptop or desktop.
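
For reference, the dataset can be inspected up front; a minimal sketch using the same `datasets` call as the scripts:

```python
from datasets import load_dataset

# Oxford-IIIT Pet: 37 cat and dog breed classes across train and test splits.
ds = load_dataset("timm/oxford-iiit-pet")
print(ds)  # split names and sizes
print(ds["train"].features["label"].names)  # the 37 breed labels
```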

See the sections below for requirements and exact commands.

## Requirements

- Python 3.10+ with `executorch` and the dependencies in `requirements-examples.txt`.
- Internet access to download pretrained weights and the Oxford-IIIT Pet dataset.

## Fine-tuning DeiT-Tiny

The `train_deit.py` script can be run as follows:

```bash
python examples/arm/image_classification_example/model_export/train_deit.py \
    --output-dir ./deit-tiny-oxford-pet \
    --num-epochs 3
```

The script splits the training set for validation, fine-tunes the model, reports test accuracy, and by default outputs the model to `deit-tiny-oxford-pet/final_model`.
Running this script achieves a test set accuracy of 86.10% in FP32.
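
Before exporting, the checkpoint can be sanity-checked with the same `transformers` classes the scripts use; a minimal sketch (the image path `my_pet.jpg` is a placeholder):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.models.vit.modeling_vit import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "./deit-tiny-oxford-pet/final_model"
).eval()
preprocessor = AutoImageProcessor.from_pretrained(
    "facebook/deit-tiny-patch16-224", use_fast=True
)

image = Image.open("my_pet.jpg").convert("RGB")  # any cat or dog photo
inputs = preprocessor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```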

## Export and quantize

The `export_deit.py` script can be run as follows:

```bash
python examples/arm/image_classification_example/model_export/export_deit.py \
    --model-path ./deit-tiny-oxford-pet/final_model \
    --output-path ./deit_quantized_exported.pte \
    --num-calibration-samples 300 \
    --num-test-samples 100
```

During export, the script:
- Exports the FP32 model using `torch.export.export()`.
- Applies symmetric quantization to each operator.
- Targets `Ethos-U85-256` with shared SRAM and lowers the network to Ethos-U.
- Writes the ExecuTorch program to the requested path.

Run after `train_deit.py`, this script achieves a test set accuracy of 85.00% for the quantized model on 100 samples. The core quantization flow is sketched below.
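
The flow follows the standard PT2E post-training quantization pattern; a condensed sketch mirroring `export_deit.py`, with the model, quantizer, and calibration samples prepared as in the script:

```python
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_pt2e(model, quantizer, calibration_samples, example_input):
    """Condensed PT2E flow, mirroring export_deit.py."""
    # 1. Export the FP32 model to an ATen graph.
    exported = torch.export.export(model, (example_input,)).module()

    # 2. Insert observers and calibrate with representative samples.
    prepared = prepare_pt2e(exported, quantizer)
    for sample in calibration_samples:
        prepared(sample)

    # 3. Convert the observed graph to a quantized one and re-export.
    quantized = convert_pt2e(prepared)
    return torch.export.export(quantized, (example_input,))
```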

### Interpreting Vela Output

After the model has been compiled for Ethos-U, the Vela compiler will output a network summary. You will see output similar to:

```
Network summary for out
Accelerator configuration Ethos_U85_256
System configuration Ethos_U85_SYS_DRAM_Mid
Memory mode Shared_Sram
Accelerator clock 1000 MHz
Design peak SRAM bandwidth 29.80 GB/s
Design peak DRAM bandwidth 11.18 GB/s

Total SRAM used 1291.80 KiB
Total DRAM used 5289.91 KiB

CPU operators = 0 (0.0%)
NPU operators = 898 (100.0%)

... (Truncated)
```

Some of this information is key to understanding the example application, which will run this model on device:

- The `Accelerator configuration` is `Ethos_U85_256`, so it will only work on an Ethos-U85 system. The FVP for this is Corstone-320.
- The `Memory mode` is `Shared_Sram`, so the tensor arena is allocated in SRAM while the model data is read from flash and DRAM.
- The `Total SRAM used` is `1291.80 KiB`, so at least this much memory must be allocated for the tensor arena (see the sketch below).
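
For build scripting, that figure can be extracted from the summary text; a small illustrative helper (the parsing is an assumption about the summary format shown above, not a Vela API):

```python
import math
import re

def sram_arena_bytes(vela_summary: str) -> int:
    """Return 'Total SRAM used' from a Vela network summary, rounded up
    to whole bytes; illustrative parsing of the format shown above."""
    match = re.search(r"Total SRAM used\s+([\d.]+)\s+KiB", vela_summary)
    if match is None:
        raise ValueError("no 'Total SRAM used' line found")
    return math.ceil(float(match.group(1)) * 1024)

# For the summary above: 1291.80 KiB -> 1322804 bytes for the tensor arena.
```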
168 changes: 168 additions & 0 deletions examples/arm/image_classification_example/model_export/export_deit.py
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
import tqdm
from datasets import DatasetDict, load_dataset

from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from transformers import AutoImageProcessor
from transformers.models.vit.modeling_vit import ViTForImageClassification


def make_transform(preprocessor):
    def transform(batch):
        img = [item.convert("RGB") for item in batch["image"]]
        inputs = preprocessor(img, return_tensors="pt")

        # unsqueeze keeps a leading batch dimension of 1 after `datasets`
        # unbatches single examples on access.
        return {
            "pixel_values": inputs["pixel_values"].unsqueeze(0),
            "labels": batch["label"],
        }

    return transform


def quantize_model(model, quantizer, calibration_data):
    example_input = calibration_data[0]["pixel_values"]

    exported_model = torch.export.export(
        model,
        (example_input,),
    ).module()

    # PT2E post-training quantization: insert observers, calibrate, convert.
    prepared_model = prepare_pt2e(exported_model, quantizer)

    print("\nCalibrating the model...")
    for example in tqdm.tqdm(calibration_data):
        prepared_model(example["pixel_values"])

    pt2e_deit = convert_pt2e(prepared_model)

    return torch.export.export(
        pt2e_deit,
        (example_input,),
    )


def measure_accuracy(quantized_model, test_set):
    examples = 0
    correct = 0

    print("\nMeasuring accuracy on the test set...")
    tbar = tqdm.tqdm(test_set)
    for example in tbar:
        img = example["pixel_values"]
        output = quantized_model(img)
        output = output.logits.argmax(dim=-1).item()

        if output == example["labels"]:
            correct += 1
        examples += 1
        accuracy = correct / examples

        tbar.set_description(f"Accuracy: {accuracy:.4f}")

    print(f"Top-1 accuracy on {examples} test samples: {accuracy:.4f}")
    return accuracy


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description="Export the fine-tuned DeiT model")
    argparser.add_argument(
        "--model-path",
        type=str,
        default="./deit-tiny-oxford-pet/final_model",
        help="Path to the fine-tuned ViT model.",
    )
    argparser.add_argument(
        "--output-path",
        type=str,
        default="./deit_quantized_exported.pte",
        help="Path to save the exported quantized model.",
    )
    argparser.add_argument(
        "--num-calibration-samples",
        type=int,
        default=300,
        help="Number of samples to use for calibration.",
    )
    argparser.add_argument(
        "--num-test-samples",
        type=int,
        default=100,
        help="Number of samples to use for testing accuracy.",
    )
    args = argparser.parse_args()

    deit = ViTForImageClassification.from_pretrained(
        args.model_path,
        num_labels=37,
        ignore_mismatched_sizes=True,
    ).eval()
    image_preprocessor = AutoImageProcessor.from_pretrained(
        "facebook/deit-tiny-patch16-224", use_fast=True
    )

    compile_spec = EthosUCompileSpec(
        target="ethos-u85-256",
        memory_mode="Shared_Sram",
    )

    quantizer = EthosUQuantizer(compile_spec)
    operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)

    ds = load_dataset("timm/oxford-iiit-pet")

    # Recreate the same train/validation/test splits used during fine-tuning.
    split = ds["train"].train_test_split(test_size=0.1, seed=42)
    dataset = DatasetDict(
        {
            "train": split["train"],
            "validation": split["test"],
            "test": ds["test"],
        }
    )
    dataset = dataset.with_transform(make_transform(image_preprocessor))

    with torch.no_grad():
        quantized_deit = quantize_model(
            deit,
            quantizer,
            dataset["train"].take(args.num_calibration_samples),
        )
        measure_accuracy(
            quantized_deit.module(), dataset["test"].take(args.num_test_samples)
        )

    partitioner = EthosUPartitioner(compile_spec)
    edge_manager = to_edge_transform_and_lower(
        programs=quantized_deit,
        partitioner=[partitioner],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
        ),
    )
    executorch_program = edge_manager.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    save_pte_program(executorch_program, args.output_path)
    print(f"\nExported model saved to {args.output_path}")
130 changes: 130 additions & 0 deletions examples/arm/image_classification_example/model_export/train_deit.py
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from pathlib import Path

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from evaluate import load as load_metric
from transformers import AutoImageProcessor, set_seed, Trainer, TrainingArguments

from transformers.models.vit.modeling_vit import ViTForImageClassification


def make_transform(preprocessor):
    def transform(batch):
        img = [item.convert("RGB") for item in batch["image"]]
        inputs = preprocessor(img, return_tensors="pt")

        return {
            "pixel_values": inputs["pixel_values"],
            "labels": batch["label"],
        }

    return transform


def make_compute_metrics(accuracy_metric):
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return accuracy_metric.compute(predictions=preds, references=labels)

    return compute_metrics


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


if __name__ == "__main__":
    # Set the seed for reproducibility
    set_seed(42)

    argparser = argparse.ArgumentParser(
        description="Fine-tune DeiT model on Oxford-IIIT Pet dataset"
    )
    argparser.add_argument(
        "--output-dir",
        type=str,
        default="./deit-tiny-oxford-pet",
        help="Directory to save the trained model",
    )
    argparser.add_argument(
        "--num-epochs", type=int, default=3, help="Number of training epochs"
    )
    args = argparser.parse_args()
    ds = load_dataset("timm/oxford-iiit-pet")

    # Create the mappings between labels and IDs
    labels = ds["train"].features["label"].names
    id2label = dict(enumerate(labels))
    label2id = {l: i for i, l in enumerate(labels)}

    deit = ViTForImageClassification.from_pretrained(
        "facebook/deit-tiny-patch16-224",
        num_labels=37,
        ignore_mismatched_sizes=True,
        id2label=id2label,
        label2id=label2id,
    )
    image_preprocessor = AutoImageProcessor.from_pretrained(
        "facebook/deit-tiny-patch16-224", use_fast=True
    )

    # Create a validation set by splitting the training set into two parts
    split = ds["train"].train_test_split(test_size=0.1, seed=42)
    dataset = DatasetDict(
        {
            "train": split["train"],
            "validation": split["test"],
            "test": ds["test"],
        }
    )
    dataset = dataset.with_transform(make_transform(image_preprocessor))

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=16,
        eval_strategy="steps",
        num_train_epochs=args.num_epochs,
        fp16=False,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        load_best_model_at_end=True,
        report_to="none",
        use_mps_device=torch.backends.mps.is_available(),
    )

    accuracy_metric = load_metric("accuracy")
    trainer = Trainer(
        model=deit,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        data_collator=collate_fn,
        compute_metrics=make_compute_metrics(accuracy_metric),
    )

    print("\nStarting training DeiT-Tiny on Oxford-IIIT Pet dataset...")
    trainer.train()

    print("\nEvaluating the model on the test set...")
    result = trainer.evaluate(dataset["test"])
    print(f"Test set accuracy: {result['eval_accuracy']:.4f}")

    final_model_path = Path(args.output_dir) / "final_model"
    trainer.save_model(str(final_model_path))
    print(f"\nTrained model saved to {final_model_path}")