
Commit e20df45

Add Backbone API fine-tuning tutorial (#41590)

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

1 parent 19df66d commit e20df45

File tree: 2 files changed, +252 −0 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -284,6 +284,8 @@
       title: Knowledge Distillation for Computer Vision
     - local: tasks/keypoint_matching
       title: Keypoint matching
+    - local: tasks/training_vision_backbone
+      title: Training vision models using Backbone API
     title: Computer vision
   - sections:
     - local: tasks/image_captioning

docs/source/en/tasks/training_vision_backbone.md

Lines changed: 250 additions & 0 deletions

@@ -0,0 +1,250 @@

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Training Vision Models using Backbone API

Computer vision workflows follow a common pattern: a pre-trained backbone extracts features ([ViT](../model_doc/vit), [DINOv3](../model_doc/dinov3)), an optional "neck" enhances them, and a task-specific head produces the predictions ([DETR](../model_doc/detr) for object detection, [MaskFormer](../model_doc/maskformer) for segmentation).

The Transformers library implements these models, and the [backbone API](../backbones) lets you swap different backbones and heads with minimal code.

![Backbone Explanation](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Backbone.png)

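To get a feel for the backbone half of this pattern on its own, here is a minimal sketch that loads a backbone with [`AutoBackbone`] and inspects the multi-scale feature maps a detection head would consume. The checkpoint and `out_indices` values are only illustrative and are not part of the training recipe below.

```py
import torch
from transformers import AutoBackbone

# Illustrative checkpoint; any backbone-capable checkpoint could be swapped in here.
backbone = AutoBackbone.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224", out_indices=(1, 2, 3, 4)
)

pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch with a single image
outputs = backbone(pixel_values)

# One feature map per requested stage; a detection or segmentation head consumes these.
for feature_map in outputs.feature_maps:
    print(feature_map.shape)
```
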
This guide combines [DINOv3 with ConvNext architecture](https://huggingface.co/facebook/dinov3-convnext-large-pretrain-lvd1689m) and a [DETR head](https://huggingface.co/facebook/detr-resnet-50). You'll train on the [license plate detection dataset](https://huggingface.co/datasets/merve/license-plates). DINOv3 delivers the best performance as of this writing.

> [!NOTE]
> This model requires access approval. Visit [the model repository](https://huggingface.co/facebook/dinov3-convnext-large-pretrain-lvd1689m) to request access.

Install [trackio](https://github.com/gradio-app/trackio) for experiment tracking and [albumentations](https://albumentations.ai/) for data augmentation. Use the latest transformers version.

```bash
pip install -Uq albumentations trackio transformers datasets
```

Initialize [`DetrConfig`] with the pre-trained DINOv3 ConvNext backbone and set `num_labels=1` to detect a single class, the license plate bounding boxes. Create [`DetrForObjectDetection`] from this configuration, freeze the backbone to preserve the DINOv3 features without updating their weights, and load the [`DetrImageProcessor`].

```py
from transformers import DetrConfig, DetrForObjectDetection, AutoImageProcessor

config = DetrConfig(backbone="facebook/dinov3-convnext-large-pretrain-lvd1689m",
                    use_pretrained_backbone=True, use_timm_backbone=False,
                    num_labels=1, id2label={0: "license_plate"}, label2id={"license_plate": 0})
model = DetrForObjectDetection(config)

# Freeze the backbone so only the neck and detection head are trained.
for param in model.model.backbone.parameters():
    param.requires_grad = False

image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
```

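As an optional sanity check (not part of the original recipe), you can confirm that the backbone is frozen by counting trainable parameters:

```py
# Only the neck and detection head should remain trainable after freezing the backbone.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable:,} / {total:,}")
```
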
Load the dataset and split it for training.

```py
from datasets import load_dataset

ds = load_dataset("merve/license-plates")
ds = ds["train"]

ds = ds.train_test_split(test_size=0.05)
train_dataset = ds["train"]
val_dataset = ds["test"]
len(train_dataset)
# 5867
```

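It helps to inspect one raw example before writing the transforms. The field names assumed throughout the rest of this guide are `image`, `image_id`, and an `objects` dictionary with COCO-style `bbox` entries:

```py
example = train_dataset[0]
print(example.keys())      # expected: image, image_id, objects (plus any extra metadata columns)
print(example["objects"])  # COCO-style annotations, e.g. {"id": [...], "bbox": [[x, y, w, h], ...], ...}
```
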
Augment the dataset. Rescale images to a maximum size, flip them, and apply affine transforms. Eliminate invalid bounding boxes and keep the annotations consistent with `rebuild_objects`.

```py
import albumentations as A
import numpy as np
from PIL import Image

train_aug = A.Compose(
    [
        A.LongestMaxSize(max_size=1024, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.Affine(rotate=(-5, 5), shear=(-5, 5), translate_percent=(0.05, 0.05), p=0.5),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["category_id"], min_visibility=0.0),
)

def train_transform(batch):
    imgs_out, objs_out = [], []
    original_imgs, original_objs = batch["image"], batch["objects"]

    for img_pil, objs in zip(original_imgs, original_objs):
        img = np.array(img_pil)
        labels = [0] * len(objs["bbox"])

        out = train_aug(image=img, bboxes=list(objs["bbox"]), category_id=labels)

        if len(out["bboxes"]) == 0:
            imgs_out.append(img_pil)  # if no boxes are left after augmentation, keep the original
            objs_out.append(objs)
            continue

        # Clamp boxes to the augmented image so they remain valid COCO annotations.
        H, W = out["image"].shape[:2]
        clamped = []
        for (x, y, w, h) in out["bboxes"]:
            x = max(0.0, min(x, W - 1.0))
            y = max(0.0, min(y, H - 1.0))
            w = max(1.0, min(w, W - x))
            h = max(1.0, min(h, H - y))
            clamped.append([x, y, w, h])

        imgs_out.append(Image.fromarray(out["image"]))
        objs_out.append(rebuild_objects(clamped, out["category_id"]))

    batch["image"] = imgs_out
    batch["objects"] = objs_out
    return batch


def rebuild_objects(bboxes, labels):
    bboxes = [list(map(float, b)) for b in bboxes]
    areas = [float(w * h) for (_, _, w, h) in bboxes]
    ids = list(range(len(bboxes)))
    return {
        "id": ids,
        "bbox": bboxes,
        "category_id": list(map(int, labels)),
        "area": areas,
        "iscrowd": [0] * len(bboxes),
    }

train_dataset = train_dataset.with_transform(train_transform)
```

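Because `with_transform` runs lazily, the augmentation is applied every time examples are read. An optional way to check it is to pull a small slice and look at the transformed images and boxes:

```py
sample = train_dataset[:2]           # the transform runs on the fly for this access
print(sample["image"][0].size)       # image after LongestMaxSize / flip / affine
print(sample["objects"][0]["bbox"])  # clamped [x, y, w, h] boxes rebuilt by rebuild_objects
```
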
Build COCO-style annotations for the image processor.

```py
import torch

def format_annotations(image, objects, image_id):
    n = len(objects["id"])
    anns = []
    iscrowd_list = objects.get("iscrowd", [0] * n)
    area_list = objects.get("area", None)

    for i in range(n):
        x, y, w, h = objects["bbox"][i]
        area = area_list[i] if area_list is not None else float(w * h)

        anns.append({
            "id": int(objects["id"][i]),
            "iscrowd": int(iscrowd_list[i]),
            "bbox": [float(x), float(y), float(w), float(h)],
            "category_id": int(objects.get("category_id", objects.get("category"))[i]),
            "area": float(area),
        })

    return {"image_id": int(image_id), "annotations": anns}
```

Create batches in the data collator. Format the annotations and pass them together with the transformed images to the image processor.

```py
def collate_fn(examples):
    images = [example["image"] for example in examples]
    ann_batch = [format_annotations(example["image"], example["objects"], example["image_id"]) for example in examples]

    inputs = image_processor(images=images, annotations=ann_batch, return_tensors="pt")
    return inputs
```

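Before launching training, it is worth checking the collator on a couple of raw validation samples. With DETR-style image processors the batch should contain `pixel_values`, `pixel_mask`, and a list of `labels` dictionaries, though the exact keys depend on the processor:

```py
batch = collate_fn([val_dataset[0], val_dataset[1]])
print(batch.keys())                 # expected: pixel_values, pixel_mask, labels
print(batch["pixel_values"].shape)  # (2, 3, H, W) after resizing and padding
```
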
Set up [`TrainingArguments`] with hyperparameters that let the model converge, then pass the model, arguments, datasets, and data collator to [`Trainer`] and start training.

```py
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./license-plate-detr-dinov3",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=8,
    learning_rate=1e-5,
    weight_decay=1e-4,
    warmup_steps=500,
    eval_strategy="steps",
    eval_steps=500,
    save_total_limit=2,
    dataloader_pin_memory=False,
    fp16=True,
    report_to="trackio",
    load_best_model_at_end=True,
    remove_unused_columns=False,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
)

trainer.train()
```

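The recipe above does not define `compute_metrics`, so evaluation reports only losses. If you want a quick look at the validation loss after training, you can call:

```py
metrics = trainer.evaluate()
print(metrics)  # eval_loss and runtime statistics
```
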
Push the trained model and the image processor to the Hub.

```py
trainer.push_to_hub()
image_processor.push_to_hub("merve/license-plate-detr-dinov3")
```

Test the model with an object detection pipeline.

```py
from transformers import pipeline

obj_detector = pipeline(
    "object-detection", model="merve/license-plate-detr-dinov3"
)
results = obj_detector("https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/license-plates.jpg", threshold=0.05)
print(results)
```

Visualize the results.

```py
from PIL import Image, ImageDraw
import numpy as np
import requests


def plot_results(image, results, threshold):
    image = Image.fromarray(np.uint8(image))
    draw = ImageDraw.Draw(image)
    width, height = image.size

    for result in results:
        score = result["score"]
        label = result["label"]
        # the pipeline returns boxes as {"xmin", "ymin", "xmax", "ymax"}
        box = list(result["box"].values())

        if score > threshold:
            x1, y1, x2, y2 = tuple(box)
            draw.rectangle((x1, y1, x2, y2), outline="red")
            draw.text((x1 + 5, y1 + 10), f"{score:.2f}", fill="green" if score > 0.7 else "red")

    return image

image = Image.open(requests.get("https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/license-plates.jpg", stream=True).raw)
plot_results(image, results, threshold=0.05)
```

![Results](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/backbone_training_results.png)
