diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c0258da704be..2bc2222dd24b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1033,6 +1033,10 @@ title: DePlot - local: model_doc/donut title: Donut + - local: model_doc/edgetam + title: EdgeTAM + - local: model_doc/edgetam_video + title: EdgeTamVideo - local: model_doc/emu3 title: Emu3 - local: model_doc/evolla diff --git a/docs/source/en/model_doc/edgetam.md b/docs/source/en/model_doc/edgetam.md new file mode 100644 index 000000000000..780ccb3f70b3 --- /dev/null +++ b/docs/source/en/model_doc/edgetam.md @@ -0,0 +1,331 @@ + +*This model was released on 2025-01-13 and added to Hugging Face Transformers on 2025-09-29.* +
+
+ PyTorch + SDPA + FlashAttention +
+
+ +# EdgeTAM + +## Overview + +The EdgeTAM model was proposed in [EdgeTAM: On-Device Track Anything Model](https://huggingface.co/papers/2501.07256) by Chong Zhou, Chenchen Zhu, Yunyang Xiong, Saksham Suri, Fanyi Xiao, Lemeng Wu, Raghuraman Krishnamoorthi, Bo Dai, Chen Change Loy, Vikas Chandra, Bilge Soran. + +EdgeTAM is an efficient adaptation of SAM 2 that introduces a 2D Spatial Perceiver architecture to optimize memory attention mechanisms for real-time video segmentation on mobile devices. + +The abstract from the paper is the following: + +*On top of Segment Anything Model (SAM), SAM 2 further extends its capability from image to video inputs through a memory bank mechanism and obtains a remarkable performance compared with previous methods, making it a foundation model for video segmentation task. In this paper, we aim at making SAM 2 much more efficient so that it even runs on mobile devices while maintaining a comparable performance. Despite several works optimizing SAM for better efficiency, we find they are not sufficient for SAM 2 because they all focus on compressing the image encoder, while our benchmark shows that the newly introduced memory attention blocks are also the latency bottleneck. Given this observation, we propose EdgeTAM, which leverages a novel 2D Spatial Perceiver to reduce the computational cost. In particular, the proposed 2D Spatial Perceiver encodes the densely stored frame-level memories with a lightweight Transformer that contains a fixed set of learnable queries. Given that video segmentation is a dense prediction task, we find preserving the spatial structure of the memories is essential so that the queries are split into global-level and patch-level groups. We also propose a distillation pipeline that further improves the performance without inference overhead. As a result, EdgeTAM achieves 87.7, 70.0, 72.3, and 71.7 J&F on DAVIS 2017, MOSE, SA-V val, and SA-V test, while running at 16 FPS on iPhone 15 Pro Max.* + +This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan). +The original code can be found [here](https://github.com/facebookresearch/EdgeTAM). 
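+
+EdgeTAM reuses the SAM 2 processing classes (`Sam2Processor`, `Sam2ImageProcessorFast`) rather than defining its own, and like other Transformers models it can be loaded with an explicit attention backend. As a minimal loading sketch (assuming the `yonigozlan/edgetam-1` checkpoint used in the examples below):
+
+```python
+import torch
+from transformers import EdgeTamModel, Sam2Processor
+
+# "sdpa" is PyTorch's native scaled-dot-product attention backend;
+# "flash_attention_2" also works when flash-attn and a supported GPU are available.
+model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1", attn_implementation="sdpa")
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+# EdgeTAM has no dedicated processor class; the SAM 2 processor handles images and prompts.
+processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1")
+```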
+ +## Usage example + +### Automatic Mask Generation with Pipeline + +EdgeTAM can be used for automatic mask generation to segment all objects in an image using the `mask-generation` pipeline: + +```python +>>> from transformers import pipeline + +>>> generator = pipeline("mask-generation", model="yonigozlan/edgetam-1", device=0) +>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" +>>> outputs = generator(image_url, points_per_batch=64) + +>>> len(outputs["masks"]) # Number of masks generated +39 +``` + +### Basic Image Segmentation + +#### Single Point Click + +You can segment objects by providing a single point click on the object you want to segment: + +```python +>>> from transformers import Sam2Processor, EdgeTamModel, infer_device +>>> import torch +>>> from PIL import Image +>>> import requests + +>>> device = infer_device() + +>>> model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1").to(device) +>>> processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1") + +>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" +>>> raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB") + +>>> input_points = [[[[500, 375]]]] # Single point click, 4 dimensions (image_dim, object_dim, point_per_object_dim, coordinates) +>>> input_labels = [[[1]]] # 1 for positive click, 0 for negative click, 3 dimensions (image_dim, object_dim, point_label) + +>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device) + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0] + +>>> # The model outputs multiple mask predictions ranked by quality score +>>> print(f"Generated {masks.shape[1]} masks with shape {masks.shape}") +Generated 3 masks with shape torch.Size([1, 3, 1200, 1800]) +>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}") +IoU scores: tensor([0.0463, 0.4859, 0.7616], device='cuda:0') +``` + +#### Multiple Points for Refinement + +You can provide multiple points to refine the segmentation: + +```python +>>> # Add both positive and negative points to refine the mask +>>> input_points = [[[[500, 375], [1125, 625]]]] # Multiple points for refinement +>>> input_labels = [[[1, 1]]] # Both positive clicks + +>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0] +>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}") +IoU scores: tensor([0.8362, 0.6900, 0.2120], device='cuda:0') +``` + +#### Bounding Box Input + +EdgeTAM also supports bounding box inputs for segmentation: + +```python +>>> # Define bounding box as [x_min, y_min, x_max, y_max] +>>> input_boxes = [[[75, 275, 1725, 850]]] + +>>> inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... 
outputs = model(**inputs) + +>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0] +>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}") +IoU scores: tensor([0.9301, 0.9348, 0.6605], device='cuda:0') +``` + +#### Multiple Objects Segmentation + +You can segment multiple objects simultaneously: + +```python +>>> # Define points for two different objects +>>> input_points = [[[[500, 375]], [[650, 750]]]] # Points for two objects in same image +>>> input_labels = [[[1], [1]]] # Positive clicks for both objects + +>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... outputs = model(**inputs, multimask_output=False) + +>>> # Each object gets its own mask +>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0] +>>> print(f"Generated masks for {masks.shape[0]} objects") +Generated masks for 2 objects +>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}") +IoU scores: tensor([0.7616, 0.9465], device='cuda:0') +``` + +### Batch Inference + +#### Batched Images + +Process multiple images simultaneously for improved efficiency: + +```python +>>> from transformers import Sam2Processor, EdgeTamModel, infer_device +>>> import torch +>>> from PIL import Image +>>> import requests + +>>> device = infer_device() + +>>> model = EdgeTamModel.from_pretrained("yonigozlan/edgetam-1").to(device) +>>> processor = Sam2Processor.from_pretrained("yonigozlan/edgetam-1") + +>>> # Load multiple images +>>> image_urls = [ +... "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg", +... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" +... ] +>>> raw_images = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in image_urls] + +>>> # Single point per image +>>> input_points = [[[[500, 375]]], [[[770, 200]]]] # One point for each image +>>> input_labels = [[[1]], [[1]]] # Positive clicks for both images + +>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device) + +>>> with torch.no_grad(): +... outputs = model(**inputs, multimask_output=False) + +>>> # Post-process masks for each image +>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"]) +>>> print(f"Processed {len(all_masks)} images, each with {all_masks[0].shape[0]} objects") +Processed 2 images, each with 1 objects +>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}") +IoU scores: tensor([0.7618, 0.7999], device='cuda:0') +``` + +#### Batched Objects per Image + +Segment multiple objects within each image using batch inference: + +```python +>>> # Multiple objects per image - different numbers of objects per image +>>> input_points = [ +... [[[500, 375]], [[650, 750]]], # Truck image: 2 objects +... [[[770, 200]]] # Dog image: 1 object +... ] +>>> input_labels = [ +... [[1], [1]], # Truck image: positive clicks for both objects +... [[1]] # Dog image: positive click for the object +... ] + +>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... 
outputs = model(**inputs, multimask_output=False) + +>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"]) +``` + +#### Batched Images with Batched Objects and Multiple Points + +Handle complex batch scenarios with multiple points per object: + +```python +>>> # Add groceries image for more complex example +>>> groceries_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg" +>>> groceries_image = Image.open(requests.get(groceries_url, stream=True).raw).convert("RGB") +>>> raw_images = [raw_images[0], groceries_image] # Use truck and groceries images + +>>> # Complex batching: multiple images, multiple objects, multiple points per object +>>> input_points = [ +... [[[500, 375]], [[650, 750]]], # Truck image: 2 objects with 1 point each +... [[[400, 300]], [[630, 300], [550, 300]]] # Groceries image: obj1 has 1 point, obj2 has 2 points +... ] +>>> input_labels = [ +... [[1], [1]], # Truck image: positive clicks +... [[1], [1, 1]] # Groceries image: positive clicks for refinement +... ] + +>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... outputs = model(**inputs, multimask_output=False) + +>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"]) +``` + +#### Batched Bounding Boxes + +Process multiple images with bounding box inputs: + +```python +>>> # Multiple bounding boxes per image (using truck and groceries images) +>>> input_boxes = [ +... [[75, 275, 1725, 850], [425, 600, 700, 875], [1375, 550, 1650, 800], [1240, 675, 1400, 750]], # Truck image: 4 boxes +... [[450, 170, 520, 350], [350, 190, 450, 350], [500, 170, 580, 350], [580, 170, 640, 350]] # Groceries image: 4 boxes +... ] + +>>> # Update images for this example +>>> raw_images = [raw_images[0], groceries_image] # truck and groceries + +>>> inputs = processor(images=raw_images, input_boxes=input_boxes, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... outputs = model(**inputs, multimask_output=False) + +>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"]) +>>> print(f"Processed {len(input_boxes)} images with {len(input_boxes[0])} and {len(input_boxes[1])} boxes respectively") +Processed 2 images with 4 and 4 boxes respectively +>>> print(f"IoU scores: {outputs.iou_scores.squeeze()}") +IoU scores: tensor([0.9301, 0.9348, 0.6605, 0.9465], device='cuda:0') +``` + +### Using Previous Masks as Input + +EdgeTAM can use masks from previous predictions as input to refine segmentation: + +```python +>>> # Get initial segmentation +>>> input_points = [[[[500, 375]]]] +>>> input_labels = [[[1]]] +>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device) + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> # Use the best mask as input for refinement +>>> mask_input = outputs.pred_masks[:, :, torch.argmax(outputs.iou_scores.squeeze())] + +>>> # Add additional points with the mask input +>>> new_input_points = [[[[500, 375], [450, 300]]]] +>>> new_input_labels = [[[1, 1]]] +>>> inputs = processor( +... input_points=new_input_points, +... input_labels=new_input_labels, +... original_sizes=inputs["original_sizes"], +... return_tensors="pt", +... ).to(device) + +>>> with torch.no_grad(): +... refined_outputs = model( +... **inputs, +... input_masks=mask_input, +... 
image_embeddings=outputs.image_embeddings, +... multimask_output=False, +... ) +``` + + +## EdgeTamConfig + +[[autodoc]] EdgeTamConfig + +## EdgeTamVisionConfig + +[[autodoc]] EdgeTamVisionConfig + +## EdgeTamMaskDecoderConfig + +[[autodoc]] EdgeTamMaskDecoderConfig + +## EdgeTamPromptEncoderConfig + +[[autodoc]] EdgeTamPromptEncoderConfig + +## EdgeTamVisionModel + +[[autodoc]] EdgeTamVisionModel + - forward + +## EdgeTamModel + +[[autodoc]] EdgeTamModel + - forward diff --git a/docs/source/en/model_doc/edgetam_video.md b/docs/source/en/model_doc/edgetam_video.md new file mode 100644 index 000000000000..381bace4dbe0 --- /dev/null +++ b/docs/source/en/model_doc/edgetam_video.md @@ -0,0 +1,297 @@ + +*This model was released on 2025-01-13 and added to Hugging Face Transformers on 2025-09-29.* + + +
+
+ PyTorch + SDPA + FlashAttention +
+
+ +# EdgeTAMVideo + +## Overview + +The EdgeTAM model was proposed in [EdgeTAM: On-Device Track Anything Model](https://huggingface.co/papers/2501.07256) by Chong Zhou, Chenchen Zhu, Yunyang Xiong, Saksham Suri, Fanyi Xiao, Lemeng Wu, Raghuraman Krishnamoorthi, Bo Dai, Chen Change Loy, Vikas Chandra, Bilge Soran. + +EdgeTAM is an efficient adaptation of SAM 2 that introduces a 2D Spatial Perceiver architecture to optimize memory attention mechanisms for real-time video segmentation on mobile devices. + +The abstract from the paper is the following: + +*On top of Segment Anything Model (SAM), SAM 2 further extends its capability from image to video inputs through a memory bank mechanism and obtains a remarkable performance compared with previous methods, making it a foundation model for video segmentation task. In this paper, we aim at making SAM 2 much more efficient so that it even runs on mobile devices while maintaining a comparable performance. Despite several works optimizing SAM for better efficiency, we find they are not sufficient for SAM 2 because they all focus on compressing the image encoder, while our benchmark shows that the newly introduced memory attention blocks are also the latency bottleneck. Given this observation, we propose EdgeTAM, which leverages a novel 2D Spatial Perceiver to reduce the computational cost. In particular, the proposed 2D Spatial Perceiver encodes the densely stored frame-level memories with a lightweight Transformer that contains a fixed set of learnable queries. Given that video segmentation is a dense prediction task, we find preserving the spatial structure of the memories is essential so that the queries are split into global-level and patch-level groups. We also propose a distillation pipeline that further improves the performance without inference overhead. As a result, EdgeTAM achieves 87.7, 70.0, 72.3, and 71.7 J&F on DAVIS 2017, MOSE, SA-V val, and SA-V test, while running at 16 FPS on iPhone 15 Pro Max.* + +This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan). +The original code can be found [here](https://github.com/facebookresearch/EdgeTAM). + +## Usage example + +### Video Segmentation and Tracking + +EdgeTAM Video's key strength is its ability to track objects across video frames efficiently on mobile devices. Here's how to use it for video segmentation: + +#### Basic Video Tracking + +```python +>>> from transformers import EdgeTamVideoModel, Sam2VideoProcessor, infer_device +>>> import torch + +>>> device = infer_device() +>>> model = EdgeTamVideoModel.from_pretrained("yonigozlan/edgetam-video-1").to(device, dtype=torch.bfloat16) +>>> processor = Sam2VideoProcessor.from_pretrained("yonigozlan/edgetam-video-1") + +>>> # Load video frames (example assumes you have a list of PIL Images) +>>> # video_frames = [Image.open(f"frame_{i:05d}.jpg") for i in range(num_frames)] + +>>> # For this example, we'll use the video loading utility +>>> from transformers.video_utils import load_video +>>> video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4" +>>> video_frames, _ = load_video(video_url) + +>>> # Initialize video inference session +>>> inference_session = processor.init_video_session( +... video=video_frames, +... inference_device=device, +... dtype=torch.bfloat16, +... 
) + +>>> # Add click on first frame to select object +>>> ann_frame_idx = 0 +>>> ann_obj_id = 1 +>>> points = [[[[210, 350]]]] +>>> labels = [[[1]]] + +>>> processor.add_inputs_to_inference_session( +... inference_session=inference_session, +... frame_idx=ann_frame_idx, +... obj_ids=ann_obj_id, +... input_points=points, +... input_labels=labels, +... ) + +>>> # Segment the object on the first frame +>>> outputs = model( +... inference_session=inference_session, +... frame_idx=ann_frame_idx, +... ) +>>> video_res_masks = processor.post_process_masks( +... [outputs.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False +... )[0] +>>> print(f"Segmentation shape: {video_res_masks.shape}") +Segmentation shape: torch.Size([1, 1, 540, 960]) + +>>> # Propagate through the entire video +>>> video_segments = {} +>>> for sam2_video_output in model.propagate_in_video_iterator(inference_session): +... video_res_masks = processor.post_process_masks( +... [sam2_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False +... )[0] +... video_segments[sam2_video_output.frame_idx] = video_res_masks + +>>> print(f"Tracked object through {len(video_segments)} frames") +Tracked object through 200 frames +``` + +#### Multi-Object Video Tracking + +Track multiple objects simultaneously across video frames: + +```python +>>> # Reset for new tracking session +>>> inference_session.reset_inference_session() + +>>> # Add multiple objects on the first frame +>>> ann_frame_idx = 0 +>>> obj_ids = [2, 3] +>>> input_points = [[[[200, 300]], [[400, 150]]]] # Points for two objects (batched) +>>> input_labels = [[[1], [1]]] + +>>> processor.add_inputs_to_inference_session( +... inference_session=inference_session, +... frame_idx=ann_frame_idx, +... obj_ids=obj_ids, +... input_points=input_points, +... input_labels=input_labels, +... ) + +>>> # Get masks for both objects on first frame +>>> outputs = model( +... inference_session=inference_session, +... frame_idx=ann_frame_idx, +... ) + +>>> # Propagate both objects through video +>>> video_segments = {} +>>> for sam2_video_output in model.propagate_in_video_iterator(inference_session): +... video_res_masks = processor.post_process_masks( +... [sam2_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False +... )[0] +... video_segments[sam2_video_output.frame_idx] = { +... obj_id: video_res_masks[i] +... for i, obj_id in enumerate(inference_session.obj_ids) +... } + +>>> print(f"Tracked {len(inference_session.obj_ids)} objects through {len(video_segments)} frames") +Tracked 2 objects through 200 frames +``` + +#### Refining Video Segmentation + +You can add additional clicks on any frame to refine the tracking: + +```python +>>> # Add refinement click on a later frame +>>> refine_frame_idx = 50 +>>> ann_obj_id = 2 # Refining first object +>>> points = [[[[220, 280]]]] # Additional point +>>> labels = [[[1]]] # Positive click + +>>> processor.add_inputs_to_inference_session( +... inference_session=inference_session, +... frame_idx=refine_frame_idx, +... obj_ids=ann_obj_id, +... input_points=points, +... input_labels=labels, +... ) + +>>> # Re-propagate with the additional information +>>> video_segments = {} +>>> for sam2_video_output in model.propagate_in_video_iterator(inference_session): +... video_res_masks = processor.post_process_masks( +... 
[sam2_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False +... )[0] +... video_segments[sam2_video_output.frame_idx] = video_res_masks +``` + +### Streaming Video Inference + +For real-time applications, EdgeTAM Video supports processing video frames as they arrive: + +```python +>>> # Initialize session for streaming +>>> inference_session = processor.init_video_session( +... inference_device=device, +... dtype=torch.bfloat16, +... ) + +>>> # Process frames one by one +>>> for frame_idx, frame in enumerate(video_frames[:10]): # Process first 10 frames +... inputs = processor(images=frame, device=device, return_tensors="pt") +... +... if frame_idx == 0: +... # Add point input on first frame +... processor.add_inputs_to_inference_session( +... inference_session=inference_session, +... frame_idx=0, +... obj_ids=1, +... input_points=[[[[210, 350], [250, 220]]]], +... input_labels=[[[1, 1]]], +... original_size=inputs.original_sizes[0], # need to be provided when using streaming video inference +... ) +... +... # Process current frame +... sam2_video_output = model(inference_session=inference_session, frame=inputs.pixel_values[0]) +... +... video_res_masks = processor.post_process_masks( +... [sam2_video_output.pred_masks], original_sizes=inputs.original_sizes, binarize=False +... )[0] +... print(f"Frame {frame_idx}: mask shape {video_res_masks.shape}") + +Frame 0: mask shape torch.Size([1, 1, 540, 960]) +... +``` + +#### Video Batch Processing for Multiple Objects + +Track multiple objects simultaneously in video by adding them all at once: + +```python +>>> # Initialize video session +>>> inference_session = processor.init_video_session( +... video=video_frames, +... inference_device=device, +... dtype=torch.bfloat16, +... ) + +>>> # Add multiple objects on the first frame using batch processing +>>> ann_frame_idx = 0 +>>> obj_ids = [2, 3] # Track two different objects +>>> input_points = [ +... [[[200, 300], [230, 250], [275, 175]], [[400, 150]]] +... ] # Object 2: 3 points (2 positive, 1 negative); Object 3: 1 point +>>> input_labels = [ +... [[1, 1, 0], [1]] +... ] # Object 2: positive, positive, negative; Object 3: positive + +>>> processor.add_inputs_to_inference_session( +... inference_session=inference_session, +... frame_idx=ann_frame_idx, +... obj_ids=obj_ids, +... input_points=input_points, +... input_labels=input_labels, +... ) + +>>> # Get masks for all objects on the first frame +>>> outputs = model( +... inference_session=inference_session, +... frame_idx=ann_frame_idx, +... ) +>>> video_res_masks = processor.post_process_masks( +... [outputs.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False +... )[0] +>>> print(f"Generated masks for {video_res_masks.shape[0]} objects") +Generated masks for 2 objects + +>>> # Propagate all objects through the video +>>> video_segments = {} +>>> for sam2_video_output in model.propagate_in_video_iterator(inference_session): +... video_res_masks = processor.post_process_masks( +... [sam2_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False +... )[0] +... video_segments[sam2_video_output.frame_idx] = { +... obj_id: video_res_masks[i] +... for i, obj_id in enumerate(inference_session.obj_ids) +... 
} + +>>> print(f"Tracked {len(inference_session.obj_ids)} objects through {len(video_segments)} frames") +Tracked 2 objects through 200 frames +``` + +## EdgeTamVideoMaskDecoderConfig + +[[autodoc]] EdgeTamVideoMaskDecoderConfig + +## EdgeTamVideoPromptEncoderConfig + +[[autodoc]] EdgeTamVideoPromptEncoderConfig + +## EdgeTamVideoConfig + +[[autodoc]] EdgeTamVideoConfig + +## EdgeTamVideoInferenceSession + +[[autodoc]] EdgeTamVideoInferenceSession + +## EdgeTamVideoModel + +[[autodoc]] EdgeTamVideoModel + - forward diff --git a/docs/source/en/model_doc/qwen3_vl.md b/docs/source/en/model_doc/qwen3_vl.md index 626b4119aa44..33c8c7e96aee 100644 --- a/docs/source/en/model_doc/qwen3_vl.md +++ b/docs/source/en/model_doc/qwen3_vl.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2025-02-19 and added to Hugging Face Transformers on 2025-09-15.* +*This model was released on 2025-09-23 and added to Hugging Face Transformers on 2025-09-15.*
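The hunks below register EdgeTAM in the auto mappings (configuration, modeling, mask generation, image processing, and processing), which is what lets EdgeTAM checkpoints resolve through the Auto API. A rough sketch of what these registrations enable, assuming the same `yonigozlan/edgetam-1` checkpoint used in the docs above:

```python
from transformers import AutoConfig, AutoModelForMaskGeneration, AutoProcessor

# configuration_auto.py maps "edgetam" -> EdgeTamConfig
config = AutoConfig.from_pretrained("yonigozlan/edgetam-1")

# modeling_auto.py adds "edgetam" to the mask-generation mapping -> EdgeTamModel
model = AutoModelForMaskGeneration.from_pretrained("yonigozlan/edgetam-1")

# processing_auto.py / image_processing_auto.py reuse the SAM 2 processors for EdgeTAM
processor = AutoProcessor.from_pretrained("yonigozlan/edgetam-1")

print(type(config).__name__, type(model).__name__, type(processor).__name__)
# EdgeTamConfig EdgeTamModel Sam2Processor
```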
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2905a842612e..c721f24a506d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -108,6 +108,8 @@ from .dots1 import * from .dpr import * from .dpt import * + from .edgetam import * + from .edgetam_video import * from .efficientloftr import * from .efficientnet import * from .electra import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c40b5a37b02a..f6a12e7cef98 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -127,6 +127,9 @@ ("dots1", "Dots1Config"), ("dpr", "DPRConfig"), ("dpt", "DPTConfig"), + ("edgetam", "EdgeTamConfig"), + ("edgetam_video", "EdgeTamVideoConfig"), + ("edgetam_vision_model", "EdgeTamVisionConfig"), ("efficientformer", "EfficientFormerConfig"), ("efficientloftr", "EfficientLoFTRConfig"), ("efficientnet", "EfficientNetConfig"), @@ -563,6 +566,9 @@ ("dots1", "dots1"), ("dpr", "DPR"), ("dpt", "DPT"), + ("edgetam", "EdgeTAM"), + ("edgetam_video", "EdgeTamVideo"), + ("edgetam_vision_model", "EdgeTamVisionModel"), ("efficientformer", "EfficientFormer"), ("efficientloftr", "EfficientLoFTR"), ("efficientnet", "EfficientNet"), @@ -983,6 +989,7 @@ ("qwen3_vl_moe_text", "qwen3_vl_moe"), ("sam_vision_model", "sam"), ("sam2_vision_model", "sam2"), + ("edgetam_vision_model", "edgetam"), ("sam2_hiera_det_model", "sam2"), ("sam_hq_vision_model", "sam_hq"), ("llama4_text", "llama4"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index aa16ac3555eb..a272735af207 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -91,6 +91,7 @@ ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")), ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")), ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")), + ("edgetam", (None, "Sam2ImageProcessorFast")), ("efficientformer", ("EfficientFormerImageProcessor", None)), ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")), ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 297d4890d131..298834bebe93 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -131,6 +131,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("dots1", "Dots1Model"), ("dpr", "DPRQuestionEncoder"), ("dpt", "DPTModel"), + ("edgetam", "EdgeTamModel"), + ("edgetam_video", "EdgeTamVideoModel"), + ("edgetam_vision_model", "EdgeTamVisionModel"), ("efficientformer", "EfficientFormerModel"), ("efficientloftr", "EfficientLoFTRModel"), ("efficientnet", "EfficientNetModel"), @@ -1709,6 +1712,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( [ + ("edgetam", "EdgeTamModel"), + ("edgetam_video", "EdgeTamModel"), ("sam", "SamModel"), ("sam2", "Sam2Model"), ("sam2_video", "Sam2Model"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 2b1ca09bb8df..11862a5896b9 100644 --- a/src/transformers/models/auto/processing_auto.py +++ 
b/src/transformers/models/auto/processing_auto.py @@ -66,6 +66,7 @@ ("deepseek_vl", "DeepseekVLProcessor"), ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"), ("dia", "DiaProcessor"), + ("edgetam", "Sam2Processor"), ("emu3", "Emu3Processor"), ("evolla", "EvollaProcessor"), ("flava", "FlavaProcessor"), diff --git a/src/transformers/models/edgetam/__init__.py b/src/transformers/models/edgetam/__init__.py new file mode 100644 index 000000000000..d9c1a55fc5bc --- /dev/null +++ b/src/transformers/models/edgetam/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_edgetam import * + from .modeling_edgetam import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/edgetam/configuration_edgetam.py b/src/transformers/models/edgetam/configuration_edgetam.py new file mode 100644 index 000000000000..07ccee36e932 --- /dev/null +++ b/src/transformers/models/edgetam/configuration_edgetam.py @@ -0,0 +1,332 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_edgetam.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class EdgeTamVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate a SAM + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny + [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture. 
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*): + Configuration for the vision backbone. This is used to instantiate the backbone using + `AutoModel.from_config`. + backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`): + The list of channel dimensions for the backbone. + backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`): + The spatial sizes of the feature maps from the backbone. + fpn_hidden_size (`int`, *optional*, defaults to 256): + The hidden dimension of the FPN. + fpn_kernel_size (`int`, *optional*, defaults to 1): + The kernel size for the convolutions in the neck. + fpn_stride (`int`, *optional*, defaults to 1): + The stride for the convolutions in the neck. + fpn_padding (`int`, *optional*, defaults to 0): + The padding for the convolutions in the neck. + fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`): + The levels for the top-down FPN connections. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels from the FPN to use. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the neck. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon for the layer normalization. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + """ + + base_config_key = "vision_config" + model_type = "edgetam_vision_model" + sub_configs = { + "backbone_config": AutoConfig, + } + + def __init__( + self, + backbone_config=None, + backbone_channel_list=None, + backbone_feature_sizes=None, + fpn_hidden_size=256, + fpn_kernel_size=1, + fpn_stride=1, + fpn_padding=0, + fpn_top_down_levels=None, + num_feature_levels=3, + hidden_act="gelu", + layer_norm_eps=1e-6, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list + backbone_feature_sizes = ( + [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes + ) + fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels + + if isinstance(backbone_config, dict): + backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper") + backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config) + elif isinstance(backbone_config, AutoConfig): + backbone_config = backbone_config + elif backbone_config is None: + backbone_config = AutoConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]}, + ) + + self.backbone_config = backbone_config + + # Neck + self.backbone_channel_list = backbone_channel_list + self.backbone_feature_sizes = backbone_feature_sizes + self.fpn_hidden_size = fpn_hidden_size + self.fpn_kernel_size = fpn_kernel_size + self.fpn_stride = fpn_stride + self.fpn_padding = fpn_padding + self.fpn_top_down_levels = fpn_top_down_levels + self.num_feature_levels = num_feature_levels + + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + + +class EdgeTamPromptEncoderConfig(PretrainedConfig): + 
r""" + This is the configuration class to store the configuration of a [`EdgeTamPromptEncoder`]. The [`EdgeTamPromptEncoder`] + module is used to encode the input 2D points and bounding boxes. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + scale (`float`, *optional*, defaults to 1): + The scale factor for the prompt encoder. + """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + scale=1, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.scale = scale + + +class EdgeTamMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamMaskDecoder`]. It is used to instantiate a EDGETAM + memory encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the EDGETAM mask decoder. + mlp_dim (`int`, *optional*, defaults to 2048): + The dimension of the MLP in the two-way transformer. + num_hidden_layers (`int`, *optional*, defaults to 2): + The number of hidden layers in the two-way transformer. + num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in the two-way transformer. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsample rate for the attention layers. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of multimask outputs. + iou_head_depth (`int`, *optional*, defaults to 3): + The depth of the IoU head. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the IoU head. + dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`): + Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05): + The stability delta for the dynamic multimask. 
dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98): + The stability threshold for the dynamic multimask. + + """ + + base_config_key = "mask_decoder_config" + + def __init__( + self, + hidden_size=256, + hidden_act="gelu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_multimask_outputs = num_multimask_outputs + self.hidden_act = hidden_act + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + # TwoWayTransformer configuration + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.mlp_dim = mlp_dim + self.attention_downsample_rate = attention_downsample_rate + + +class EdgeTamConfig(PretrainedConfig): + r""" + [`EdgeTamConfig`] is the configuration class to store the configuration of a [`EdgeTamModel`]. It is used to instantiate an + EDGETAM model according to the specified arguments, defining the vision encoder, prompt encoder, and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + [facebook/edgetam.1-hiera-tiny](https://huggingface.co/facebook/edgetam.1-hiera-tiny) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `EdgeTamVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`EdgeTamVisionConfig`]. + prompt_encoder_config (Union[`dict`, `EdgeTamPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`EdgeTamPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `EdgeTamMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`EdgeTamMaskDecoderConfig`]. + initializer_range (`float`, *optional*, defaults to 0.02): + Standard deviation for parameter initialization. + + Example: + + ```python + >>> from transformers import ( + ... EdgeTamVisionConfig, + ... EdgeTamPromptEncoderConfig, + ... EdgeTamMaskDecoderConfig, + ... EdgeTamModel, + ... 
) + + >>> # Initializing an EdgeTamConfig with `"facebook/edgetam.1_hiera_tiny"` style configuration + >>> configuration = EdgeTamConfig() + + >>> # Initializing an EdgeTamModel (with random weights) from the `"facebook/edgetam.1_hiera_tiny"` style configuration + >>> model = EdgeTamModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize an EdgeTamConfig from an EdgeTamVisionConfig, EdgeTamPromptEncoderConfig, and EdgeTamMaskDecoderConfig + + >>> # Initializing EDGETAM vision encoder, prompt encoder, and mask decoder configurations + >>> vision_config = EdgeTamVisionConfig() + >>> prompt_encoder_config = EdgeTamPromptEncoderConfig() + >>> mask_decoder_config = EdgeTamMaskDecoderConfig() + + >>> config = EdgeTamConfig(vision_config, prompt_encoder_config, mask_decoder_config) + ```""" + + model_type = "edgetam" + sub_configs = { + "vision_config": AutoConfig, + "prompt_encoder_config": EdgeTamPromptEncoderConfig, + "mask_decoder_config": EdgeTamMaskDecoderConfig, + } + + def __init__( + self, + vision_config=None, + prompt_encoder_config=None, + mask_decoder_config=None, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + vision_config = vision_config if vision_config is not None else {} + prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {} + mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {} + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "edgetam_vision_model") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + if isinstance(prompt_encoder_config, EdgeTamPromptEncoderConfig): + prompt_encoder_config = prompt_encoder_config.to_dict() + if isinstance(mask_decoder_config, EdgeTamMaskDecoderConfig): + mask_decoder_config = mask_decoder_config.to_dict() + + self.vision_config = vision_config + self.prompt_encoder_config = EdgeTamPromptEncoderConfig(**prompt_encoder_config) + self.mask_decoder_config = EdgeTamMaskDecoderConfig(**mask_decoder_config) + + self.initializer_range = initializer_range + + +__all__ = ["EdgeTamConfig", "EdgeTamVisionConfig", "EdgeTamPromptEncoderConfig", "EdgeTamMaskDecoderConfig"] diff --git a/src/transformers/models/edgetam/convert_edgetam_to_hf.py b/src/transformers/models/edgetam/convert_edgetam_to_hf.py new file mode 100644 index 000000000000..382bc1559ec4 --- /dev/null +++ b/src/transformers/models/edgetam/convert_edgetam_to_hf.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert EdgeTAM checkpoints from the original repository. + +URL: https://github.com/facebookresearch/EdgeTAM. 
+""" + +import argparse +import re + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + EdgeTamConfig, + EdgeTamMaskDecoderConfig, + EdgeTamModel, + EdgeTamPromptEncoderConfig, + EdgeTamVisionConfig, + Sam2ImageProcessorFast, + Sam2Processor, + TimmWrapperConfig, +) + + +def get_config(model_name): + backbone_config = TimmWrapperConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={"in_chans": 3, "features_only": True, "out_indices": (0, 1, 2, 3)}, + ) + vision_config = EdgeTamVisionConfig(backbone_config=backbone_config) + + prompt_encoder_config = EdgeTamPromptEncoderConfig() + mask_decoder_config = EdgeTamMaskDecoderConfig() + enable_temporal_pos_encoding_for_object_pointers = False + project_temporal_pos_encoding_in_object_pointers = False + enable_occlusion_spatial_embedding = False + + config = EdgeTamConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + enable_temporal_pos_encoding_for_object_pointers=enable_temporal_pos_encoding_for_object_pointers, + project_temporal_pos_encoding_in_object_pointers=project_temporal_pos_encoding_in_object_pointers, + enable_occlusion_spatial_embedding=enable_occlusion_spatial_embedding, + ) + + return config + + +KEYS_TO_MODIFY_MAPPING = { + "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", + "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", + "iou_prediction_head.layers.2": "iou_prediction_head.proj_out", + "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1", + "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm", + "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2", + "mask_downscaling.0": "mask_embed.conv1", + "mask_downscaling.1": "mask_embed.layer_norm1", + "mask_downscaling.3": "mask_embed.conv2", + "mask_downscaling.4": "mask_embed.layer_norm2", + "mask_downscaling.6": "mask_embed.conv3", + "dwconv": "depthwise_conv", + "pwconv": "pointwise_conv", + "fuser": "memory_fuser", + "point_embeddings": "point_embed", + "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", + "obj_ptr_tpos_proj": "temporal_positional_encoding_projection_layer", + "no_obj_embed_spatial": "occlusion_spatial_embedding_parameter", + "sam_prompt_encoder": "prompt_encoder", + "sam_mask_decoder": "mask_decoder", + "maskmem_tpos_enc": "memory_temporal_positional_encoding", + "gamma": "scale", + "image_encoder.neck": "vision_encoder.neck", + "image_encoder": "vision_encoder.backbone", + "neck.0": "neck.conv1", + "neck.1": "neck.layer_norm1", + "neck.2": "neck.conv2", + "neck.3": "neck.layer_norm2", + "pix_feat_proj": "feature_projection", + "patch_embed.proj": "patch_embed.projection", + "no_mem_embed": "no_memory_embedding", + "no_mem_pos_enc": "no_memory_positional_encoding", + "obj_ptr": "object_pointer", + ".norm": ".layer_norm", + "trunk.": "", + "out_proj": "o_proj", + "body.": "timm_model.", + "ff.0": "feed_forward.layer_norm", + "ff.1": "feed_forward.linear1", + "ff.3": "feed_forward.linear2", +} + + +def replace_keys(state_dict): + model_state_dict = {} + output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*" + output_mask_decoder_mlps_pattern = r"mask_decoder.transformer.layers.(\d+).mlp.layers.(\d+).*" + output_mask_decoder_score_head_pattern = r"mask_decoder.pred_obj_score_head.layers.(\d+).*" + output_vision_encoder_mlps_pattern = 
r"vision_encoder.backbone.blocks.(\d+).mlp.layers.(\d+).*" + output_vision_encoder_neck_pattern = r"vision_encoder.neck.convs.(\d+).conv" + output_memory_encoder_projection_pattern = r"memory_encoder.o_proj.*" + output_object_pointer_proj_pattern = r"object_pointer_proj.layers.(\d+).*" + for key, value in state_dict.items(): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + # vision_encoder.blocks.0.mlp.layers.1.weight -> vision_encoder.blocks.0.mlp.proj_out.weight + if re.match(output_vision_encoder_mlps_pattern, key): + layer_nb = int(re.match(output_vision_encoder_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "proj_out") + + # mask_decoder.transformer.layers.0.mlp.layers.1.weight -> mask_decoder.transformer.layers.1.mlp.proj_out.weight + if re.match(output_mask_decoder_mlps_pattern, key): + layer_nb = int(re.match(output_mask_decoder_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("mlp.layers.0", "mlp.proj_in") + elif layer_nb == 1: + key = key.replace("mlp.layers.1", "mlp.proj_out") + + # mask_decoder.pred_obj_score_head.layers.1.weight -> mask_decoder.pred_obj_score_head.proj_in.weight + if re.match(output_mask_decoder_score_head_pattern, key): + layer_nb = int(re.match(output_mask_decoder_score_head_pattern, key).group(1)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + if re.match(output_hypernetworks_mlps_pattern, key): + layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + # vision_encoder.neck.convs.1.conv.bias -> vision_encoder.neck.convs.1.bias + if re.match(output_vision_encoder_neck_pattern, key): + key = key.replace(".conv.", ".") + + # memory_encoder.o_proj.weight -> memory_encoder.projection.weight + if re.match(output_memory_encoder_projection_pattern, key): + key = key.replace(".o_proj.", ".projection.") + + if re.match(output_object_pointer_proj_pattern, key): + layer_nb = int(re.match(output_object_pointer_proj_pattern, key).group(1)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + key = key.replace("layers.2", "proj_out") + + model_state_dict[key] = value + + model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ + "prompt_encoder.shared_embedding.positional_embedding" + ] + model_state_dict["prompt_encoder.point_embed.weight"] = torch.cat( + [model_state_dict.pop(f"prompt_encoder.point_embed.{i}.weight") for i in range(4)], + dim=0, + ) + + return model_state_dict + + +def convert_edgetam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub, run_sanity_check): + config = get_config(model_name) + + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + state_dict = replace_keys(state_dict) + + image_processor = Sam2ImageProcessorFast() + processor = Sam2Processor(image_processor=image_processor) + hf_model = EdgeTamModel(config) + hf_model.eval() + + device = "cuda" if 
torch.cuda.is_available() else "cpu" + + missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False) + hf_model = hf_model.to(device) + for pattern in EdgeTamModel._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None] + if missing_keys or unexpected_keys: + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + raise ValueError("Missing or unexpected keys in the state dict") + + if run_sanity_check: + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + input_points = [[[[1000, 600]]]] + input_labels = [[[1]]] + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + output = hf_model(**inputs) + scores = output.iou_scores.squeeze() + + assert torch.allclose(scores, torch.tensor([0.0356, 0.2141, 0.9707]).cuda(), atol=1e-3) + + if pytorch_dump_folder is not None: + processor.save_pretrained(pytorch_dump_folder) + hf_model.save_pretrained(pytorch_dump_folder) + + if push_to_hub: + repo_id = f"yonigozlan/{pytorch_dump_folder.split('/')[-1]}" + processor.push_to_hub(repo_id) + hf_model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + choices = ["EdgeTAM"] + parser.add_argument( + "--model_name", + default="EdgeTAM", + choices=choices, + type=str, + help="Name of the original model to convert", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=False, + help="Path to the original checkpoint", + ) + parser.add_argument("--pytorch_dump_folder_path", default="", type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model and processor to the hub after converting", + ) + parser.add_argument( + "--run_sanity_check", + action="store_true", + help="Whether to run the sanity check after converting", + ) + + args = parser.parse_args() + + hf_model_name = args.model_name.replace("_", "-") + checkpoint_path = ( + hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name.lower()}.pt") + if args.checkpoint_path is None + else args.checkpoint_path + ) + + convert_edgetam_checkpoint( + args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub, args.run_sanity_check + ) diff --git a/src/transformers/models/edgetam/modeling_edgetam.py b/src/transformers/models/edgetam/modeling_edgetam.py new file mode 100644 index 000000000000..d7e3ee6009cf --- /dev/null +++ b/src/transformers/models/edgetam/modeling_edgetam.py @@ -0,0 +1,1252 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_edgetam.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, auto_docstring +from ..auto import AutoModel +from .configuration_edgetam import ( + EdgeTamConfig, + EdgeTamMaskDecoderConfig, + EdgeTamPromptEncoderConfig, + EdgeTamVisionConfig, +) + + +# fix this in modular +if True: + from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel + + +class EdgeTamLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +@dataclass +@auto_docstring(custom_intro="Base class for the vision encoder's outputs.") +class EdgeTamVisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + fpn_hidden_states (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. + fpn_position_encoding (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + fpn_hidden_states: Optional[torch.FloatTensor] = None + fpn_position_encoding: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EdgeTamAttention(nn.Module): + """ + EDGETAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_similarity: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class EdgeTamTwoWayAttentionBlock(nn.Module): + def __init__(self, config: EdgeTamMaskDecoderConfig, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`EdgeTamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + self.self_attn = EdgeTamAttention(config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = EdgeTamAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = EdgeTamFeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = EdgeTamAttention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + # Self attention block + if self.skip_first_layer_pe: + queries, _ = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out, _ = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + return queries, keys, attn_out + + +class EdgeTamFeedForward(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +@auto_docstring +class EdgeTamPreTrainedModel(PreTrainedModel): + config_class = EdgeTamConfig + base_model_prefix = "edgetam" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): + 
module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, EdgeTamModel): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class EdgeTamSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: Union[torch.device, str], + dtype: torch.dtype, + mask: Optional[Tensor] = None, + ) -> Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class EdgeTamVisionNeck(nn.Module): + def __init__(self, config: EdgeTamVisionConfig): + super().__init__() + self.config = config + + self.position_encoding = EdgeTamSinePositionEmbedding( + num_pos_feats=config.fpn_hidden_size // 2, normalize=True + ) + self.convs = nn.ModuleList() + for in_channels in config.backbone_channel_list: + self.convs.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=config.fpn_hidden_size, + kernel_size=config.fpn_kernel_size, + stride=config.fpn_stride, + padding=config.fpn_padding, + ), + ) + self.fpn_top_down_levels = config.fpn_top_down_levels + + def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: + fpn_hidden_states = () + fpn_position_encoding = () + + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + lateral_features = hidden_states[i].permute(0, 3, 1, 2) + lateral_features = self.convs[n - i](lateral_features) + if i not in self.fpn_top_down_levels or i == n: + prev_features = lateral_features + else: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode="nearest", + align_corners=None, + antialias=False, + ).to(lateral_features.dtype) + prev_features = lateral_features + top_down_features + + prev_position_encoding = self.position_encoding( + prev_features.shape, prev_features.device, prev_features.dtype + ).to(prev_features.dtype) + + fpn_hidden_states += 
(prev_features,) + fpn_position_encoding += (prev_position_encoding,) + + return fpn_hidden_states, fpn_position_encoding + + +@auto_docstring( + custom_intro=""" + The vision model from EdgeTAM without any head or projection on top. + """ +) +class EdgeTamVisionModel(EdgeTamPreTrainedModel): + config_class = EdgeTamVisionConfig + main_input_name = "pixel_values" + _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel} + + def __init__(self, config: EdgeTamVisionConfig): + super().__init__(config) + self.config = config + + self.backbone = AutoModel.from_config(config.backbone_config) + + self.neck = EdgeTamVisionNeck(config) + self.num_feature_levels = config.num_feature_levels + + self.post_init() + + @check_model_inputs + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, EdgeTamVisionEncoderOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Forward through backbone + backbone_output = self.backbone(pixel_values) + intermediate_hidden_states = backbone_output.last_hidden_state + intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states] + + fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) + # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution + fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] + fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] + + return EdgeTamVisionEncoderOutput( + last_hidden_state=intermediate_hidden_states[-1], + fpn_hidden_states=fpn_hidden_states, + fpn_position_encoding=fpn_position_encoding, + ) + + +@dataclass +@auto_docstring(custom_intro="Base class for the EdgeTam model's output.") +class EdgeTamImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. + image_embeddings (`tuple(torch.FloatTensor)`): + The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. 
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. + """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + object_score_logits: Optional[torch.FloatTensor] = None + image_embeddings: tuple[torch.FloatTensor, ...] = None + vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EdgeTamPositionalEmbedding(nn.Module): + def __init__(self, config: EdgeTamPromptEncoderConfig): + super().__init__() + self.scale = config.scale + positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2)) + self.register_buffer("positional_embedding", positional_embedding) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + coordinates.to(torch.float32) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class EdgeTamMaskEmbedding(nn.Module): + def __init__(self, config: EdgeTamPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = EdgeTamLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = EdgeTamLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class EdgeTamPromptEncoder(nn.Module): + def __init__(self, config: EdgeTamPromptEncoderConfig): + super().__init__() + self.shared_embedding = EdgeTamPositionalEmbedding(config) + self.mask_embed = EdgeTamMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + self.input_image_size = config.image_size + + self.point_embed = nn.Embedding(config.num_point_embeddings, 
config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.zeros_like(point_embedding), + ) + + # Add point embeddings for labels >= 0 + point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes += 0.5 # Shift to center of pixel + coords = boxes.view(*boxes.shape[:2], 2, 2) + # add padding point for consistency with the original implementation + coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0) + corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) + corner_embedding[:, :, 0, :] += self.point_embed.weight[2] + corner_embedding[:, :, 1, :] += self.point_embed.weight[3] + corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) + return corner_embedding + + def forward( + self, + input_points: Optional[tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. 
+ boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size = input_points.shape[0] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class EdgeTamTwoWayTransformer(nn.Module): + def __init__(self, config: EdgeTamMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(EdgeTamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = EdgeTamAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + **kwargs, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys + + +class EdgeTamMaskDecoder(nn.Module): + def __init__(self, config: EdgeTamMaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = EdgeTamTwoWayTransformer(config) + + # should we create a new class for this? 
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = EdgeTamLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [EdgeTamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + self.iou_prediction_head = EdgeTamFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = EdgeTamFeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + high_resolution_features: list[torch.Tensor], + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (`bool`): + Whether to return multiple masks or a single mask. + high_resolution_features (`list[torch.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`torch.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`torch.Tensor`, *optional*): + The target embedding. 
+ """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. 
+ """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1) + ) + best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +@auto_docstring( + custom_intro=""" + Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and + input points and labels, boxes, or masks. 
+ """ +) +class EdgeTamModel(EdgeTamPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamTwoWayAttentionBlock, index=2)} + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"spatial_perceiver.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def __init__(self, config: EdgeTamConfig): + super().__init__(config) + self.shared_image_embedding = EdgeTamPositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.prompt_encoder = EdgeTamPromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = EdgeTamMaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_image_wide_positional_embeddings(self) -> torch.Tensor: + size = self.prompt_encoder.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones(size, device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> list[torch.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. 
+ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + batch_size = pixel_values.shape[0] + feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @check_model_inputs + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> EdgeTamImageSegmentationOutput: + r""" + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. 
If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points. This is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels: + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points. This is used by the prompt encoder to encode the prompt. Generally yields + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + which will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + The model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, which will later be fed to the mask decoder. These masks need to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings. These are used by the mask decoder to generate masks and IoU scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to output just a single mask that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny") + >>> processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get segmentation mask + >>> outputs = model(**inputs) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + if not ((pixel_values is None) ^ (image_embeddings is None)): + raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") + if input_points is not None and input_boxes is not None: + if input_points.shape[1] != input_boxes.shape[1]: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( + pixel_values, + **kwargs, + ) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is None and input_boxes is None: + # If no points are provide, pad with an empty point (with label -1) + input_points = torch.zeros( + batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device + ) + input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device) + + if input_masks is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: + input_masks = F.interpolate( + input_masks.float(), + size=self.prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ).to(input_masks.dtype) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder( + image_embeddings=image_embeddings[-1], + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + 
dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + + return EdgeTamImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_multimasks, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ + list[torch.Tensor], + list[torch.Tensor], + Optional[tuple[torch.FloatTensor, ...]], + Optional[tuple[torch.FloatTensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`torch.FloatTensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `tuple`: A tuple containing: + - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels. + - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level. + - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder. + - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder. + """ + vision_outputs: EdgeTamVisionEncoderOutput = self.vision_encoder( + pixel_values, + **kwargs, + ) + + feature_maps = vision_outputs.fpn_hidden_states + feature_maps_position_embeddings = vision_outputs.fpn_position_encoding + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + feature_maps = list(feature_maps) + feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) + feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) + + # flatten NxCxHxW to HWxNxC + feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] + feature_maps_position_embeddings = [ + feature_map_position_embedding.flatten(2).permute(2, 0, 1) + for feature_map_position_embedding in feature_maps_position_embeddings + ] + + return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions + + +__all__ = ["EdgeTamModel", "EdgeTamVisionModel", "EdgeTamPreTrainedModel"] diff --git a/src/transformers/models/edgetam/modular_edgetam.py b/src/transformers/models/edgetam/modular_edgetam.py new file mode 100644 index 000000000000..e26d58d96b81 --- /dev/null +++ b/src/transformers/models/edgetam/modular_edgetam.py @@ -0,0 +1,261 @@ +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
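# --- Editor's illustrative sketch (not part of the generated sources) -------------------------
# The `EdgeTamModel.forward` docstring above notes that image embeddings can be precomputed with
# `get_image_embeddings` and passed to `forward` in place of `pixel_values`, so that only the
# prompt encoder and mask decoder are re-run for each new prompt. A minimal sketch of that
# workflow, reusing the checkpoint id and image URL from the docstring example (both may be
# placeholders for the final released checkpoint):
import requests
import torch
from PIL import Image

from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny")

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# Run the vision encoder once and cache the multi-scale image embeddings.
inputs = processor(images=raw_image, return_tensors="pt")
image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

# Each new click now only runs the prompt encoder and the mask decoder.
prompt_inputs = processor(images=raw_image, input_points=[[[400, 650]]], return_tensors="pt")
with torch.no_grad():
    outputs = model(input_points=prompt_inputs["input_points"], image_embeddings=image_embeddings)

masks = processor.post_process_masks(
    outputs.pred_masks, prompt_inputs["original_sizes"], prompt_inputs["reshaped_input_sizes"]
)
# ----------------------------------------------------------------------------------------------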
+"""PyTorch SAM 2 model.""" + +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig +from transformers.models.sam2.modeling_sam2 import ( + Sam2Attention, + Sam2FeedForward, + Sam2LayerNorm, + Sam2Model, + Sam2PreTrainedModel, + Sam2TwoWayAttentionBlock, + Sam2VisionEncoderOutput, + Sam2VisionModel, +) +from transformers.utils.generic import TransformersKwargs, check_model_inputs + +from ...configuration_utils import PretrainedConfig +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, +) +from ..auto import CONFIG_MAPPING, AutoConfig + + +# fix this in modular +if True: + from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel + + +class EdgeTamVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate a SAM + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny + [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*): + Configuration for the vision backbone. This is used to instantiate the backbone using + `AutoModel.from_config`. + backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`): + The list of channel dimensions for the backbone. + backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`): + The spatial sizes of the feature maps from the backbone. + fpn_hidden_size (`int`, *optional*, defaults to 256): + The hidden dimension of the FPN. + fpn_kernel_size (`int`, *optional*, defaults to 1): + The kernel size for the convolutions in the neck. + fpn_stride (`int`, *optional*, defaults to 1): + The stride for the convolutions in the neck. + fpn_padding (`int`, *optional*, defaults to 0): + The padding for the convolutions in the neck. + fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`): + The levels for the top-down FPN connections. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels from the FPN to use. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the neck. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon for the layer normalization. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + """ + + base_config_key = "vision_config" + model_type = "edgetam_vision_model" + sub_configs = { + "backbone_config": AutoConfig, + } + + def __init__( + self, + backbone_config=None, + backbone_channel_list=None, + backbone_feature_sizes=None, + fpn_hidden_size=256, + fpn_kernel_size=1, + fpn_stride=1, + fpn_padding=0, + fpn_top_down_levels=None, + num_feature_levels=3, + hidden_act="gelu", + layer_norm_eps=1e-6, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list + backbone_feature_sizes = ( + [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes + ) + fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels + + if isinstance(backbone_config, dict): + backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper") + backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config) + elif isinstance(backbone_config, AutoConfig): + backbone_config = backbone_config + elif backbone_config is None: + backbone_config = AutoConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]}, + ) + + self.backbone_config = backbone_config + + # Neck + self.backbone_channel_list = backbone_channel_list + self.backbone_feature_sizes = backbone_feature_sizes + self.fpn_hidden_size = fpn_hidden_size + self.fpn_kernel_size = fpn_kernel_size + self.fpn_stride = fpn_stride + self.fpn_padding = fpn_padding + self.fpn_top_down_levels = fpn_top_down_levels + self.num_feature_levels = num_feature_levels + + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + + +class EdgeTamPromptEncoderConfig(Sam2PromptEncoderConfig): + pass + + +class EdgeTamMaskDecoderConfig(Sam2MaskDecoderConfig): + pass + + +class EdgeTamConfig(Sam2Config): + pass + + +class EdgeTamLayerNorm(Sam2LayerNorm): + pass + + +class EdgeTamVisionEncoderOutput(Sam2VisionEncoderOutput): + pass + + +class EdgeTamAttention(Sam2Attention): + pass + + +class EdgeTamTwoWayAttentionBlock(Sam2TwoWayAttentionBlock): + pass + + +class EdgeTamFeedForward(Sam2FeedForward): + pass + + +@auto_docstring +class EdgeTamPreTrainedModel(Sam2PreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, EdgeTamModel): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +@auto_docstring( + custom_intro=""" + The vision model from EdgeTAM without any head or projection on top. 
+ """ +) +class EdgeTamVisionModel(Sam2VisionModel): + config_class = EdgeTamVisionConfig + main_input_name = "pixel_values" + _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel} + + def get_input_embeddings(self): + raise NotImplementedError("Can't get input embeddings from timm wrapper model") + + @check_model_inputs + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, EdgeTamVisionEncoderOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Forward through backbone + backbone_output = self.backbone(pixel_values) + intermediate_hidden_states = backbone_output.last_hidden_state + intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states] + + fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) + # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution + fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] + fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] + + return EdgeTamVisionEncoderOutput( + last_hidden_state=intermediate_hidden_states[-1], + fpn_hidden_states=fpn_hidden_states, + fpn_position_encoding=fpn_position_encoding, + ) + + +class EdgeTamModel(Sam2Model): + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"spatial_perceiver.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def get_input_embeddings(self): + raise NotImplementedError("Can't get input embeddings from timm wrapper model") + + +__all__ = [ + "EdgeTamModel", + "EdgeTamVisionModel", + "EdgeTamPreTrainedModel", + "EdgeTamConfig", + "EdgeTamVisionConfig", + "EdgeTamPromptEncoderConfig", + "EdgeTamMaskDecoderConfig", +] diff --git a/src/transformers/models/edgetam_video/__init__.py b/src/transformers/models/edgetam_video/__init__.py new file mode 100644 index 000000000000..669dd64ec304 --- /dev/null +++ b/src/transformers/models/edgetam_video/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_edgetam_video import * + from .modeling_edgetam_video import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/edgetam_video/configuration_edgetam_video.py b/src/transformers/models/edgetam_video/configuration_edgetam_video.py new file mode 100644 index 000000000000..954864397dcb --- /dev/null +++ b/src/transformers/models/edgetam_video/configuration_edgetam_video.py @@ -0,0 +1,435 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/edgetam_video/modular_edgetam_video.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_edgetam_video.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class EdgeTamVideoPromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamVideoPromptEncoder`]. The [`EdgeTamVideoPromptEncoder`] + module is used to encode the input 2D points and bounding boxes. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + scale (`float`, *optional*, defaults to 1): + The scale factor for the prompt encoder. 
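+
+    Example (a minimal sketch instantiating the configuration with its default values):
+
+    ```python
+    >>> from transformers import EdgeTamVideoPromptEncoderConfig
+
+    >>> # Initializing an EdgeTamVideoPromptEncoderConfig with default values
+    >>> configuration = EdgeTamVideoPromptEncoderConfig()
+
+    >>> # By default the prompt encoder expects 1024x1024 images split into 16x16 patches
+    >>> (configuration.image_size, configuration.patch_size, configuration.hidden_size)
+    (1024, 16, 256)
+    ```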
+ """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + scale=1, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.scale = scale + + +class EdgeTamVideoMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EdgeTamVideoMaskDecoder`]. It is used to instantiate a EDGETAM_VIDEO + memory encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the EDGETAM_VIDEO mask decoder. + mlp_dim (`int`, *optional*, defaults to 2048): + The dimension of the MLP in the two-way transformer. + num_hidden_layers (`int`, *optional*, defaults to 2): + The number of hidden layers in the two-way transformer. + num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in the two-way transformer. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsample rate for the attention layers. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of multimask outputs. + iou_head_depth (`int`, *optional*, defaults to 3): + The depth of the IoU head. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the IoU head. + dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`): + Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05): + The stability delta for the dynamic multimask. + dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98): + The stability threshold for the dynamic multimask. 
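+
+    Example (a minimal sketch instantiating the configuration with its default values):
+
+    ```python
+    >>> from transformers import EdgeTamVideoMaskDecoderConfig
+
+    >>> # Initializing an EdgeTamVideoMaskDecoderConfig with default values
+    >>> configuration = EdgeTamVideoMaskDecoderConfig()
+
+    >>> # The decoder predicts `num_multimask_outputs` candidate masks, each scored by the IoU head
+    >>> configuration.num_multimask_outputs
+    3
+    ```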
+ + """ + + base_config_key = "mask_decoder_config" + + def __init__( + self, + hidden_size=256, + hidden_act="gelu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_multimask_outputs = num_multimask_outputs + self.hidden_act = hidden_act + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + # TwoWayTransformer configuration + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.mlp_dim = mlp_dim + self.attention_downsample_rate = attention_downsample_rate + + +class EdgeTamVideoConfig(PretrainedConfig): + r""" + [`EdgeTamVideoConfig`] is the configuration class to store the configuration of a [`EdgeTamVideoModel`]. It is used to instantiate a + EDGETAM model according to the specified arguments, defining the memory attention, memory encoder, and image encoder + configs. Instantiating a configuration defaults will yield a similar configuration to that of the SAM 2.1 Hiera-tiny + [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `EdgeTamVideoVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`EdgeTamVideoVisionConfig`]. + prompt_encoder_config (Union[`dict`, `EdgeTamVideoPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`EdgeTamVideoPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `EdgeTamVideoMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`EdgeTamMaskDecoderConfig`]. + initializer_range (`float`, *optional*, defaults to 0.02): + Standard deviation for parameter initialization. + num_maskmem (`int`, *optional*, defaults to 7): + The number of memory slots for the mask memory. + image_size (`int`, *optional*, defaults to 1024): + The size of the input images. + sigmoid_scale_for_mem_enc (`float`, *optional*, defaults to 20.0): + Scale factor for the sigmoid function in the memory encoder. + sigmoid_bias_for_mem_enc (`float`, *optional*, defaults to -10.0): + Bias for the sigmoid function in the memory encoder. + enable_occlusion_spatial_embedding (`bool`, *optional*, defaults to `True`): + Whether to enable spatial embedding for occlusions. + multimask_output_in_sam (`bool`, *optional*, defaults to `True`): + Whether to output multiple masks from the SAM head. + multimask_min_pt_num (`int`, *optional*, defaults to 0): + The minimum number of points to trigger multimask output. + multimask_max_pt_num (`int`, *optional*, defaults to 1): + The maximum number of points to trigger multimask output. + multimask_output_for_tracking (`bool`, *optional*, defaults to `True`): + Whether to use multimask output for tracking. 
+ max_object_pointers_in_encoder (`int`, *optional*, defaults to 16): + The maximum number of object pointers in the encoder. + enable_temporal_pos_encoding_for_object_pointers (`bool`, *optional*, defaults to `True`): + Whether to enable temporal positional encoding for object pointers. + memory_attention_hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the memory attention hidden states. + memory_attention_num_layers (`int`, *optional*, defaults to 2): + The number of layers in the memory attention module. + memory_attention_num_attention_heads (`int`, *optional*, defaults to 1): + Number of attention heads for each attention layer in the memory attention. + memory_attention_downsample_rate (`int`, *optional*, defaults to 1): + The downsample rate for the attention layers. + memory_attention_mlp_hidden_size (`int`, *optional*, defaults to 2048): + The dimension of the feedforward network in the memory attention module. + memory_attention_mlp_hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the feedforward network in the memory attention module. + memory_attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout rate for the memory attention module. + memory_attention_rope_theta (`float`, *optional*, defaults to 10000): + The Rope theta parameter. + memory_attention_rope_feat_sizes (`Tuple[int, int]`, *optional*, defaults to `[64, 64]`): + The feature sizes for the Rope positional encoding. + memory_attention_rope_k_sizes (`List[int]`, *optional*, defaults to `[16, 16]`): + The key feature sizes for the RoPE positional encoding in memory attention. + memory_attention_rope_dropout (`float`, *optional*, defaults to 0.1): + The dropout rate for the Rope positional encoding. + perceiver_resampler_num_latents (`int`, *optional*, defaults to 256): + The number of 1D latent tokens in the perceiver resampler. + perceiver_resampler_num_latents_2d (`int`, *optional*, defaults to 256): + The number of 2D latent tokens in the perceiver resampler. + perceiver_resampler_hidden_size (`int`, *optional*, defaults to 64): + The hidden size of the perceiver resampler. + perceiver_resampler_mlp_intermediate_size (`int`, *optional*, defaults to 256): + The intermediate size of the feedforward network in the perceiver resampler. + perceiver_resampler_num_attention_heads (`int`, *optional*, defaults to 1): + The number of attention heads in the perceiver resampler. + perceiver_resampler_attention_head_dim (`int`, *optional*, defaults to 64): + The dimension of each attention head in the perceiver resampler. + perceiver_resampler_num_layers (`int`, *optional*, defaults to 2): + The number of layers in the perceiver resampler. + perceiver_resampler_hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate for the hidden layers in the perceiver resampler. + perceiver_resampler_attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate for the attention layers in the perceiver resampler. + memory_encoder_hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the memory encoder hidden states. + memory_encoder_output_channels (`int`, *optional*, defaults to 64): + The number of output channels for the memory encoder. + mask_downsampler_embed_dim (`int`, *optional*, defaults to 256): + The dimension of the mask downsampler embedding. + memory_fuser_intermediate_dim (`int`, *optional*, defaults to 1024): + The intermediate dimension of the memory fuser feedforward network. 
+ mask_downsampler_kernel_size (`int`, *optional*, defaults to 3): + The kernel size for the mask downsampler. + mask_downsampler_stride (`int`, *optional*, defaults to 2): + The stride for the mask downsampler. + mask_downsampler_padding (`int`, *optional*, defaults to 1): + The padding for the mask downsampler. + mask_downsampler_total_stride (`int`, *optional*, defaults to 16): + The total stride for the mask downsampler. + mask_downsampler_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the mask downsampler. + memory_fuser_num_layers (`int`, *optional*, defaults to 2): + The number of layers in the memory fuser. + memory_fuser_embed_dim (`int`, *optional*, defaults to 256): + The dimension of the memory fuser embedding. + memory_fuser_kernel_size (`int`, *optional*, defaults to 7): + The kernel size for the memory fuser. + memory_fuser_padding (`int`, *optional*, defaults to 3): + The padding for the memory fuser. + memory_fuser_layer_scale_init_value (`float`, *optional*, defaults to 1e-06): + The initial value for the layer scale in the memory fuser. + memory_fuser_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the memory fuser. + + Example: + + ```python + >>> from transformers import ( + ... EdgeTamVisionConfig, + ... EdgeTamVideoPromptEncoderConfig, + ... EdgeTamVideoMaskDecoderConfig, + ... EdgeTamVideoModel, + ... EdgeTamVideoConfig, + ... ) + + >>> # Initializing a EdgeTamVideoConfig with `"facebook/edgetam.1_hiera_tiny"` style configuration + >>> configuration = EdgeTamVideoConfig() + + >>> # Initializing a EdgeTamVideoModel (with random weights) from the `"facebook/edgetam.1_hiera_tiny"` style configuration + >>> model = EdgeTamVideoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a EdgeTamConfig from a EdgeTamVisionConfig, EdgeTamPromptEncoderConfig, and EdgeTamMaskDecoderConfig + + >>> # Initializing EDGETAM vision encoder, memory attention, and memory encoder configurations + >>> vision_config = EdgeTamVisionConfig() + >>> prompt_encoder_config = EdgeTamVideoPromptEncoderConfig() + >>> mask_decoder_config = EdgeTamVideoMaskDecoderConfig() + + >>> config = EdgeTamVideoConfig(vision_config, prompt_encoder_config, mask_decoder_config) + ```""" + + model_type = "edgetam_video" + sub_configs = { + "vision_config": AutoConfig, + "prompt_encoder_config": EdgeTamVideoPromptEncoderConfig, + "mask_decoder_config": EdgeTamVideoMaskDecoderConfig, + } + + def __init__( + self, + vision_config=None, + prompt_encoder_config=None, + mask_decoder_config=None, + initializer_range=0.02, + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + enable_occlusion_spatial_embedding=True, + multimask_output_in_sam=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + multimask_output_for_tracking=True, + max_object_pointers_in_encoder=16, + enable_temporal_pos_encoding_for_object_pointers=True, + # memory attention + memory_attention_hidden_size=256, + memory_attention_num_layers=2, + memory_attention_num_attention_heads=1, + memory_attention_downsample_rate=1, + memory_attention_mlp_hidden_size=2048, + memory_attention_mlp_hidden_act="relu", + memory_attention_dropout=0.1, + memory_attention_rope_theta=10000, + memory_attention_rope_feat_sizes=None, + memory_attention_rope_k_sizes=None, + memory_attention_rope_dropout=0.1, + # spatial perceiver resampler + 
perceiver_resampler_num_latents=256, + perceiver_resampler_num_latents_2d=256, + perceiver_resampler_hidden_size=64, + perceiver_resampler_mlp_intermediate_size=256, + perceiver_resampler_num_attention_heads=1, + perceiver_resampler_attention_head_dim=64, + perceiver_resampler_num_layers=2, + perceiver_resampler_hidden_dropout=0.0, + perceiver_resampler_attention_dropout=0.0, + # memory encoder + memory_encoder_hidden_size=256, + memory_encoder_output_channels=64, + mask_downsampler_embed_dim=256, + memory_fuser_intermediate_dim=1024, + mask_downsampler_kernel_size=3, + mask_downsampler_stride=2, + mask_downsampler_padding=1, + mask_downsampler_total_stride=16, + mask_downsampler_hidden_act="gelu", + memory_fuser_num_layers=2, + memory_fuser_embed_dim=256, + memory_fuser_kernel_size=7, + memory_fuser_padding=3, + memory_fuser_layer_scale_init_value=1e-6, + memory_fuser_hidden_act="gelu", + **kwargs, + ): + super().__init__(**kwargs) + vision_config = vision_config if vision_config is not None else {} + prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {} + mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {} + memory_attention_rope_feat_sizes = ( + [64, 64] if memory_attention_rope_feat_sizes is None else memory_attention_rope_feat_sizes + ) + memory_attention_rope_k_sizes = ( + [16, 16] if memory_attention_rope_k_sizes is None else memory_attention_rope_k_sizes + ) + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + if isinstance(prompt_encoder_config, EdgeTamVideoPromptEncoderConfig): + prompt_encoder_config = prompt_encoder_config.to_dict() + if isinstance(mask_decoder_config, EdgeTamVideoMaskDecoderConfig): + mask_decoder_config = mask_decoder_config.to_dict() + + self.vision_config = vision_config + self.prompt_encoder_config = EdgeTamVideoPromptEncoderConfig(**prompt_encoder_config) + self.mask_decoder_config = EdgeTamVideoMaskDecoderConfig(**mask_decoder_config) + + self.initializer_range = initializer_range + self.num_maskmem = num_maskmem # default 1 input frame + 6 previous frames + self.image_size = image_size + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc # scale factor for mask sigmoid prob + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc # bias factor for mask sigmoid prob + self.enable_occlusion_spatial_embedding = enable_occlusion_spatial_embedding + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.max_object_pointers_in_encoder = max_object_pointers_in_encoder + self.enable_temporal_pos_encoding_for_object_pointers = enable_temporal_pos_encoding_for_object_pointers + + # memory attention + self.memory_attention_hidden_size = memory_attention_hidden_size + self.memory_attention_num_layers = memory_attention_num_layers + self.memory_attention_num_attention_heads = memory_attention_num_attention_heads + self.memory_attention_downsample_rate = memory_attention_downsample_rate + self.memory_attention_mlp_hidden_size = memory_attention_mlp_hidden_size + self.memory_attention_mlp_hidden_act = memory_attention_mlp_hidden_act + self.memory_attention_dropout = memory_attention_dropout + self.memory_attention_rope_theta = memory_attention_rope_theta + 
self.memory_attention_rope_feat_sizes = memory_attention_rope_feat_sizes + self.memory_attention_rope_k_sizes = memory_attention_rope_k_sizes + self.memory_attention_rope_dropout = memory_attention_rope_dropout + + # spatial perceiver resampler + self.perceiver_resampler_num_latents = perceiver_resampler_num_latents + self.perceiver_resampler_num_latents_2d = perceiver_resampler_num_latents_2d + self.perceiver_resampler_hidden_size = perceiver_resampler_hidden_size + self.perceiver_resampler_mlp_intermediate_size = perceiver_resampler_mlp_intermediate_size + self.perceiver_resampler_attention_head_dim = perceiver_resampler_attention_head_dim + self.perceiver_resampler_num_attention_heads = perceiver_resampler_num_attention_heads + self.perceiver_resampler_num_layers = perceiver_resampler_num_layers + self.perceiver_resampler_hidden_dropout = perceiver_resampler_hidden_dropout + self.perceiver_resampler_attention_dropout = perceiver_resampler_attention_dropout + + # memory encoder + self.memory_encoder_hidden_size = memory_encoder_hidden_size + self.memory_encoder_output_channels = memory_encoder_output_channels + self.mask_downsampler_embed_dim = mask_downsampler_embed_dim + self.mask_downsampler_kernel_size = mask_downsampler_kernel_size + self.mask_downsampler_stride = mask_downsampler_stride + self.mask_downsampler_padding = mask_downsampler_padding + self.mask_downsampler_total_stride = mask_downsampler_total_stride + self.mask_downsampler_hidden_act = mask_downsampler_hidden_act + self.memory_fuser_num_layers = memory_fuser_num_layers + self.memory_fuser_embed_dim = memory_fuser_embed_dim + self.memory_fuser_intermediate_dim = memory_fuser_intermediate_dim + self.memory_fuser_kernel_size = memory_fuser_kernel_size + self.memory_fuser_padding = memory_fuser_padding + self.memory_fuser_layer_scale_init_value = memory_fuser_layer_scale_init_value + self.memory_fuser_hidden_act = memory_fuser_hidden_act + + +__all__ = ["EdgeTamVideoMaskDecoderConfig", "EdgeTamVideoPromptEncoderConfig", "EdgeTamVideoConfig"] diff --git a/src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py b/src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py new file mode 100644 index 000000000000..6290bef5e1c8 --- /dev/null +++ b/src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py @@ -0,0 +1,320 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert SAM checkpoints from the original repository. + +URL: https://github.com/facebookresearch/segment-anything-2. 
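+
+Example usage (a sketch; the output folder is a placeholder, and the original checkpoint is downloaded from the
+Hub when `--checkpoint_path` is not provided):
+
+    python src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py \
+        --model_name EdgeTAM \
+        --pytorch_dump_folder_path ./edgetam-video-hf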
+""" + +import argparse +import re + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ( + EdgeTamVideoConfig, + EdgeTamVideoMaskDecoderConfig, + EdgeTamVideoModel, + EdgeTamVideoPromptEncoderConfig, + EdgeTamVisionConfig, + Sam2ImageProcessorFast, + Sam2VideoProcessor, + Sam2VideoVideoProcessor, + TimmWrapperConfig, +) + + +def get_config(model_name): + backbone_config = TimmWrapperConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={"in_chans": 3, "features_only": True, "out_indices": (0, 1, 2, 3)}, + ) + vision_config = EdgeTamVisionConfig(backbone_config=backbone_config) + + prompt_encoder_config = EdgeTamVideoPromptEncoderConfig() + mask_decoder_config = EdgeTamVideoMaskDecoderConfig() + enable_temporal_pos_encoding_for_object_pointers = False + enable_occlusion_spatial_embedding = False + + config = EdgeTamVideoConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + enable_temporal_pos_encoding_for_object_pointers=enable_temporal_pos_encoding_for_object_pointers, + enable_occlusion_spatial_embedding=enable_occlusion_spatial_embedding, + ) + + return config + + +KEYS_TO_MODIFY_MAPPING = { + "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", + "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", + "iou_prediction_head.layers.2": "iou_prediction_head.proj_out", + "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1", + "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm", + "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2", + "mask_downscaling.0": "mask_embed.conv1", + "mask_downscaling.1": "mask_embed.layer_norm1", + "mask_downscaling.3": "mask_embed.conv2", + "mask_downscaling.4": "mask_embed.layer_norm2", + "mask_downscaling.6": "mask_embed.conv3", + "dwconv": "depthwise_conv", + "pwconv": "pointwise_conv", + "fuser": "memory_fuser", + "point_embeddings": "point_embed", + "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", + "obj_ptr_tpos_proj": "temporal_positional_encoding_projection_layer", + "no_obj_embed_spatial": "occlusion_spatial_embedding_parameter", + "sam_prompt_encoder": "prompt_encoder", + "sam_mask_decoder": "mask_decoder", + "maskmem_tpos_enc": "memory_temporal_positional_encoding", + "gamma": "scale", + "image_encoder.neck": "vision_encoder.neck", + "image_encoder": "vision_encoder.backbone", + "neck.0": "neck.conv1", + "neck.1": "neck.layer_norm1", + "neck.2": "neck.conv2", + "neck.3": "neck.layer_norm2", + "pix_feat_proj": "feature_projection", + "patch_embed.proj": "patch_embed.projection", + "no_mem_embed": "no_memory_embedding", + "no_mem_pos_enc": "no_memory_positional_encoding", + "obj_ptr": "object_pointer", + ".norm": ".layer_norm", + "trunk.": "", + "out_proj": "o_proj", + "body.": "timm_model.", + "ff.0": "mlp.layer_norm", + "ff.1": "mlp.up_proj", + "ff.3": "mlp.down_proj", +} + + +def replace_keys(state_dict): + model_state_dict = {} + output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*" + output_mask_decoder_mlps_pattern = r"mask_decoder.transformer.layers.(\d+).mlp.layers.(\d+).*" + output_mask_decoder_score_head_pattern = r"mask_decoder.pred_obj_score_head.layers.(\d+).*" + output_vision_encoder_mlps_pattern = r"vision_encoder.backbone.blocks.(\d+).mlp.layers.(\d+).*" + output_vision_encoder_neck_pattern = 
r"vision_encoder.neck.convs.(\d+).conv" + output_memory_encoder_projection_pattern = r"memory_encoder.o_proj.*" + memory_attention_pattern = r"memory_attention.*" + output_object_pointer_proj_pattern = r"object_pointer_proj.layers.(\d+).*" + output_memory_encoder_mask_downsampler_pattern = r"memory_encoder.mask_downsampler.encoder.(\d+).*" + perceiver_resampler_patterns = { + r"spatial_perceiver.latents": r"spatial_perceiver.latents_1d", + r"spatial_perceiver.latents_1d_2d": r"spatial_perceiver.latents_2d", + r"spatial_perceiver.layers.(\d+).attn.layer_norm_x": r"spatial_perceiver.layers.\1.layer_norm_input", + r"spatial_perceiver.layers.(\d+).attn.layer_norm_latents": r"spatial_perceiver.layers.\1.layer_norm_latents", + r"spatial_perceiver.layers.(\d+).self_attn.layer_norm": r"spatial_perceiver.layers.\1.layer_norm_self", + r"spatial_perceiver.layers.(\d+).attn.to_q": r"spatial_perceiver.layers.\1.cross_attention.q_proj", + r"spatial_perceiver.layers.(\d+).attn.to_kv": r"spatial_perceiver.layers.\1.cross_attention.kv_proj_combined", + r"spatial_perceiver.layers.(\d+).attn.to_out": r"spatial_perceiver.layers.\1.cross_attention.o_proj", + r"spatial_perceiver.layers.(\d+).self_attn.to_q": r"spatial_perceiver.layers.\1.self_attention.q_proj", + r"spatial_perceiver.layers.(\d+).self_attn.to_kv": r"spatial_perceiver.layers.\1.self_attention.kv_proj_combined", + r"spatial_perceiver.layers.(\d+).self_attn.to_out": r"spatial_perceiver.layers.\1.self_attention.o_proj", + r"spatial_perceiver.layers.(\d+).attn": r"spatial_perceiver.layers.\1.cross_attention", + r"spatial_perceiver.layers.(\d+).self_attn": r"spatial_perceiver.layers.\1.self_attention", + } + + for key, value in state_dict.items(): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + for pattern, replacement in perceiver_resampler_patterns.items(): + if re.match(pattern, key): + key = re.sub(pattern, replacement, key) + + # vision_encoder.blocks.0.mlp.layers.1.weight -> vision_encoder.blocks.0.mlp.proj_out.weight + if re.match(output_vision_encoder_mlps_pattern, key): + layer_nb = int(re.match(output_vision_encoder_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "proj_out") + + if re.match(memory_attention_pattern, key): + key = key.replace("linear1", "mlp.up_proj") + key = key.replace("linear2", "mlp.down_proj") + + # mask_decoder.transformer.layers.0.mlp.layers.1.weight -> mask_decoder.transformer.layers.1.mlp.proj_out.weight + if re.match(output_mask_decoder_mlps_pattern, key): + layer_nb = int(re.match(output_mask_decoder_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("mlp.layers.0", "mlp.proj_in") + elif layer_nb == 1: + key = key.replace("mlp.layers.1", "mlp.proj_out") + + # mask_decoder.pred_obj_score_head.layers.1.weight -> mask_decoder.pred_obj_score_head.proj_in.weight + if re.match(output_mask_decoder_score_head_pattern, key): + layer_nb = int(re.match(output_mask_decoder_score_head_pattern, key).group(1)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + if re.match(output_hypernetworks_mlps_pattern, key): + layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = 
key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + # vision_encoder.neck.convs.1.conv.bias -> vision_encoder.neck.convs.1.bias + if re.match(output_vision_encoder_neck_pattern, key): + key = key.replace(".conv.", ".") + + # memory_encoder.o_proj.weight -> memory_encoder.projection.weight + if re.match(output_memory_encoder_projection_pattern, key): + key = key.replace(".o_proj.", ".projection.") + + if re.match(output_object_pointer_proj_pattern, key): + layer_nb = int(re.match(output_object_pointer_proj_pattern, key).group(1)) + if layer_nb == 0: + key = key.replace("layers.0", "proj_in") + elif layer_nb == 1: + key = key.replace("layers.1", "layers.0") + elif layer_nb == 2: + key = key.replace("layers.2", "proj_out") + + key = key.replace("layers.2", "proj_out") + + if re.match(output_memory_encoder_mask_downsampler_pattern, key): + layer_nb = int(re.match(output_memory_encoder_mask_downsampler_pattern, key).group(1)) + if layer_nb == 12: + key = key.replace(f"encoder.{layer_nb}", "final_conv") + elif layer_nb % 3 == 0: + key = key.replace(f"encoder.{layer_nb}", f"layers.{layer_nb // 3}.conv") + elif layer_nb % 3 == 1: + key = key.replace(f"encoder.{layer_nb}", f"layers.{layer_nb // 3}.layer_norm") + if "kv_proj_combined" in key: + # Split the weight tensor in half along dimension 0 (output dimension) + k_weight, v_weight = torch.chunk(value, 2, dim=0) + # Create the k_proj and v_proj keys + k_key = key.replace("kv_proj_combined", "k_proj") + v_key = key.replace("kv_proj_combined", "v_proj") + model_state_dict[k_key] = k_weight + model_state_dict[v_key] = v_weight + continue + + model_state_dict[key] = value + + model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ + "prompt_encoder.shared_embedding.positional_embedding" + ] + model_state_dict["prompt_encoder.point_embed.weight"] = torch.cat( + [model_state_dict.pop(f"prompt_encoder.point_embed.{i}.weight") for i in range(4)], + dim=0, + ) + + return model_state_dict + + +def convert_edgetam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub, run_sanity_check): + config = get_config(model_name) + + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + state_dict = replace_keys(state_dict) + + image_processor = Sam2ImageProcessorFast() + video_processor = Sam2VideoVideoProcessor() + processor = Sam2VideoProcessor(image_processor=image_processor, video_processor=video_processor) + hf_model = EdgeTamVideoModel(config) + hf_model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=True) + hf_model = hf_model.to(device) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + if run_sanity_check: + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + input_points = [[[[1000, 600]]]] + input_labels = [[[1]]] + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + output = hf_model._single_frame_forward(**inputs) + scores = output.iou_scores.squeeze() + + assert torch.allclose(scores, torch.tensor([0.0356, 0.2141, 0.9707]).cuda(), atol=1e-3) + + if pytorch_dump_folder is not None: + processor.save_pretrained(pytorch_dump_folder) + 
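+        # Save the converted model weights alongside the processor so both can be reloaded with `from_pretrained`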
hf_model.save_pretrained(pytorch_dump_folder) + + if push_to_hub: + repo_id = f"yonigozlan/{pytorch_dump_folder.split('/')[-1]}" + processor.push_to_hub(repo_id) + hf_model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + choices = ["EdgeTAM"] + parser.add_argument( + "--model_name", + default="EdgeTAM", + choices=choices, + type=str, + help="Name of the original model to convert", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=False, + help="Path to the original checkpoint", + ) + parser.add_argument("--pytorch_dump_folder_path", default="", type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model and processor to the hub after converting", + ) + parser.add_argument( + "--run_sanity_check", + action="store_true", + help="Whether to run the sanity check after converting", + ) + + args = parser.parse_args() + + hf_model_name = args.model_name.replace("_", "-") + checkpoint_path = ( + hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name.lower()}.pt") + if args.checkpoint_path is None + else args.checkpoint_path + ) + + convert_edgetam_checkpoint( + args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub, args.run_sanity_check + ) diff --git a/src/transformers/models/edgetam_video/modeling_edgetam_video.py b/src/transformers/models/edgetam_video/modeling_edgetam_video.py new file mode 100644 index 000000000000..3ba7ab4ebf2f --- /dev/null +++ b/src/transformers/models/edgetam_video/modeling_edgetam_video.py @@ -0,0 +1,3062 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/edgetam_video/modular_edgetam_video.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_edgetam_video.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from collections import OrderedDict +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from tqdm import tqdm + +from transformers.utils.generic import OutputRecorder + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ModelOutput, auto_docstring +from ...utils.generic import TransformersKwargs +from ..auto import AutoModel +from .configuration_edgetam_video import ( + EdgeTamVideoConfig, + EdgeTamVideoMaskDecoderConfig, + EdgeTamVideoPromptEncoderConfig, +) + + +class EdgeTamVideoLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class EdgeTamVideoMemoryFuserCXBlock(GradientCheckpointingLayer): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.depthwise_conv = nn.Conv2d( + config.memory_fuser_embed_dim, + config.memory_fuser_embed_dim, + kernel_size=config.memory_fuser_kernel_size, + padding=config.memory_fuser_padding, + groups=config.memory_fuser_embed_dim, + ) # depthwise conv + self.layer_norm = EdgeTamVideoLayerNorm(config.memory_fuser_embed_dim, eps=1e-6, data_format="channels_first") + self.activation = ACT2FN[config.memory_fuser_hidden_act] + self.pointwise_conv1 = nn.Linear( + config.memory_fuser_embed_dim, config.memory_fuser_intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.pointwise_conv2 = nn.Linear(config.memory_fuser_intermediate_dim, config.memory_fuser_embed_dim) + self.scale = nn.Parameter( + config.memory_fuser_layer_scale_init_value * torch.ones(config.memory_fuser_embed_dim), + requires_grad=True, + ) + + def forward(self, hidden_states): + input = hidden_states + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + hidden_states = self.pointwise_conv1(hidden_states) + hidden_states = self.activation(hidden_states) + 
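+        # project back from the intermediate MLP width to the embedding dimension before layer scaling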
hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.scale * hidden_states + hidden_states = hidden_states.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + hidden_states = input + hidden_states + return hidden_states + + +@dataclass +@auto_docstring(custom_intro="Base class for the vision encoder's outputs.") +class EdgeTamVideoVisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + fpn_hidden_states (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. + fpn_position_encoding (`tuple(torch.FloatTensor)`): + Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + fpn_hidden_states: Optional[torch.FloatTensor] = None + fpn_position_encoding: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class EdgeTamVideoVisionRotaryEmbedding(nn.Module): + """ + Vision Rotary Position Embedding for SAM2, following transformers library standards. + Supports 2D (axial) rotary embeddings for spatial dimensions. 
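+
+    The rotary frequencies are split between the x and y axes of a fixed `end_x` x `end_y` feature grid, and the
+    resulting cos/sin tables are precomputed and registered as non-persistent buffers since the feature map size
+    is fixed.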
+ """ + + def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None): + super().__init__() + dim = config.memory_attention_hidden_size // ( + config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads + ) + # Ensure even dimension for proper axial splitting + if dim % 4 != 0: + raise ValueError("Dimension must be divisible by 4 for axial RoPE") + end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y) + freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + # Generate 2D position indices for axial rotary embedding + flattened_indices = torch.arange(end_x * end_y, dtype=torch.long) + x_positions = flattened_indices % end_x + y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") + freqs_x = torch.outer(x_positions, freqs).float() + freqs_y = torch.outer(y_positions, freqs).float() + inv_freq = torch.cat([freqs_x, freqs_y], dim=-1) + inv_freq = inv_freq.repeat_interleave(2, dim=-1) + # directly register the cos and sin embeddings as we have a fixed feature shape + self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False) + self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False) + + @torch.no_grad() + def forward(self) -> tuple[torch.Tensor, torch.Tensor]: + # As the feature map size is fixed, we can just return the pre-computed embeddings. + return self.rope_embeddings_cos, self.rope_embeddings_sin + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EdgeTamVideoAttention(nn.Module): + """ + EDGETAM_VIDEO's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
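+
+    Queries, keys, and values are projected down to `hidden_size // downsample_rate` channels before multi-head
+    attention and projected back to `hidden_size` by the output projection.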
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_similarity: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +def rotate_pairwise(x): + """ + pairwise rotation of the hidden dims of the input. Differerent from Llama Half-Tensor Rotation. + + This is an optimized version of the following more explicit implementation: + ```python + x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device) + x_rotated[..., ::2] = -x[..., 1::2] + x_rotated[..., 1::2] = x[..., ::2] + return x_rotated + ``` + """ + x = x.view(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(start_dim=-2) + + +def apply_rotary_pos_emb_2d_self_attn( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding to query and key tensors for self-attention. 
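+
+    The same precomputed cos/sin tables are applied to both queries and keys, and the rotation is computed in
+    float32 before casting back to the input dtype.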
+ + Args: + q: Query tensor of shape (..., seq_len, head_dim) + k: Key tensor of shape (..., seq_len, head_dim) + cos: Cosine position embedding of shape (seq_len, head_dim) + sin: Sine position embedding of shape (seq_len, head_dim) + + Returns: + Rotated (q, k) tensors + """ + # Apply RoPE to queries + q_embed = q.float() # force upscale to float32 as in the original implementation + q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin) + + # Apply RoPE to keys (same embeddings as queries for self-attention) + k_embed = k.float() # force upscale to float32 as in the original implementation + k_embed = (k_embed * cos) + (rotate_pairwise(k_embed) * sin) + + return q_embed.type_as(q), k_embed.type_as(k) + + +class EdgeTamVideoRoPESelfAttention(nn.Module): + """Self-attention with rotary position encoding.""" + + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate + self.num_attention_heads = config.memory_attention_num_attention_heads + self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + self.dropout_p = config.memory_attention_rope_dropout + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tensor: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + cos, sin = position_embeddings + # Apply rotary position encoding for self-attention + query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def apply_rotary_pos_emb_2d_cross_attn( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cos_k: torch.Tensor, + sin_k: torch.Tensor, + num_k_exclude_rope: int = 0, + repeat_freqs_k: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding to query and key tensors for cross-attention. 
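+
+    The key sequence is expected to contain, for each of the `repeat_freqs_k` memory groups, temporal tokens
+    followed by spatial tokens, with `num_k_exclude_rope` extra tokens (e.g. object pointers) appended at the end.
+    Only the spatial memory tokens are rotated; the temporal and excluded tokens are passed through unchanged.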
+ + Args: + q: Query tensor of shape (..., seq_len, head_dim) + k: Key tensor of shape (..., seq_len, head_dim) + cos: Cosine position embedding of shape (seq_len, head_dim) + sin: Sine position embedding of shape (seq_len, head_dim) + cos_k: Cosine position embedding for keys of shape (seq_len, head_dim) + sin_k: Sine position embedding for keys of shape (seq_len, head_dim) + num_k_exclude_rope: Number of tokens at end of k to exclude from RoPE (e.g., object pointer tokens) + repeat_freqs_k: Frequency repetition for keys in cross-attention (e.g., for spatial memory tokens) + + Returns: + Rotated (q, k) tensors + """ + # Apply RoPE to queries (always straightforward) + q_embed = q.float() + q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin) + + # Split keys: RoPE tokens and excluded tokens (e.g., object pointers) + num_total_k_tokens = k.shape[-2] + k_for_rope = k[..., : num_total_k_tokens - num_k_exclude_rope, :] + k_excluded = k[..., num_total_k_tokens - num_k_exclude_rope :, :] + + # Early return if no keys need RoPE + if k_for_rope.shape[-2] == 0: + return q_embed.type_as(q), k_excluded + + batch_size, num_heads, k_seq_len, channels_per_head = k_for_rope.shape + + # Handle temporal/spatial token structure for memory + # Keys have temporal + spatial structure, only spatial tokens get RoPE + tokens_per_group = k_seq_len // repeat_freqs_k + spatial_tokens = cos_k.shape[-2] + temporal_tokens = tokens_per_group - spatial_tokens + + # Reshape and separate temporal/spatial tokens + k_grouped = k_for_rope.view(batch_size, num_heads, repeat_freqs_k, tokens_per_group, channels_per_head) + k_temporal = k_grouped[..., :temporal_tokens, :].reshape(batch_size, num_heads, -1, channels_per_head) + k_spatial = k_grouped[..., temporal_tokens:, :].reshape(batch_size, num_heads, -1, channels_per_head) + + # Only apply RoPE to spatial tokens + k_rope_input = k_spatial + + # Prepare position embeddings for repeated groups + if repeat_freqs_k > 1: + cos_k = cos_k.repeat(1, 1, repeat_freqs_k, 1) + sin_k = sin_k.repeat(1, 1, repeat_freqs_k, 1) + + # Apply RoPE to spatial tokens + k_spatial_embed = k_rope_input.float() + k_spatial_embed = (k_spatial_embed * cos_k) + (rotate_pairwise(k_spatial_embed) * sin_k) + + # Reconstruct: temporal + spatial tokens back to original structure + k_spatial_reshaped = k_spatial_embed.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head) + k_temporal_reshaped = k_temporal.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head) + k_final = torch.cat([k_temporal_reshaped, k_spatial_reshaped], dim=3) + k_final = k_final.view(batch_size, num_heads, k_seq_len, channels_per_head) + + # Combine RoPE-processed keys with excluded tokens + k_embed = torch.cat([k_final.type_as(k), k_excluded], dim=-2) + return q_embed.type_as(q), k_embed + + +class EdgeTamVideoRoPECrossAttention(nn.Module): + """Cross-attention with rotary position encoding.""" + + def __init__(self, config: EdgeTamVideoConfig, kv_in_dim: int): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate + self.num_attention_heads = config.memory_attention_num_attention_heads + self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.kv_in_dim = kv_in_dim + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, 
self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + self.dropout_p = config.memory_attention_rope_dropout + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_embeddings_k: tuple[torch.Tensor, torch.Tensor], + num_k_exclude_rope: int = 0, + rope_k_repeat: int = 0, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tensor: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + cos, sin = position_embeddings + cos_k, sin_k = position_embeddings_k + # Apply rotary position encoding for cross-attention + query, key = apply_rotary_pos_emb_2d_cross_attn( + query, + key, + cos=cos, + sin=sin, + cos_k=cos_k, + sin_k=sin_k, + repeat_freqs_k=rope_k_repeat, + num_k_exclude_rope=num_k_exclude_rope, + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EdgeTamVideoTwoWayAttentionBlock(nn.Module): + def __init__(self, config: EdgeTamVideoMaskDecoderConfig, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`EdgeTamVideoMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + self.self_attn = EdgeTamVideoAttention(config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = EdgeTamVideoAttention(config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size) + + self.mlp = EdgeTamVideoFeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = EdgeTamVideoAttention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + **kwargs: Unpack[TransformersKwargs], + ): + # Self attention block + if self.skip_first_layer_pe: + queries, _ = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out, _ = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + return queries, keys, attn_out + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class EdgeTamVideoPositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. 
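+
+    The forward pass is wrapped in an LRU cache (size 2), so the encoding is only recomputed when the feature map
+    shape, device, or dtype changes.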
+ """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + @compile_compatible_method_lru_cache(maxsize=2) + def forward( + self, + shape: torch.Size, + device: Union[torch.device, str], + dtype: torch.dtype, + mask: Optional[Tensor] = None, + ) -> Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class EdgeTamVideoMemoryFuser(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.layers = nn.ModuleList( + [EdgeTamVideoMemoryFuserCXBlock(config) for _ in range(config.memory_fuser_num_layers)] + ) + + def forward(self, hidden_states): + # normally hidden_states: (N, C, H, W) + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class EdgeTamVideoMaskDownSamplerLayer(nn.Module): + def __init__(self, config: EdgeTamVideoConfig, in_channels: int, out_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=config.mask_downsampler_kernel_size, + stride=config.mask_downsampler_stride, + padding=config.mask_downsampler_padding, + ) + self.layer_norm = EdgeTamVideoLayerNorm(out_channels, eps=1e-6, data_format="channels_first") + self.activation = ACT2FN[config.mask_downsampler_hidden_act] + + def forward(self, x): + return self.activation(self.layer_norm(self.conv(x))) + + +class EdgeTamVideoMaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. 
+ """ + + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + + num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride)) + + self.layers = nn.ModuleList() + self.activation = ACT2FN[config.mask_downsampler_hidden_act] + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2) + self.layers.append(EdgeTamVideoMaskDownSamplerLayer(config, mask_in_chans, mask_out_chans)) + mask_in_chans = mask_out_chans + + self.final_conv = nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = self.final_conv(x) + return x + + +class EdgeTamVideoMemoryEncoder(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + + hidden_size = config.memory_encoder_hidden_size + output_channels = config.memory_encoder_output_channels + self.mask_downsampler = EdgeTamVideoMaskDownSampler(config) + self.feature_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1) + self.memory_fuser = EdgeTamVideoMemoryFuser(config) + self.position_encoding = EdgeTamVideoPositionEmbeddingSine(num_pos_feats=output_channels // 2, normalize=True) + self.projection = nn.Conv2d(hidden_size, output_channels, kernel_size=1) + + def forward( + self, + vision_features: torch.Tensor, + masks: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + ## Process masks + masks = self.mask_downsampler(masks) + ## Fuse pixel_features and downsampled masks + + vision_features = self.feature_projection(vision_features) + vision_features = vision_features + masks + vision_features = self.memory_fuser(vision_features) + vision_features = self.projection(vision_features) + + vision_pos_enc = self.position_encoding(vision_features.shape, vision_features.device, vision_features.dtype) + + return vision_features, vision_pos_enc + + +class EdgeTamVideoFeedForward(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +@auto_docstring +class EdgeTamVideoPreTrainedModel(PreTrainedModel): + config_class = EdgeTamVideoConfig + base_model_prefix = "edgetam_video" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + 
module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, EdgeTamVideoLayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, EdgeTamVideoModel): + if module.no_memory_positional_encoding is not None: + module.no_memory_positional_encoding.data.zero_() + if module.memory_temporal_positional_encoding is not None: + module.memory_temporal_positional_encoding.data.zero_() + if module.no_object_pointer is not None: + module.no_object_pointer.data.zero_() + if module.occlusion_spatial_embedding_parameter is not None: + module.occlusion_spatial_embedding_parameter.data.zero_() + if isinstance(module, EdgeTamVideoMemoryFuserCXBlock): + if module.scale is not None: + module.scale.data.zero_() + + +class EdgeTamVideoInferenceCache: + """Cache for vision features and model constants.""" + + def __init__( + self, + inference_device: Union[torch.device, str] = "cpu", + inference_state_device: Union[torch.device, str] = "cpu", + max_vision_features_cache_size: int = 1, + ): + self.inference_device = inference_device + self.inference_state_device = inference_state_device + self.max_vision_features_cache_size = max_vision_features_cache_size + + self._vision_features = {} + + def cache_vision_features(self, frame_idx: int, features: dict): + """Cache vision features with automatic device management.""" + cached = {} + if len(self._vision_features) >= self.max_vision_features_cache_size: + # remove the oldest frame + self._vision_features.pop(min(self._vision_features.keys())) + + for key, value in features.items(): + if isinstance(value, torch.Tensor): + cached[key] = value.to(self.inference_state_device, non_blocking=True) + elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor): + cached[key] = [v.to(self.inference_state_device, non_blocking=True) for v in value] + else: + cached[key] = value + self._vision_features[frame_idx] = cached + + def get_vision_features(self, frame_idx: int) -> Optional[dict]: + """Get cached vision features, automatically moved to inference device.""" + if frame_idx not in self._vision_features: + return None + + cached = self._vision_features[frame_idx] + moved = {} + for key, value in cached.items(): + if isinstance(value, torch.Tensor): + moved[key] = value.to(self.inference_device, non_blocking=True) + elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor): + moved[key] = [v.to(self.inference_device, non_blocking=True) for v in value] + else: + moved[key] = value + return moved + + def clear_all(self): + """Clear all cached data.""" + self._vision_features.clear() + + +class EdgeTamVideoInferenceSession: + r""" + Manages video inference session parameters, state and cache. + + Args: + video (`torch.FloatTensor`, *optional*): + The video to process. No need to provide when streaming. + video_height (`int`, *optional*): + The height of the video. + video_width (`int`, *optional*): + The width of the video. + inference_device (`torch.device`, *optional*, defaults to `"cpu"`): + The device to use for inference. + inference_state_device (`torch.device`, *optional*, defaults to `"cpu"`): + The device to store the inference state on. + video_storage_device (`torch.device`, *optional*, defaults to `"cpu"`): + The device to store the video on. + dtype (`torch.dtype`, *optional*, defaults to `"float32"`): + The dtype to use for the video. 
+ max_vision_features_cache_size (`int`, *optional*, defaults to 1): + The maximum number of vision features to cache. + """ + + def __init__( + self, + video: Optional[torch.FloatTensor] = None, + video_height: Optional[int] = None, + video_width: Optional[int] = None, + inference_device: Union[torch.device, str] = "cpu", + inference_state_device: Union[torch.device, str] = "cpu", + video_storage_device: Union[torch.device, str] = "cpu", + dtype: Union[torch.dtype, str] = "float32", + max_vision_features_cache_size: int = 1, + ): + # store as a dictionary to avoid double memory allocation with torch.cat when adding new frames + self.processed_frames = ( + dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None + ) + self.video_height = video_height + self.video_width = video_width + + self.inference_device = inference_device + self.inference_state_device = inference_state_device + self.video_storage_device = video_storage_device + self.dtype = dtype + self.max_vision_features_cache_size = max_vision_features_cache_size + + # Cache for computed features + self.cache = EdgeTamVideoInferenceCache( + inference_device=self.inference_device, + inference_state_device=self.inference_state_device, + max_vision_features_cache_size=self.max_vision_features_cache_size, + ) + + # Persistent object tracking state + self._obj_id_to_idx = OrderedDict() + self._obj_idx_to_id = OrderedDict() + self.obj_ids = [] + + # Persistent user inputs + self.point_inputs_per_obj = {} + self.mask_inputs_per_obj = {} + + # Persistent model outputs/history + self.output_dict_per_obj = {} + self.frames_tracked_per_obj = {} + + # Session state flags + self.obj_with_new_inputs = [] + + @property + def num_frames(self) -> Optional[int]: + return len(self.processed_frames) if self.processed_frames is not None else None + + # Object management + def obj_id_to_idx(self, obj_id: int) -> int: + """Map object ID to index, creating new entry if needed.""" + obj_idx = self._obj_id_to_idx.get(obj_id, None) + if obj_idx is not None: + return obj_idx + + obj_idx = len(self._obj_id_to_idx) + self._obj_id_to_idx[obj_id] = obj_idx + self._obj_idx_to_id[obj_idx] = obj_id + self.obj_ids = list(self._obj_id_to_idx) + + self.point_inputs_per_obj[obj_idx] = {} + self.mask_inputs_per_obj[obj_idx] = {} + self.output_dict_per_obj[obj_idx] = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + self.frames_tracked_per_obj[obj_idx] = {} + + return obj_idx + + # Video Inference specific functions + def obj_idx_to_id(self, obj_idx: int) -> int: + """Map model-side object index to client-side object id.""" + return self._obj_idx_to_id[obj_idx] + + def get_obj_num(self) -> int: + """Get the total number of unique object ids received so far in this session.""" + return len(self._obj_idx_to_id) + + # Input management with device handling + def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict): + """Add point inputs with automatic device placement.""" + device_inputs = {} + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + device_inputs[key] = value.to(self.inference_device, non_blocking=True) + else: + device_inputs[key] = value + self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs + + def remove_point_inputs(self, obj_idx: int, frame_idx: int): + """Remove point inputs.""" + self.point_inputs_per_obj[obj_idx].pop(frame_idx, None) + + def add_mask_inputs(self, obj_idx: int, frame_idx: int, inputs: torch.Tensor): + """Add mask inputs with automatic device 
placement.""" + self.mask_inputs_per_obj[obj_idx][frame_idx] = inputs.to( + self.inference_device, dtype=self.dtype, non_blocking=True + ) + + def remove_mask_inputs(self, obj_idx: int, frame_idx: int): + """Remove mask inputs.""" + self.mask_inputs_per_obj[obj_idx].pop(frame_idx, None) + + # Output management with smart device placement + def store_output( + self, + obj_idx: int, + frame_idx: int, + output_key: Optional[str] = None, + output_value: Optional[Union[torch.Tensor, dict]] = None, + is_conditioning_frame: bool = True, + ): + """ + Store output with smart device management. + If output_key is None, the output is stored as a dictionary. + + Args: + obj_idx (int): The index of the object. + frame_idx (int): The index of the frame. + output_key (Optional[str]): The key of the output. If None, the output is stored as a dictionary. + output_value (Optional[Union[torch.Tensor, dict]]): The value of the output. + is_conditioning_frame (bool): Whether the output is for a conditioning frame. + """ + storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs" + + if output_key is None and isinstance(output_value, dict): + self.output_dict_per_obj[obj_idx][storage_key][frame_idx] = {} + for key, value in output_value.items(): + self.store_output(obj_idx, frame_idx, key, value, is_conditioning_frame) + return + + # Device placement: small tensors stay on inference device, large ones go to inference state device + if output_key in ["object_pointer", "object_score_logits"]: # Small tensors + self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value + elif isinstance(output_value, torch.Tensor): # Large tensors like masks, features + self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value.to( + self.inference_state_device, non_blocking=True + ) + else: + self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value + + def get_output( + self, + obj_idx: int, + frame_idx: int, + output_key: str, + is_conditioning_frame: bool = True, + ): + """ + Get output with smart device management. + + Args: + obj_idx (int): The index of the object. + frame_idx (int): The index of the frame. + output_key (str): The key of the output. + is_conditioning_frame (bool): Whether the output is for a conditioning frame. 
+ """ + storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs" + out = self.output_dict_per_obj[obj_idx][storage_key].get(frame_idx, None) + # move to inference device if needed + if out is None: + return None + value = out[output_key] + if isinstance(value, torch.Tensor): + value = value.to(self.inference_device, non_blocking=True) + return value + + # Video frame management + def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int: + """Add new frame with automatic device placement.""" + pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True) + if pixel_values.dim() == 4: + pixel_values = pixel_values.squeeze(0) + + if frame_idx is None: + frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0 + + if self.processed_frames is None: + self.processed_frames = {frame_idx: pixel_values} + else: + self.processed_frames[frame_idx] = pixel_values + + return frame_idx + + def get_frame(self, frame_idx: int) -> torch.Tensor: + """Get frame from video.""" + return self.processed_frames[frame_idx].to(self.inference_device, non_blocking=True) + + def reset_tracking_data(self): + """Reset tracking data but keep cache.""" + self._obj_id_to_idx.clear() + self._obj_idx_to_id.clear() + self.obj_ids.clear() + self.point_inputs_per_obj.clear() + self.mask_inputs_per_obj.clear() + self.output_dict_per_obj.clear() + self.frames_tracked_per_obj.clear() + self.obj_with_new_inputs = [] + # Note: cache and video data are preserved + + def reset_inference_session(self): + """Reset tracking data and cache.""" + self._obj_id_to_idx.clear() + self._obj_idx_to_id.clear() + self.obj_ids.clear() + self.point_inputs_per_obj.clear() + self.mask_inputs_per_obj.clear() + self.output_dict_per_obj.clear() + self.frames_tracked_per_obj.clear() + self.obj_with_new_inputs = [] + self.cache.clear_all() + + +class EdgeTamVideoMemoryAttentionMLP(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.intermediate_size = config.memory_attention_mlp_hidden_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size) + self.dropout = nn.Dropout(config.memory_attention_dropout) + self.act_fn = ACT2FN[config.memory_attention_mlp_hidden_act] + + def forward(self, x): + return self.down_proj(self.dropout(self.act_fn(self.up_proj(x)))) + + +class EdgeTamVideoMemoryAttentionLayer(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + hidden_size = config.memory_attention_hidden_size + self.self_attn = EdgeTamVideoRoPESelfAttention(config) + self.cross_attn_image = EdgeTamVideoRoPECrossAttention(config, kv_in_dim=64) + + # MLP module + self.mlp = EdgeTamVideoMemoryAttentionMLP(config) + + self.layer_norm1 = nn.LayerNorm(hidden_size) + self.layer_norm2 = nn.LayerNorm(hidden_size) + self.layer_norm3 = nn.LayerNorm(hidden_size) + self.dropout1 = nn.Dropout(config.memory_attention_dropout) + self.dropout2 = nn.Dropout(config.memory_attention_dropout) + self.dropout3 = nn.Dropout(config.memory_attention_dropout) + + def forward( + self, + queries: Tensor, + keys: Tensor, + key_point_embedding: Tensor, + rope_position_embeddings: tuple[Tensor, Tensor], + rope_position_embeddings_k: Optional[tuple[Tensor, Tensor]] = None, + num_k_exclude_rope: int = 0, + rope_k_repeat: int = 0, + ) -> 
torch.Tensor: + # Self-Attention + query = self.layer_norm1(queries) + query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings) + queries = queries + self.dropout1(query) + + # Cross-Attention + query = self.layer_norm2(queries) + query, _ = self.cross_attn_image( + query=query, + key=keys + key_point_embedding, + value=keys, + position_embeddings=rope_position_embeddings, + position_embeddings_k=rope_position_embeddings_k, + num_k_exclude_rope=num_k_exclude_rope, + rope_k_repeat=rope_k_repeat, + ) + queries = queries + self.dropout2(query) + # MLP + query = self.layer_norm3(queries) + query = self.mlp(query) + queries = queries + self.dropout3(query) + return queries + + +class EdgeTamVideoMemoryAttention(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.layers = nn.ModuleList( + [EdgeTamVideoMemoryAttentionLayer(config) for _ in range(config.memory_attention_num_layers)] + ) + self.layer_norm = nn.LayerNorm(config.memory_attention_hidden_size) + self.rotary_emb = EdgeTamVideoVisionRotaryEmbedding(config=config) + self.rotary_emb_k = EdgeTamVideoVisionRotaryEmbedding( + config, end_x=config.memory_attention_rope_k_sizes[0], end_y=config.memory_attention_rope_k_sizes[1] + ) + + def forward( + self, + current_vision_features: torch.Tensor, + memory: torch.Tensor, + current_vision_position_embeddings: Optional[Tensor] = None, + memory_posision_embeddings: Optional[Tensor] = None, + num_object_pointer_tokens: int = 0, + num_spatial_memory_tokens: int = -1, + ): + """ + Args: + current_vision_features (`torch.FloatTensor`): + The current vision features used for self-attention. + memory (`torch.FloatTensor`): + The memory features used for cross-attention. + current_vision_position_embeddings (`torch.FloatTensor`, *optional*): + The position embeddings for the current vision features. + memory_posision_embeddings (`torch.FloatTensor`, *optional*): + The position embeddings for the memory features. + num_object_pointer_tokens (`int`, *optional*, defaults to 0): + The number of object pointer tokens. 
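+ num_spatial_memory_tokens (`int`, *optional*, defaults to -1): + The number of spatial memory tokens; forwarded to each layer as `rope_k_repeat` so the key rotary embeddings can be repeated across the spatial memories.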
+ """ + output = current_vision_features + if current_vision_position_embeddings is not None: + output = output + 0.1 * current_vision_position_embeddings + + # Convert to batch first + output = output.transpose(0, 1) + memory = memory.transpose(0, 1).unsqueeze(1) + memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1) + rope_position_embeddings = self.rotary_emb() + rope_position_embeddings_k = self.rotary_emb_k() + for layer in self.layers: + output = layer( + queries=output.unsqueeze(1) if output.ndim == 3 else output, + keys=memory, + key_point_embedding=memory_posision_embeddings, + rope_position_embeddings=rope_position_embeddings, + rope_position_embeddings_k=rope_position_embeddings_k, + num_k_exclude_rope=num_object_pointer_tokens, + rope_k_repeat=num_spatial_memory_tokens, + ) + + normed_output = self.layer_norm(output) + + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + + return normed_output + + +class EdgeTamVideoPerceiverMLP(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.hidden_size = config.perceiver_resampler_hidden_size + self.intermediate_size = config.perceiver_resampler_mlp_intermediate_size + + self.layer_norm = nn.LayerNorm(self.hidden_size) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.down_proj(self.act_fn(self.up_proj(hidden_states))) + return hidden_states + + +class EdgeTamVideoPerceiverAttention(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.perceiver_resampler_hidden_size + self.num_attention_heads = config.perceiver_resampler_num_attention_heads + self.head_dim = config.perceiver_resampler_attention_head_dim + self.attention_dropout = config.perceiver_resampler_attention_dropout + + self.inner_dim = self.head_dim * self.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # Project queries, keys, and values + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + # Reshape for multi-head attention + batch_size, seq_len_q = query.shape[:2] + query = query.view(batch_size, seq_len_q, self.num_attention_heads, self.head_dim).transpose(1, 2) + seq_len_kv = key.shape[1] + key = key.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2) + + # Add positional encoding if provided + if positional_encoding is not None: + pos_encoding = positional_encoding.view( + batch_size, seq_len_kv, self.num_attention_heads, self.head_dim + ).transpose(1, 2) + key = key + pos_encoding + value = value + pos_encoding + + # Apply attention + attention_interface: 
Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + # Reshape output + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.inner_dim) + return self.o_proj(attn_output) + + +class EdgeTamVideoPerceiverEncoderLayer(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + + self.cross_attention = EdgeTamVideoPerceiverAttention(config) + self.mlp = EdgeTamVideoPerceiverMLP(config) + self.dropout = nn.Dropout(config.perceiver_resampler_hidden_dropout) + + self.self_attention = EdgeTamVideoPerceiverAttention(config) + self.self_mlp = EdgeTamVideoPerceiverMLP(config) + + # Layer norms moved from attention classes to here + self.layer_norm_input = nn.LayerNorm(config.perceiver_resampler_hidden_size) + self.layer_norm_latents = nn.LayerNorm(config.perceiver_resampler_hidden_size) + self.layer_norm_self = nn.LayerNorm(config.perceiver_resampler_hidden_size) + + def forward( + self, + latents: torch.Tensor, + input_features: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Cross attention with layer norms + normalized_latents = self.layer_norm_latents(latents) + normalized_input = self.layer_norm_input(input_features) + cross_attention_output = self.cross_attention( + query=normalized_latents, + key=normalized_input, + value=normalized_input, + positional_encoding=positional_encoding, + ) + latents = latents + self.dropout(cross_attention_output) + + mlp_output = self.mlp(latents) + latents = latents + mlp_output + + # Self attention with layer norm + normalized_latents_self = self.layer_norm_self(latents) + self_attention_output = self.self_attention( + query=normalized_latents_self, key=normalized_latents_self, value=normalized_latents_self + ) + latents = latents + self_attention_output + + self_mlp_output = self.self_mlp(latents) + latents = latents + self_mlp_output + + return latents + + +def window_partition(hidden_state, window_size): + """ + Partition into non-overlapping windows with padding if needed. + + Args: + hidden_state (`torch.Tensor`): + Input tokens with [batch_size, height, width, num_channels]. + window_size (`int`): + Window size. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements: + - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels]. + - (padded_height, padded_width): padded height and width before partition + """ + batch_size, height, width, num_channels = hidden_state.shape + + pad_height = (window_size - height % window_size) % window_size + pad_width = (window_size - width % window_size) % window_size + + # Noop in case pad_width == 0 and pad_height == 0. 
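+ # nn.functional.pad consumes (before, after) pairs starting from the last dimension, so (0, 0, 0, pad_width, 0, pad_height) + # leaves the channel dimension untouched and pads only the right of the width and the bottom of the height.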
+ hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height)) + + padded_height, padded_width = height + pad_height, width + pad_width + + hidden_state = hidden_state.view( + batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels + ) + windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows, (padded_height, padded_width) + + +class EdgeTamVideoPerceiverResampler(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.perceiver_resampler_hidden_size + self.num_latents_1d = config.perceiver_resampler_num_latents + self.num_latents_2d = config.perceiver_resampler_num_latents_2d + self.num_layers = config.perceiver_resampler_num_layers + + if self.num_latents_1d > 0: + self.latents_1d = nn.Parameter(torch.randn(self.num_latents_1d, self.hidden_size)) + if self.num_latents_2d > 0: + self.latents_2d = nn.Parameter(torch.randn(self.num_latents_2d, self.hidden_size)) + + self.positional_encoding = EdgeTamVideoPositionEmbeddingSine( + num_pos_feats=self.hidden_size // 2, normalize=True + ) + + self.layers = nn.ModuleList([EdgeTamVideoPerceiverEncoderLayer(config) for _ in range(self.num_layers)]) + + self.layer_norm = nn.LayerNorm(self.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + output_latents = [] + output_positional_encodings = [] + + if self.num_latents_1d > 0: + latents_1d, pos_1d = self._forward_1d(hidden_states, positional_encoding) + output_latents.append(latents_1d) + output_positional_encodings.append(pos_1d) + + if self.num_latents_2d > 0: + latents_2d, pos_2d = self._forward_2d(hidden_states) + output_latents.append(latents_2d) + output_positional_encodings.append(pos_2d) + + combined_latents = torch.cat(output_latents, dim=1) + + combined_positional_encoding = None + if positional_encoding is not None and output_positional_encodings: + combined_positional_encoding = torch.cat(output_positional_encodings, dim=1) + + return combined_latents, combined_positional_encoding + + def _forward_1d( + self, + hidden_states: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size = hidden_states.shape[0] + + latents = self.latents_1d.unsqueeze(0).expand(batch_size, -1, -1) + flattened_features = hidden_states.permute(0, 2, 3, 1).flatten(1, 2) + + positional_features = None + if positional_encoding is not None: + positional_features = positional_encoding.permute(0, 2, 3, 1).flatten(1, 2) + + for layer in self.layers: + latents = layer(latents, flattened_features, positional_features) + + latents = self.layer_norm(latents) + + output_positional_encoding = None + if positional_encoding is not None: + output_positional_encoding = torch.zeros_like(latents) + + return latents, output_positional_encoding + + def _forward_2d(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, channels, height, width = hidden_states.shape + + latents_2d = self.latents_2d.unsqueeze(0).expand(batch_size, -1, -1).view(-1, 1, channels) + + num_windows_per_dim = int(math.sqrt(self.num_latents_2d)) + window_size = height // num_windows_per_dim + + windowed_input = hidden_states.permute(0, 2, 3, 1) + windowed_features, _ = window_partition(windowed_input, 
window_size) + windowed_features = windowed_features.flatten(1, 2) + + for layer in self.layers: + latents_2d = layer(latents_2d, windowed_features, positional_encoding=None) + + latents_2d = latents_2d.view(batch_size, num_windows_per_dim, num_windows_per_dim, channels).permute( + 0, 3, 1, 2 + ) + + positional_encoding_2d = self.positional_encoding(latents_2d.shape, latents_2d.device, latents_2d.dtype).to( + dtype=hidden_states.dtype + ) + positional_encoding_2d = positional_encoding_2d.permute(0, 2, 3, 1).flatten(1, 2) + + latents_2d = latents_2d.permute(0, 2, 3, 1).flatten(1, 2) + latents_2d = self.layer_norm(latents_2d) + + return latents_2d, positional_encoding_2d + + +@dataclass +@auto_docstring(custom_intro="Base class for the EdgeTamVideo model's output.") +class EdgeTamVideoImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. + image_embeddings (`tuple(torch.FloatTensor)`): + The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. + high_res_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, image_size, image_size)`, *optional*): + The predicted masks, upscaled to the original image size. Only used for EdgeTamVideoModel. + object_pointer (`torch.FloatTensor` of shape `(batch_size, point_batch_size, hidden_size)`, *optional*): + A tensor representing the object pointer, used for tracking in videos. Only used for EdgeTamVideoModel. + """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + object_score_logits: Optional[torch.FloatTensor] = None + image_embeddings: tuple[torch.FloatTensor, ...] 
= None + vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + high_res_masks: Optional[torch.FloatTensor] = None + object_pointer: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring(custom_intro="Base class for the Sam2 model's output.") +class EdgeTamVideoSegmentationOutput(ModelOutput): + r""" + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`): + The predicted masks stored at the model's resolution. + frame_idx (`int`): + The frame index of the video. + """ + + pred_masks: Optional[torch.FloatTensor] = None + frame_idx: Optional[int] = None + + +class EdgeTamVideoPositionalEmbedding(nn.Module): + def __init__(self, config: EdgeTamVideoPromptEncoderConfig): + super().__init__() + self.scale = config.scale + positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2)) + self.register_buffer("positional_embedding", positional_embedding) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + coordinates.to(torch.float32) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class EdgeTamVideoMaskEmbedding(nn.Module): + def __init__(self, config: EdgeTamVideoPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = EdgeTamVideoLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = EdgeTamVideoLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class EdgeTamVideoPromptEncoder(nn.Module): + def __init__(self, config: EdgeTamVideoPromptEncoderConfig): + super().__init__() + self.shared_embedding = EdgeTamVideoPositionalEmbedding(config) + self.mask_embed = EdgeTamVideoMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + 
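# mask prompts are expected at 4x the image-embedding grid resolution; EdgeTamVideoMaskEmbedding downsamples them by 4 back to the embedding grid +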
self.input_image_size = config.image_size + + self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.zeros_like(point_embedding), + ) + + # Add point embeddings for labels >= 0 + point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes += 0.5 # Shift to center of pixel + coords = boxes.view(*boxes.shape[:2], 2, 2) + # add padding point for consistency with the original implementation + coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0) + corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) + corner_embedding[:, :, 0, :] += self.point_embed.weight[2] + corner_embedding[:, :, 1, :] += self.point_embed.weight[3] + corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) + return corner_embedding + + def forward( + self, + input_points: Optional[tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. 
+ boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size = input_points.shape[0] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class EdgeTamVideoTwoWayTransformer(nn.Module): + def __init__(self, config: EdgeTamVideoMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(EdgeTamVideoTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = EdgeTamVideoAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + **kwargs, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys + + +class EdgeTamVideoMaskDecoder(nn.Module): + def __init__(self, config: EdgeTamVideoMaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = EdgeTamVideoTwoWayTransformer(config) + + # should we create a new class for this? 
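+ # two stride-2 transposed convolutions upscale the image embedding by 4x; in the forward pass the intermediate and final + # maps are summed with the high-resolution FPN features (feat_s1, feat_s0) before the hypernetwork mask prediction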
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = EdgeTamVideoLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [EdgeTamVideoFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + self.iou_prediction_head = EdgeTamVideoFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = EdgeTamVideoFeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + high_resolution_features: list[torch.Tensor], + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (`bool`): + Whether to return multiple masks or a single mask. + high_resolution_features (`list[torch.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`torch.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`torch.Tensor`, *optional*): + The target embedding. 
+ """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. 
+ """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1) + ) + best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. 
+ """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +@auto_docstring +class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)} + _keys_to_ignore_on_load_unexpected = [] + + def __init__(self, config: EdgeTamVideoConfig): + super().__init__(config) + self.shared_image_embedding = EdgeTamVideoPositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.prompt_encoder = EdgeTamVideoPromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = EdgeTamVideoMaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.config = config + # For video sequence inference + self.image_size = config.image_size + self.memory_attention = EdgeTamVideoMemoryAttention(config) + self.memory_encoder = EdgeTamVideoMemoryEncoder(config) + self.no_memory_positional_encoding = torch.nn.Parameter( + torch.zeros(1, 1, config.vision_config.fpn_hidden_size) + ) + self.mem_dim = config.memory_encoder_output_channels + self.num_maskmem = config.num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.memory_temporal_positional_encoding = torch.nn.Parameter( + torch.zeros(self.num_maskmem, 1, 1, self.mem_dim) + ) + + self.no_object_pointer = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. 
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + # a feedforward layer on SAM output tokens to turn them into object pointers + self.object_pointer_proj = EdgeTamVideoFeedForward(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + + if self.config.enable_temporal_pos_encoding_for_object_pointers: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.temporal_positional_encoding_projection_layer = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.temporal_positional_encoding_projection_layer = torch.nn.Identity() + + self.occlusion_spatial_embedding_parameter = None # compatibility with Sam2 + if config.enable_occlusion_spatial_embedding: + self.occlusion_spatial_embedding_parameter = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + self.spatial_perceiver = EdgeTamVideoPerceiverResampler(config) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self) -> torch.Tensor: + size = self.prompt_encoder.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones(size, device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> list[torch.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + batch_size = pixel_values.shape[0] + feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. 
+        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+            Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+            processor, or they can be fed by the user.
+        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+            Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+            processor. Users can also pass the input boxes manually.
+        input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+            Optional input masks for the prompt encoder.
+        """
+        prompt_output = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+        return prompt_output
+
+    @torch.inference_mode()
+    @auto_docstring(custom_intro="Propagate the objects through a streamed video frame.")
+    def forward(
+        self,
+        inference_session: EdgeTamVideoInferenceSession,
+        frame_idx: Optional[int] = None,
+        frame: Optional[torch.Tensor] = None,
+        reverse: bool = False,
+    ) -> EdgeTamVideoSegmentationOutput:
+        r"""
+        inference_session (`EdgeTamVideoInferenceSession`):
+            The video inference session object.
+        frame_idx (`int`, *optional*):
+            The index of the frame on which to run inference. No need to provide it when inferring
+            on a new streamed frame.
+        frame (`torch.Tensor`, *optional*):
+            The frame to process. Provide it when streaming.
+        reverse (`bool`, *optional*, defaults to `False`):
+            Whether to propagate in reverse.
+        """
+        if frame is not None:
+            frame_idx = inference_session.add_new_frame(frame, frame_idx)
+
+        if frame is not None and inference_session.get_obj_num() == 0:
+            raise ValueError("No objects are provided for tracking; please add inputs first.")
+
+        num_objects = inference_session.get_obj_num()
+        pred_masks_per_obj = [None] * num_objects
+        # Note: We avoid batched inference here because per-object inputs (clicks/masks)
+        # can differ across objects.
+        for obj_idx in range(num_objects):
+            obj_id = inference_session.obj_idx_to_id(obj_idx)
+            has_new_inputs = obj_id in inference_session.obj_with_new_inputs
+            has_cond_output = frame_idx in inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+            # If this object has no new inputs and this frame already has a
+            # conditioning output, reuse the cached masks instead of recomputing.
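+            # (e.g. when propagation revisits a frame on which the user already provided prompts)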
+ if (not has_new_inputs) and has_cond_output: + pred_masks = inference_session.get_output(obj_idx, frame_idx, "pred_masks", is_conditioning_frame=True) + is_init_cond_frame = True + else: + # Defaults when there are no new inputs + is_init_cond_frame = False + point_inputs = None + mask_inputs = None + + if has_new_inputs: + is_init_cond_frame = frame_idx not in inference_session.frames_tracked_per_obj[obj_idx] + if is_init_cond_frame: + reverse = False + point_inputs = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None) + mask_inputs = inference_session.mask_inputs_per_obj[obj_idx].get(frame_idx, None) + if point_inputs is not None or mask_inputs is not None: + inference_session.obj_with_new_inputs.remove(obj_id) + + current_out = self._run_single_frame_inference( + inference_session=inference_session, + obj_idx=obj_idx, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + reverse=reverse, + run_mem_encoder=True, + streaming=frame is not None, + ) + inference_session.store_output( + obj_idx, frame_idx, output_value=current_out, is_conditioning_frame=is_init_cond_frame + ) + pred_masks = current_out["pred_masks"] + + pred_masks_per_obj[obj_idx] = pred_masks + if not is_init_cond_frame: + # only for tracked frames, not for initial conditioning frames + inference_session.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + if len(pred_masks_per_obj) > 1: + all_pred_masks = torch.cat(pred_masks_per_obj, dim=0) + else: + all_pred_masks = pred_masks_per_obj[0] + + return EdgeTamVideoSegmentationOutput(pred_masks=all_pred_masks, frame_idx=frame_idx) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ + list[torch.Tensor], + list[torch.Tensor], + Optional[tuple[torch.FloatTensor, ...]], + Optional[tuple[torch.FloatTensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`torch.FloatTensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `tuple`: A tuple containing: + - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels. + - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level. + - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder. + - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder. 
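+
+        The feature maps and positional embeddings are flattened from `(batch_size, num_channels, height, width)`
+        to `(height * width, batch_size, num_channels)` before being returned.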
+ """ + vision_outputs: EdgeTamVideoVisionEncoderOutput = self.vision_encoder( + pixel_values, + **kwargs, + ) + + feature_maps = vision_outputs.fpn_hidden_states + feature_maps_position_embeddings = vision_outputs.fpn_position_encoding + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + feature_maps = list(feature_maps) + feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) + feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) + + # flatten NxCxHxW to HWxNxC + feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] + feature_maps_position_embeddings = [ + feature_map_position_embedding.flatten(2).permute(2, 0, 1) + for feature_map_position_embedding in feature_maps_position_embeddings + ] + + return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions + + def _prepare_vision_features( + self, + inference_session: EdgeTamVideoInferenceSession, + frame_idx: int, + batch_size: int, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Prepare vision features for a frame.""" + + # Check if features are cached + if cached_features := inference_session.cache.get_vision_features(frame_idx): + vision_feats = cached_features["vision_feats"] + vision_pos_embeds = cached_features["vision_pos_embeds"] + else: + # Compute features using image encoder + image_batch = inference_session.get_frame(frame_idx).unsqueeze(0) # Add batch dimension + vision_feats, vision_pos_embeds, _, _ = self.get_image_features(image_batch) + # Cache features + inference_session.cache.cache_vision_features( + frame_idx, {"vision_feats": vision_feats, "vision_pos_embeds": vision_pos_embeds} + ) + + # Expand to batch size if needed + if batch_size > 1: + vision_feats = vision_feats.expand(batch_size, -1, -1, -1) + vision_pos_embeds = [pe.expand(batch_size, -1, -1, -1) for pe in vision_pos_embeds] + + return vision_feats, vision_pos_embeds + + def _single_frame_forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> EdgeTamVideoImageSegmentationOutput: + """ + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. 
If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). 
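+
+        As an illustration, prompting a single image with one point on one object corresponds to `input_points` of
+        shape `(1, 1, 1, 2)` and `input_labels` of shape `(1, 1, 1)`; with `multimask_output=True`, the mask decoder
+        then predicts 3 candidate masks for that object.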
+ """ + if not ((pixel_values is None) ^ (image_embeddings is None)): + raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") + if input_points is not None and input_boxes is not None: + if input_points.shape[1] != input_boxes.shape[1]: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." + ) + elif input_points is not None: + num_objects = input_points.shape[1] + elif input_boxes is not None: + num_objects = input_boxes.shape[1] + elif input_masks is not None: + num_objects = input_masks.shape[1] + else: + num_objects = 1 + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( + pixel_values, + **kwargs, + ) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + if input_points is None and input_boxes is None: + # If no points are provide, pad with an empty point (with label -1) + input_points = torch.zeros( + batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device + ) + input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device) + + if input_masks is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: + input_masks = F.interpolate( + input_masks.float(), + size=self.prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ).to(input_masks.dtype) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + low_res_multimasks, iou_scores, sam_output_tokens, object_score_logits = self.mask_decoder( + image_embeddings=image_embeddings[-1], + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + + is_obj_appearing = object_score_logits > 0 + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support 
`interpolate` on bf16) + high_res_multimasks = ( + F.interpolate( + low_res_multimasks.squeeze(1).float(), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + .unsqueeze(1) + .to(low_res_multimasks.dtype) + ) + sam_output_token = sam_output_tokens[:, :, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(iou_scores, dim=-1) + batch_inds = torch.arange(batch_size, device=high_res_multimasks.device) + object_batch_inds = torch.arange(num_objects, device=high_res_multimasks.device) + low_res_masks = low_res_multimasks[batch_inds, object_batch_inds, best_iou_inds] + high_res_masks = high_res_multimasks[batch_inds, object_batch_inds, best_iou_inds] + if sam_output_tokens.size(2) > 1: + sam_output_token = sam_output_tokens[batch_inds, object_batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks[:, :, 0], high_res_multimasks[:, :, 0] + + # Extract object pointer from the SAM output token (with occlusion handling) + object_pointer = self.object_pointer_proj(sam_output_token) + lambda_is_obj_appearing = is_obj_appearing.to(object_pointer.dtype) + + object_pointer = lambda_is_obj_appearing * object_pointer + object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer + + return EdgeTamVideoImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_masks, + high_res_masks=high_res_masks, + object_pointer=object_pointer, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + def _use_mask_as_output( + self, + backbone_features: torch.Tensor, + high_res_features: list[torch.Tensor], + mask_inputs: torch.Tensor, + ) -> EdgeTamVideoImageSegmentationOutput: + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in forward above). + """ + # Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.to(backbone_features[0].dtype) + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks.float(), + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ).to(backbone_features[0].dtype) + # a dummy IoU prediction of all 1's under mask input + iou_scores = mask_inputs.new_ones(mask_inputs.size(0), 1).to(backbone_features[0].dtype) + # produce an object pointer using the SAM decoder from the mask input + object_pointer = self._single_frame_forward( + input_masks=self.mask_downsample(mask_inputs_float.to(backbone_features[0].dtype)), + image_embeddings=high_res_features + [backbone_features], + ).object_pointer + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. 
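+        # With out_scale=20.0 and out_bias=-10.0, this yields object_score_logits of +10.0 when the input mask
+        # contains at least one positive pixel and -10.0 when it is empty.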
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.to(backbone_features[0].dtype) + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + object_pointer = lambda_is_obj_appearing * object_pointer + object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer + return EdgeTamVideoImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_masks, + high_res_masks=high_res_masks, + object_pointer=object_pointer, + object_score_logits=object_score_logits, + image_embeddings=high_res_features + [backbone_features], + ) + + def _gather_memory_frame_outputs( + self, + inference_session: EdgeTamVideoInferenceSession, + obj_idx: int, + frame_idx: int, + track_in_reverse_time: bool = False, + ) -> list[tuple[int, dict]]: + """ + Get memory frames from conditioning and non-conditioning outputs. + + Returns: + List of (relative_temporal_offset, output_data) tuples. + """ + temporal_positions_and_previous_outputs = [] + + # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints) + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + if not conditioning_outputs: + raise ValueError( + "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame" + ) + + # Store (temporal_position, output_data) tuples + temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()] + + # Add non-conditioning memory frames (up to self.num_maskmem - 1) + # These are typically frames tracked by the model without direct user input. + # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity. + for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1): + # relative_temporal_offset: how many frames before (or after if reversing) the current frame + if not track_in_reverse_time: + previous_frame_idx = frame_idx - relative_temporal_offset + else: + previous_frame_idx = frame_idx + relative_temporal_offset + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + previous_frame_idx, None + ) + + temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data)) + + return temporal_positions_and_previous_outputs + + def _build_memory_attention_inputs( + self, + temporal_positions_and_previous_outputs: list[tuple[int, dict]], + device: torch.device, + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """ + Concatenate memory features and positional embeddings from previous frames. + + Returns: + Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate). 
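+
+        Temporal positions whose previous output is not available (`None`) are skipped, so the returned lists only
+        contain entries for frames that already have encoded memories.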
+ """ + memories_to_concatenate = [] + memory_positional_embeddings_to_concatenate = [] + + for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: + if prev_output_data is None: + continue # Skip if no output data for this temporal position (e.g., padding frames) + + # Load memory features (potentially from CPU to GPU) + # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) + memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True) + memories_to_concatenate.append(memory_features.permute(1, 0, 2)) + + # Spatial positional encoding (potentially from CPU to GPU) + spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True) + spatial_memory_pos_embed = spatial_memory_pos_embed.squeeze(1).permute(1, 0, 2) + + # Add temporal positional encoding + # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) + combined_memory_pos_embed = ( + spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] + ) + memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) + + return memories_to_concatenate, memory_positional_embeddings_to_concatenate + + def _get_object_pointers( + self, + inference_session: EdgeTamVideoInferenceSession, + obj_idx: int, + frame_idx: int, + num_total_frames: int, + device: torch.device, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> tuple[list[int], list[torch.Tensor], int]: + """ + Get object pointers and their positional embeddings from past frames. + + Returns: + Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use). + """ + temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1 + + # Determine max object pointers to use + if streaming: + max_object_pointers_to_use = self.config.max_object_pointers_in_encoder + else: + max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder) + + temporal_offsets: list[int] = [] + pointer_tokens: list[torch.Tensor] = [] + + # Add object pointers from selected conditioning frames + # Optionally, only include pointers from past frames during evaluation + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + eligible_conditioning_outputs = conditioning_outputs + if not self.training: + eligible_conditioning_outputs = { + temporal_idx: out + for temporal_idx, out in conditioning_outputs.items() + if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx) + } + + for temporal_idx, out_data in eligible_conditioning_outputs.items(): + temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier + temporal_offsets.append(temporal_difference) + pointer_tokens.append(out_data["object_pointer"].to(device)) + + # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1) + for t_diff_offset in range(1, max_object_pointers_to_use): + ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset + if ref_frame_idx < 0 or ( + not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames + ): + break # Stop if frame index is out of bounds + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + ref_frame_idx, None + ) + if out_data 
is not None: + temporal_offsets.append(t_diff_offset) + pointer_tokens.append(out_data["object_pointer"].to(device)) + + return temporal_offsets, pointer_tokens, max_object_pointers_to_use + + def _process_object_pointers( + self, + temporal_offsets: list[int], + pointer_tokens: list[torch.Tensor], + max_object_pointers_to_use: int, + batch_size: int, + num_channels: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Process object pointers and compute their positional embeddings. + + Returns: + Tuple of (object_pointers, object_pointers_pos_embed). + """ + if not pointer_tokens: + return None, None + + # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels) + object_pointers = torch.stack(pointer_tokens, dim=0) + + if self.config.enable_temporal_pos_encoding_for_object_pointers: + max_temporal_diff = float(max_object_pointers_to_use - 1) + # Determine dimensionality for temporal positional encoding of pointers + pointer_tpos_dim = num_channels + + # Normalize temporal differences before sine PE calculation + normalized_temporal_diffs = ( + torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff + ) + sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype) + projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe) + object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim) + else: + object_pointers_pos_embed = object_pointers.new_zeros( + len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype + ) + + if self.mem_dim < num_channels: + # If memory dimension is smaller, reshape/split pointers and repeat positional encoding + num_splits = num_channels // self.mem_dim + object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim) + object_pointers = object_pointers.permute(0, 2, 1, 3).flatten( + 0, 1 + ) # (SeqLen_ptr*num_splits, Batch, MemDim) + object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0) + + return object_pointers, object_pointers_pos_embed + + def _prepare_memory_conditioned_features( + self, + inference_session: EdgeTamVideoInferenceSession, + frame_idx: int, + obj_idx: int, + is_initial_conditioning_frame: bool, + current_vision_features: list[torch.Tensor], + current_vision_positional_embeddings: list[torch.Tensor], + num_total_frames: int, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> torch.Tensor: + """ + Fuse current frame's visual features with memory from previous frames for enhanced object tracking. + + This method conditions the current frame's visual features on temporal memory from previous frames, + enabling consistent object tracking across video sequences. For initial conditioning frames, it uses + no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both + conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention. + + Args: + inference_session (`EdgeTamVideoInferenceSession`): + The video inference session object. + frame_idx (`int`): + Index of the current frame being processed. + obj_idx (`int`): + Index of the object being processed. + is_initial_conditioning_frame (`bool`): + Whether this is an initial conditioning frame with user inputs (True) or a subsequent + tracking frame (False). 
+ current_vision_features (`torch.Tensor`): + Highest-level vision features of shape `(seq_len, batch_size, channels)`. + current_vision_positional_embeddings (`torch.Tensor`): + Positional embedding tensors corresponding to the highest-level vision features. + num_total_frames (`int`): + Total number of frames in the video sequence. + track_in_reverse_time (`bool`, *optional*, defaults to `False`): + Whether tracking is performed in reverse temporal order. + streaming (`bool`, *optional*, defaults to `False`): + Whether this is streaming inference mode. + + Returns: + `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)` + suitable for input to the SAM decoder. + """ + # Get dimensions from the highest-level (lowest-resolution) feature map + batch_size = current_vision_features.size(1) + num_channels = self.hidden_dim + height, width = self.backbone_feature_sizes[-1] + device = current_vision_features.device + + # If memory is disabled (e.g., for single image SAM), return current features directly. + if self.num_maskmem == 0: + # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width) + # Assuming SeqLen = Height * Width for the last feature map + current_feature_map = current_vision_features.permute(1, 2, 0).view( + batch_size, num_channels, height, width + ) + return current_feature_map + + # Step 1: Handle initial conditioning frames + if is_initial_conditioning_frame: + # For initial conditioning frames, no prior memory is used directly in this block. + # If configured, directly add a learnable "no memory" embedding. + # current_vision_features has shape (SeqLen, Batch, Channels) + conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding + # Reshape to (Batch, Channels, Height, Width) + conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view( + batch_size, num_channels, height, width + ) + return conditioned_feature_map + + # Step 2: Get memory frames and concatenate their features + temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs( + inference_session, obj_idx, frame_idx, track_in_reverse_time + ) + + memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs( + temporal_positions_and_previous_outputs, device + ) + num_spatial_memory_tokens = len(memories_to_concatenate) + + # Step 3: Get and process object pointers + temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers( + inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming + ) + + num_object_pointer_tokens = 0 + if pointer_tokens: + object_pointers, object_pointers_pos_embed = self._process_object_pointers( + temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device + ) + + if object_pointers is not None: + memories_to_concatenate.append(object_pointers) + memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) + num_object_pointer_tokens = object_pointers.shape[0] + + # Step 4: Concatenate all retrieved memories and their positional embeddings + combined_memory = torch.cat(memories_to_concatenate, dim=0) + combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0) + + # Step 5: Forward through the memory attention mechanism + conditioned_feature_map_flat = self.memory_attention( + current_vision_features=current_vision_features, + 
current_vision_position_embeddings=current_vision_positional_embeddings, + memory=combined_memory, + memory_posision_embeddings=combined_memory_positional_embeddings, # Corrected typo from API + num_object_pointer_tokens=num_object_pointer_tokens, + num_spatial_memory_tokens=num_spatial_memory_tokens, + ) + + # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width) + conditioned_feature_map = ( + conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width) + ) + return conditioned_feature_map + + def _use_multimask(self, is_init_cond_frame: bool, point_inputs: Optional[dict]) -> bool: + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(2) + multimask_output = ( + self.config.multimask_output_in_sam + and (is_init_cond_frame or self.config.multimask_output_for_tracking) + and (self.config.multimask_min_pt_num <= num_pts <= self.config.multimask_max_pt_num) + ) + return multimask_output + + def _run_single_frame_inference( + self, + inference_session: EdgeTamVideoInferenceSession, + frame_idx: int, + obj_idx: int, + batch_size: int, + is_init_cond_frame: bool, + point_inputs: Optional[torch.Tensor], + mask_inputs: Optional[torch.Tensor], + reverse: bool, + run_mem_encoder: bool, + prev_sam_mask_logits: Optional[torch.Tensor] = None, + streaming: bool = False, + ) -> dict[str, Any]: + """ + Perform a single tracking step for video object segmentation. + + Args: + inference_session (`EdgeTamVideoInferenceSession`): + The video inference session object. + frame_idx (`int`): + Index of the current frame. + obj_idx (`int`): + Index of the current object. + batch_size (`int`): + Batch size of the current frame. + is_init_cond_frame (`bool`): + Whether this is an initial conditioning frame with user inputs. + point_inputs (`dict`, *optional*): + Point prompt inputs for the current frame. + mask_inputs (`torch.Tensor`, *optional*): + Mask prompt inputs for the current frame. + reverse (`bool`, *optional*, defaults to `False`): + Whether to track in reverse time order. + run_mem_encoder (`bool`, *optional*, defaults to `True`): + Whether to run the memory encoder on predicted masks. + prev_sam_mask_logits (`torch.Tensor`, *optional*): + Previously predicted SAM mask logits that can be fed with new clicks. + streaming (`bool`, *optional*, defaults to `False`): + Whether this is streaming inference. + + Returns: + `dict`: Dictionary containing the tracking results for the current frame, including: + - pred_masks: Predicted low-resolution masks. + - object_pointer: Object pointer for memory. + - object_score_logits: Object score logits (inference only). + - maskmem_features: Memory features for future frames. + - maskmem_pos_enc: Memory positional encodings. 
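+
+        `maskmem_features` and `maskmem_pos_enc` are `None` when `run_mem_encoder` is `False` or when the model is
+        configured with `num_maskmem == 0`.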
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds = self._prepare_vision_features( + inference_session, frame_idx, batch_size + ) + # point and mask should not appear as input simultaneously on the same frame + if point_inputs is not None and mask_inputs is not None: + raise ValueError( + "point_inputs and mask_inputs should not appear as input simultaneously on the same frame" + ) + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], self.backbone_feature_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None: + # We directly output the mask input (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *self.backbone_feature_sizes[-1]) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + inference_session=inference_session, + frame_idx=frame_idx, + obj_idx=obj_idx, + is_initial_conditioning_frame=is_init_cond_frame, + current_vision_features=current_vision_feats[-1], + current_vision_positional_embeddings=current_vision_pos_embeds[-1], + num_total_frames=inference_session.num_frames, + track_in_reverse_time=reverse, + streaming=streaming, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._single_frame_forward( + pixel_values=None, # Vision features already computed + input_points=point_inputs["point_coords"] if point_inputs is not None else None, + input_labels=point_inputs["point_labels"] if point_inputs is not None else None, + input_masks=mask_inputs, + image_embeddings=high_res_features + [pix_feat], + multimask_output=multimask_output, + ) + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (which will be used to condition vision features in future frames) + maskmem_features = None + maskmem_pos_enc = None + if run_mem_encoder and self.num_maskmem > 0: + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats[-1], + pred_masks_high_res=sam_outputs.high_res_masks, + object_score_logits=sam_outputs.object_score_logits, + is_mask_from_pts=(point_inputs is not None or mask_inputs is not None), + ) + + current_out = { + "pred_masks": sam_outputs.pred_masks, + "object_pointer": sam_outputs.object_pointer, + "maskmem_features": maskmem_features if maskmem_features is not None else None, + "maskmem_pos_enc": maskmem_pos_enc, + } + if not self.training: + current_out["object_score_logits"] = sam_outputs.object_score_logits + + return current_out + + def _encode_new_memory( + self, + current_vision_feats: torch.Tensor, + pred_masks_high_res: torch.Tensor, + object_score_logits: torch.Tensor, + is_mask_from_pts: 
bool, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Encode the current image and its prediction into a memory feature.""" + batch_size = current_vision_feats.size(1) # batch size on this frame + channels = self.hidden_dim + height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width) + if is_mask_from_pts and not self.training: + # binarize the mask logits + mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype) + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc + mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc + + maskmem_features, maskmem_pos_enc = self.memory_encoder( + pix_feat, + mask_for_mem, + ) + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.occlusion_spatial_embedding_parameter is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[ + ..., None, None + ].expand(*maskmem_features.shape) + + maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype) + maskmem_features, maskmem_pos_enc = self.spatial_perceiver(maskmem_features, maskmem_pos_enc) + maskmem_features = maskmem_features.to(pred_masks_high_res.dtype) + maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype) + + return maskmem_features, maskmem_pos_enc + + @torch.inference_mode() + @auto_docstring( + custom_intro=""" + Propagate the objects through the video frames. Used when initializing an inference session with a whole video. + Yields EdgeTamVideoSegmentationOutput for each frame. + """ + ) + def propagate_in_video_iterator( + self, + inference_session: EdgeTamVideoInferenceSession, + start_frame_idx: Optional[int] = None, + max_frame_num_to_track: Optional[int] = None, + reverse: bool = False, + ) -> Iterator[EdgeTamVideoSegmentationOutput]: + r""" + inference_session (`EdgeTamVideoInferenceSession`): + The video inference session object. + start_frame_idx (`int`, *optional*): + The starting frame index for propagation. + Need to be provided if `forward` hasn't been called on new inputs yet. + If not provided, the starting frame index will be the earliest frame with input points. + max_frame_num_to_track (`int`, *optional*): + The maximum number of frames to track. + reverse (`bool`, *optional*, defaults to `False`): + Whether to propagate in reverse. + """ + num_frames = inference_session.num_frames + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + frames_with_inputs = [ + frame_idx + for obj_output_dict in inference_session.output_dict_per_obj.values() + for frame_idx in obj_output_dict["cond_frame_outputs"] + ] + if not frames_with_inputs: + raise ValueError( + "Cannot determine the starting frame index; please specify it manually, or run inference on a frame with inputs first." 
+ ) + start_frame_idx = min(frames_with_inputs) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + edgetam_video_output = self(inference_session, frame_idx=frame_idx, reverse=reverse) + yield edgetam_video_output + + +__all__ = ["EdgeTamVideoModel", "EdgeTamVideoInferenceSession", "EdgeTamVideoPreTrainedModel"] diff --git a/src/transformers/models/edgetam_video/modular_edgetam_video.py b/src/transformers/models/edgetam_video/modular_edgetam_video.py new file mode 100644 index 000000000000..b520cd5a756b --- /dev/null +++ b/src/transformers/models/edgetam_video/modular_edgetam_video.py @@ -0,0 +1,1243 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch import Tensor + +from transformers.models.sam2.modeling_sam2 import ( + eager_attention_forward, + window_partition, +) +from transformers.utils.generic import OutputRecorder + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import ( + auto_docstring, +) +from ..auto import CONFIG_MAPPING, AutoConfig +from ..sam2_video.configuration_sam2_video import ( + Sam2VideoConfig, + Sam2VideoMaskDecoderConfig, + Sam2VideoPromptEncoderConfig, +) +from ..sam2_video.modeling_sam2_video import ( + Sam2VideoAttention, + Sam2VideoFeedForward, + Sam2VideoInferenceSession, + Sam2VideoLayerNorm, + Sam2VideoMemoryAttention, + Sam2VideoMemoryEncoder, + Sam2VideoMemoryFuserCXBlock, + Sam2VideoModel, + Sam2VideoPositionEmbeddingSine, + Sam2VideoPreTrainedModel, + Sam2VideoTwoWayAttentionBlock, + Sam2VideoVisionEncoderOutput, + Sam2VideoVisionRotaryEmbedding, + rotate_pairwise, +) + + +class EdgeTamVideoPromptEncoderConfig(Sam2VideoPromptEncoderConfig): + pass + + +class EdgeTamVideoMaskDecoderConfig(Sam2VideoMaskDecoderConfig): + pass + + +class EdgeTamVideoConfig(Sam2VideoConfig): + r""" + [`EdgeTamVideoConfig`] is the configuration class to store the configuration of a [`EdgeTamVideoModel`]. 
It is used to instantiate an
+    EDGETAM model according to the specified arguments, defining the memory attention, memory encoder, and image encoder
+    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the EdgeTAM
+    [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vision_config (Union[`dict`, `EdgeTamVideoVisionConfig`], *optional*):
+            Dictionary of configuration options used to initialize [`EdgeTamVideoVisionConfig`].
+        prompt_encoder_config (Union[`dict`, `EdgeTamVideoPromptEncoderConfig`], *optional*):
+            Dictionary of configuration options used to initialize [`EdgeTamVideoPromptEncoderConfig`].
+        mask_decoder_config (Union[`dict`, `EdgeTamVideoMaskDecoderConfig`], *optional*):
+            Dictionary of configuration options used to initialize [`EdgeTamVideoMaskDecoderConfig`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            Standard deviation for parameter initialization.
+        num_maskmem (`int`, *optional*, defaults to 7):
+            The number of memory slots for the mask memory.
+        image_size (`int`, *optional*, defaults to 1024):
+            The size of the input images.
+        sigmoid_scale_for_mem_enc (`float`, *optional*, defaults to 20.0):
+            Scale factor for the sigmoid function in the memory encoder.
+        sigmoid_bias_for_mem_enc (`float`, *optional*, defaults to -10.0):
+            Bias for the sigmoid function in the memory encoder.
+        enable_occlusion_spatial_embedding (`bool`, *optional*, defaults to `True`):
+            Whether to enable spatial embedding for occlusions.
+        multimask_output_in_sam (`bool`, *optional*, defaults to `True`):
+            Whether to output multiple masks from the SAM head.
+        multimask_min_pt_num (`int`, *optional*, defaults to 0):
+            The minimum number of points to trigger multimask output.
+        multimask_max_pt_num (`int`, *optional*, defaults to 1):
+            The maximum number of points to trigger multimask output.
+        multimask_output_for_tracking (`bool`, *optional*, defaults to `True`):
+            Whether to use multimask output for tracking.
+        max_object_pointers_in_encoder (`int`, *optional*, defaults to 16):
+            The maximum number of object pointers in the encoder.
+        enable_temporal_pos_encoding_for_object_pointers (`bool`, *optional*, defaults to `True`):
+            Whether to enable temporal positional encoding for object pointers.
+        memory_attention_hidden_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the memory attention hidden states.
+        memory_attention_num_layers (`int`, *optional*, defaults to 2):
+            The number of layers in the memory attention module.
+        memory_attention_num_attention_heads (`int`, *optional*, defaults to 1):
+            Number of attention heads for each attention layer in the memory attention.
+        memory_attention_downsample_rate (`int`, *optional*, defaults to 1):
+            The downsample rate for the attention layers.
+        memory_attention_mlp_hidden_size (`int`, *optional*, defaults to 2048):
+            The dimension of the feedforward network in the memory attention module.
+        memory_attention_mlp_hidden_act (`str`, *optional*, defaults to `"relu"`):
+            The non-linear activation function in the feedforward network in the memory attention module.
+        memory_attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout rate for the memory attention module.
+ memory_attention_rope_theta (`float`, *optional*, defaults to 10000): + The Rope theta parameter. + memory_attention_rope_feat_sizes (`Tuple[int, int]`, *optional*, defaults to `[64, 64]`): + The feature sizes for the Rope positional encoding. + memory_attention_rope_k_sizes (`List[int]`, *optional*, defaults to `[16, 16]`): + The key feature sizes for the RoPE positional encoding in memory attention. + memory_attention_rope_dropout (`float`, *optional*, defaults to 0.1): + The dropout rate for the Rope positional encoding. + perceiver_resampler_num_latents (`int`, *optional*, defaults to 256): + The number of 1D latent tokens in the perceiver resampler. + perceiver_resampler_num_latents_2d (`int`, *optional*, defaults to 256): + The number of 2D latent tokens in the perceiver resampler. + perceiver_resampler_hidden_size (`int`, *optional*, defaults to 64): + The hidden size of the perceiver resampler. + perceiver_resampler_mlp_intermediate_size (`int`, *optional*, defaults to 256): + The intermediate size of the feedforward network in the perceiver resampler. + perceiver_resampler_num_attention_heads (`int`, *optional*, defaults to 1): + The number of attention heads in the perceiver resampler. + perceiver_resampler_attention_head_dim (`int`, *optional*, defaults to 64): + The dimension of each attention head in the perceiver resampler. + perceiver_resampler_num_layers (`int`, *optional*, defaults to 2): + The number of layers in the perceiver resampler. + perceiver_resampler_hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate for the hidden layers in the perceiver resampler. + perceiver_resampler_attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate for the attention layers in the perceiver resampler. + memory_encoder_hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the memory encoder hidden states. + memory_encoder_output_channels (`int`, *optional*, defaults to 64): + The number of output channels for the memory encoder. + mask_downsampler_embed_dim (`int`, *optional*, defaults to 256): + The dimension of the mask downsampler embedding. + memory_fuser_intermediate_dim (`int`, *optional*, defaults to 1024): + The intermediate dimension of the memory fuser feedforward network. + mask_downsampler_kernel_size (`int`, *optional*, defaults to 3): + The kernel size for the mask downsampler. + mask_downsampler_stride (`int`, *optional*, defaults to 2): + The stride for the mask downsampler. + mask_downsampler_padding (`int`, *optional*, defaults to 1): + The padding for the mask downsampler. + mask_downsampler_total_stride (`int`, *optional*, defaults to 16): + The total stride for the mask downsampler. + mask_downsampler_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the mask downsampler. + memory_fuser_num_layers (`int`, *optional*, defaults to 2): + The number of layers in the memory fuser. + memory_fuser_embed_dim (`int`, *optional*, defaults to 256): + The dimension of the memory fuser embedding. + memory_fuser_kernel_size (`int`, *optional*, defaults to 7): + The kernel size for the memory fuser. + memory_fuser_padding (`int`, *optional*, defaults to 3): + The padding for the memory fuser. + memory_fuser_layer_scale_init_value (`float`, *optional*, defaults to 1e-06): + The initial value for the layer scale in the memory fuser. + memory_fuser_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the memory fuser. 
+ + Example: + + ```python + >>> from transformers import ( + ... EdgeTamVisionConfig, + ... EdgeTamVideoPromptEncoderConfig, + ... EdgeTamVideoMaskDecoderConfig, + ... EdgeTamVideoModel, + ... EdgeTamVideoConfig, + ... ) + + >>> # Initializing a EdgeTamVideoConfig with `"facebook/edgetam.1_hiera_tiny"` style configuration + >>> configuration = EdgeTamVideoConfig() + + >>> # Initializing a EdgeTamVideoModel (with random weights) from the `"facebook/edgetam.1_hiera_tiny"` style configuration + >>> model = EdgeTamVideoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a EdgeTamConfig from a EdgeTamVisionConfig, EdgeTamPromptEncoderConfig, and EdgeTamMaskDecoderConfig + + >>> # Initializing EDGETAM vision encoder, memory attention, and memory encoder configurations + >>> vision_config = EdgeTamVisionConfig() + >>> prompt_encoder_config = EdgeTamVideoPromptEncoderConfig() + >>> mask_decoder_config = EdgeTamVideoMaskDecoderConfig() + + >>> config = EdgeTamVideoConfig(vision_config, prompt_encoder_config, mask_decoder_config) + ```""" + + model_type = "edgetam_video" + sub_configs = { + "vision_config": AutoConfig, + "prompt_encoder_config": EdgeTamVideoPromptEncoderConfig, + "mask_decoder_config": EdgeTamVideoMaskDecoderConfig, + } + + def __init__( + self, + vision_config=None, + prompt_encoder_config=None, + mask_decoder_config=None, + initializer_range=0.02, + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + enable_occlusion_spatial_embedding=True, + multimask_output_in_sam=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + multimask_output_for_tracking=True, + max_object_pointers_in_encoder=16, + enable_temporal_pos_encoding_for_object_pointers=True, + # memory attention + memory_attention_hidden_size=256, + memory_attention_num_layers=2, + memory_attention_num_attention_heads=1, + memory_attention_downsample_rate=1, + memory_attention_mlp_hidden_size=2048, + memory_attention_mlp_hidden_act="relu", + memory_attention_dropout=0.1, + memory_attention_rope_theta=10000, + memory_attention_rope_feat_sizes=None, + memory_attention_rope_k_sizes=None, + memory_attention_rope_dropout=0.1, + # spatial perceiver resampler + perceiver_resampler_num_latents=256, + perceiver_resampler_num_latents_2d=256, + perceiver_resampler_hidden_size=64, + perceiver_resampler_mlp_intermediate_size=256, + perceiver_resampler_num_attention_heads=1, + perceiver_resampler_attention_head_dim=64, + perceiver_resampler_num_layers=2, + perceiver_resampler_hidden_dropout=0.0, + perceiver_resampler_attention_dropout=0.0, + # memory encoder + memory_encoder_hidden_size=256, + memory_encoder_output_channels=64, + mask_downsampler_embed_dim=256, + memory_fuser_intermediate_dim=1024, + mask_downsampler_kernel_size=3, + mask_downsampler_stride=2, + mask_downsampler_padding=1, + mask_downsampler_total_stride=16, + mask_downsampler_hidden_act="gelu", + memory_fuser_num_layers=2, + memory_fuser_embed_dim=256, + memory_fuser_kernel_size=7, + memory_fuser_padding=3, + memory_fuser_layer_scale_init_value=1e-6, + memory_fuser_hidden_act="gelu", + **kwargs, + ): + PretrainedConfig.__init__(**kwargs) + vision_config = vision_config if vision_config is not None else {} + prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {} + mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {} + memory_attention_rope_feat_sizes = ( + [64, 64] if 
memory_attention_rope_feat_sizes is None else memory_attention_rope_feat_sizes + ) + memory_attention_rope_k_sizes = ( + [16, 16] if memory_attention_rope_k_sizes is None else memory_attention_rope_k_sizes + ) + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + if isinstance(prompt_encoder_config, EdgeTamVideoPromptEncoderConfig): + prompt_encoder_config = prompt_encoder_config.to_dict() + if isinstance(mask_decoder_config, EdgeTamVideoMaskDecoderConfig): + mask_decoder_config = mask_decoder_config.to_dict() + + self.vision_config = vision_config + self.prompt_encoder_config = EdgeTamVideoPromptEncoderConfig(**prompt_encoder_config) + self.mask_decoder_config = EdgeTamVideoMaskDecoderConfig(**mask_decoder_config) + + self.initializer_range = initializer_range + self.num_maskmem = num_maskmem # default 1 input frame + 6 previous frames + self.image_size = image_size + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc # scale factor for mask sigmoid prob + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc # bias factor for mask sigmoid prob + self.enable_occlusion_spatial_embedding = enable_occlusion_spatial_embedding + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.max_object_pointers_in_encoder = max_object_pointers_in_encoder + self.enable_temporal_pos_encoding_for_object_pointers = enable_temporal_pos_encoding_for_object_pointers + + # memory attention + self.memory_attention_hidden_size = memory_attention_hidden_size + self.memory_attention_num_layers = memory_attention_num_layers + self.memory_attention_num_attention_heads = memory_attention_num_attention_heads + self.memory_attention_downsample_rate = memory_attention_downsample_rate + self.memory_attention_mlp_hidden_size = memory_attention_mlp_hidden_size + self.memory_attention_mlp_hidden_act = memory_attention_mlp_hidden_act + self.memory_attention_dropout = memory_attention_dropout + self.memory_attention_rope_theta = memory_attention_rope_theta + self.memory_attention_rope_feat_sizes = memory_attention_rope_feat_sizes + self.memory_attention_rope_k_sizes = memory_attention_rope_k_sizes + self.memory_attention_rope_dropout = memory_attention_rope_dropout + + # spatial perceiver resampler + self.perceiver_resampler_num_latents = perceiver_resampler_num_latents + self.perceiver_resampler_num_latents_2d = perceiver_resampler_num_latents_2d + self.perceiver_resampler_hidden_size = perceiver_resampler_hidden_size + self.perceiver_resampler_mlp_intermediate_size = perceiver_resampler_mlp_intermediate_size + self.perceiver_resampler_attention_head_dim = perceiver_resampler_attention_head_dim + self.perceiver_resampler_num_attention_heads = perceiver_resampler_num_attention_heads + self.perceiver_resampler_num_layers = perceiver_resampler_num_layers + self.perceiver_resampler_hidden_dropout = perceiver_resampler_hidden_dropout + self.perceiver_resampler_attention_dropout = perceiver_resampler_attention_dropout + + # memory encoder + self.memory_encoder_hidden_size = memory_encoder_hidden_size + self.memory_encoder_output_channels = memory_encoder_output_channels + self.mask_downsampler_embed_dim = mask_downsampler_embed_dim + self.mask_downsampler_kernel_size = 
mask_downsampler_kernel_size
+        self.mask_downsampler_stride = mask_downsampler_stride
+        self.mask_downsampler_padding = mask_downsampler_padding
+        self.mask_downsampler_total_stride = mask_downsampler_total_stride
+        self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
+        self.memory_fuser_num_layers = memory_fuser_num_layers
+        self.memory_fuser_embed_dim = memory_fuser_embed_dim
+        self.memory_fuser_intermediate_dim = memory_fuser_intermediate_dim
+        self.memory_fuser_kernel_size = memory_fuser_kernel_size
+        self.memory_fuser_padding = memory_fuser_padding
+        self.memory_fuser_layer_scale_init_value = memory_fuser_layer_scale_init_value
+        self.memory_fuser_hidden_act = memory_fuser_hidden_act
+
+
+class EdgeTamVideoLayerNorm(Sam2VideoLayerNorm):
+    pass
+
+
+class EdgeTamVideoMemoryFuserCXBlock(Sam2VideoMemoryFuserCXBlock):
+    pass
+
+
+class EdgeTamVideoVisionEncoderOutput(Sam2VideoVisionEncoderOutput):
+    pass
+
+
+class EdgeTamVideoVisionRotaryEmbedding(Sam2VideoVisionRotaryEmbedding):
+    def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
+        nn.Module.__init__(self)
+        dim = config.memory_attention_hidden_size // (
+            config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
+        )
+        # Ensure the dimension is divisible by 4 for proper axial (x/y) splitting
+        if dim % 4 != 0:
+            raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+        end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+        freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+        # Generate 2D position indices for axial rotary embedding
+        flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+        x_positions = flattened_indices % end_x
+        y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
+        freqs_x = torch.outer(x_positions, freqs).float()
+        freqs_y = torch.outer(y_positions, freqs).float()
+        inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+        inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+        # directly register the cos and sin embeddings as we have a fixed feature shape
+        self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
+        self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
+
+
+class EdgeTamVideoAttention(Sam2VideoAttention):
+    pass
+
+
+def apply_rotary_pos_emb_2d_self_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary position embedding to query and key tensors for self-attention.
+ + Args: + q: Query tensor of shape (..., seq_len, head_dim) + k: Key tensor of shape (..., seq_len, head_dim) + cos: Cosine position embedding of shape (seq_len, head_dim) + sin: Sine position embedding of shape (seq_len, head_dim) + + Returns: + Rotated (q, k) tensors + """ + # Apply RoPE to queries + q_embed = q.float() # force upscale to float32 as in the original implementation + q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin) + + # Apply RoPE to keys (same embeddings as queries for self-attention) + k_embed = k.float() # force upscale to float32 as in the original implementation + k_embed = (k_embed * cos) + (rotate_pairwise(k_embed) * sin) + + return q_embed.type_as(q), k_embed.type_as(k) + + +def apply_rotary_pos_emb_2d_cross_attn( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cos_k: torch.Tensor, + sin_k: torch.Tensor, + num_k_exclude_rope: int = 0, + repeat_freqs_k: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding to query and key tensors for cross-attention. + + Args: + q: Query tensor of shape (..., seq_len, head_dim) + k: Key tensor of shape (..., seq_len, head_dim) + cos: Cosine position embedding of shape (seq_len, head_dim) + sin: Sine position embedding of shape (seq_len, head_dim) + cos_k: Cosine position embedding for keys of shape (seq_len, head_dim) + sin_k: Sine position embedding for keys of shape (seq_len, head_dim) + num_k_exclude_rope: Number of tokens at end of k to exclude from RoPE (e.g., object pointer tokens) + repeat_freqs_k: Frequency repetition for keys in cross-attention (e.g., for spatial memory tokens) + + Returns: + Rotated (q, k) tensors + """ + # Apply RoPE to queries (always straightforward) + q_embed = q.float() + q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin) + + # Split keys: RoPE tokens and excluded tokens (e.g., object pointers) + num_total_k_tokens = k.shape[-2] + k_for_rope = k[..., : num_total_k_tokens - num_k_exclude_rope, :] + k_excluded = k[..., num_total_k_tokens - num_k_exclude_rope :, :] + + # Early return if no keys need RoPE + if k_for_rope.shape[-2] == 0: + return q_embed.type_as(q), k_excluded + + batch_size, num_heads, k_seq_len, channels_per_head = k_for_rope.shape + + # Handle temporal/spatial token structure for memory + # Keys have temporal + spatial structure, only spatial tokens get RoPE + tokens_per_group = k_seq_len // repeat_freqs_k + spatial_tokens = cos_k.shape[-2] + temporal_tokens = tokens_per_group - spatial_tokens + + # Reshape and separate temporal/spatial tokens + k_grouped = k_for_rope.view(batch_size, num_heads, repeat_freqs_k, tokens_per_group, channels_per_head) + k_temporal = k_grouped[..., :temporal_tokens, :].reshape(batch_size, num_heads, -1, channels_per_head) + k_spatial = k_grouped[..., temporal_tokens:, :].reshape(batch_size, num_heads, -1, channels_per_head) + + # Only apply RoPE to spatial tokens + k_rope_input = k_spatial + + # Prepare position embeddings for repeated groups + if repeat_freqs_k > 1: + cos_k = cos_k.repeat(1, 1, repeat_freqs_k, 1) + sin_k = sin_k.repeat(1, 1, repeat_freqs_k, 1) + + # Apply RoPE to spatial tokens + k_spatial_embed = k_rope_input.float() + k_spatial_embed = (k_spatial_embed * cos_k) + (rotate_pairwise(k_spatial_embed) * sin_k) + + # Reconstruct: temporal + spatial tokens back to original structure + k_spatial_reshaped = k_spatial_embed.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head) + k_temporal_reshaped = k_temporal.view(batch_size, num_heads, 
repeat_freqs_k, -1, channels_per_head) + k_final = torch.cat([k_temporal_reshaped, k_spatial_reshaped], dim=3) + k_final = k_final.view(batch_size, num_heads, k_seq_len, channels_per_head) + + # Combine RoPE-processed keys with excluded tokens + k_embed = torch.cat([k_final.type_as(k), k_excluded], dim=-2) + return q_embed.type_as(q), k_embed + + +class EdgeTamVideoRoPESelfAttention(nn.Module): + """Self-attention with rotary position encoding.""" + + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate + self.num_attention_heads = config.memory_attention_num_attention_heads + self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + self.dropout_p = config.memory_attention_rope_dropout + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tensor: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + cos, sin = position_embeddings + # Apply rotary position encoding for self-attention + query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EdgeTamVideoRoPECrossAttention(nn.Module): + """Cross-attention with rotary position encoding.""" + + def __init__(self, config: EdgeTamVideoConfig, kv_in_dim: int): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate + self.num_attention_heads = config.memory_attention_num_attention_heads + self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.kv_in_dim = kv_in_dim + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.o_proj = nn.Linear(self.internal_dim, self.hidden_size) + self.dropout_p = config.memory_attention_rope_dropout + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + 
value: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_embeddings_k: tuple[torch.Tensor, torch.Tensor], + num_k_exclude_rope: int = 0, + rope_k_repeat: int = 0, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tensor: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + cos, sin = position_embeddings + cos_k, sin_k = position_embeddings_k + # Apply rotary position encoding for cross-attention + query, key = apply_rotary_pos_emb_2d_cross_attn( + query, + key, + cos=cos, + sin=sin, + cos_k=cos_k, + sin_k=sin_k, + repeat_freqs_k=rope_k_repeat, + num_k_exclude_rope=num_k_exclude_rope, + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EdgeTamVideoTwoWayAttentionBlock(Sam2VideoTwoWayAttentionBlock): + pass + + +class EdgeTamVideoPositionEmbeddingSine(Sam2VideoPositionEmbeddingSine): + # maxsize=2 because we need to cache the forward method for both memory encoder and perceiver resampler + @compile_compatible_method_lru_cache(maxsize=2) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +class EdgeTamVideoMemoryEncoder(Sam2VideoMemoryEncoder): + pass + + +class EdgeTamVideoFeedForward(Sam2VideoFeedForward): + pass + + +class EdgeTamVideoPreTrainedModel(Sam2VideoPreTrainedModel): + pass + + +class EdgeTamVideoInferenceSession(Sam2VideoInferenceSession): + pass + + +class EdgeTamVideoMemoryAttentionMLP(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.intermediate_size = config.memory_attention_mlp_hidden_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size) + self.dropout = nn.Dropout(config.memory_attention_dropout) + self.act_fn = ACT2FN[config.memory_attention_mlp_hidden_act] + + def forward(self, x): + return self.down_proj(self.dropout(self.act_fn(self.up_proj(x)))) + + +class EdgeTamVideoMemoryAttentionLayer(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + hidden_size = config.memory_attention_hidden_size + self.self_attn = EdgeTamVideoRoPESelfAttention(config) + self.cross_attn_image = EdgeTamVideoRoPECrossAttention(config, kv_in_dim=64) + + # MLP module + self.mlp = EdgeTamVideoMemoryAttentionMLP(config) + + self.layer_norm1 = nn.LayerNorm(hidden_size) + self.layer_norm2 = nn.LayerNorm(hidden_size) + self.layer_norm3 = nn.LayerNorm(hidden_size) + self.dropout1 = nn.Dropout(config.memory_attention_dropout) + self.dropout2 = nn.Dropout(config.memory_attention_dropout) + self.dropout3 = 
nn.Dropout(config.memory_attention_dropout) + + def forward( + self, + queries: Tensor, + keys: Tensor, + key_point_embedding: Tensor, + rope_position_embeddings: tuple[Tensor, Tensor], + rope_position_embeddings_k: Optional[tuple[Tensor, Tensor]] = None, + num_k_exclude_rope: int = 0, + rope_k_repeat: int = 0, + ) -> torch.Tensor: + # Self-Attention + query = self.layer_norm1(queries) + query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings) + queries = queries + self.dropout1(query) + + # Cross-Attention + query = self.layer_norm2(queries) + query, _ = self.cross_attn_image( + query=query, + key=keys + key_point_embedding, + value=keys, + position_embeddings=rope_position_embeddings, + position_embeddings_k=rope_position_embeddings_k, + num_k_exclude_rope=num_k_exclude_rope, + rope_k_repeat=rope_k_repeat, + ) + queries = queries + self.dropout2(query) + # MLP + query = self.layer_norm3(queries) + query = self.mlp(query) + queries = queries + self.dropout3(query) + return queries + + +class EdgeTamVideoMemoryAttention(Sam2VideoMemoryAttention): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.rotary_emb_k = EdgeTamVideoVisionRotaryEmbedding( + config, end_x=config.memory_attention_rope_k_sizes[0], end_y=config.memory_attention_rope_k_sizes[1] + ) + + def forward( + self, + current_vision_features: torch.Tensor, + memory: torch.Tensor, + current_vision_position_embeddings: Optional[Tensor] = None, + memory_posision_embeddings: Optional[Tensor] = None, + num_object_pointer_tokens: int = 0, + num_spatial_memory_tokens: int = -1, + ): + """ + Args: + current_vision_features (`torch.FloatTensor`): + The current vision features used for self-attention. + memory (`torch.FloatTensor`): + The memory features used for cross-attention. + current_vision_position_embeddings (`torch.FloatTensor`, *optional*): + The position embeddings for the current vision features. + memory_posision_embeddings (`torch.FloatTensor`, *optional*): + The position embeddings for the memory features. + num_object_pointer_tokens (`int`, *optional*, defaults to 0): + The number of object pointer tokens. 
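+            num_spatial_memory_tokens (`int`, *optional*, defaults to -1):
+                The number of spatial memory entries gathered for this frame; forwarded to each layer as
+                `rope_k_repeat` so that the key rotary position embeddings are repeated once per memory entry.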
+ """ + output = current_vision_features + if current_vision_position_embeddings is not None: + output = output + 0.1 * current_vision_position_embeddings + + # Convert to batch first + output = output.transpose(0, 1) + memory = memory.transpose(0, 1).unsqueeze(1) + memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1) + rope_position_embeddings = self.rotary_emb() + rope_position_embeddings_k = self.rotary_emb_k() + for layer in self.layers: + output = layer( + queries=output.unsqueeze(1) if output.ndim == 3 else output, + keys=memory, + key_point_embedding=memory_posision_embeddings, + rope_position_embeddings=rope_position_embeddings, + rope_position_embeddings_k=rope_position_embeddings_k, + num_k_exclude_rope=num_object_pointer_tokens, + rope_k_repeat=num_spatial_memory_tokens, + ) + + normed_output = self.layer_norm(output) + + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + + return normed_output + + +class EdgeTamVideoPerceiverMLP(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.hidden_size = config.perceiver_resampler_hidden_size + self.intermediate_size = config.perceiver_resampler_mlp_intermediate_size + + self.layer_norm = nn.LayerNorm(self.hidden_size) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.down_proj(self.act_fn(self.up_proj(hidden_states))) + return hidden_states + + +class EdgeTamVideoPerceiverAttention(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.perceiver_resampler_hidden_size + self.num_attention_heads = config.perceiver_resampler_num_attention_heads + self.head_dim = config.perceiver_resampler_attention_head_dim + self.attention_dropout = config.perceiver_resampler_attention_dropout + + self.inner_dim = self.head_dim * self.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # Project queries, keys, and values + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + # Reshape for multi-head attention + batch_size, seq_len_q = query.shape[:2] + query = query.view(batch_size, seq_len_q, self.num_attention_heads, self.head_dim).transpose(1, 2) + seq_len_kv = key.shape[1] + key = key.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2) + + # Add positional encoding if provided + if positional_encoding is not None: + pos_encoding = positional_encoding.view( + batch_size, seq_len_kv, self.num_attention_heads, self.head_dim + ).transpose(1, 2) + key = key + pos_encoding + value = value + pos_encoding + + # Apply attention + attention_interface: 
Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + # Reshape output + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.inner_dim) + return self.o_proj(attn_output) + + +class EdgeTamVideoPerceiverEncoderLayer(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + + self.cross_attention = EdgeTamVideoPerceiverAttention(config) + self.mlp = EdgeTamVideoPerceiverMLP(config) + self.dropout = nn.Dropout(config.perceiver_resampler_hidden_dropout) + + self.self_attention = EdgeTamVideoPerceiverAttention(config) + self.self_mlp = EdgeTamVideoPerceiverMLP(config) + + # Layer norms moved from attention classes to here + self.layer_norm_input = nn.LayerNorm(config.perceiver_resampler_hidden_size) + self.layer_norm_latents = nn.LayerNorm(config.perceiver_resampler_hidden_size) + self.layer_norm_self = nn.LayerNorm(config.perceiver_resampler_hidden_size) + + def forward( + self, + latents: torch.Tensor, + input_features: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Cross attention with layer norms + normalized_latents = self.layer_norm_latents(latents) + normalized_input = self.layer_norm_input(input_features) + cross_attention_output = self.cross_attention( + query=normalized_latents, + key=normalized_input, + value=normalized_input, + positional_encoding=positional_encoding, + ) + latents = latents + self.dropout(cross_attention_output) + + mlp_output = self.mlp(latents) + latents = latents + mlp_output + + # Self attention with layer norm + normalized_latents_self = self.layer_norm_self(latents) + self_attention_output = self.self_attention( + query=normalized_latents_self, key=normalized_latents_self, value=normalized_latents_self + ) + latents = latents + self_attention_output + + self_mlp_output = self.self_mlp(latents) + latents = latents + self_mlp_output + + return latents + + +class EdgeTamVideoPerceiverResampler(nn.Module): + def __init__(self, config: EdgeTamVideoConfig): + super().__init__() + self.config = config + self.hidden_size = config.perceiver_resampler_hidden_size + self.num_latents_1d = config.perceiver_resampler_num_latents + self.num_latents_2d = config.perceiver_resampler_num_latents_2d + self.num_layers = config.perceiver_resampler_num_layers + + if self.num_latents_1d > 0: + self.latents_1d = nn.Parameter(torch.randn(self.num_latents_1d, self.hidden_size)) + if self.num_latents_2d > 0: + self.latents_2d = nn.Parameter(torch.randn(self.num_latents_2d, self.hidden_size)) + + self.positional_encoding = EdgeTamVideoPositionEmbeddingSine( + num_pos_feats=self.hidden_size // 2, normalize=True + ) + + self.layers = nn.ModuleList([EdgeTamVideoPerceiverEncoderLayer(config) for _ in range(self.num_layers)]) + + self.layer_norm = nn.LayerNorm(self.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + output_latents = [] + output_positional_encodings = [] + + if self.num_latents_1d > 0: + latents_1d, pos_1d = self._forward_1d(hidden_states, positional_encoding) + output_latents.append(latents_1d) + 
output_positional_encodings.append(pos_1d) + + if self.num_latents_2d > 0: + latents_2d, pos_2d = self._forward_2d(hidden_states) + output_latents.append(latents_2d) + output_positional_encodings.append(pos_2d) + + combined_latents = torch.cat(output_latents, dim=1) + + combined_positional_encoding = None + if positional_encoding is not None and output_positional_encodings: + combined_positional_encoding = torch.cat(output_positional_encodings, dim=1) + + return combined_latents, combined_positional_encoding + + def _forward_1d( + self, + hidden_states: torch.Tensor, + positional_encoding: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size = hidden_states.shape[0] + + latents = self.latents_1d.unsqueeze(0).expand(batch_size, -1, -1) + flattened_features = hidden_states.permute(0, 2, 3, 1).flatten(1, 2) + + positional_features = None + if positional_encoding is not None: + positional_features = positional_encoding.permute(0, 2, 3, 1).flatten(1, 2) + + for layer in self.layers: + latents = layer(latents, flattened_features, positional_features) + + latents = self.layer_norm(latents) + + output_positional_encoding = None + if positional_encoding is not None: + output_positional_encoding = torch.zeros_like(latents) + + return latents, output_positional_encoding + + def _forward_2d(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, channels, height, width = hidden_states.shape + + latents_2d = self.latents_2d.unsqueeze(0).expand(batch_size, -1, -1).view(-1, 1, channels) + + num_windows_per_dim = int(math.sqrt(self.num_latents_2d)) + window_size = height // num_windows_per_dim + + windowed_input = hidden_states.permute(0, 2, 3, 1) + windowed_features, _ = window_partition(windowed_input, window_size) + windowed_features = windowed_features.flatten(1, 2) + + for layer in self.layers: + latents_2d = layer(latents_2d, windowed_features, positional_encoding=None) + + latents_2d = latents_2d.view(batch_size, num_windows_per_dim, num_windows_per_dim, channels).permute( + 0, 3, 1, 2 + ) + + positional_encoding_2d = self.positional_encoding(latents_2d.shape, latents_2d.device, latents_2d.dtype).to( + dtype=hidden_states.dtype + ) + positional_encoding_2d = positional_encoding_2d.permute(0, 2, 3, 1).flatten(1, 2) + + latents_2d = latents_2d.permute(0, 2, 3, 1).flatten(1, 2) + latents_2d = self.layer_norm(latents_2d) + + return latents_2d, positional_encoding_2d + + +@auto_docstring +class EdgeTamVideoModel(Sam2VideoModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _keys_to_ignore_on_load_unexpected = [] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)} + + def __init__(self, config: EdgeTamVideoConfig): + super().__init__(config) + self.spatial_perceiver = EdgeTamVideoPerceiverResampler(config) + + self.post_init() + + def _build_memory_attention_inputs( + self, + temporal_positions_and_previous_outputs: list[tuple[int, dict]], + device: torch.device, + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """ + Concatenate memory features and positional embeddings from previous frames. + + Returns: + Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate). 
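+
+        Note:
+            Unlike the SAM2 base implementation, the stored memories here are perceiver-resampled latents of
+            shape (Batch, NumLatents, MemDim); both the features and their positional encodings are permuted to
+            (NumLatents, Batch, MemDim) before being collected.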
+ """ + memories_to_concatenate = [] + memory_positional_embeddings_to_concatenate = [] + + for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: + if prev_output_data is None: + continue # Skip if no output data for this temporal position (e.g., padding frames) + + # Load memory features (potentially from CPU to GPU) + # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) + memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True) + memories_to_concatenate.append(memory_features.permute(1, 0, 2)) + + # Spatial positional encoding (potentially from CPU to GPU) + spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True) + spatial_memory_pos_embed = spatial_memory_pos_embed.squeeze(1).permute(1, 0, 2) + + # Add temporal positional encoding + # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) + combined_memory_pos_embed = ( + spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] + ) + memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) + + return memories_to_concatenate, memory_positional_embeddings_to_concatenate + + def _prepare_memory_conditioned_features( + self, + inference_session: EdgeTamVideoInferenceSession, + frame_idx: int, + obj_idx: int, + is_initial_conditioning_frame: bool, + current_vision_features: list[torch.Tensor], + current_vision_positional_embeddings: list[torch.Tensor], + num_total_frames: int, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> torch.Tensor: + """ + Fuse current frame's visual features with memory from previous frames for enhanced object tracking. + + This method conditions the current frame's visual features on temporal memory from previous frames, + enabling consistent object tracking across video sequences. For initial conditioning frames, it uses + no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both + conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention. + + Args: + inference_session (`EdgeTamVideoInferenceSession`): + The video inference session object. + frame_idx (`int`): + Index of the current frame being processed. + obj_idx (`int`): + Index of the object being processed. + is_initial_conditioning_frame (`bool`): + Whether this is an initial conditioning frame with user inputs (True) or a subsequent + tracking frame (False). + current_vision_features (`torch.Tensor`): + Highest-level vision features of shape `(seq_len, batch_size, channels)`. + current_vision_positional_embeddings (`torch.Tensor`): + Positional embedding tensors corresponding to the highest-level vision features. + num_total_frames (`int`): + Total number of frames in the video sequence. + track_in_reverse_time (`bool`, *optional*, defaults to `False`): + Whether tracking is performed in reverse temporal order. + streaming (`bool`, *optional*, defaults to `False`): + Whether this is streaming inference mode. + + Returns: + `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)` + suitable for input to the SAM decoder. 
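+
+        Note:
+            The number of memory entries gathered in Step 2 is forwarded to the memory attention as
+            `num_spatial_memory_tokens`, which controls how the key rotary position embeddings are repeated
+            across memory frames.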
+ """ + # Get dimensions from the highest-level (lowest-resolution) feature map + batch_size = current_vision_features.size(1) + num_channels = self.hidden_dim + height, width = self.backbone_feature_sizes[-1] + device = current_vision_features.device + + # If memory is disabled (e.g., for single image SAM), return current features directly. + if self.num_maskmem == 0: + # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width) + # Assuming SeqLen = Height * Width for the last feature map + current_feature_map = current_vision_features.permute(1, 2, 0).view( + batch_size, num_channels, height, width + ) + return current_feature_map + + # Step 1: Handle initial conditioning frames + if is_initial_conditioning_frame: + # For initial conditioning frames, no prior memory is used directly in this block. + # If configured, directly add a learnable "no memory" embedding. + # current_vision_features has shape (SeqLen, Batch, Channels) + conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding + # Reshape to (Batch, Channels, Height, Width) + conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view( + batch_size, num_channels, height, width + ) + return conditioned_feature_map + + # Step 2: Get memory frames and concatenate their features + temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs( + inference_session, obj_idx, frame_idx, track_in_reverse_time + ) + + memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs( + temporal_positions_and_previous_outputs, device + ) + num_spatial_memory_tokens = len(memories_to_concatenate) + + # Step 3: Get and process object pointers + temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers( + inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming + ) + + num_object_pointer_tokens = 0 + if pointer_tokens: + object_pointers, object_pointers_pos_embed = self._process_object_pointers( + temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device + ) + + if object_pointers is not None: + memories_to_concatenate.append(object_pointers) + memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) + num_object_pointer_tokens = object_pointers.shape[0] + + # Step 4: Concatenate all retrieved memories and their positional embeddings + combined_memory = torch.cat(memories_to_concatenate, dim=0) + combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0) + + # Step 5: Forward through the memory attention mechanism + conditioned_feature_map_flat = self.memory_attention( + current_vision_features=current_vision_features, + current_vision_position_embeddings=current_vision_positional_embeddings, + memory=combined_memory, + memory_posision_embeddings=combined_memory_positional_embeddings, # Corrected typo from API + num_object_pointer_tokens=num_object_pointer_tokens, + num_spatial_memory_tokens=num_spatial_memory_tokens, + ) + + # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width) + conditioned_feature_map = ( + conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width) + ) + return conditioned_feature_map + + def _encode_new_memory( + self, + current_vision_feats: torch.Tensor, + pred_masks_high_res: torch.Tensor, + object_score_logits: torch.Tensor, + 
is_mask_from_pts: bool, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Encode the current image and its prediction into a memory feature.""" + batch_size = current_vision_feats.size(1) # batch size on this frame + channels = self.hidden_dim + height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width) + if is_mask_from_pts and not self.training: + # binarize the mask logits + mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype) + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc + mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc + + maskmem_features, maskmem_pos_enc = self.memory_encoder( + pix_feat, + mask_for_mem, + ) + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.occlusion_spatial_embedding_parameter is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[ + ..., None, None + ].expand(*maskmem_features.shape) + + maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype) + maskmem_features, maskmem_pos_enc = self.spatial_perceiver(maskmem_features, maskmem_pos_enc) + maskmem_features = maskmem_features.to(pred_masks_high_res.dtype) + maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype) + + return maskmem_features, maskmem_pos_enc + + +__all__ = [ + "EdgeTamVideoMaskDecoderConfig", + "EdgeTamVideoPromptEncoderConfig", + "EdgeTamVideoConfig", + "EdgeTamVideoModel", + "EdgeTamVideoInferenceSession", + "EdgeTamVideoPreTrainedModel", +] diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index 8e420bf27904..47b5b47d3630 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -551,8 +551,8 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs): def to_dict(self) -> dict[str, Any]: output = super().to_dict() - output["num_classes"] = self.num_labels - output["label_names"] = list(self.id2label.values()) + output.setdefault("num_classes", self.num_labels) + output.setdefault("label_names", list(self.id2label.values())) output.pop("id2label", None) output.pop("label2id", None) return output diff --git a/src/transformers/models/sam2/configuration_sam2.py b/src/transformers/models/sam2/configuration_sam2.py index 8a93f28d5a20..e14583181d38 100644 --- a/src/transformers/models/sam2/configuration_sam2.py +++ b/src/transformers/models/sam2/configuration_sam2.py @@ -379,8 +379,6 @@ class Sam2Config(PretrainedConfig): Dictionary of configuration options used to initialize [`Sam2MaskDecoderConfig`]. initializer_range (`float`, *optional*, defaults to 0.02): Standard deviation for parameter initialization. - kwargs (*optional*): - Dictionary of keyword arguments. 
Example: diff --git a/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py b/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py index 322aa5507978..cc2ee0c7c612 100644 --- a/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py +++ b/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py @@ -190,7 +190,7 @@ def replace_keys(state_dict, config): if re.match(output_vision_encoder_neck_pattern, key): key = key.replace(".conv.", ".") - # memory_encoder.out_proj.weight -> memory_encoder.projection.weight + # memory_encoder.o_proj.weight -> memory_encoder.projection.weight if re.match(output_memory_encoder_projection_pattern, key): key = key.replace(".o_proj.", ".projection.") diff --git a/src/transformers/models/sam2_video/modeling_sam2_video.py b/src/transformers/models/sam2_video/modeling_sam2_video.py index f4c1261d6779..caa07d1f63b5 100644 --- a/src/transformers/models/sam2_video/modeling_sam2_video.py +++ b/src/transformers/models/sam2_video/modeling_sam2_video.py @@ -134,8 +134,10 @@ def __init__( dtype: Union[torch.dtype, str] = "float32", max_vision_features_cache_size: int = 1, ): - # store as a list to avoid double memory allocation with torch.cat when adding new frames - self.processed_frames = list(video.to(video_storage_device, dtype=dtype)) if video is not None else None + # store as a dictionary to avoid double memory allocation with torch.cat when adding new frames + self.processed_frames = ( + dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None + ) self.video_height = video_height self.video_width = video_width @@ -293,18 +295,21 @@ def get_output( return value # Video frame management - def add_new_frame(self, pixel_values: torch.Tensor) -> int: + def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int: """Add new frame with automatic device placement.""" pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True) if pixel_values.dim() == 4: pixel_values = pixel_values.squeeze(0) + if frame_idx is None: + frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0 + if self.processed_frames is None: - self.processed_frames = [pixel_values] + self.processed_frames = {frame_idx: pixel_values} else: - self.processed_frames.append(pixel_values) + self.processed_frames[frame_idx] = pixel_values - return self.num_frames - 1 + return frame_idx def get_frame(self, frame_idx: int) -> torch.Tensor: """Get frame from video.""" @@ -1714,7 +1719,7 @@ def forward( Whether to propagate in reverse. """ if frame is not None: - frame_idx = inference_session.add_new_frame(frame) + frame_idx = inference_session.add_new_frame(frame, frame_idx) if frame is not None and inference_session.get_obj_num() == 0: raise ValueError("No objects are provided for tracking; please add inputs first.") @@ -2097,6 +2102,195 @@ def _use_mask_as_output( image_embeddings=high_res_features + [backbone_features], ) + def _gather_memory_frame_outputs( + self, + inference_session: Sam2VideoInferenceSession, + obj_idx: int, + frame_idx: int, + track_in_reverse_time: bool = False, + ) -> list[tuple[int, dict]]: + """ + Get memory frames from conditioning and non-conditioning outputs. + + Returns: + List of (relative_temporal_offset, output_data) tuples. 
+ """ + temporal_positions_and_previous_outputs = [] + + # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints) + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + if not conditioning_outputs: + raise ValueError( + "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame" + ) + + # Store (temporal_position, output_data) tuples + temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()] + + # Add non-conditioning memory frames (up to self.num_maskmem - 1) + # These are typically frames tracked by the model without direct user input. + # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity. + for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1): + # relative_temporal_offset: how many frames before (or after if reversing) the current frame + if not track_in_reverse_time: + previous_frame_idx = frame_idx - relative_temporal_offset + else: + previous_frame_idx = frame_idx + relative_temporal_offset + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + previous_frame_idx, None + ) + + temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data)) + + return temporal_positions_and_previous_outputs + + def _build_memory_attention_inputs( + self, + temporal_positions_and_previous_outputs: list[tuple[int, dict]], + device: torch.device, + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """ + Concatenate memory features and positional embeddings from previous frames. + + Returns: + Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate). + """ + memories_to_concatenate = [] + memory_positional_embeddings_to_concatenate = [] + + for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: + if prev_output_data is None: + continue # Skip if no output data for this temporal position (e.g., padding frames) + + # Load memory features (potentially from CPU to GPU) + # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) + memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True) + memories_to_concatenate.append(memory_features) + + # Spatial positional encoding (potentially from CPU to GPU) + spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True) + + # Add temporal positional encoding + # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) + combined_memory_pos_embed = ( + spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] + ) + memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) + + return memories_to_concatenate, memory_positional_embeddings_to_concatenate + + def _get_object_pointers( + self, + inference_session: Sam2VideoInferenceSession, + obj_idx: int, + frame_idx: int, + num_total_frames: int, + device: torch.device, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> tuple[list[int], list[torch.Tensor], int]: + """ + Get object pointers and their positional embeddings from past frames. + + Returns: + Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use). 
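+
+        Note:
+            Outside of training, pointers from conditioning frames that lie ahead of `frame_idx` in the tracking
+            direction are excluded; non-conditioning pointers are gathered from up to
+            `max_object_pointers_to_use - 1` earlier frames along the tracking direction.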
+ """ + temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1 + + # Determine max object pointers to use + if streaming: + max_object_pointers_to_use = self.config.max_object_pointers_in_encoder + else: + max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder) + + temporal_offsets: list[int] = [] + pointer_tokens: list[torch.Tensor] = [] + + # Add object pointers from selected conditioning frames + # Optionally, only include pointers from past frames during evaluation + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + eligible_conditioning_outputs = conditioning_outputs + if not self.training: + eligible_conditioning_outputs = { + temporal_idx: out + for temporal_idx, out in conditioning_outputs.items() + if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx) + } + + for temporal_idx, out_data in eligible_conditioning_outputs.items(): + temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier + temporal_offsets.append(temporal_difference) + pointer_tokens.append(out_data["object_pointer"].to(device)) + + # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1) + for t_diff_offset in range(1, max_object_pointers_to_use): + ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset + if ref_frame_idx < 0 or ( + not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames + ): + break # Stop if frame index is out of bounds + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + ref_frame_idx, None + ) + if out_data is not None: + temporal_offsets.append(t_diff_offset) + pointer_tokens.append(out_data["object_pointer"].to(device)) + + return temporal_offsets, pointer_tokens, max_object_pointers_to_use + + def _process_object_pointers( + self, + temporal_offsets: list[int], + pointer_tokens: list[torch.Tensor], + max_object_pointers_to_use: int, + batch_size: int, + num_channels: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Process object pointers and compute their positional embeddings. + + Returns: + Tuple of (object_pointers, object_pointers_pos_embed). 
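+
+        Note:
+            Returns `(None, None)` when `pointer_tokens` is empty.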
+ """ + if not pointer_tokens: + return None, None + + # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels) + object_pointers = torch.stack(pointer_tokens, dim=0) + + if self.config.enable_temporal_pos_encoding_for_object_pointers: + max_temporal_diff = float(max_object_pointers_to_use - 1) + # Determine dimensionality for temporal positional encoding of pointers + pointer_tpos_dim = num_channels + + # Normalize temporal differences before sine PE calculation + normalized_temporal_diffs = ( + torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff + ) + sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype) + projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe) + object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim) + else: + object_pointers_pos_embed = object_pointers.new_zeros( + len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype + ) + + if self.mem_dim < num_channels: + # If memory dimension is smaller, reshape/split pointers and repeat positional encoding + num_splits = num_channels // self.mem_dim + object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim) + object_pointers = object_pointers.permute(0, 2, 1, 3).flatten( + 0, 1 + ) # (SeqLen_ptr*num_splits, Batch, MemDim) + object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0) + + return object_pointers, object_pointers_pos_embed + def _prepare_memory_conditioned_features( self, inference_session: Sam2VideoInferenceSession, @@ -2157,135 +2351,9 @@ def _prepare_memory_conditioned_features( ) return current_feature_map - num_object_pointer_tokens = 0 - temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1 - - # Step 1: Condition the visual features of the current frame on previous memories - if not is_initial_conditioning_frame: - # Retrieve memories encoded from previous frames - memories_to_concatenate = [] - memory_positional_embeddings_to_concatenate = [] - - # Ensure there are conditioning frame outputs to process - conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] - if not conditioning_outputs: - raise ValueError( - "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame" - ) - - # Select a maximum number of temporally closest conditioning frames for cross-attention (no limit here, as is the case in the original checkpoints) - # Store (temporal_position, output_data) tuples - temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()] - - # Add non-conditioning memory frames (up to self.num_maskmem - 1) - # These are typically frames tracked by the model without direct user input. - # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity. 
- for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1): - # relative_temporal_offset: how many frames before (or after if reversing) the current frame - if not track_in_reverse_time: - previous_frame_idx = frame_idx - relative_temporal_offset - else: - previous_frame_idx = frame_idx + relative_temporal_offset - - # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU - output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( - previous_frame_idx, None - ) - - temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data)) - - for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: - if prev_output_data is None: - continue # Skip if no output data for this temporal position (e.g., padding frames) - - # Load memory features (potentially from CPU to GPU) - # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) - memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True) - memories_to_concatenate.append(memory_features) - - # Spatial positional encoding (potentially from CPU to GPU) - spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True) - - # Add temporal positional encoding - # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) - combined_memory_pos_embed = ( - spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] - ) - memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) - - # Construct the list of past object pointers to be used in attention - if streaming: - max_object_pointers_to_use = self.config.max_object_pointers_in_encoder - else: - max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder) - temporal_diff_and_pointers = [] - - # Add object pointers from selected conditioning frames - # Optionally, only include pointers from past frames during evaluation - eligible_conditioning_outputs = conditioning_outputs - if not self.training: - eligible_conditioning_outputs = { - temporal_idx: out - for temporal_idx, out in conditioning_outputs.items() - if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx) - } - - for temporal_idx, out_data in eligible_conditioning_outputs.items(): - temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier - temporal_diff_and_pointers.append((temporal_difference, out_data["object_pointer"])) - - # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1) - for t_diff_offset in range(1, max_object_pointers_to_use): - ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset - if ref_frame_idx < 0 or ( - not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames - ): - break # Stop if frame index is out of bounds - - # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU - out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( - ref_frame_idx, None - ) - if out_data is not None: - temporal_diff_and_pointers.append((t_diff_offset, out_data["object_pointer"])) - - if temporal_diff_and_pointers: - temporal_differences, object_pointers_list = zip(*temporal_diff_and_pointers) - # Stack object pointers: List of (Batch, Channels) 
-> (SeqLen_ptr, Batch, Channels) - object_pointers = torch.stack(object_pointers_list, dim=0) - - if self.config.enable_temporal_pos_encoding_for_object_pointers: - max_temporal_diff = float(max_object_pointers_to_use - 1) - # Determine dimensionality for temporal positional encoding of pointers - pointer_tpos_dim = num_channels - - # Normalize temporal differences before sine PE calculation - normalized_temporal_diffs = ( - torch.tensor(temporal_differences, device=device, dtype=torch.float32) / max_temporal_diff - ) - sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype) - projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe) - object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim) - else: - object_pointers_pos_embed = object_pointers.new_zeros( - len(temporal_differences), batch_size, self.mem_dim, dtype=object_pointers.dtype - ) - - if self.mem_dim < num_channels: - # If memory dimension is smaller, reshape/split pointers and repeat positional encoding - num_splits = num_channels // self.mem_dim - object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim) - object_pointers = object_pointers.permute(0, 2, 1, 3).flatten( - 0, 1 - ) # (SeqLen_ptr*num_splits, Batch, MemDim) - object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0) - - memories_to_concatenate.append(object_pointers) - memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) - num_object_pointer_tokens = object_pointers.shape[0] - else: + # Step 1: Handle initial conditioning frames + if is_initial_conditioning_frame: # For initial conditioning frames, no prior memory is used directly in this block. - # The model might handle this with a special token or mechanism. # If configured, directly add a learnable "no memory" embedding. # current_vision_features has shape (SeqLen, Batch, Channels) conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding @@ -2295,11 +2363,36 @@ def _prepare_memory_conditioned_features( ) return conditioned_feature_map - # Step 2: Concatenate all retrieved memories and their positional embeddings. 
+ # Step 2: Get memory frames and concatenate their features + temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs( + inference_session, obj_idx, frame_idx, track_in_reverse_time + ) + + memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs( + temporal_positions_and_previous_outputs, device + ) + + # Step 3: Get and process object pointers + temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers( + inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming + ) + + num_object_pointer_tokens = 0 + if pointer_tokens: + object_pointers, object_pointers_pos_embed = self._process_object_pointers( + temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device + ) + + if object_pointers is not None: + memories_to_concatenate.append(object_pointers) + memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) + num_object_pointer_tokens = object_pointers.shape[0] + + # Step 4: Concatenate all retrieved memories and their positional embeddings combined_memory = torch.cat(memories_to_concatenate, dim=0) combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0) - # Step 3: Forward through the memory attention mechanism. + # Step 5: Forward through the memory attention mechanism conditioned_feature_map_flat = self.memory_attention( current_vision_features=current_vision_features, current_vision_position_embeddings=current_vision_positional_embeddings, diff --git a/src/transformers/models/sam2_video/modular_sam2_video.py b/src/transformers/models/sam2_video/modular_sam2_video.py index 53e10998b2a7..fa0d6c21d5e6 100644 --- a/src/transformers/models/sam2_video/modular_sam2_video.py +++ b/src/transformers/models/sam2_video/modular_sam2_video.py @@ -403,8 +403,10 @@ def __init__( dtype: Union[torch.dtype, str] = "float32", max_vision_features_cache_size: int = 1, ): - # store as a list to avoid double memory allocation with torch.cat when adding new frames - self.processed_frames = list(video.to(video_storage_device, dtype=dtype)) if video is not None else None + # store as a dictionary to avoid double memory allocation with torch.cat when adding new frames + self.processed_frames = ( + dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None + ) self.video_height = video_height self.video_width = video_width @@ -562,18 +564,21 @@ def get_output( return value # Video frame management - def add_new_frame(self, pixel_values: torch.Tensor) -> int: + def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int: """Add new frame with automatic device placement.""" pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True) if pixel_values.dim() == 4: pixel_values = pixel_values.squeeze(0) + if frame_idx is None: + frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0 + if self.processed_frames is None: - self.processed_frames = [pixel_values] + self.processed_frames = {frame_idx: pixel_values} else: - self.processed_frames.append(pixel_values) + self.processed_frames[frame_idx] = pixel_values - return self.num_frames - 1 + return frame_idx def get_frame(self, frame_idx: int) -> torch.Tensor: """Get frame from video.""" @@ -1799,6 +1804,195 @@ def _use_mask_as_output( image_embeddings=high_res_features + [backbone_features], ) + def 
_gather_memory_frame_outputs( + self, + inference_session: Sam2VideoInferenceSession, + obj_idx: int, + frame_idx: int, + track_in_reverse_time: bool = False, + ) -> list[tuple[int, dict]]: + """ + Get memory frames from conditioning and non-conditioning outputs. + + Returns: + List of (relative_temporal_offset, output_data) tuples. + """ + temporal_positions_and_previous_outputs = [] + + # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints) + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + if not conditioning_outputs: + raise ValueError( + "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame" + ) + + # Store (temporal_position, output_data) tuples + temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()] + + # Add non-conditioning memory frames (up to self.num_maskmem - 1) + # These are typically frames tracked by the model without direct user input. + # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity. + for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1): + # relative_temporal_offset: how many frames before (or after if reversing) the current frame + if not track_in_reverse_time: + previous_frame_idx = frame_idx - relative_temporal_offset + else: + previous_frame_idx = frame_idx + relative_temporal_offset + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + previous_frame_idx, None + ) + + temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data)) + + return temporal_positions_and_previous_outputs + + def _build_memory_attention_inputs( + self, + temporal_positions_and_previous_outputs: list[tuple[int, dict]], + device: torch.device, + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """ + Concatenate memory features and positional embeddings from previous frames. + + Returns: + Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate). 
+ """ + memories_to_concatenate = [] + memory_positional_embeddings_to_concatenate = [] + + for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: + if prev_output_data is None: + continue # Skip if no output data for this temporal position (e.g., padding frames) + + # Load memory features (potentially from CPU to GPU) + # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) + memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True) + memories_to_concatenate.append(memory_features) + + # Spatial positional encoding (potentially from CPU to GPU) + spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True) + + # Add temporal positional encoding + # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) + combined_memory_pos_embed = ( + spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] + ) + memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) + + return memories_to_concatenate, memory_positional_embeddings_to_concatenate + + def _get_object_pointers( + self, + inference_session: Sam2VideoInferenceSession, + obj_idx: int, + frame_idx: int, + num_total_frames: int, + device: torch.device, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> tuple[list[int], list[torch.Tensor], int]: + """ + Get object pointers and their positional embeddings from past frames. + + Returns: + Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use). + """ + temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1 + + # Determine max object pointers to use + if streaming: + max_object_pointers_to_use = self.config.max_object_pointers_in_encoder + else: + max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder) + + temporal_offsets: list[int] = [] + pointer_tokens: list[torch.Tensor] = [] + + # Add object pointers from selected conditioning frames + # Optionally, only include pointers from past frames during evaluation + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + eligible_conditioning_outputs = conditioning_outputs + if not self.training: + eligible_conditioning_outputs = { + temporal_idx: out + for temporal_idx, out in conditioning_outputs.items() + if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx) + } + + for temporal_idx, out_data in eligible_conditioning_outputs.items(): + temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier + temporal_offsets.append(temporal_difference) + pointer_tokens.append(out_data["object_pointer"].to(device)) + + # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1) + for t_diff_offset in range(1, max_object_pointers_to_use): + ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset + if ref_frame_idx < 0 or ( + not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames + ): + break # Stop if frame index is out of bounds + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + ref_frame_idx, None + ) + if out_data is not None: + temporal_offsets.append(t_diff_offset) + 
pointer_tokens.append(out_data["object_pointer"].to(device)) + + return temporal_offsets, pointer_tokens, max_object_pointers_to_use + + def _process_object_pointers( + self, + temporal_offsets: list[int], + pointer_tokens: list[torch.Tensor], + max_object_pointers_to_use: int, + batch_size: int, + num_channels: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Process object pointers and compute their positional embeddings. + + Returns: + Tuple of (object_pointers, object_pointers_pos_embed). + """ + if not pointer_tokens: + return None, None + + # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels) + object_pointers = torch.stack(pointer_tokens, dim=0) + + if self.config.enable_temporal_pos_encoding_for_object_pointers: + max_temporal_diff = float(max_object_pointers_to_use - 1) + # Determine dimensionality for temporal positional encoding of pointers + pointer_tpos_dim = num_channels + + # Normalize temporal differences before sine PE calculation + normalized_temporal_diffs = ( + torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff + ) + sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype) + projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe) + object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim) + else: + object_pointers_pos_embed = object_pointers.new_zeros( + len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype + ) + + if self.mem_dim < num_channels: + # If memory dimension is smaller, reshape/split pointers and repeat positional encoding + num_splits = num_channels // self.mem_dim + object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim) + object_pointers = object_pointers.permute(0, 2, 1, 3).flatten( + 0, 1 + ) # (SeqLen_ptr*num_splits, Batch, MemDim) + object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0) + + return object_pointers, object_pointers_pos_embed + def _prepare_memory_conditioned_features( self, inference_session: Sam2VideoInferenceSession, @@ -1859,135 +2053,9 @@ def _prepare_memory_conditioned_features( ) return current_feature_map - num_object_pointer_tokens = 0 - temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1 - - # Step 1: Condition the visual features of the current frame on previous memories - if not is_initial_conditioning_frame: - # Retrieve memories encoded from previous frames - memories_to_concatenate = [] - memory_positional_embeddings_to_concatenate = [] - - # Ensure there are conditioning frame outputs to process - conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] - if not conditioning_outputs: - raise ValueError( - "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame" - ) - - # Select a maximum number of temporally closest conditioning frames for cross-attention (no limit here, as is the case in the original checkpoints) - # Store (temporal_position, output_data) tuples - temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()] - - # Add non-conditioning memory frames (up to self.num_maskmem - 1) - # These are typically frames tracked by the model without direct user input. - # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity. 
- for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1): - # relative_temporal_offset: how many frames before (or after if reversing) the current frame - if not track_in_reverse_time: - previous_frame_idx = frame_idx - relative_temporal_offset - else: - previous_frame_idx = frame_idx + relative_temporal_offset - - # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU - output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( - previous_frame_idx, None - ) - - temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data)) - - for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: - if prev_output_data is None: - continue # Skip if no output data for this temporal position (e.g., padding frames) - - # Load memory features (potentially from CPU to GPU) - # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) - memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True) - memories_to_concatenate.append(memory_features) - - # Spatial positional encoding (potentially from CPU to GPU) - spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True) - - # Add temporal positional encoding - # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) - combined_memory_pos_embed = ( - spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] - ) - memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) - - # Construct the list of past object pointers to be used in attention - if streaming: - max_object_pointers_to_use = self.config.max_object_pointers_in_encoder - else: - max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder) - temporal_diff_and_pointers = [] - - # Add object pointers from selected conditioning frames - # Optionally, only include pointers from past frames during evaluation - eligible_conditioning_outputs = conditioning_outputs - if not self.training: - eligible_conditioning_outputs = { - temporal_idx: out - for temporal_idx, out in conditioning_outputs.items() - if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx) - } - - for temporal_idx, out_data in eligible_conditioning_outputs.items(): - temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier - temporal_diff_and_pointers.append((temporal_difference, out_data["object_pointer"])) - - # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1) - for t_diff_offset in range(1, max_object_pointers_to_use): - ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset - if ref_frame_idx < 0 or ( - not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames - ): - break # Stop if frame index is out of bounds - - # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU - out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( - ref_frame_idx, None - ) - if out_data is not None: - temporal_diff_and_pointers.append((t_diff_offset, out_data["object_pointer"])) - - if temporal_diff_and_pointers: - temporal_differences, object_pointers_list = zip(*temporal_diff_and_pointers) - # Stack object pointers: List of (Batch, Channels) 
-> (SeqLen_ptr, Batch, Channels) - object_pointers = torch.stack(object_pointers_list, dim=0) - - if self.config.enable_temporal_pos_encoding_for_object_pointers: - max_temporal_diff = float(max_object_pointers_to_use - 1) - # Determine dimensionality for temporal positional encoding of pointers - pointer_tpos_dim = num_channels - - # Normalize temporal differences before sine PE calculation - normalized_temporal_diffs = ( - torch.tensor(temporal_differences, device=device, dtype=torch.float32) / max_temporal_diff - ) - sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype) - projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe) - object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim) - else: - object_pointers_pos_embed = object_pointers.new_zeros( - len(temporal_differences), batch_size, self.mem_dim, dtype=object_pointers.dtype - ) - - if self.mem_dim < num_channels: - # If memory dimension is smaller, reshape/split pointers and repeat positional encoding - num_splits = num_channels // self.mem_dim - object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim) - object_pointers = object_pointers.permute(0, 2, 1, 3).flatten( - 0, 1 - ) # (SeqLen_ptr*num_splits, Batch, MemDim) - object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0) - - memories_to_concatenate.append(object_pointers) - memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) - num_object_pointer_tokens = object_pointers.shape[0] - else: + # Step 1: Handle initial conditioning frames + if is_initial_conditioning_frame: # For initial conditioning frames, no prior memory is used directly in this block. - # The model might handle this with a special token or mechanism. # If configured, directly add a learnable "no memory" embedding. # current_vision_features has shape (SeqLen, Batch, Channels) conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding @@ -1997,11 +2065,36 @@ def _prepare_memory_conditioned_features( ) return conditioned_feature_map - # Step 2: Concatenate all retrieved memories and their positional embeddings. 
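The offset bookkeeping in `_get_object_pointers` mirrors the removed inline code above: offsets are counted from the current frame, the sign is flipped when tracking in reverse, and at evaluation time only conditioning frames that are not ahead of the tracking direction contribute pointers. A small illustrative sketch with invented frame indices (not part of the library):

```python
# Invented frame indices, for illustration only.
frame_idx = 10
conditioning_frames = [0, 5, 12]  # frames that received user prompts
track_in_reverse_time = False

sign = -1 if track_in_reverse_time else 1
eligible = [
    t for t in conditioning_frames
    if (t >= frame_idx if track_in_reverse_time else t <= frame_idx)
]
offsets = [(frame_idx - t) * sign for t in eligible]
print(offsets)  # [10, 5] -> frame 0 is 10 frames back, frame 5 is 5 frames back
```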
+ # Step 2: Get memory frames and concatenate their features + temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs( + inference_session, obj_idx, frame_idx, track_in_reverse_time + ) + + memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs( + temporal_positions_and_previous_outputs, device + ) + + # Step 3: Get and process object pointers + temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers( + inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming + ) + + num_object_pointer_tokens = 0 + if pointer_tokens: + object_pointers, object_pointers_pos_embed = self._process_object_pointers( + temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device + ) + + if object_pointers is not None: + memories_to_concatenate.append(object_pointers) + memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) + num_object_pointer_tokens = object_pointers.shape[0] + + # Step 4: Concatenate all retrieved memories and their positional embeddings combined_memory = torch.cat(memories_to_concatenate, dim=0) combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0) - # Step 3: Forward through the memory attention mechanism. + # Step 5: Forward through the memory attention mechanism conditioned_feature_map_flat = self.memory_attention( current_vision_features=current_vision_features, current_vision_position_embeddings=current_vision_positional_embeddings, @@ -2211,7 +2304,7 @@ def forward( Whether to propagate in reverse. """ if frame is not None: - frame_idx = inference_session.add_new_frame(frame) + frame_idx = inference_session.add_new_frame(frame, frame_idx) if frame is not None and inference_session.get_obj_num() == 0: raise ValueError("No objects are provided for tracking; please add inputs first.") diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 24142232241f..34e640ade8bf 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -121,8 +121,8 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs): def to_dict(self) -> dict[str, Any]: output = super().to_dict() - output["num_classes"] = self.num_labels - output["label_names"] = list(self.id2label.values()) + output.setdefault("num_classes", self.num_labels) + output.setdefault("label_names", list(self.id2label.values())) output.pop("id2label", None) output.pop("label2id", None) return output diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index cfc3c1c104d3..d388ff05297f 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -160,6 +160,7 @@ def __init__(self, config: TimmWrapperConfig): super().__init__(config) # using num_classes=0 to avoid creating classification head extra_init_kwargs = config.model_args or {} + self.features_only = extra_init_kwargs.get("features_only", False) self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs) self.post_init() @@ -233,20 +234,25 @@ def forward( pixel_values = pixel_values.to(self.device, self.dtype) - if output_hidden_states: - # to 
enable hidden states selection - if isinstance(output_hidden_states, (list, tuple)): - kwargs["indices"] = output_hidden_states - last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) - else: - last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs) - hidden_states = None - - if do_pooling: - # classification head is not created, applying pooling only - pooler_output = self.timm_model.forward_head(last_hidden_state) - else: + if self.features_only: + last_hidden_state = self.timm_model.forward(pixel_values, **kwargs) + hidden_states = last_hidden_state if output_hidden_states else None pooler_output = None + else: + if output_hidden_states: + # to enable hidden states selection + if isinstance(output_hidden_states, (list, tuple)): + kwargs["indices"] = output_hidden_states + last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs) + else: + last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs) + hidden_states = None + + if do_pooling: + # classification head is not created, applying pooling only + pooler_output = self.timm_model.forward_head(last_hidden_state) + else: + pooler_output = None if not return_dict: outputs = (last_hidden_state, pooler_output, hidden_states) diff --git a/tests/models/edgetam/__init__.py b/tests/models/edgetam/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/edgetam/test_modeling_edgetam.py b/tests/models/edgetam/test_modeling_edgetam.py new file mode 100644 index 000000000000..701642a43d41 --- /dev/null +++ b/tests/models/edgetam/test_modeling_edgetam.py @@ -0,0 +1,734 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
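For context on the `features_only` branch added to `TimmWrapperModel` above: a timm backbone created with `features_only=True` returns a list of intermediate feature maps directly from its plain `forward`, which is why the new branch calls `self.timm_model.forward(...)` instead of `forward_features`/`forward_head`. A rough sketch (assumes `timm` is installed; the model name is taken from the EdgeTAM test config below, and the exact shapes depend on the backbone):

```python
import timm
import torch

# A features_only backbone returns one (B, C, H, W) tensor per selected stage.
backbone = timm.create_model(
    "repvit_m1.dist_in1k", pretrained=False, features_only=True, out_indices=(0, 1, 2, 3)
)
feature_maps = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feature_maps])
```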
+"""Testing suite for the PyTorch EDGETAM model.""" + +import gc +import tempfile +import unittest + +import requests + +from transformers import ( + EdgeTamConfig, + EdgeTamMaskDecoderConfig, + EdgeTamPromptEncoderConfig, + EdgeTamVisionConfig, + Sam2Processor, + pipeline, +) +from transformers.testing_utils import ( + backend_empty_cache, + require_torch, + slow, + torch_device, +) +from transformers.utils import is_torch_available, is_vision_available +from transformers.video_utils import load_video + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import AutoConfig, EdgeTamModel, Sam2Processor + + +if is_vision_available(): + from PIL import Image + + +class EdgeTamPromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=128, + patch_size=16, + mask_input_channels=8, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return EdgeTamPromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + + return config, dummy_points + + +class EdgeTamMaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=32, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + + def get_config(self): + return EdgeTamMaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downsample_rate=self.attention_downsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + dummy_inputs = { + "image_embedding": floats_tensor([self.batch_size, self.hidden_size]), + } + + return config, dummy_inputs + + +class EdgeTamModelTester: + def __init__( + self, + parent, + num_channels=3, + image_size=128, + hidden_size=12, + patch_kernel_size=7, + patch_stride=4, + patch_padding=3, + dim_mul=2.0, + backbone_channel_list=[96, 48, 24, 12], + backbone_feature_sizes=[[32, 32], [16, 16], [8, 8]], + fpn_hidden_size=32, + memory_encoder_hidden_size=32, + batch_size=2, + is_training=False, + ): + self.parent = parent + self.image_size = image_size + self.hidden_size = hidden_size + self.patch_kernel_size = 
patch_kernel_size + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.dim_mul = dim_mul + self.backbone_channel_list = backbone_channel_list + self.backbone_feature_sizes = backbone_feature_sizes + self.fpn_hidden_size = fpn_hidden_size + self.batch_size = batch_size + self.num_channels = num_channels + self.is_training = is_training + self.memory_encoder_hidden_size = memory_encoder_hidden_size + + self.prompt_encoder_tester = EdgeTamPromptEncoderTester() + self.mask_decoder_tester = EdgeTamMaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + vision_config = EdgeTamVisionConfig( + backbone_config=AutoConfig.from_pretrained( + "timm/repvit_m1.dist_in1k", + model_args={ + "in_chans": 3, + "features_only": True, + "out_indices": (0, 1, 2, 3), + "embed_dim": self.backbone_channel_list[::-1], + }, + ), + backbone_channel_list=self.backbone_channel_list, + backbone_feature_sizes=self.backbone_feature_sizes, + fpn_hidden_size=self.fpn_hidden_size, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + + mask_decoder_config = self.mask_decoder_tester.get_config() + + return EdgeTamConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + memory_attention_hidden_size=self.hidden_size, + memory_encoder_hidden_size=self.memory_encoder_hidden_size, + image_size=self.image_size, + mask_downsampler_embed_dim=32, + memory_fuser_embed_dim=32, + memory_attention_num_layers=1, + memory_attention_feed_forward_hidden_size=32, + ) + + def create_and_check_model(self, config, pixel_values): + model = EdgeTamModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class EdgeTamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
+ """ + + all_model_classes = (EdgeTamModel,) if is_torch_available() else () + pipeline_model_mapping = ( + {"feature-extraction": EdgeTamModel, "mask-generation": EdgeTamModel} if is_torch_available() else {} + ) + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + _is_composite = True + + def setUp(self): + self.model_tester = EdgeTamModelTester(self) + common_properties = ["initializer_range"] + self.config_tester = ConfigTester( + self, config_class=EdgeTamConfig, has_text_modality=False, common_properties=common_properties + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="Timm model does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Can't get or set embeddings for Timm model") + def test_model_get_set_embeddings(self): + pass + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # Override as EdgeTamModel doesn't have hidden states + def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str): + r""" + Tests the equivalence between the eager and flash attention implementations. + This test is only for inference and runs with `torch_dtype=torch.bfloat16`. + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_model_classes: + if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( + attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 + ): + self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + if padding_side == "left": + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + else: + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.vision_hidden_states[-1] + logits_fa = outputs_fa.vision_hidden_states[-1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + 
"output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = outputs.vision_hidden_states[-1] + logits_fa = outputs_fa.vision_hidden_states[-1] + + if padding_side == "left": + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + else: + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + # Override as diffence slightly higher than the threshold + # def test_batching_equivalence(self, atol=5e-4, rtol=5e-4): + # super().test_batching_equivalence(atol=atol, rtol=rtol) + + @unittest.skip(reason="TimmWrapperModel does not support an attention implementation") + def test_can_set_attention_dynamically_composite_model(self): + pass + + @unittest.skip(reason="vision_hidden_states from TimmWrapperModel") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Timm weights cannot be fully constructed in _init_weights") + def test_can_init_all_missing_weights(self): + pass + + @unittest.skip(reason="Timm weights cannot be fully constructed in _init_weights") + def test_initialization(self): + pass + + @unittest.skip( + reason="TIMM's attention implementation is self configured and won't raise ValueError on global attention implementation." + ) + def test_flash_attn_2_can_dispatch_composite_models(self): + pass + + @unittest.skip("TimmWrapperModel cannot be tested with meta device") + def test_can_be_initialized_on_meta(self): + pass + + @unittest.skip("TimmWrapperModel cannot be tested with meta device") + def test_can_load_with_meta_device_context_manager(self): + pass + + ## Skip flash attention releated tests below + ## correct configuration: + ## from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"} + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_eager_matches_fa2_generate(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_from_config(self): + pass + + @unittest.skip("SDPA test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_eager_matches_sdpa_generate_with_dynamic_cache(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip("SDPA test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_eager_matches_sdpa_generate(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_inference_equivalence(self): + pass + + @unittest.skip("EdgeTAM does not have 
language_model, vision_tower, multi_modal_projector.") + def test_sdpa_can_dispatch_composite_models(self): + pass + + @unittest.skip("Cannot set `output_attentions` for timm models.") + def test_attention_outputs(self): + pass + + @unittest.skip("Cannot set `output_attentions` for timm models.") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip("Cannot set `output_attentions` for timm models.") + def test_generate_compilation_all_outputs(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "yonigozlan/EdgeTAM-hf" + model = EdgeTamModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_sdpa_can_compile_dynamic(self): + self.skipTest(reason="EDGETAM model can't be compiled dynamic yet") + + +def prepare_image(): + img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_groceries_image(): + img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_video(): + video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4" + raw_video, _ = load_video(video_url) + return raw_video + + +@slow +class EdgeTamModelIntegrationTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(torch.float32) + self.processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf") + self.model.to(torch_device) + self.model.eval() + + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + backend_empty_cache(torch_device) + + def test_inference_mask_generation_one_point_multimask(self): + raw_image = prepare_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + inputs = self.processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + self.assertEqual(outputs.iou_scores.shape, (1, 1, 3)) + self.assertEqual(outputs.pred_masks.shape, (1, 1, 3, 256, 256)) + sorted_indices = torch.argsort(outputs.iou_scores.squeeze(), descending=True) + scores = outputs.iou_scores.squeeze()[sorted_indices] + masks_logits = outputs.pred_masks.squeeze()[sorted_indices][0, :3, :3] + torch.testing.assert_close( + scores, torch.tensor([0.7621, 0.4859, 0.0461]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + masks_logits, + torch.tensor( + [[-19.5483, -22.3549, -26.0962], [-18.1821, -23.4761, -24.2262], [-20.3549, -24.5518, -22.7232]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_mask_generation_one_point_no_multimask(self): + raw_image = prepare_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + + inputs = self.processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs, 
multimask_output=False) + self.assertEqual(outputs.iou_scores.shape, (1, 1, 1)) + self.assertEqual(outputs.pred_masks.shape, (1, 1, 1, 256, 256)) + scores = outputs.iou_scores.squeeze((0, 1)) + masks_logits = outputs.pred_masks.squeeze((0, 1))[0, :3, :3] + torch.testing.assert_close(scores, torch.tensor([0.7621]).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + masks_logits, + torch.tensor( + [[-19.5483, -22.3549, -26.0962], [-18.1821, -23.4761, -24.2262], [-20.3549, -24.5518, -22.7232]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_mask_generation_batched_images_multi_points(self): + raw_image1 = prepare_image() + raw_image2 = prepare_dog_img() + input_points = [[[[500, 375]]], [[[770, 200], [730, 120]]]] + input_labels = [[[1]], [[1, 0]]] + + inputs = self.processor( + images=[raw_image1, raw_image2], input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = self.model(**inputs) + self.assertEqual(outputs.iou_scores.shape, (2, 1, 3)) + self.assertEqual(outputs.pred_masks.shape, (2, 1, 3, 256, 256)) + + sorted_indices = torch.argsort(outputs.iou_scores[0].squeeze(), descending=True) + scores1 = outputs.iou_scores[0].squeeze()[sorted_indices] + masks_logits1 = outputs.pred_masks[0].squeeze()[sorted_indices][0, :3, :3] + sorted_indices = torch.argsort(outputs.iou_scores[1].squeeze(), descending=True) + scores2 = outputs.iou_scores[1].squeeze()[sorted_indices] + masks_logits2 = outputs.pred_masks[1].squeeze()[sorted_indices][0, :3, :3] + torch.testing.assert_close( + scores1, torch.tensor([0.7490, 0.4685, 0.0463]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + masks_logits1, + torch.tensor( + [[-19.1423, -21.6488, -25.6816], [-17.8018, -22.6512, -23.5699], [-19.9140, -23.6919, -22.3147]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + torch.testing.assert_close( + scores2, torch.tensor([0.7225, 0.6515, 0.6350]).to(torch_device), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + masks_logits2, + torch.tensor([[-8.8259, -7.7961, -9.3665], [-8.2648, -8.7771, -9.1390], [-9.5951, -8.3995, -9.0599]]).to( + torch_device + ), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_mask_generation_batched_images_batched_points_multi_points(self): + raw_image1 = prepare_image() + raw_image2 = prepare_groceries_image() + input_points = [[[[500, 375]], [[650, 750]]], [[[400, 300]], [[630, 300], [550, 300]]]] + input_labels = [[[1], [1]], [[1], [1, 1]]] + inputs = self.processor( + images=[raw_image1, raw_image2], input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + with torch.no_grad(): + outputs = self.model(**inputs, multimask_output=False) + self.assertEqual(outputs.iou_scores.shape, (2, 2, 1)) + self.assertEqual(outputs.pred_masks.shape, (2, 2, 1, 256, 256)) + torch.testing.assert_close( + outputs.iou_scores, + torch.tensor([[[0.7490], [0.9397]], [[0.7952], [0.8723]]]).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[:, :, :, :2, :2], + torch.tensor( + [ + [[[[-19.1423, -21.6488], [-17.8018, -22.6512]]], [[[-7.1591, -9.8201], [-7.4133, -9.2781]]]], + [[[[-16.7645, -15.2790], [-16.1805, -16.2937]]], [[[-8.5934, -8.4215], [-8.1873, -8.3722]]]], + ] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_batched_images_batched_boxes(self): + raw_image1 = prepare_image() + raw_image2 = prepare_groceries_image() + input_boxes = [ 
+ [[75, 275, 1725, 850], [425, 600, 700, 875], [1375, 550, 1650, 800], [1240, 675, 1400, 750]], + [[450, 170, 520, 350], [350, 190, 450, 350], [500, 170, 580, 350], [580, 170, 640, 350]], + ] + inputs = self.processor(images=[raw_image1, raw_image2], input_boxes=input_boxes, return_tensors="pt").to( + torch_device + ) + with torch.no_grad(): + outputs = self.model(**inputs, multimask_output=False) + self.assertEqual(outputs.iou_scores.shape, (2, 4, 1)) + self.assertEqual(outputs.pred_masks.shape, (2, 4, 1, 256, 256)) + torch.testing.assert_close( + outputs.iou_scores, + torch.tensor([[[0.9773], [0.9415], [0.9683], [0.8792]], [[0.9721], [0.9852], [0.9812], [0.9760]]]).to( + torch_device + ), + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close( + outputs.pred_masks[:, :, :, :2, :2], + torch.tensor( + [ + [ + [[[-12.6412, -12.0553], [-11.8415, -13.1696]]], + [[[-16.0378, -19.9641], [-15.4939, -19.0260]]], + [[[-18.8254, -23.6185], [-17.7889, -23.2116]]], + [[[-25.7024, -29.8722], [-22.9264, -30.0557]]], + ], + [ + [[[-19.0264, -17.0396], [-16.9458, -16.3287]]], + [[[-20.9671, -19.2132], [-18.5827, -18.0511]]], + [[[-22.4642, -19.7389], [-19.4541, -19.4717]]], + [[[-21.9226, -18.6297], [-18.9272, -18.8151]]], + ], + ] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_mask_generation_from_existing_points_and_mask(self): + raw_image = prepare_image() + input_points = [[[[500, 375]]]] + input_labels = [[[1]]] + original_inputs = self.processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + with torch.no_grad(): + outputs = self.model(**original_inputs) + + # best mask to use as input for new points + mask_input = outputs.pred_masks[:, :, torch.argmax(outputs.iou_scores)] + + new_input_points = [[[[500, 375], [1125, 625]]]] + new_input_labels = [[[1, 1]]] + inputs = self.processor( + input_points=new_input_points, + input_labels=new_input_labels, + original_sizes=original_inputs["original_sizes"], + return_tensors="pt", + ).to(torch_device) + with torch.no_grad(): + outputs = self.model( + **inputs, + input_masks=mask_input, + image_embeddings=outputs.image_embeddings, + multimask_output=False, + ) + + self.assertEqual(outputs.iou_scores.shape, (1, 1, 1)) + self.assertEqual(outputs.pred_masks.shape, (1, 1, 1, 256, 256)) + scores = outputs.iou_scores.squeeze((0, 1)) + masks_logits = outputs.pred_masks.squeeze((0, 1))[0, :3, :3] + torch.testing.assert_close(scores, torch.tensor([0.9431]).to(torch_device), atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + masks_logits, + torch.tensor([[-4.1968, -4.9034, -6.0680], [-4.4053, -5.1200, -5.8580], [-4.3920, -5.5096, -5.8166]]).to( + torch_device + ), + atol=1e-4, + rtol=1e-4, + ) + + # with negative point + new_input_points = [[[[500, 375], [1125, 625]]]] + new_input_labels = [[[1, 0]]] + inputs = self.processor( + input_points=new_input_points, + input_labels=new_input_labels, + original_sizes=original_inputs["original_sizes"], + return_tensors="pt", + ).to(torch_device) + with torch.no_grad(): + outputs = self.model( + **inputs, + input_masks=mask_input, + image_embeddings=outputs.image_embeddings, + multimask_output=False, + ) + self.assertEqual(outputs.iou_scores.shape, (1, 1, 1)) + self.assertEqual(outputs.pred_masks.shape, (1, 1, 1, 256, 256)) + scores = outputs.iou_scores.squeeze((0, 1)) + masks_logits = outputs.pred_masks.squeeze((0, 1))[0, :3, :3] + torch.testing.assert_close(scores, torch.tensor([0.9695]).to(torch_device), atol=1e-4, 
rtol=1e-4) + torch.testing.assert_close( + masks_logits, + torch.tensor( + [[-14.3212, -15.4295, -17.4482], [-13.2246, -15.9468, -17.1341], [-15.1678, -16.4498, -14.7385]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_dummy_pipeline_generation(self): + generator = pipeline("mask-generation", model="yonigozlan/EdgeTAM-hf", device=torch_device) + raw_image = prepare_image() + + _ = generator(raw_image, points_per_batch=64) diff --git a/tests/models/edgetam_video/__init__.py b/tests/models/edgetam_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/edgetam_video/test_modeling_edgetam_video.py b/tests/models/edgetam_video/test_modeling_edgetam_video.py new file mode 100644 index 000000000000..a2ad383351d2 --- /dev/null +++ b/tests/models/edgetam_video/test_modeling_edgetam_video.py @@ -0,0 +1,507 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch SAM2 model.""" + +import gc +import unittest + +import requests + +from transformers.testing_utils import ( + backend_empty_cache, + slow, + torch_device, +) +from transformers.utils import is_torch_available, is_vision_available +from transformers.video_utils import load_video + + +if is_torch_available(): + import torch + + from transformers import EdgeTamVideoModel, Sam2VideoProcessor + + +if is_vision_available(): + from PIL import Image + + +def prepare_image(): + img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_groceries_image(): + img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_video(): + video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4" + raw_video, _ = load_video(video_url) + return raw_video + + +@slow +class EdgeTamVideoModelIntegrationTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.video_model = EdgeTamVideoModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(torch.float32) + self.processor = Sam2VideoProcessor.from_pretrained("yonigozlan/EdgeTAM-hf") + self.video_model.to(torch_device) + self.video_model.eval() + + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + backend_empty_cache(torch_device) + + def test_inference_mask_generation_video_one_point(self): + raw_video = prepare_video() + inference_session = 
self.processor.init_video_session(video=raw_video, inference_device=torch_device) + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) + + self.processor.add_inputs_to_inference_session( + inference_session=inference_session, + frame_idx=ann_frame_idx, + obj_ids=ann_obj_id, + input_points=[[[[210, 350]]]], + input_labels=[[[1]]], + ) + outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx) + low_res_masks = outputs.pred_masks + self.assertEqual(low_res_masks.shape, (1, 1, 256, 256)) + video_res_masks = self.processor.post_process_masks([low_res_masks], [raw_video.shape[-3:-1]], binarize=False)[ + 0 + ] + self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2])) + torch.testing.assert_close( + video_res_masks[0, 0, :3, :3], + torch.tensor( + [[-28.3880, -28.3880, -27.9277], [-27.5260, -27.5260, -27.2455], [-25.5902, -25.5902, -25.7136]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + # test propagate in video frames + frames = [] + for sam2_video_output in self.video_model.propagate_in_video_iterator( + inference_session=inference_session, + max_frame_num_to_track=2, + ): + video_res_masks = self.processor.post_process_masks( + [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + frames.append(video_res_masks) + frames = torch.stack(frames, dim=0) + self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2])) + torch.testing.assert_close( + frames[:3, :, :, :2, :2], + torch.tensor( + [ + [[[[-28.3880, -28.3880], [-27.5260, -27.5260]]]], + [[[[-15.3350, -15.3350], [-15.0002, -15.0002]]]], + [[[[-14.8729, -14.8729], [-14.6724, -14.6724]]]], + ], + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_mask_generation_video_one_point_propagate_in_video_directly(self): + raw_video = prepare_video() + inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device) + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) + + self.processor.add_inputs_to_inference_session( + inference_session=inference_session, + frame_idx=ann_frame_idx, + obj_ids=ann_obj_id, + input_points=[[[[210, 350]]]], + input_labels=[[[1]]], + ) + # test propagate in video frames + frames = [] + for sam2_video_output in self.video_model.propagate_in_video_iterator( + inference_session=inference_session, + start_frame_idx=ann_frame_idx, + max_frame_num_to_track=2, + ): + video_res_masks = self.processor.post_process_masks( + [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + frames.append(video_res_masks) + frames = torch.stack(frames, dim=0) + self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2])) + print(f"VIDEO_TEST2 - ACTUAL frames[:3, :, :, :2, :2]: {frames[:3, :, :, :2, :2]}") + torch.testing.assert_close( + frames[:3, :, :, :2, :2], + torch.tensor( + [ + [[[[-28.3880, -28.3880], [-27.5260, -27.5260]]]], + [[[[-15.3350, -15.3350], [-15.0002, -15.0002]]]], + [[[[-14.8729, -14.8729], [-14.6724, -14.6724]]]], + ] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + def test_inference_mask_generation_video_multi_points(self): + raw_video = prepare_video() + inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device) + ann_frame_idx = 0 # the frame index we 
interact with + ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) + + self.processor.add_inputs_to_inference_session( + inference_session=inference_session, + frame_idx=ann_frame_idx, + obj_ids=ann_obj_id, + input_points=[[[[210, 350], [250, 220]]]], + input_labels=[[[1, 1]]], + ) + outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx) + low_res_masks = outputs.pred_masks + video_res_masks = self.processor.post_process_masks( + [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + self.assertEqual(low_res_masks.shape, (1, 1, 256, 256)) + self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2])) + torch.testing.assert_close( + video_res_masks[0, 0, :3, :3], + torch.tensor( + [[-17.3081, -17.3081, -16.9805], [-16.8430, -16.8430, -16.6766], [-15.7986, -15.7986, -15.9941]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + # test propagate in video frames + frames = [] + for sam2_video_output in self.video_model.propagate_in_video_iterator( + inference_session=inference_session, + start_frame_idx=ann_frame_idx, + max_frame_num_to_track=2, + ): + video_res_masks = self.processor.post_process_masks( + [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + frames.append(video_res_masks) + frames = torch.stack(frames, dim=0) + self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2])) + # higher tolerance due to errors propagating from frame to frame + torch.testing.assert_close( + frames[:3, :, :, :2, :2], + torch.tensor( + [ + [[[[-17.3081, -17.3081], [-16.8430, -16.8430]]]], + [[[[-14.9302, -14.9302], [-14.8802, -14.8802]]]], + [[[[-14.4372, -14.4372], [-14.3697, -14.3697]]]], + ] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + + def test_inference_mask_generation_video_one_bb(self): + raw_video = prepare_video() + inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device) + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) + + self.processor.add_inputs_to_inference_session( + inference_session=inference_session, + frame_idx=ann_frame_idx, + obj_ids=ann_obj_id, + input_boxes=[[[300, 0, 500, 400]]], + ) + outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx) + low_res_masks = outputs.pred_masks + video_res_masks = self.processor.post_process_masks( + [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + self.assertEqual(low_res_masks.shape, (1, 1, 256, 256)) + self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2])) + torch.testing.assert_close( + video_res_masks[0, 0, :3, :3], + torch.tensor( + [[-17.3245, -17.3245, -16.9231], [-16.8773, -16.8773, -16.6082], [-15.8731, -15.8731, -15.9011]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + # test propagate in video frames + frames = [] + for sam2_video_output in self.video_model.propagate_in_video_iterator( + inference_session=inference_session, + start_frame_idx=ann_frame_idx, + max_frame_num_to_track=2, + ): + video_res_masks = self.processor.post_process_masks( + [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + frames.append(video_res_masks) + frames = torch.stack(frames, dim=0) + self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2])) + # higher tolerance due to errors 
propagating from frame to frame + torch.testing.assert_close( + frames[:3, :, :, :2, :2], + torch.tensor( + [ + [[[[-17.3245, -17.3245], [-16.8773, -16.8773]]]], + [[[[-16.2826, -16.2826], [-15.9087, -15.9087]]]], + [[[[-15.8716, -15.8716], [-15.3992, -15.3992]]]], + ] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + + def test_inference_mask_generation_video_one_point_one_bb(self): + raw_video = prepare_video() + inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device) + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) + + self.processor.add_inputs_to_inference_session( + inference_session=inference_session, + frame_idx=ann_frame_idx, + obj_ids=ann_obj_id, + input_boxes=[[[300, 0, 500, 400]]], + input_points=[[[[460, 60]]]], + input_labels=[[[1]]], + ) + outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx) + low_res_masks = outputs.pred_masks + video_res_masks = self.processor.post_process_masks( + [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + self.assertEqual(low_res_masks.shape, (1, 1, 256, 256)) + self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2])) + torch.testing.assert_close( + video_res_masks[0, 0, :3, :3], + torch.tensor( + [[-13.9780, -13.9780, -13.7824], [-13.7642, -13.7642, -13.6000], [-13.2842, -13.2842, -13.1904]] + ).to(torch_device), + atol=1e-4, + rtol=1e-4, + ) + + # test propagate in video frames + frames = [] + for sam2_video_output in self.video_model.propagate_in_video_iterator( + inference_session=inference_session, + start_frame_idx=ann_frame_idx, + max_frame_num_to_track=2, + ): + video_res_masks = self.processor.post_process_masks( + [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + frames.append(video_res_masks) + frames = torch.stack(frames, dim=0) + self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2])) + # higher tolerance due to errors propagating from frame to frame + torch.testing.assert_close( + frames[:3, :, :, :2, :2], + torch.tensor( + [ + [[[[-13.9780, -13.9780], [-13.7642, -13.7642]]]], + [[[[-16.0142, -16.0142], [-15.5600, -15.5600]]]], + [[[[-16.7568, -16.7568], [-16.2460, -16.2460]]]], + ] + ).to(torch_device), + atol=1e-2, + rtol=1e-2, + ) + + def test_inference_mask_generation_video_multi_objects_multi_points(self): + raw_video = prepare_video() + inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device) + ann_frame_idx = 0 # the frame index we interact with + ann_obj_ids = [2, 3] # give a unique id to each object we interact with (it can be any integers) + + self.processor.add_inputs_to_inference_session( + inference_session=inference_session, + frame_idx=ann_frame_idx, + obj_ids=ann_obj_ids, + input_points=[[[[200, 300], [230, 250], [275, 175]], [[400, 150]]]], + input_labels=[[[1, 1, 0], [1]]], + ) + outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx) + low_res_masks = outputs.pred_masks + video_res_masks = self.processor.post_process_masks( + [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False + )[0] + self.assertEqual(low_res_masks.shape, (2, 1, 256, 256)) + self.assertEqual(video_res_masks.shape, (2, 1, raw_video.shape[-3], raw_video.shape[-2])) + torch.testing.assert_close( + video_res_masks[:, 0, :2, :2], # first object + torch.tensor( + 
+                [[[-12.6233, -12.6233], [-12.1809, -12.1809]], [[-13.4556, -13.4556], [-12.9549, -12.9549]]]
+            ).to(torch_device),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+        # test propagation through the video frames
+        frames = []
+        for sam2_video_output in self.video_model.propagate_in_video_iterator(
+            inference_session=inference_session,
+            start_frame_idx=ann_frame_idx,
+            max_frame_num_to_track=2,
+        ):
+            video_res_masks = self.processor.post_process_masks(
+                [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+            )[0]
+            frames.append(video_res_masks)
+        frames = torch.stack(frames, dim=0)
+        self.assertEqual(frames.shape, (3, 2, 1, raw_video.shape[-3], raw_video.shape[-2]))
+        torch.testing.assert_close(
+            frames[:3, :, :, :2, :2],
+            torch.tensor(
+                [
+                    [[[[-12.6233, -12.6233], [-12.1809, -12.1809]]], [[[-13.4556, -13.4556], [-12.9549, -12.9549]]]],
+                    [[[[-12.5589, -12.5589], [-12.4450, -12.4450]]], [[[-12.2181, -12.2181], [-12.0188, -12.0188]]]],
+                    [[[[-15.3170, -15.3170], [-15.0254, -15.0254]]], [[[-11.4912, -11.4912], [-11.3171, -11.3171]]]],
+                ]
+            ).to(torch_device),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+    def test_inference_propagate_video_from_mask_input(self):
+        raw_video = prepare_video()
+        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+        ann_frame_idx = 0  # the frame index we interact with
+        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)
+
+        # first, get a mask prediction from point inputs
+        self.processor.add_inputs_to_inference_session(
+            inference_session=inference_session,
+            frame_idx=ann_frame_idx,
+            obj_ids=ann_obj_id,
+            input_points=[[[[210, 350], [250, 220]]]],
+            input_labels=[[[1, 1]]],
+        )
+        sam2_video_output = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+
+        # then feed that mask back in as the input prompt
+        self.processor.add_inputs_to_inference_session(
+            inference_session=inference_session,
+            frame_idx=ann_frame_idx,
+            obj_ids=ann_obj_id,
+            input_masks=self.processor.post_process_masks(
+                [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+            )[0],
+        )
+        sam2_video_output = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+        low_res_masks = sam2_video_output.pred_masks
+        self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
+        video_res_masks = self.processor.post_process_masks(
+            [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+        )[0]
+        self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+        torch.testing.assert_close(
+            video_res_masks[0, 0, :3, :3],
+            torch.tensor(
+                [[-10.0000, -10.0000, -10.0000], [-10.0000, -10.0000, -10.0000], [-10.0000, -10.0000, -10.0000]]
+            ).to(torch_device),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+        # test propagation through the video frames
+        frames = []
+        for sam2_video_output in self.video_model.propagate_in_video_iterator(
+            inference_session=inference_session,
+            start_frame_idx=ann_frame_idx,
+            max_frame_num_to_track=2,
+        ):
+            video_res_masks = self.processor.post_process_masks(
+                [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+            )[0]
+            frames.append(video_res_masks)
+        frames = torch.stack(frames, dim=0)
+        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+        torch.testing.assert_close(
+            frames[:3, :, :, :2, :2],
+            torch.tensor(
+                [
+                    [[[[-10.0000, -10.0000], [-10.0000, -10.0000]]]],
+                    [[[[-17.4083, -17.4083], [-17.2256, -17.2256]]]],
+                    [[[[-13.8533, -13.8533], [-13.7759, -13.7759]]]],
+                ],
+            ).to(torch_device),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+    def test_inference_propagate_on_streamed_video(self):
+        raw_video = prepare_video()
+
+        inference_session = self.processor.init_video_session(inference_device=torch_device)
+        video_res_masks = []
+        max_frame_num_to_track = 3
+        for frame_idx, frame in enumerate(raw_video):
+            if frame_idx >= max_frame_num_to_track:
+                break
+            inputs = self.processor(images=frame, device=torch_device, return_tensors="pt")
+            if frame_idx == 0:
+                # add point inputs on the first streamed frame only
+                self.processor.add_inputs_to_inference_session(
+                    inference_session,
+                    frame_idx=0,
+                    obj_ids=1,
+                    input_points=[[[[210, 350], [250, 220]]]],
+                    input_labels=[[[1, 1]]],
+                    original_size=inputs.original_sizes[0],
+                )
+            sam2_video_output = self.video_model(inference_session=inference_session, frame=inputs.pixel_values[0])
+            video_res_masks.append(
+                self.processor.post_process_masks(
+                    [sam2_video_output.pred_masks], inputs.original_sizes, binarize=False
+                )[0]
+            )
+
+        video_res_masks = torch.stack(video_res_masks, dim=0)
+        self.assertEqual(
+            video_res_masks.shape, (max_frame_num_to_track, 1, 1, raw_video.shape[-3], raw_video.shape[-2])
+        )
+        # higher tolerance due to errors propagating from frame to frame
+        torch.testing.assert_close(
+            video_res_masks[:3, :, :, :2, :2],
+            torch.tensor(
+                [
+                    [[[[-17.3081, -17.3081], [-16.8430, -16.8430]]]],
+                    [[[[-14.9302, -14.9302], [-14.8802, -14.8802]]]],
+                    [[[[-14.4372, -14.4372], [-14.3697, -14.3697]]]],
+                ]
+            ).to(torch_device),
+            atol=1e-2,
+            rtol=1e-2,
+        )
diff --git a/tests/models/sam2/test_modeling_sam2.py b/tests/models/sam2/test_modeling_sam2.py
index a19c6a13d220..dcacd3920a7a 100644
--- a/tests/models/sam2/test_modeling_sam2.py
+++ b/tests/models/sam2/test_modeling_sam2.py
@@ -558,7 +558,6 @@ def test_attention_outputs(self):
         )
 
     # Override as Sam2Model has different sub-modules
-
     def test_sdpa_can_dispatch_composite_models(self):
         """
         Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 0890a1abc4da..2df8d17d6fad 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -140,7 +140,9 @@
     "BarkCausalModel",  # Building part of bigger (tested) model.
     "BarkModel",  # Does not have a forward signature - generation tested with integration tests.
     "Sam2HieraDetModel",  # Building part of bigger (tested) model.
-    "Sam2VideoModel",  # inherit from Sam2Model (tested).
+    "Sam2VideoModel",  # Partly tested in Sam2Model, not a regular model.
+    "EdgeTamVisionModel",  # Building part of bigger (tested) model.
+    "EdgeTamVideoModel",  # Partly tested in EdgeTamModel, not a regular model.
     "SeamlessM4TTextToUnitModel",  # Building part of bigger (tested) model.
     "SeamlessM4TCodeHifiGan",  # Building part of bigger (tested) model.
     "SeamlessM4TTextToUnitForConditionalGeneration",  # Building part of bigger (tested) model.
@@ -208,6 +210,7 @@
     "models/shieldgemma2/test_modeling_shieldgemma2.py",
     "models/llama4/test_modeling_llama4.py",
     "models/sam2_video/test_modeling_sam2_video.py",
+    "models/edgetam_video/test_modeling_edgetam_video.py",
 ]
 
 # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
@@ -256,6 +259,8 @@
     "SamModel",
     "Sam2Model",
     "Sam2VideoModel",
+    "EdgeTamModel",
+    "EdgeTamVideoModel",
     "SamHQModel",
     "DPTForDepthEstimation",
     "DecisionTransformerGPT2Model",