Commit bd37c45

yonigozlan, SangbumChoi, RUFFY-369, haithamkhedr, and sangbum choi authored
Add EdgeTAM (#39800)
* initial comment
* test
* initial conversion for outline
* intermediate commit for configuration
* chore:init files for sam2
* adding arbitary undefined config
* check
* add vision
* make style
* init sam2 base model
* Fix imports
* Linting
* chore:sam to sam2 classes
* Linting
* Add sam2 to models.__init__
* chore:match prompt encoder with sam2 code
* chore:prepare kwargs for mask decoder
* Add image/video predictors
* Add CUDA kernel
* Add output classes
* linting
* Add logging info
* tmp commit
* docs for sam2
* enable image processing
* check difference of original SAM2 - difference is the order of ToTensor() - please see https://pytorch.org/vision/main/_modules/torchvision/transforms/functional.html#resize
* enable promptencoder of sam2
* fix promprencoder
* Confirmed that PromptEncoder is exactly same (Be aware of bfloat16 and float32 difference)
* Confirmed that ImageEncoder is exactly same (Be aware the linting of init)
* Confirmed that MaskDecoder is exactly same (TO DO: lint variable name)
* SamModel is now available (Need more chore for name)
* make fix-copies
* make style
* make CI happy
* Refactor VisionEncoder and PostioinEmbedding
* TO DO : fix the image_embeddings and sparse_embeddings part
* pure image inference done
* reusable features fix and make style
* styling
* refactor memoryattention
* tmp
* tmp
* refactor memoryencoder TO DO : convert and inference the video pipeline
* TO DO : fix the image_encoder shape
* conversion finish TO DO: need to check video inference
* make style
* remove video model
* lint
* change
* python utils/check_docstringspy --check_all
* python utils/check_config_attributes.py
* remove copies for sam2promptencoder due to configuration
* change __init__.py
* remove tensorflow version
* fix that to not use direct comparison
* make style
* add missing import
* fix image_embedding_size
* refactor Sam2 Attention
* add fully working video inference (refactoring todo)
* clarify _prepare_memory_conditioned_features
* simplify modeling code, remove unused paths
* use one model
* use auto_docstring
* refactor rope embeddings
* nit
* not using multimask when several points given
* add all sam2.1
* add video tmp
* add Sam2VideoSessionState + fast image proc + video proc
* remove init_states from model
* fix batch inference
* add image integration tests
* uniformize modeling code with other sam models and use modular
* pass vision tests an most model tests
* All tests passing
* add offloading inference state and video to cpu
* fix inference from image embedding and existing mask
* fix multi_boxes mask inference
* Fix batch images + batch boxes inference
* improve processing for image inference
* add support for mask generation pipeline
* add support for get_connected_components post processing in mask generation
* add fast image processor sam, image processor tests and use modular for sam2 image processor
* fix mistake in sam after #39120
* fix init weights
* refactor convert
* add integration tests for video + other improvements
* add needed missing docstrings
* Improve docstrings and
* improve inference speed by avoiding cuda sync
* add test
* skip test for vision_model
* minor fix for vision_model
* fix vision_model by adding sam2model and change the torch dependencies
* remove patch_size
* remove image_embedding_size
* fix patch_size
* fix test
* make style
* Separate hieradet and vision encoder in sam2
* fixup
* review changes part 1
* remove MemoryEncoderConfig and MemoryAttentionConfig
* pass q_stride instead of q_pool module
* add inference on streamed videos
* explicitely process streamed frames
* nit
* Improve docstrings in Sam2Model
* update sam2 modeling with better gestion of inference state and cache, and separate Sam2Model and Sam2VideoModel
* improve video inference api
* change inference_state to inference_session
* use modular for Sam2Model
* fix convert sam2 hf
* modular
* Update src/transformers/models/sam2/video_processing_sam2.py
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* fix minor config
* fix attention loading error
* update modeling tests to use hub checkpoints
* Use CI A10 runner for integration tests values + higher tolerance for video integration tests
* PR review part 1
* fix doc
* nit improvements
* enforce one input format for points, labels and boxes
* nit
* last few nits from PR review
* fix style
* fix the input type
* fix docs
* add sam2 model as conversion script
* improve sam2 doc
* add rough necessarry changes
* first working edgetam
* fix issue with object pointers
* Use modular as much as possible
* nit fixes + optimization
* refactor spatial perceiver
* cleanup after merge
* add working edgetam
* improve perceiver resampler code
* simplify/unify rope attention logic
* Improve comments in apply_rotary_pos_emb_2d
* add working tests
* fix test timmwrapper
* add docs
* make fixup
* nits
* fix modular
* fix modular
* PR review part 1
* split apply_rotary_pos_emb_2d
* add granularity to _prepare_memory_conditioned_features
* add dates to doc
* add separate mlp for memory attention
* Fix memory on wrong device
* store processed frames in dict
* update checkpoints in tests
* update dates

---------

Co-authored-by: sangbumchoi <danielsejong55@gmail.com>
Co-authored-by: RUFFY-369 <prakarshkaushik369@gmail.com>
Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>
Co-authored-by: Haitham Khedr <haithamkhedr@meta.com>
Co-authored-by: sangbum choi <sangbumchoi@sangbumui-MacBookAir.local>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent c1db386 commit bd37c45

32 files changed: +9621 −297 lines changed

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -1033,6 +1033,10 @@
      title: DePlot
    - local: model_doc/donut
      title: Donut
+   - local: model_doc/edgetam
+     title: EdgeTAM
+   - local: model_doc/edgetam_video
+     title: EdgeTamVideo
    - local: model_doc/emu3
      title: Emu3
    - local: model_doc/evolla
Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on 2025-01-13 and added to Hugging Face Transformers on 2025-09-29.*
<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
    </div>
</div>

# EdgeTAM

## Overview

The EdgeTAM model was proposed in [EdgeTAM: On-Device Track Anything Model](https://huggingface.co/papers/2501.07256) by Chong Zhou, Chenchen Zhu, Yunyang Xiong, Saksham Suri, Fanyi Xiao, Lemeng Wu, Raghuraman Krishnamoorthi, Bo Dai, Chen Change Loy, Vikas Chandra, and Bilge Soran.

EdgeTAM is an efficient adaptation of SAM 2 that introduces a 2D Spatial Perceiver architecture to optimize the memory attention mechanism for real-time video segmentation on mobile devices.

The abstract from the paper is the following:

*On top of Segment Anything Model (SAM), SAM 2 further extends its capability from image to video inputs through a memory bank mechanism and obtains a remarkable performance compared with previous methods, making it a foundation model for video segmentation task. In this paper, we aim at making SAM 2 much more efficient so that it even runs on mobile devices while maintaining a comparable performance. Despite several works optimizing SAM for better efficiency, we find they are not sufficient for SAM 2 because they all focus on compressing the image encoder, while our benchmark shows that the newly introduced memory attention blocks are also the latency bottleneck. Given this observation, we propose EdgeTAM, which leverages a novel 2D Spatial Perceiver to reduce the computational cost. In particular, the proposed 2D Spatial Perceiver encodes the densely stored frame-level memories with a lightweight Transformer that contains a fixed set of learnable queries. Given that video segmentation is a dense prediction task, we find preserving the spatial structure of the memories is essential so that the queries are split into global-level and patch-level groups. We also propose a distillation pipeline that further improves the performance without inference overhead. As a result, EdgeTAM achieves 87.7, 70.0, 72.3, and 71.7 J&F on DAVIS 2017, MOSE, SA-V val, and SA-V test, while running at 16 FPS on iPhone 15 Pro Max.*

This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/facebookresearch/EdgeTAM).

## Usage example

### Automatic Mask Generation with Pipeline

EdgeTAM can be used for automatic mask generation to segment all objects in an image using the `mask-generation` pipeline:

```python
>>> from transformers import pipeline

>>> generator = pipeline("mask-generation", model="yonigozlan/edgetam-1", device=0)
>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
>>> outputs = generator(image_url, points_per_batch=64)

>>> len(outputs["masks"])  # Number of masks generated
39
```
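
The pipeline returns the segmented masks under `outputs["masks"]` together with their scores. As a minimal visualization sketch (assuming `matplotlib` and `numpy` are installed and that each returned mask is a binary array at the original image resolution), you could overlay the masks on the image like this:

```python
import numpy as np
import matplotlib.pyplot as plt
import requests
from PIL import Image

raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

plt.imshow(raw_image)
ax = plt.gca()
for mask in outputs["masks"]:
    mask = np.asarray(mask).astype(bool)          # one binary mask per segmented object
    overlay = np.zeros((*mask.shape, 4))          # transparent RGBA layer
    overlay[mask] = [*np.random.random(3), 0.6]   # random color, fixed alpha
    ax.imshow(overlay)
plt.axis("off")
plt.savefig("edgetam_masks.png", bbox_inches="tight")
```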

### Basic Image Segmentation

#### Single Point Click

You can segment an object by providing a single point click on it:

```python
>>> from transformers import Sam2Processor, EdgeTamModel, infer_device
>>> import torch
>>> from PIL import Image
>>> import requests

>>> device = infer_device()

>>> model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1").to(device)
>>> processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1")

>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
>>> raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

>>> input_points = [[[[500, 375]]]]  # Single point click, 4 dimensions (image_dim, object_dim, point_per_object_dim, coordinates)
>>> input_labels = [[[1]]]  # 1 for positive click, 0 for negative click, 3 dimensions (image_dim, object_dim, point_label)

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]

>>> # The model outputs multiple mask predictions ranked by quality score
>>> print(f"Generated {masks.shape[1]} masks with shape {masks.shape}")
Generated 3 masks with shape torch.Size([1, 3, 1200, 1800])
>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}")
IoU scores: tensor([0.0463, 0.4859, 0.7616], device='cuda:0')
```
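
Since the candidate masks are ranked by their predicted IoU, a common follow-up is to keep only the highest-scoring one. A small sketch, relying on the shapes shown in the output above:

```python
best_idx = outputs.iou_scores.squeeze().argmax().item()  # index of the best-scoring candidate
best_mask = masks[0, best_idx]                           # (1200, 1800) binary mask for that candidate
print(f"Best mask index: {best_idx}, mask shape: {best_mask.shape}")
```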

#### Multiple Points for Refinement

You can provide multiple points to refine the segmentation:

```python
>>> # Add a second positive point to refine the mask
>>> input_points = [[[[500, 375], [1125, 625]]]]  # Multiple points for refinement
>>> input_labels = [[[1, 1]]]  # Both positive clicks

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}")
IoU scores: tensor([0.8362, 0.6900, 0.2120], device='cuda:0')
```

#### Bounding Box Input

EdgeTAM also supports bounding box inputs for segmentation:

```python
>>> # Define the bounding box as [x_min, y_min, x_max, y_max]
>>> input_boxes = [[[75, 275, 1725, 850]]]

>>> inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}")
IoU scores: tensor([0.9301, 0.9348, 0.6605], device='cuda:0')
```

#### Multiple Objects Segmentation

You can segment multiple objects simultaneously:

```python
>>> # Define points for two different objects
>>> input_points = [[[[500, 375]], [[650, 750]]]]  # Points for two objects in same image
>>> input_labels = [[[1], [1]]]  # Positive clicks for both objects

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> # Each object gets its own mask
>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
>>> print(f"Generated masks for {masks.shape[0]} objects")
Generated masks for 2 objects
>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}")
IoU scores: tensor([0.7616, 0.9465], device='cuda:0')
```
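
With `multimask_output=False`, the first dimension of the post-processed `masks` tensor indexes the prompted objects (here a `(2, 1, height, width)` tensor), so each object's mask can be pulled out by index. A brief sketch:

```python
first_object_mask = masks[0, 0]   # mask for the object prompted at (500, 375)
second_object_mask = masks[1, 0]  # mask for the object prompted at (650, 750)
print(first_object_mask.shape, second_object_mask.shape)
```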

### Batch Inference

#### Batched Images

Process multiple images simultaneously for improved efficiency:

```python
>>> from transformers import Sam2Processor, EdgeTamModel, infer_device
>>> import torch
>>> from PIL import Image
>>> import requests

>>> device = infer_device()

>>> model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1").to(device)
>>> processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1")

>>> # Load multiple images
>>> image_urls = [
...     "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg",
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
... ]
>>> raw_images = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in image_urls]

>>> # Single point per image
>>> input_points = [[[[500, 375]]], [[[770, 200]]]]  # One point for each image
>>> input_labels = [[[1]], [[1]]]  # Positive clicks for both images

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> # Post-process masks for each image
>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
>>> print(f"Processed {len(all_masks)} images, each with {all_masks[0].shape[0]} objects")
Processed 2 images, each with 1 objects
>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}")
IoU scores: tensor([0.7618, 0.7999], device='cuda:0')
```

#### Batched Objects per Image

Segment multiple objects within each image using batch inference:

```python
>>> # Multiple objects per image - different numbers of objects per image
>>> input_points = [
...     [[[500, 375]], [[650, 750]]],  # Truck image: 2 objects
...     [[[770, 200]]]  # Dog image: 1 object
... ]
>>> input_labels = [
...     [[1], [1]],  # Truck image: positive clicks for both objects
...     [[1]]  # Dog image: positive click for the object
... ]

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
```
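
As in the batched-image example, `post_process_masks` returns one tensor per input image. A quick sketch for inspecting the per-image results; note that the exact first-dimension size can depend on how the processor pads prompts when images have different numbers of objects:

```python
for i, image_masks in enumerate(all_masks):
    # image_masks is expected to be (num_objects, num_masks, height, width); num_masks == 1 here
    print(f"image {i}: {image_masks.shape[0]} object masks at resolution {tuple(image_masks.shape[-2:])}")
```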

#### Batched Images with Batched Objects and Multiple Points

Handle complex batch scenarios with multiple points per object:

```python
>>> # Add the groceries image for a more complex example
>>> groceries_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg"
>>> groceries_image = Image.open(requests.get(groceries_url, stream=True).raw).convert("RGB")
>>> raw_images = [raw_images[0], groceries_image]  # Use truck and groceries images

>>> # Complex batching: multiple images, multiple objects, multiple points per object
>>> input_points = [
...     [[[500, 375]], [[650, 750]]],  # Truck image: 2 objects with 1 point each
...     [[[400, 300]], [[630, 300], [550, 300]]]  # Groceries image: obj1 has 1 point, obj2 has 2 points
... ]
>>> input_labels = [
...     [[1], [1]],  # Truck image: positive clicks
...     [[1], [1, 1]]  # Groceries image: positive clicks for refinement
... ]

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
```

#### Batched Bounding Boxes

Process multiple images with bounding box inputs:

```python
>>> # Multiple bounding boxes per image (using truck and groceries images)
>>> input_boxes = [
...     [[75, 275, 1725, 850], [425, 600, 700, 875], [1375, 550, 1650, 800], [1240, 675, 1400, 750]],  # Truck image: 4 boxes
...     [[450, 170, 520, 350], [350, 190, 450, 350], [500, 170, 580, 350], [580, 170, 640, 350]]  # Groceries image: 4 boxes
... ]

>>> # Update images for this example
>>> raw_images = [raw_images[0], groceries_image]  # truck and groceries

>>> inputs = processor(images=raw_images, input_boxes=input_boxes, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
>>> print(f"Processed {len(input_boxes)} images with {len(input_boxes[0])} and {len(input_boxes[1])} boxes respectively")
Processed 2 images with 4 and 4 boxes respectively
>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}")
IoU scores: tensor([0.9301, 0.9348, 0.6605, 0.9465], device='cuda:0')
```

### Using Previous Masks as Input

EdgeTAM can use masks from previous predictions as input to refine segmentation:

```python
>>> # Get initial segmentation
>>> input_points = [[[[500, 375]]]]
>>> input_labels = [[[1]]]
>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Use the best mask as input for refinement
>>> mask_input = outputs.pred_masks[:, :, torch.argmax(outputs.iou_scores.squeeze())]

>>> # Add additional points with the mask input
>>> new_input_points = [[[[500, 375], [450, 300]]]]
>>> new_input_labels = [[[1, 1]]]
>>> inputs = processor(
...     input_points=new_input_points,
...     input_labels=new_input_labels,
...     original_sizes=inputs["original_sizes"],
...     return_tensors="pt",
... ).to(device)

>>> with torch.no_grad():
...     refined_outputs = model(
...         **inputs,
...         input_masks=mask_input,
...         image_embeddings=outputs.image_embeddings,
...         multimask_output=False,
...     )
```
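
The refined prediction can then be post-processed like the earlier outputs. This sketch simply reuses the `post_process_masks` call from the examples above and assumes the original image sizes are still available on `inputs`; if the processor does not echo them back, keep them in a separate variable before the second call:

```python
refined_masks = processor.post_process_masks(
    refined_outputs.pred_masks.cpu(), inputs["original_sizes"]
)[0]
print(refined_masks.shape)  # expected (1, 1, height, width) with multimask_output=False
```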

## EdgeTamConfig

[[autodoc]] EdgeTamConfig

## EdgeTamVisionConfig

[[autodoc]] EdgeTamVisionConfig

## EdgeTamMaskDecoderConfig

[[autodoc]] EdgeTamMaskDecoderConfig

## EdgeTamPromptEncoderConfig

[[autodoc]] EdgeTamPromptEncoderConfig

## EdgeTamVisionModel

[[autodoc]] EdgeTamVisionModel
    - forward

## EdgeTamModel

[[autodoc]] EdgeTamModel
    - forward

0 commit comments