Commit 51d9c75

Arm backend: Add training & export scripts for DEiT example (#16315)
* Add example training and export scripts into examples/arm/image_classification_example
* Add README.md documenting usage

Signed-off-by: Tom Allsop <tom.allsop@arm.com>
1 parent 3233761 commit 51d9c75

File tree

4 files changed (+392 lines added, 0 deleted)

examples/arm/image_classification_example/README.md
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Image Classification Example Application

This end-to-end example shows how to use the Arm backend in ExecuTorch across both the ahead-of-time (AoT) and runtime flows. It does this by providing:

- Scripts to fine-tune a DeiT-Tiny model on the Oxford-IIIT Pet dataset, quantize it, and export an Ethos-U–ready ExecuTorch program.
- A simple bare-metal image-classification app for Corstone-320 (Ethos-U85-256) that embeds the exported program and a sample image.
- Instructions for running the app on the Corstone-320 Fixed Virtual Platform (FVP).

## Layout

The example is divided into two sections:

- `model_export/README.md` — Covers fine-tuning a model for a new use case, quantizing it to INT8, lowering it to Ethos-U via ExecuTorch, and generating the `.pte`.
- `runtime/README.md` — Covers building the bare-metal app, generating headers from the `.pte` and image, and running on the FVP.

In addition, this example uses `../executor_runner/` for various utilities (linker scripts, memory allocators, and the PTE-to-header converter).
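
For orientation, the ahead-of-time part of the flow boils down to two commands, sketched below with the defaults used in this example; `model_export/README.md` documents every flag in detail.

```bash
# Fine-tune DeiT-Tiny on Oxford-IIIT Pet, then quantize and export for Ethos-U85-256.
python examples/arm/image_classification_example/model_export/train_deit.py \
    --output-dir ./deit-tiny-oxford-pet --num-epochs 3
python examples/arm/image_classification_example/model_export/export_deit.py \
    --model-path ./deit-tiny-oxford-pet/final_model \
    --output-path ./deit_quantized_exported.pte
```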
examples/arm/image_classification_example/model_export/README.md
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# DEiT Fine-Tuning & Export

This example provides two scripts:

- `train_deit.py` — Fine-tunes the DEiT-Tiny model, originally trained on ImageNet-1k, on the Oxford-IIIT Pet dataset to repurpose the network to classify cat and dog breeds. This demonstrates how to prepare a model for a new use case before lowering it via ExecuTorch.
- `export_deit.py` — Loads the trained checkpoint from `train_deit.py`, applies post-training quantization (PT2E), evaluates the quantized model, and exports an Ethos-U–ready ExecuTorch program.

The Oxford-IIIT Pet dataset is used by both scripts because it is relatively small, which allows this example to be run on a high-end laptop or desktop.

See the sections below for requirements and exact commands.

## Requirements

- Python 3.10+ with `executorch` and the dependencies in `requirements-examples.txt`.
- Internet access to download the pretrained weights and the Oxford-IIIT Pet dataset.
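
As a rough setup sketch only (it assumes a fresh virtual environment; the exact location of `requirements-examples.txt`, and whether you install ExecuTorch from a wheel or from source, depend on your checkout):

```bash
# Illustrative only -- adjust to your environment and ExecuTorch installation method.
python3 -m venv .venv && source .venv/bin/activate
pip install executorch                      # or install ExecuTorch from source
pip install -r requirements-examples.txt    # path to this file is an assumption
```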
## Fine-tuning DEiT Tiny

The `train_deit.py` script can be run as follows:

```bash
python examples/arm/image_classification_example/model_export/train_deit.py \
    --output-dir ./deit-tiny-oxford-pet \
    --num-epochs 3
```

The script splits the training set to create a validation set, fine-tunes the model, reports test accuracy, and by default writes the model to `deit-tiny-oxford-pet/final_model`.
Running this script achieves a test set accuracy of 86.10% in FP32.

## Export and quantize

The `export_deit.py` script can be run as follows:

```bash
python examples/arm/image_classification_example/model_export/export_deit.py \
    --model-path ./deit-tiny-oxford-pet/final_model \
    --output-path ./deit_quantized_exported.pte \
    --num-calibration-samples 300 \
    --num-test-samples 100
```

During export, the script:

- Exports the FP32 model using `torch.export.export()`.
- Applies symmetric quantization to each operator.
- Targets `Ethos-U85-256` with shared SRAM and lowers the network to Ethos-U.
- Writes the ExecuTorch program to the requested path.

Running this script on the checkpoint produced by `train_deit.py` achieves a test set accuracy of 85.00% for the quantized model on 100 test samples.

### Interpreting Vela Output

After the model has been compiled for Ethos-U, the Vela compiler prints a network summary. You will see output similar to:

```
Network summary for out
Accelerator configuration               Ethos_U85_256
System configuration             Ethos_U85_SYS_DRAM_Mid
Memory mode                                 Shared_Sram
Accelerator clock                                  1000 MHz
Design peak SRAM bandwidth                        29.80 GB/s
Design peak DRAM bandwidth                        11.18 GB/s

Total SRAM used                                 1291.80 KiB
Total DRAM used                                 5289.91 KiB

CPU operators =          0 (0.0%)
NPU operators =        898 (100.0%)

... (Truncated)
```

Some of this information is key to understanding the example application, which will run this model on device:

- The `Accelerator configuration` is `Ethos_U85_256`, so the program will only run on an Ethos-U85 system. The FVP for this is Corstone-320.
- The `Memory mode` is `Shared_Sram`, so the tensor arena is allocated in SRAM while the model data is read from flash and DRAM.
- The `Total SRAM used` is `1291.80 KiB`, so at least this much memory will need to be allocated for the tensor arena.
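
As a loose sanity check after export, the size of the generated `.pte` should be broadly comparable to the `Total DRAM used` figure, since in the `Shared_Sram` memory mode the quantized weights are stored with the program and read from flash/DRAM:

```bash
ls -lh ./deit_quantized_exported.pte   # expect a file on the order of a few MiB for this model
```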
examples/arm/image_classification_example/model_export/export_deit.py
Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
import tqdm
from datasets import DatasetDict, load_dataset

from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from transformers import AutoImageProcessor
from transformers.models.vit.modeling_vit import ViTForImageClassification


def make_transform(preprocessor):
    def transform(batch):
        img = [item.convert("RGB") for item in batch["image"]]
        inputs = preprocessor(img, return_tensors="pt")

        return {
            "pixel_values": inputs["pixel_values"].unsqueeze(0),
            "labels": batch["label"],
        }

    return transform


# PT2E flow: export to an FX graph, insert observers, calibrate, then convert
# to a quantized graph.
def quantize_model(model, quantizer, calibration_data):
    example_input = calibration_data[0]["pixel_values"]

    exported_model = torch.export.export(
        model,
        (example_input,),
    ).module()

    quantize = prepare_pt2e(exported_model, quantizer)

    print("\nCalibrating the model...")
    for example in tqdm.tqdm(calibration_data):
        quantize(example["pixel_values"])

    pt2e_deit = convert_pt2e(quantize)

    return torch.export.export(
        pt2e_deit,
        (example_input,),
    )


def measure_accuracy(quantized_model, test_set):
    examples = 0
    correct = 0

    print("\nMeasuring accuracy on the test set...")
    tbar = tqdm.tqdm(test_set)
    for example in tbar:
        img = example["pixel_values"]
        output = quantized_model(img)
        output = output.logits.argmax(dim=-1).item()

        if output == example["labels"]:
            correct += 1
        examples += 1
        accuracy = correct / examples

        tbar.set_description(f"Accuracy: {accuracy:.4f}")

    print(f"Top-1 accuracy on {examples} test samples: {accuracy:.4f}")
    return accuracy


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description="Export ViT model")
    argparser.add_argument(
        "--model-path",
        type=str,
        default="./deit-tiny-oxford-pet/final_model",
        required=True,
        help="Path to the fine-tuned ViT model.",
    )
    argparser.add_argument(
        "--output-path",
        type=str,
        default="./deit_quantized_exported.pte",
        help="Path to save the exported quantized model.",
    )
    argparser.add_argument(
        "--num-calibration-samples",
        type=int,
        default=300,
        help="Number of samples to use for calibration.",
    )
    argparser.add_argument(
        "--num-test-samples",
        type=int,
        default=100,
        help="Number of samples to use for testing accuracy.",
    )
    args = argparser.parse_args()

    deit = ViTForImageClassification.from_pretrained(
        args.model_path,
        num_labels=37,
        ignore_mismatched_sizes=True,
    ).eval()
    image_preprocessor = AutoImageProcessor.from_pretrained(
        "facebook/deit-tiny-patch16-224", use_fast=True
    )

    # Target Ethos-U85 with 256 MACs and the Shared_Sram memory mode.
    compile_spec = EthosUCompileSpec(
        target="ethos-u85-256",
        memory_mode="Shared_Sram",
    )

    quantizer = EthosUQuantizer(compile_spec)
    operator_config = get_symmetric_quantization_config()
    quantizer.set_global(operator_config)

    ds = load_dataset("timm/oxford-iiit-pet")

    # Recreate the same train/validation/test split used by train_deit.py (same seed).
    split = ds["train"].train_test_split(test_size=0.1, seed=42)
    dataset = DatasetDict(
        {
            "train": split["train"],
            "validation": split["test"],
            "test": ds["test"],
        }
    )
    dataset = dataset.with_transform(make_transform(image_preprocessor))

    with torch.no_grad():
        quantized_deit = quantize_model(
            deit,
            quantizer,
            dataset["train"].take(args.num_calibration_samples),
        )
        measure_accuracy(
            quantized_deit.module(), dataset["test"].take(args.num_test_samples)
        )

    # Lower the quantized graph to the Ethos-U backend and serialize it as an
    # ExecuTorch program.
    partition = EthosUPartitioner(compile_spec)
    edge_encoder = to_edge_transform_and_lower(
        programs=quantized_deit,
        partitioner=[partition],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
        ),
    )
    edge_manager = edge_encoder.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    save_pte_program(edge_manager, args.output_path)
    print(f"\nExported model saved to {args.output_path}")
examples/arm/image_classification_example/model_export/train_deit.py
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from pathlib import Path

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from evaluate import load as load_metric
from transformers import AutoImageProcessor, set_seed, Trainer, TrainingArguments

from transformers.models.vit.modeling_vit import ViTForImageClassification


def make_transform(preprocessor):
    def transform(batch):
        img = [item.convert("RGB") for item in batch["image"]]
        inputs = preprocessor(img, return_tensors="pt")

        return {
            "pixel_values": inputs["pixel_values"],
            "labels": batch["label"],
        }

    return transform


def make_compute_metrics(accuracy_metric):
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return accuracy_metric.compute(predictions=preds, references=labels)

    return compute_metrics


# Stack per-example tensors into batches for the Trainer.
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


if __name__ == "__main__":
    # Set the seed for reproducibility
    set_seed(42)

    argparser = argparse.ArgumentParser(
        description="Fine-tune DeIT model on Oxford-IIIT Pet dataset"
    )
    argparser.add_argument(
        "--output-dir",
        type=str,
        default="./deit-tiny-oxford-pet",
        help="Directory to save the trained model",
    )
    argparser.add_argument(
        "--num-epochs", type=int, default=3, help="Number of training epochs"
    )
    args = argparser.parse_args()
    ds = load_dataset("timm/oxford-iiit-pet")

    # Create the mappings between labels and IDs
    labels = ds["train"].features["label"].names
    ids2label = dict(enumerate(labels))
    label2ids = {l: i for i, l in enumerate(labels)}

    deit = ViTForImageClassification.from_pretrained(
        "facebook/deit-tiny-patch16-224",
        num_labels=37,
        ignore_mismatched_sizes=True,
        id2label=ids2label,
        label2id=label2ids,
    )
    image_preprocessor = AutoImageProcessor.from_pretrained(
        "facebook/deit-tiny-patch16-224", use_fast=True
    )

    # Create a validation set by splitting the training set into two parts
    split = ds["train"].train_test_split(test_size=0.1, seed=42)
    dataset = DatasetDict(
        {
            "train": split["train"],
            "validation": split["test"],
            "test": ds["test"],
        }
    )
    dataset = dataset.with_transform(make_transform(image_preprocessor))

    # Short fine-tuning run: evaluate and checkpoint every 100 steps, keep the
    # two most recent checkpoints, and reload the best one at the end.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=16,
        eval_strategy="steps",
        num_train_epochs=args.num_epochs,
        fp16=False,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        load_best_model_at_end=True,
        report_to="none",
        use_mps_device=torch.backends.mps.is_available(),
    )

    accuracy_metric = load_metric("accuracy")
    trainer = Trainer(
        model=deit,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        data_collator=collate_fn,
        compute_metrics=make_compute_metrics(accuracy_metric),
    )

    print("\nStarting training DEiT Tiny on Oxford-IIIT Pet dataset...")
    trainer.train()

    print("\nEvaluating the model on the test set...")
    result = trainer.evaluate(dataset["test"])
    print(f"Test set accuracy: {result['eval_accuracy']:.4f}")

    final_model_path = Path(args.output_dir) / "final_model"
    trainer.save_model(str(final_model_path))
    print(f"\nTrained model saved to {final_model_path}")
